Set up Solana Arbitrage Trading System
- Created Alembic migrations for SQLite database - Set up database initialization on app startup - Fixed linting issues with Ruff - Updated README with comprehensive documentation - Configured startup tasks and health checks
This commit is contained in:
parent
88ccf1d134
commit
73b706f0eb
95
README.md
95
README.md
@ -1,3 +1,94 @@
|
||||
# FastAPI Application
|
||||
# Solana Arbitrage Trading System
|
||||
|
||||
This is a FastAPI application bootstrapped by BackendIM, the AI-powered backend generation platform.
|
||||
A backend system for detecting and executing arbitrage opportunities on Solana DEXes.
|
||||
|
||||
## Overview
|
||||
|
||||
This FastAPI application provides a robust backend for:
|
||||
|
||||
1. Monitoring price differences between different Solana DEXes (currently Jupiter and Raydium)
|
||||
2. Identifying profitable arbitrage opportunities based on configurable parameters
|
||||
3. Optionally executing trades to capture these opportunities
|
||||
4. Tracking performance metrics and historical trade data
|
||||
|
||||
## Features
|
||||
|
||||
- Real-time price monitoring across multiple DEXes
|
||||
- Configurable profit thresholds and slippage tolerance
|
||||
- Wallet integration for trade execution
|
||||
- Historical opportunity and trade tracking
|
||||
- Comprehensive API for monitoring and configuration
|
||||
- SQLite database for persistent storage
|
||||
|
||||
## API Endpoints
|
||||
|
||||
### Base Endpoints
|
||||
|
||||
- `GET /` - Root endpoint with API information
|
||||
- `GET /health` - Health check endpoint
|
||||
- `GET /docs` - Swagger UI documentation
|
||||
- `GET /redoc` - ReDoc documentation
|
||||
|
||||
### V1 API Endpoints
|
||||
|
||||
- `GET /api/v1/status` - System status and statistics
|
||||
- `GET /api/v1/opportunities` - List arbitrage opportunities with filtering options
|
||||
- `GET /api/v1/trades` - List historical trades with filtering options
|
||||
|
||||
## Environment Variables
|
||||
|
||||
| Variable | Description | Default |
|
||||
|----------|-------------|---------|
|
||||
| `SOLANA_RPC_URL` | Solana RPC endpoint URL | `https://api.mainnet-beta.solana.com` |
|
||||
| `SOLANA_NETWORK` | Solana network to use (mainnet-beta, testnet, devnet) | `mainnet-beta` |
|
||||
| `WALLET_KEYPAIR_PATH` | Path to wallet keypair JSON file | `None` |
|
||||
| `PROFIT_THRESHOLD_PERCENT` | Minimum profit percentage to consider | `1.0` |
|
||||
| `MAX_SLIPPAGE_PERCENT` | Maximum allowed slippage percentage | `0.5` |
|
||||
| `EXECUTION_ENABLED` | Whether to execute trades or just monitor | `False` |
|
||||
| `SCAN_INTERVAL_SECONDS` | How often to scan for opportunities (seconds) | `10` |
|
||||
| `MONITORED_TOKENS` | Comma-separated list of token addresses to monitor | `[]` |
|
||||
| `ENABLED_DEXES` | Comma-separated list of DEXes to monitor | `jupiter,raydium` |
|
||||
|
||||
## Installation
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Python 3.9+
|
||||
- SQLite
|
||||
|
||||
### Setup
|
||||
|
||||
1. Clone the repository
|
||||
2. Install dependencies:
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
3. Configure environment variables (see above)
|
||||
4. Run the application:
|
||||
```
|
||||
uvicorn main:app --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
## Development
|
||||
|
||||
### Database Migrations
|
||||
|
||||
Database migrations are handled with Alembic:
|
||||
|
||||
```bash
|
||||
# Create a new migration
|
||||
alembic revision -m "description"
|
||||
|
||||
# Run migrations
|
||||
alembic upgrade head
|
||||
```
|
||||
|
||||
### Running Tests
|
||||
|
||||
```bash
|
||||
pytest
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
MIT
|
74
alembic.ini
Normal file
74
alembic.ini
Normal file
@ -0,0 +1,74 @@
|
||||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts
|
||||
script_location = migrations
|
||||
|
||||
# template used to generate migration files
|
||||
# file_template = %%(rev)s_%%(slug)s
|
||||
|
||||
# timezone to use when rendering the date
|
||||
# within the migration file as well as the filename.
|
||||
# string value is passed to dateutil.tz.gettz()
|
||||
# leave blank for localtime
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the
|
||||
# "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
# revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
# sourceless = false
|
||||
|
||||
# version location specification; this defaults
|
||||
# to migrations/versions. When using multiple version
|
||||
# directories, initial revisions must be specified with --version-path
|
||||
# version_locations = %(here)s/bar %(here)s/bat migrations/versions
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
# SQLite URL - using absolute path
|
||||
sqlalchemy.url = sqlite:////app/storage/db/db.sqlite
|
||||
|
||||
# Logging configuration
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
8
app/api/api_v1/api.py
Normal file
8
app/api/api_v1/api.py
Normal file
@ -0,0 +1,8 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.api_v1.endpoints import status, opportunities, trades
|
||||
|
||||
api_router = APIRouter(prefix="/api/v1")
|
||||
api_router.include_router(status.router, prefix="/status", tags=["status"])
|
||||
api_router.include_router(opportunities.router, prefix="/opportunities", tags=["opportunities"])
|
||||
api_router.include_router(trades.router, prefix="/trades", tags=["trades"])
|
48
app/api/api_v1/endpoints/opportunities.py
Normal file
48
app/api/api_v1/endpoints/opportunities.py
Normal file
@ -0,0 +1,48 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc
|
||||
|
||||
from app.db.session import get_db
|
||||
from app.models.arbitrage import ArbitrageOpportunity
|
||||
from app.schemas.arbitrage import OpportunitiesList
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=OpportunitiesList)
|
||||
async def get_arbitrage_opportunities(
|
||||
viable_only: bool = Query(True, description="Show only viable opportunities that meet profit threshold"),
|
||||
token_address: Optional[str] = Query(None, description="Filter by specific token address"),
|
||||
min_profit_percent: Optional[float] = Query(None, description="Filter by minimum profit percentage"),
|
||||
limit: int = Query(20, ge=1, le=100, description="Number of opportunities to return"),
|
||||
offset: int = Query(0, ge=0, description="Pagination offset"),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Retrieve arbitrage opportunities with optional filtering.
|
||||
"""
|
||||
query = db.query(ArbitrageOpportunity)
|
||||
|
||||
# Apply filters
|
||||
if viable_only:
|
||||
query = query.filter(ArbitrageOpportunity.is_viable.is_(True))
|
||||
|
||||
if token_address:
|
||||
query = query.filter(ArbitrageOpportunity.token_address == token_address)
|
||||
|
||||
if min_profit_percent is not None:
|
||||
query = query.filter(ArbitrageOpportunity.price_difference_percent >= min_profit_percent)
|
||||
|
||||
# Get total count
|
||||
total_count = query.count()
|
||||
|
||||
# Get paginated results
|
||||
opportunities = query.order_by(desc(ArbitrageOpportunity.created_at)).offset(offset).limit(limit).all()
|
||||
|
||||
return {
|
||||
"opportunities": opportunities,
|
||||
"count": total_count,
|
||||
"timestamp": datetime.utcnow()
|
||||
}
|
69
app/api/api_v1/endpoints/status.py
Normal file
69
app/api/api_v1/endpoints/status.py
Normal file
@ -0,0 +1,69 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.db.session import get_db
|
||||
from app.models.arbitrage import ArbitrageOpportunity, Trade
|
||||
from app.schemas.arbitrage import StatusResponse
|
||||
from app.services.solana import get_wallet_balance
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=StatusResponse)
|
||||
async def get_system_status(db: Session = Depends(get_db)) -> Any:
|
||||
"""
|
||||
Get current system status including:
|
||||
- System configuration
|
||||
- Wallet balance (if connected)
|
||||
- Active opportunities count
|
||||
- Recent trade statistics
|
||||
"""
|
||||
# Get count of active opportunities
|
||||
active_opportunities_count = db.query(ArbitrageOpportunity).filter(
|
||||
ArbitrageOpportunity.is_viable.is_(True)
|
||||
).count()
|
||||
|
||||
# Get trade stats for last 24 hours
|
||||
yesterday = datetime.utcnow() - timedelta(days=1)
|
||||
recent_trades = db.query(Trade).filter(
|
||||
Trade.created_at >= yesterday,
|
||||
Trade.tx_status == "success"
|
||||
).all()
|
||||
|
||||
# Calculate profit
|
||||
profit_last_24h_usd = sum(trade.profit_amount_usd for trade in recent_trades)
|
||||
|
||||
# Get wallet balance
|
||||
wallet_balance_sol = None
|
||||
wallet_balance_usdc = None
|
||||
wallet_connected = False
|
||||
|
||||
if settings.WALLET_KEYPAIR_PATH:
|
||||
wallet_connected = True
|
||||
try:
|
||||
balances = get_wallet_balance()
|
||||
wallet_balance_sol = balances.get("SOL", 0.0)
|
||||
wallet_balance_usdc = balances.get("USDC", 0.0)
|
||||
except Exception:
|
||||
# If there's an error getting balance, we'll just return None
|
||||
pass
|
||||
|
||||
return {
|
||||
"status": "running",
|
||||
"version": settings.VERSION,
|
||||
"network": settings.SOLANA_NETWORK,
|
||||
"execution_enabled": settings.EXECUTION_ENABLED,
|
||||
"scan_interval_seconds": settings.SCAN_INTERVAL_SECONDS,
|
||||
"last_scan_time": None, # Will be populated when the scanner is implemented
|
||||
"monitored_tokens_count": len(settings.MONITORED_TOKENS),
|
||||
"enabled_dexes": settings.ENABLED_DEXES,
|
||||
"wallet_connected": wallet_connected,
|
||||
"wallet_balance_sol": wallet_balance_sol,
|
||||
"wallet_balance_usdc": wallet_balance_usdc,
|
||||
"active_opportunities_count": active_opportunities_count,
|
||||
"trades_last_24h": len(recent_trades),
|
||||
"profit_last_24h_usd": profit_last_24h_usd
|
||||
}
|
56
app/api/api_v1/endpoints/trades.py
Normal file
56
app/api/api_v1/endpoints/trades.py
Normal file
@ -0,0 +1,56 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Optional
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc, func
|
||||
|
||||
from app.db.session import get_db
|
||||
from app.models.arbitrage import Trade
|
||||
from app.schemas.arbitrage import TradesList
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=TradesList)
|
||||
async def get_trades(
|
||||
token_address: Optional[str] = Query(None, description="Filter by specific token address"),
|
||||
status: Optional[str] = Query(None, description="Filter by trade status (success, failed, pending)"),
|
||||
time_period: Optional[int] = Query(24, description="Time period in hours to look back"),
|
||||
limit: int = Query(20, ge=1, le=100, description="Number of trades to return"),
|
||||
offset: int = Query(0, ge=0, description="Pagination offset"),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Retrieve trade history with optional filtering.
|
||||
"""
|
||||
query = db.query(Trade)
|
||||
|
||||
# Apply filters
|
||||
if token_address:
|
||||
query = query.filter(Trade.token_address == token_address)
|
||||
|
||||
if status:
|
||||
query = query.filter(Trade.tx_status == status)
|
||||
|
||||
if time_period is not None:
|
||||
start_time = datetime.utcnow() - timedelta(hours=time_period)
|
||||
query = query.filter(Trade.created_at >= start_time)
|
||||
|
||||
# Get total count
|
||||
total_count = query.count()
|
||||
|
||||
# Get paginated results
|
||||
trades = query.order_by(desc(Trade.created_at)).offset(offset).limit(limit).all()
|
||||
|
||||
# Calculate total profit for the filtered trades
|
||||
total_profit = db.query(func.sum(Trade.profit_amount_usd)).filter(
|
||||
Trade.tx_status == "success",
|
||||
Trade.id.in_([trade.id for trade in trades])
|
||||
).scalar() or 0.0
|
||||
|
||||
return {
|
||||
"trades": trades,
|
||||
"count": total_count,
|
||||
"timestamp": datetime.utcnow(),
|
||||
"total_profit_usd": total_profit
|
||||
}
|
62
app/core/config.py
Normal file
62
app/core/config.py
Normal file
@ -0,0 +1,62 @@
|
||||
import secrets
|
||||
from typing import Any, List, Optional
|
||||
from pathlib import Path
|
||||
from pydantic import field_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
# Basic application info
|
||||
PROJECT_NAME: str = "Solana Arbitrage Trading System"
|
||||
PROJECT_DESCRIPTION: str = "A backend system for detecting and executing arbitrage opportunities on Solana DEXes"
|
||||
VERSION: str = "0.1.0"
|
||||
API_V1_STR: str = "/api/v1"
|
||||
|
||||
# Security
|
||||
SECRET_KEY: str = secrets.token_urlsafe(32)
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8 # 8 days
|
||||
|
||||
# Database
|
||||
DB_DIR: Path = Path("/app") / "storage" / "db"
|
||||
SQLALCHEMY_DATABASE_URL: str = f"sqlite:///{DB_DIR}/db.sqlite"
|
||||
|
||||
# Solana configuration
|
||||
SOLANA_RPC_URL: str = "https://api.mainnet-beta.solana.com"
|
||||
SOLANA_NETWORK: str = "mainnet-beta" # Can be mainnet-beta, testnet, or devnet
|
||||
WALLET_KEYPAIR_PATH: Optional[str] = None # Path to the keypair JSON file
|
||||
|
||||
# Trading parameters
|
||||
PROFIT_THRESHOLD_PERCENT: float = 1.0 # Minimum profit percentage to consider an opportunity
|
||||
MAX_SLIPPAGE_PERCENT: float = 0.5 # Maximum allowed slippage percentage
|
||||
EXECUTION_ENABLED: bool = False # Whether to actually execute trades or just monitor
|
||||
SCAN_INTERVAL_SECONDS: int = 10 # How often to scan for arbitrage opportunities
|
||||
|
||||
# Monitored tokens - comma-separated list of token addresses to monitor
|
||||
MONITORED_TOKENS: List[str] = []
|
||||
|
||||
@field_validator("MONITORED_TOKENS", mode="before")
|
||||
def validate_monitored_tokens(cls, v: Any) -> List[str]:
|
||||
if isinstance(v, str) and v:
|
||||
return [token.strip() for token in v.split(",")]
|
||||
return []
|
||||
|
||||
# DEX configuration
|
||||
ENABLED_DEXES: List[str] = ["jupiter", "raydium"]
|
||||
|
||||
@field_validator("ENABLED_DEXES", mode="before")
|
||||
def validate_enabled_dexes(cls, v: Any) -> List[str]:
|
||||
if isinstance(v, str) and v:
|
||||
return [dex.strip().lower() for dex in v.split(",")]
|
||||
return ["jupiter", "raydium"] # Default if not specified
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
case_sensitive=True,
|
||||
)
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
# Create DB directory if it doesn't exist
|
||||
settings.DB_DIR.mkdir(parents=True, exist_ok=True)
|
3
app/db/base.py
Normal file
3
app/db/base.py
Normal file
@ -0,0 +1,3 @@
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
Base = declarative_base()
|
36
app/db/init_db.py
Normal file
36
app/db/init_db.py
Normal file
@ -0,0 +1,36 @@
|
||||
import logging
|
||||
import alembic.config
|
||||
from pathlib import Path
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def init_db():
|
||||
"""Initialize the database by running migrations"""
|
||||
logger.info("Initializing database")
|
||||
|
||||
# Create DB directory if it doesn't exist
|
||||
settings.DB_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Get the path to the alembic.ini file
|
||||
alembic_ini_path = Path(__file__).parent.parent.parent / "alembic.ini"
|
||||
|
||||
if not alembic_ini_path.exists():
|
||||
logger.error(f"Alembic config file not found at {alembic_ini_path}")
|
||||
return
|
||||
|
||||
try:
|
||||
# Run the migrations
|
||||
alembic_args = [
|
||||
'--raiseerr',
|
||||
'-c', str(alembic_ini_path),
|
||||
'upgrade', 'head',
|
||||
]
|
||||
alembic.config.main(argv=alembic_args)
|
||||
|
||||
logger.info("Database initialized successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing database: {str(e)}")
|
||||
raise
|
19
app/db/session.py
Normal file
19
app/db/session.py
Normal file
@ -0,0 +1,19 @@
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
engine = create_engine(
|
||||
settings.SQLALCHEMY_DATABASE_URL,
|
||||
connect_args={"check_same_thread": False}
|
||||
)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
def get_db():
|
||||
"""Dependency for getting DB session"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
98
app/models/arbitrage.py
Normal file
98
app/models/arbitrage.py
Normal file
@ -0,0 +1,98 @@
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Boolean, Column, Float, Integer, String, DateTime, Text, ForeignKey
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.db.base import Base
|
||||
|
||||
|
||||
class ArbitrageOpportunity(Base):
|
||||
__tablename__ = "arbitrage_opportunities"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow, index=True)
|
||||
|
||||
# Token information
|
||||
token_address = Column(String, index=True)
|
||||
token_symbol = Column(String, index=True)
|
||||
|
||||
# Price information
|
||||
source_dex = Column(String, index=True)
|
||||
target_dex = Column(String, index=True)
|
||||
source_price = Column(Float)
|
||||
target_price = Column(Float)
|
||||
price_difference = Column(Float) # Absolute price difference
|
||||
price_difference_percent = Column(Float, index=True) # Percentage difference
|
||||
|
||||
# Profit details
|
||||
estimated_profit_usd = Column(Float)
|
||||
estimated_profit_token = Column(Float)
|
||||
|
||||
# Trade parameters
|
||||
max_trade_amount_usd = Column(Float)
|
||||
max_trade_amount_token = Column(Float)
|
||||
slippage_estimate = Column(Float)
|
||||
fees_estimate = Column(Float)
|
||||
|
||||
# Status
|
||||
is_viable = Column(Boolean, default=False, index=True) # Whether it meets profit threshold
|
||||
was_executed = Column(Boolean, default=False, index=True) # Whether a trade was attempted
|
||||
|
||||
# Relationships
|
||||
trades = relationship("Trade", back_populates="opportunity")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ArbitrageOpportunity {self.token_symbol}: {self.source_dex}->{self.target_dex} {self.price_difference_percent:.2f}%>"
|
||||
|
||||
|
||||
class Trade(Base):
|
||||
__tablename__ = "trades"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow, index=True)
|
||||
|
||||
# Reference to opportunity
|
||||
opportunity_id = Column(Integer, ForeignKey("arbitrage_opportunities.id"), index=True)
|
||||
opportunity = relationship("ArbitrageOpportunity", back_populates="trades")
|
||||
|
||||
# Trade details
|
||||
token_address = Column(String, index=True)
|
||||
token_symbol = Column(String, index=True)
|
||||
source_dex = Column(String)
|
||||
target_dex = Column(String)
|
||||
|
||||
# Amounts
|
||||
input_amount = Column(Float) # Amount in token
|
||||
input_amount_usd = Column(Float) # USD value at execution time
|
||||
output_amount = Column(Float) # Amount in token
|
||||
output_amount_usd = Column(Float) # USD value at execution time
|
||||
|
||||
# Trade outcome
|
||||
profit_amount = Column(Float) # Amount in token
|
||||
profit_amount_usd = Column(Float) # USD value
|
||||
profit_percent = Column(Float) # Percentage gain
|
||||
|
||||
# Transaction data
|
||||
tx_signature = Column(String, unique=True, index=True, nullable=True) # Solana transaction signature
|
||||
tx_status = Column(String, index=True) # success, failed, pending
|
||||
tx_error = Column(Text, nullable=True) # Error message if failed
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Trade {self.id}: {self.token_symbol} {self.profit_percent:.2f}% {'SUCCESS' if self.tx_status == 'success' else self.tx_status.upper()}>"
|
||||
|
||||
|
||||
class SystemEvent(Base):
|
||||
__tablename__ = "system_events"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
timestamp = Column(DateTime, default=datetime.utcnow, index=True)
|
||||
|
||||
# Event categorization
|
||||
event_type = Column(String, index=True) # startup, shutdown, error, warning, info
|
||||
component = Column(String, index=True) # which component generated the event
|
||||
|
||||
# Event details
|
||||
message = Column(Text)
|
||||
details = Column(Text, nullable=True)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<SystemEvent {self.event_type.upper()}: {self.component} - {self.message[:50]}>"
|
119
app/schemas/arbitrage.py
Normal file
119
app/schemas/arbitrage.py
Normal file
@ -0,0 +1,119 @@
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ArbitrageOpportunityBase(BaseModel):
|
||||
token_address: str
|
||||
token_symbol: str
|
||||
source_dex: str
|
||||
target_dex: str
|
||||
source_price: float
|
||||
target_price: float
|
||||
price_difference: float
|
||||
price_difference_percent: float
|
||||
estimated_profit_usd: float
|
||||
estimated_profit_token: float
|
||||
max_trade_amount_usd: float
|
||||
max_trade_amount_token: float
|
||||
slippage_estimate: float
|
||||
fees_estimate: float
|
||||
is_viable: bool
|
||||
|
||||
|
||||
class ArbitrageOpportunityCreate(ArbitrageOpportunityBase):
|
||||
pass
|
||||
|
||||
|
||||
class ArbitrageOpportunity(ArbitrageOpportunityBase):
|
||||
id: int
|
||||
created_at: datetime
|
||||
was_executed: bool
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class TradeBase(BaseModel):
|
||||
token_address: str
|
||||
token_symbol: str
|
||||
source_dex: str
|
||||
target_dex: str
|
||||
input_amount: float
|
||||
input_amount_usd: float
|
||||
output_amount: float
|
||||
output_amount_usd: float
|
||||
profit_amount: float
|
||||
profit_amount_usd: float
|
||||
profit_percent: float
|
||||
|
||||
|
||||
class TradeCreate(TradeBase):
|
||||
opportunity_id: int
|
||||
tx_signature: Optional[str] = None
|
||||
tx_status: str
|
||||
tx_error: Optional[str] = None
|
||||
|
||||
|
||||
class Trade(TradeBase):
|
||||
id: int
|
||||
created_at: datetime
|
||||
opportunity_id: int
|
||||
tx_signature: Optional[str] = None
|
||||
tx_status: str
|
||||
tx_error: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class SystemEventBase(BaseModel):
|
||||
event_type: str
|
||||
component: str
|
||||
message: str
|
||||
details: Optional[str] = None
|
||||
|
||||
|
||||
class SystemEventCreate(SystemEventBase):
|
||||
pass
|
||||
|
||||
|
||||
class SystemEvent(SystemEventBase):
|
||||
id: int
|
||||
timestamp: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class StatusResponse(BaseModel):
|
||||
"""Response model for the status endpoint"""
|
||||
status: str
|
||||
version: str
|
||||
network: str
|
||||
execution_enabled: bool
|
||||
scan_interval_seconds: int
|
||||
last_scan_time: Optional[datetime] = None
|
||||
monitored_tokens_count: int
|
||||
enabled_dexes: List[str]
|
||||
wallet_connected: bool
|
||||
wallet_balance_sol: Optional[float] = None
|
||||
wallet_balance_usdc: Optional[float] = None
|
||||
active_opportunities_count: int = Field(default=0, description="Number of currently viable arbitrage opportunities")
|
||||
trades_last_24h: int = Field(default=0, description="Number of trades executed in the last 24 hours")
|
||||
profit_last_24h_usd: float = Field(default=0.0, description="Total profit in USD for the last 24 hours")
|
||||
|
||||
|
||||
class OpportunitiesList(BaseModel):
|
||||
"""Response model for the opportunities endpoint"""
|
||||
opportunities: List[ArbitrageOpportunity]
|
||||
count: int
|
||||
timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class TradesList(BaseModel):
|
||||
"""Response model for the trades endpoint"""
|
||||
trades: List[Trade]
|
||||
count: int
|
||||
timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||
total_profit_usd: float = 0.0
|
47
app/services/dex/__init__.py
Normal file
47
app/services/dex/__init__.py
Normal file
@ -0,0 +1,47 @@
|
||||
from typing import Dict, Optional
|
||||
import logging
|
||||
|
||||
from app.core.config import settings
|
||||
from app.services.dex.base import BaseDexService
|
||||
from app.services.dex.jupiter import JupiterDexService
|
||||
from app.services.dex.raydium import RaydiumDexService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# DEX service registry
|
||||
_dex_services: Dict[str, BaseDexService] = {}
|
||||
|
||||
|
||||
def get_dex_service(dex_name: str) -> Optional[BaseDexService]:
|
||||
"""Get a DEX service by name"""
|
||||
return _dex_services.get(dex_name.lower())
|
||||
|
||||
|
||||
def get_all_dex_services() -> Dict[str, BaseDexService]:
|
||||
"""Get all available DEX services"""
|
||||
return _dex_services
|
||||
|
||||
|
||||
def initialize_dex_services():
|
||||
"""Initialize all enabled DEX services"""
|
||||
global _dex_services
|
||||
|
||||
# Clear existing services
|
||||
_dex_services = {}
|
||||
|
||||
# Initialize services based on configuration
|
||||
enabled_dexes = settings.ENABLED_DEXES
|
||||
|
||||
if "jupiter" in enabled_dexes:
|
||||
_dex_services["jupiter"] = JupiterDexService()
|
||||
logger.info("Jupiter DEX service initialized")
|
||||
|
||||
if "raydium" in enabled_dexes:
|
||||
_dex_services["raydium"] = RaydiumDexService()
|
||||
logger.info("Raydium DEX service initialized")
|
||||
|
||||
logger.info(f"Initialized {len(_dex_services)} DEX services: {', '.join(_dex_services.keys())}")
|
||||
|
||||
|
||||
# Initialize services on module import
|
||||
initialize_dex_services()
|
56
app/services/dex/base.py
Normal file
56
app/services/dex/base.py
Normal file
@ -0,0 +1,56 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Any
|
||||
from decimal import Decimal
|
||||
|
||||
from app.services.solana import get_token_metadata
|
||||
|
||||
|
||||
class BaseDexService(ABC):
|
||||
"""Base class for DEX price monitoring services"""
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
@abstractmethod
|
||||
async def get_token_price(self, token_address: str, quote_token_address: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Get token price from the DEX
|
||||
|
||||
Args:
|
||||
token_address: The address of the token to get price for
|
||||
quote_token_address: The address of the quote token (default USDC)
|
||||
|
||||
Returns:
|
||||
Dict containing:
|
||||
price: The token price in quote token
|
||||
liquidity: Available liquidity for the token pair
|
||||
timestamp: When the price was fetched
|
||||
metadata: Additional DEX-specific data
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_token_prices(self, token_addresses: List[str], quote_token_address: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
Get prices for multiple tokens from the DEX
|
||||
|
||||
Args:
|
||||
token_addresses: List of token addresses to get prices for
|
||||
quote_token_address: The address of the quote token (default USDC)
|
||||
|
||||
Returns:
|
||||
Dict of token_address -> price_data
|
||||
"""
|
||||
pass
|
||||
|
||||
def format_token_amount(self, token_address: str, amount: int) -> float:
|
||||
"""Convert raw token amount to human-readable format"""
|
||||
token_metadata = get_token_metadata(token_address)
|
||||
decimals = token_metadata.get("decimals", 9)
|
||||
return float(Decimal(amount) / Decimal(10**decimals))
|
||||
|
||||
def parse_token_amount(self, token_address: str, amount: float) -> int:
|
||||
"""Convert human-readable token amount to raw format"""
|
||||
token_metadata = get_token_metadata(token_address)
|
||||
decimals = token_metadata.get("decimals", 9)
|
||||
return int(Decimal(amount) * Decimal(10**decimals))
|
166
app/services/dex/jupiter.py
Normal file
166
app/services/dex/jupiter.py
Normal file
@ -0,0 +1,166 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, List, Optional, Any
|
||||
import httpx
|
||||
|
||||
from app.services.dex.base import BaseDexService
|
||||
from app.services.solana import USDC_TOKEN_ADDRESS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Jupiter API V6 endpoints
|
||||
JUPITER_API_BASE = "https://quote-api.jup.ag/v6"
|
||||
PRICE_ENDPOINT = f"{JUPITER_API_BASE}/price"
|
||||
QUOTE_ENDPOINT = f"{JUPITER_API_BASE}/quote"
|
||||
|
||||
|
||||
class JupiterDexService(BaseDexService):
|
||||
"""Service for Jupiter DEX price monitoring"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("jupiter")
|
||||
self.http_client = httpx.AsyncClient(timeout=10.0)
|
||||
|
||||
async def get_token_price(self, token_address: str, quote_token_address: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Get token price from Jupiter"""
|
||||
if not quote_token_address:
|
||||
quote_token_address = USDC_TOKEN_ADDRESS
|
||||
|
||||
try:
|
||||
params = {
|
||||
"ids": token_address,
|
||||
"vsToken": quote_token_address
|
||||
}
|
||||
|
||||
response = await self.http_client.get(PRICE_ENDPOINT, params=params)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
if "data" in data and token_address in data["data"]:
|
||||
price_data = data["data"][token_address]
|
||||
return {
|
||||
"price": float(price_data["price"]),
|
||||
"liquidity": float(price_data.get("liquidity", 0)),
|
||||
"timestamp": int(time.time()),
|
||||
"metadata": {
|
||||
"raw_data": price_data
|
||||
}
|
||||
}
|
||||
else:
|
||||
logger.warning(f"No price data returned from Jupiter for {token_address}")
|
||||
return {
|
||||
"price": 0.0,
|
||||
"liquidity": 0.0,
|
||||
"timestamp": int(time.time()),
|
||||
"metadata": {
|
||||
"error": "No price data returned"
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting Jupiter price for {token_address}: {str(e)}")
|
||||
return {
|
||||
"price": 0.0,
|
||||
"liquidity": 0.0,
|
||||
"timestamp": int(time.time()),
|
||||
"metadata": {
|
||||
"error": str(e)
|
||||
}
|
||||
}
|
||||
|
||||
async def get_token_prices(self, token_addresses: List[str], quote_token_address: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
|
||||
"""Get prices for multiple tokens from Jupiter"""
|
||||
if not quote_token_address:
|
||||
quote_token_address = USDC_TOKEN_ADDRESS
|
||||
|
||||
try:
|
||||
params = {
|
||||
"ids": ",".join(token_addresses),
|
||||
"vsToken": quote_token_address
|
||||
}
|
||||
|
||||
response = await self.http_client.get(PRICE_ENDPOINT, params=params)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
result = {}
|
||||
|
||||
if "data" in data:
|
||||
price_data = data["data"]
|
||||
timestamp = int(time.time())
|
||||
|
||||
for token_address in token_addresses:
|
||||
if token_address in price_data:
|
||||
token_price_data = price_data[token_address]
|
||||
result[token_address] = {
|
||||
"price": float(token_price_data["price"]),
|
||||
"liquidity": float(token_price_data.get("liquidity", 0)),
|
||||
"timestamp": timestamp,
|
||||
"metadata": {
|
||||
"raw_data": token_price_data
|
||||
}
|
||||
}
|
||||
else:
|
||||
result[token_address] = {
|
||||
"price": 0.0,
|
||||
"liquidity": 0.0,
|
||||
"timestamp": timestamp,
|
||||
"metadata": {
|
||||
"error": "No price data returned"
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting Jupiter prices: {str(e)}")
|
||||
timestamp = int(time.time())
|
||||
return {
|
||||
token_address: {
|
||||
"price": 0.0,
|
||||
"liquidity": 0.0,
|
||||
"timestamp": timestamp,
|
||||
"metadata": {
|
||||
"error": str(e)
|
||||
}
|
||||
}
|
||||
for token_address in token_addresses
|
||||
}
|
||||
|
||||
async def get_swap_quote(self, input_token: str, output_token: str, amount: float, slippage_bps: int = 50) -> Dict[str, Any]:
|
||||
"""
|
||||
Get a swap quote from Jupiter
|
||||
|
||||
Args:
|
||||
input_token: Address of the input token
|
||||
output_token: Address of the output token
|
||||
amount: Amount of input token to swap
|
||||
slippage_bps: Slippage tolerance in basis points (1 bps = 0.01%)
|
||||
|
||||
Returns:
|
||||
Quote data or error
|
||||
"""
|
||||
try:
|
||||
# Convert amount to raw format
|
||||
raw_amount = self.parse_token_amount(input_token, amount)
|
||||
|
||||
params = {
|
||||
"inputMint": input_token,
|
||||
"outputMint": output_token,
|
||||
"amount": str(raw_amount),
|
||||
"slippageBps": slippage_bps,
|
||||
"onlyDirectRoutes": False,
|
||||
"asLegacyTransaction": False,
|
||||
}
|
||||
|
||||
response = await self.http_client.get(QUOTE_ENDPOINT, params=params)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
return data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting Jupiter swap quote: {str(e)}")
|
||||
return {
|
||||
"error": str(e)
|
||||
}
|
187
app/services/dex/raydium.py
Normal file
187
app/services/dex/raydium.py
Normal file
@ -0,0 +1,187 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, List, Optional, Any
|
||||
import httpx
|
||||
|
||||
from app.services.dex.base import BaseDexService
|
||||
from app.services.solana import USDC_TOKEN_ADDRESS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Raydium API endpoints
|
||||
RAYDIUM_API_BASE = "https://api.raydium.io/v2"
|
||||
PAIRS_ENDPOINT = f"{RAYDIUM_API_BASE}/main/pairs"
|
||||
|
||||
|
||||
class RaydiumDexService(BaseDexService):
|
||||
"""Service for Raydium DEX price monitoring"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("raydium")
|
||||
self.http_client = httpx.AsyncClient(timeout=10.0)
|
||||
self.pairs_cache = {}
|
||||
self.last_refresh = 0
|
||||
self.cache_ttl = 60 # 60 seconds
|
||||
|
||||
async def refresh_pairs_cache(self):
|
||||
"""Refresh the pairs cache if needed"""
|
||||
current_time = time.time()
|
||||
if current_time - self.last_refresh < self.cache_ttl and self.pairs_cache:
|
||||
return
|
||||
|
||||
try:
|
||||
response = await self.http_client.get(PAIRS_ENDPOINT)
|
||||
response.raise_for_status()
|
||||
|
||||
pairs_data = response.json()
|
||||
|
||||
# Reorganize by token address for faster lookups
|
||||
pairs_by_token = {}
|
||||
for pair in pairs_data:
|
||||
base_token = pair.get("baseMint")
|
||||
quote_token = pair.get("quoteMint")
|
||||
|
||||
if base_token:
|
||||
if base_token not in pairs_by_token:
|
||||
pairs_by_token[base_token] = []
|
||||
pairs_by_token[base_token].append(pair)
|
||||
|
||||
if quote_token:
|
||||
if quote_token not in pairs_by_token:
|
||||
pairs_by_token[quote_token] = []
|
||||
pairs_by_token[quote_token].append(pair)
|
||||
|
||||
self.pairs_cache = pairs_by_token
|
||||
self.last_refresh = current_time
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error refreshing Raydium pairs cache: {str(e)}")
|
||||
|
||||
async def get_token_price(self, token_address: str, quote_token_address: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Get token price from Raydium"""
|
||||
if not quote_token_address:
|
||||
quote_token_address = USDC_TOKEN_ADDRESS
|
||||
|
||||
try:
|
||||
await self.refresh_pairs_cache()
|
||||
|
||||
# Find all pairs for the token
|
||||
token_pairs = self.pairs_cache.get(token_address, [])
|
||||
|
||||
# Find a pair with the quote token
|
||||
target_pair = None
|
||||
for pair in token_pairs:
|
||||
if pair.get("baseMint") == token_address and pair.get("quoteMint") == quote_token_address:
|
||||
target_pair = pair
|
||||
price = 1.0 / float(pair.get("price", 0)) if float(pair.get("price", 0)) > 0 else 0
|
||||
break
|
||||
elif pair.get("quoteMint") == token_address and pair.get("baseMint") == quote_token_address:
|
||||
target_pair = pair
|
||||
price = float(pair.get("price", 0))
|
||||
break
|
||||
|
||||
if target_pair:
|
||||
# Calculate liquidity
|
||||
amm_id = target_pair.get("ammId")
|
||||
liquidity = float(target_pair.get("liquidity", 0))
|
||||
|
||||
return {
|
||||
"price": price,
|
||||
"liquidity": liquidity,
|
||||
"timestamp": int(time.time()),
|
||||
"metadata": {
|
||||
"pair_data": target_pair,
|
||||
"amm_id": amm_id
|
||||
}
|
||||
}
|
||||
else:
|
||||
logger.warning(f"No Raydium pair found for {token_address} with quote {quote_token_address}")
|
||||
return {
|
||||
"price": 0.0,
|
||||
"liquidity": 0.0,
|
||||
"timestamp": int(time.time()),
|
||||
"metadata": {
|
||||
"error": "No pair found"
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting Raydium price for {token_address}: {str(e)}")
|
||||
return {
|
||||
"price": 0.0,
|
||||
"liquidity": 0.0,
|
||||
"timestamp": int(time.time()),
|
||||
"metadata": {
|
||||
"error": str(e)
|
||||
}
|
||||
}
|
||||
|
||||
async def get_token_prices(self, token_addresses: List[str], quote_token_address: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
|
||||
"""Get prices for multiple tokens from Raydium"""
|
||||
if not quote_token_address:
|
||||
quote_token_address = USDC_TOKEN_ADDRESS
|
||||
|
||||
try:
|
||||
await self.refresh_pairs_cache()
|
||||
|
||||
result = {}
|
||||
timestamp = int(time.time())
|
||||
|
||||
for token_address in token_addresses:
|
||||
# Find all pairs for the token
|
||||
token_pairs = self.pairs_cache.get(token_address, [])
|
||||
|
||||
# Find a pair with the quote token
|
||||
target_pair = None
|
||||
price = 0.0
|
||||
|
||||
for pair in token_pairs:
|
||||
if pair.get("baseMint") == token_address and pair.get("quoteMint") == quote_token_address:
|
||||
target_pair = pair
|
||||
price = 1.0 / float(pair.get("price", 0)) if float(pair.get("price", 0)) > 0 else 0
|
||||
break
|
||||
elif pair.get("quoteMint") == token_address and pair.get("baseMint") == quote_token_address:
|
||||
target_pair = pair
|
||||
price = float(pair.get("price", 0))
|
||||
break
|
||||
|
||||
if target_pair:
|
||||
# Calculate liquidity
|
||||
amm_id = target_pair.get("ammId")
|
||||
liquidity = float(target_pair.get("liquidity", 0))
|
||||
|
||||
result[token_address] = {
|
||||
"price": price,
|
||||
"liquidity": liquidity,
|
||||
"timestamp": timestamp,
|
||||
"metadata": {
|
||||
"pair_data": target_pair,
|
||||
"amm_id": amm_id
|
||||
}
|
||||
}
|
||||
else:
|
||||
result[token_address] = {
|
||||
"price": 0.0,
|
||||
"liquidity": 0.0,
|
||||
"timestamp": timestamp,
|
||||
"metadata": {
|
||||
"error": "No pair found"
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting Raydium prices: {str(e)}")
|
||||
timestamp = int(time.time())
|
||||
return {
|
||||
token_address: {
|
||||
"price": 0.0,
|
||||
"liquidity": 0.0,
|
||||
"timestamp": timestamp,
|
||||
"metadata": {
|
||||
"error": str(e)
|
||||
}
|
||||
}
|
||||
for token_address in token_addresses
|
||||
}
|
180
app/services/solana.py
Normal file
180
app/services/solana.py
Normal file
@ -0,0 +1,180 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
import base64
|
||||
import base58
|
||||
from solana.rpc.api import Client
|
||||
from solana.keypair import Keypair
|
||||
from solana.transaction import Transaction
|
||||
from solana.rpc.types import TxOpts
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize Solana client
|
||||
solana_client = Client(settings.SOLANA_RPC_URL)
|
||||
|
||||
# Token constants
|
||||
USDC_TOKEN_ADDRESS = "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v" # USDC on Solana
|
||||
SOL_DECIMALS = 9
|
||||
USDC_DECIMALS = 6
|
||||
|
||||
# Cache for token metadata
|
||||
token_metadata_cache = {}
|
||||
|
||||
|
||||
def get_solana_client() -> Client:
|
||||
"""Get Solana RPC client"""
|
||||
return solana_client
|
||||
|
||||
|
||||
def load_wallet_keypair() -> Optional[Keypair]:
|
||||
"""Load wallet keypair from file if configured"""
|
||||
if not settings.WALLET_KEYPAIR_PATH:
|
||||
logger.warning("No wallet keypair path configured")
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(settings.WALLET_KEYPAIR_PATH, "r") as f:
|
||||
keypair_data = json.load(f)
|
||||
if isinstance(keypair_data, list):
|
||||
# Array format [private_key_bytes]
|
||||
secret_key = bytes(keypair_data)
|
||||
return Keypair.from_secret_key(secret_key)
|
||||
elif isinstance(keypair_data, dict) and "secretKey" in keypair_data:
|
||||
# Phantom wallet export format {"publicKey": "...", "secretKey": "..."}
|
||||
secret_key = base58.b58decode(keypair_data["secretKey"])
|
||||
return Keypair.from_secret_key(secret_key)
|
||||
else:
|
||||
# Solflare and other wallets might use different formats
|
||||
logger.error("Unsupported wallet keypair format")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load wallet keypair: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
def get_wallet_balance() -> Dict[str, float]:
|
||||
"""Get SOL and USDC balance for the configured wallet"""
|
||||
keypair = load_wallet_keypair()
|
||||
if not keypair:
|
||||
return {"SOL": 0.0, "USDC": 0.0}
|
||||
|
||||
wallet_pubkey = keypair.public_key
|
||||
|
||||
# Get SOL balance
|
||||
sol_balance_response = solana_client.get_balance(wallet_pubkey)
|
||||
sol_balance = sol_balance_response["result"]["value"] / 10**SOL_DECIMALS if "result" in sol_balance_response else 0
|
||||
|
||||
# Get USDC balance
|
||||
try:
|
||||
token_accounts = solana_client.get_token_accounts_by_owner(
|
||||
wallet_pubkey,
|
||||
{"mint": USDC_TOKEN_ADDRESS}
|
||||
)
|
||||
|
||||
usdc_balance = 0
|
||||
if "result" in token_accounts and "value" in token_accounts["result"]:
|
||||
for account in token_accounts["result"]["value"]:
|
||||
account_data = account["account"]["data"]
|
||||
if isinstance(account_data, list) and len(account_data) > 1:
|
||||
decoded_data = base64.b64decode(account_data[0])
|
||||
# Parse the token account data - this is a simplified approach
|
||||
# In a real implementation, you'd use proper parsing
|
||||
if len(decoded_data) >= 64: # Minimum length for token account data
|
||||
amount_bytes = decoded_data[64:72]
|
||||
amount = int.from_bytes(amount_bytes, byteorder="little")
|
||||
usdc_balance += amount / 10**USDC_DECIMALS
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting USDC balance: {str(e)}")
|
||||
usdc_balance = 0
|
||||
|
||||
return {
|
||||
"SOL": sol_balance,
|
||||
"USDC": usdc_balance
|
||||
}
|
||||
|
||||
|
||||
def get_token_metadata(token_address: str) -> Dict[str, Any]:
|
||||
"""Get token metadata including symbol and decimals"""
|
||||
if token_address in token_metadata_cache:
|
||||
return token_metadata_cache[token_address]
|
||||
|
||||
try:
|
||||
# Simplification: In a real implementation, you'd query the token's metadata
|
||||
# properly from the Solana token registry or on-chain data
|
||||
|
||||
# For now, we just use a placeholder implementation
|
||||
if token_address == USDC_TOKEN_ADDRESS:
|
||||
metadata = {
|
||||
"address": token_address,
|
||||
"symbol": "USDC",
|
||||
"name": "USD Coin",
|
||||
"decimals": 6,
|
||||
"logo": "https://raw.githubusercontent.com/solana-labs/token-list/main/assets/mainnet/EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v/logo.png"
|
||||
}
|
||||
else:
|
||||
# For other tokens, make an RPC call to get the decimals
|
||||
token_info = solana_client.get_token_supply(token_address)
|
||||
if "result" in token_info and "value" in token_info["result"]:
|
||||
decimals = token_info["result"]["value"]["decimals"]
|
||||
metadata = {
|
||||
"address": token_address,
|
||||
"symbol": f"TOKEN-{token_address[:4]}", # Placeholder symbol
|
||||
"name": f"Unknown Token {token_address[:8]}",
|
||||
"decimals": decimals,
|
||||
"logo": None
|
||||
}
|
||||
else:
|
||||
# Default fallback
|
||||
metadata = {
|
||||
"address": token_address,
|
||||
"symbol": f"TOKEN-{token_address[:4]}",
|
||||
"name": f"Unknown Token {token_address[:8]}",
|
||||
"decimals": 9, # Default to 9 decimals
|
||||
"logo": None
|
||||
}
|
||||
|
||||
# Cache the result
|
||||
token_metadata_cache[token_address] = metadata
|
||||
return metadata
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting token metadata for {token_address}: {str(e)}")
|
||||
# Return a default metadata object
|
||||
default_metadata = {
|
||||
"address": token_address,
|
||||
"symbol": f"TOKEN-{token_address[:4]}",
|
||||
"name": f"Unknown Token {token_address[:8]}",
|
||||
"decimals": 9,
|
||||
"logo": None
|
||||
}
|
||||
return default_metadata
|
||||
|
||||
|
||||
def send_transaction(transaction: Transaction, signers: List[Keypair], opts: Optional[TxOpts] = None) -> Tuple[bool, str, Optional[str]]:
|
||||
"""
|
||||
Send a transaction to the Solana network
|
||||
|
||||
Returns:
|
||||
Tuple of (success, signature, error_message)
|
||||
"""
|
||||
try:
|
||||
# Sign the transaction
|
||||
transaction.sign(*signers)
|
||||
|
||||
# Send the transaction
|
||||
result = solana_client.send_transaction(transaction, *signers, opts=opts)
|
||||
|
||||
if "result" in result:
|
||||
signature = result["result"]
|
||||
return True, signature, None
|
||||
else:
|
||||
error_msg = result.get("error", {}).get("message", "Unknown error")
|
||||
return False, "", error_msg
|
||||
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
logger.error(f"Transaction failed: {error_message}")
|
||||
return False, "", error_message
|
69
main.py
Normal file
69
main.py
Normal file
@ -0,0 +1,69 @@
|
||||
import logging
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.api.api_v1.api import api_router
|
||||
from app.core.config import settings
|
||||
from app.db.init_db import init_db
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = FastAPI(
|
||||
title=settings.PROJECT_NAME,
|
||||
description=settings.PROJECT_DESCRIPTION,
|
||||
version=settings.VERSION,
|
||||
openapi_url="/openapi.json",
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc",
|
||||
)
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Include API router
|
||||
app.include_router(api_router)
|
||||
|
||||
@app.get("/", tags=["Home"])
|
||||
async def root():
|
||||
"""
|
||||
Root endpoint returning basic information about the API.
|
||||
"""
|
||||
return {
|
||||
"title": settings.PROJECT_NAME,
|
||||
"description": settings.PROJECT_DESCRIPTION,
|
||||
"version": settings.VERSION,
|
||||
"docs_url": "/docs",
|
||||
"health_check": "/health",
|
||||
}
|
||||
|
||||
@app.get("/health", tags=["Health"])
|
||||
async def health_check():
|
||||
"""
|
||||
Health check endpoint to verify the service is running.
|
||||
"""
|
||||
return {"status": "healthy"}
|
||||
|
||||
# Initialize the database on startup
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""Run database migrations on startup"""
|
||||
logger.info("Running startup tasks")
|
||||
try:
|
||||
init_db()
|
||||
except Exception as e:
|
||||
logger.error(f"Error during startup: {str(e)}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
|
87
migrations/env.py
Normal file
87
migrations/env.py
Normal file
@ -0,0 +1,87 @@
|
||||
import os
|
||||
import sys
|
||||
from logging.config import fileConfig
|
||||
|
||||
from sqlalchemy import engine_from_config
|
||||
from sqlalchemy import pool
|
||||
|
||||
from alembic import context
|
||||
|
||||
# Add the parent directory to the Python path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
# Import the SQLAlchemy Base and models
|
||||
from app.db.base import Base
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# add your model's MetaData object here
|
||||
# for 'autogenerate' support
|
||||
target_metadata = Base.metadata
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
# can be acquired:
|
||||
# my_important_option = config.get_main_option("my_important_option")
|
||||
# ... etc.
|
||||
|
||||
|
||||
def run_migrations_offline():
|
||||
"""Run migrations in 'offline' mode.
|
||||
|
||||
This configures the context with just a URL
|
||||
and not an Engine, though an Engine is acceptable
|
||||
here as well. By skipping the Engine creation
|
||||
we don't even need a DBAPI to be available.
|
||||
|
||||
Calls to context.execute() here emit the given string to the
|
||||
script output.
|
||||
|
||||
"""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online():
|
||||
"""Run migrations in 'online' mode.
|
||||
|
||||
In this scenario we need to create an Engine
|
||||
and associate a connection with the context.
|
||||
|
||||
"""
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
is_sqlite = connection.dialect.name == 'sqlite'
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
render_as_batch=is_sqlite, # Important for SQLite
|
||||
compare_type=True,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
24
migrations/script.py.mako
Normal file
24
migrations/script.py.mako
Normal file
@ -0,0 +1,24 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = ${repr(up_revision)}
|
||||
down_revision = ${repr(down_revision)}
|
||||
branch_labels = ${repr(branch_labels)}
|
||||
depends_on = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade():
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade():
|
||||
${downgrades if downgrades else "pass"}
|
80
migrations/versions/001_initial_schema.py
Normal file
80
migrations/versions/001_initial_schema.py
Normal file
@ -0,0 +1,80 @@
|
||||
"""Initial schema
|
||||
|
||||
Revision ID: 001
|
||||
Revises:
|
||||
Create Date: 2023-06-05
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '001'
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# Create arbitrage_opportunities table
|
||||
op.create_table(
|
||||
'arbitrage_opportunities',
|
||||
sa.Column('id', sa.Integer(), primary_key=True, index=True),
|
||||
sa.Column('created_at', sa.DateTime(), default=sa.func.current_timestamp(), index=True),
|
||||
sa.Column('token_address', sa.String(), index=True),
|
||||
sa.Column('token_symbol', sa.String(), index=True),
|
||||
sa.Column('source_dex', sa.String(), index=True),
|
||||
sa.Column('target_dex', sa.String(), index=True),
|
||||
sa.Column('source_price', sa.Float()),
|
||||
sa.Column('target_price', sa.Float()),
|
||||
sa.Column('price_difference', sa.Float()),
|
||||
sa.Column('price_difference_percent', sa.Float(), index=True),
|
||||
sa.Column('estimated_profit_usd', sa.Float()),
|
||||
sa.Column('estimated_profit_token', sa.Float()),
|
||||
sa.Column('max_trade_amount_usd', sa.Float()),
|
||||
sa.Column('max_trade_amount_token', sa.Float()),
|
||||
sa.Column('slippage_estimate', sa.Float()),
|
||||
sa.Column('fees_estimate', sa.Float()),
|
||||
sa.Column('is_viable', sa.Boolean(), default=False, index=True),
|
||||
sa.Column('was_executed', sa.Boolean(), default=False, index=True)
|
||||
)
|
||||
|
||||
# Create trades table
|
||||
op.create_table(
|
||||
'trades',
|
||||
sa.Column('id', sa.Integer(), primary_key=True, index=True),
|
||||
sa.Column('created_at', sa.DateTime(), default=sa.func.current_timestamp(), index=True),
|
||||
sa.Column('opportunity_id', sa.Integer(), sa.ForeignKey("arbitrage_opportunities.id"), index=True),
|
||||
sa.Column('token_address', sa.String(), index=True),
|
||||
sa.Column('token_symbol', sa.String(), index=True),
|
||||
sa.Column('source_dex', sa.String()),
|
||||
sa.Column('target_dex', sa.String()),
|
||||
sa.Column('input_amount', sa.Float()),
|
||||
sa.Column('input_amount_usd', sa.Float()),
|
||||
sa.Column('output_amount', sa.Float()),
|
||||
sa.Column('output_amount_usd', sa.Float()),
|
||||
sa.Column('profit_amount', sa.Float()),
|
||||
sa.Column('profit_amount_usd', sa.Float()),
|
||||
sa.Column('profit_percent', sa.Float()),
|
||||
sa.Column('tx_signature', sa.String(), unique=True, index=True, nullable=True),
|
||||
sa.Column('tx_status', sa.String(), index=True),
|
||||
sa.Column('tx_error', sa.Text(), nullable=True)
|
||||
)
|
||||
|
||||
# Create system_events table
|
||||
op.create_table(
|
||||
'system_events',
|
||||
sa.Column('id', sa.Integer(), primary_key=True, index=True),
|
||||
sa.Column('timestamp', sa.DateTime(), default=sa.func.current_timestamp(), index=True),
|
||||
sa.Column('event_type', sa.String(), index=True),
|
||||
sa.Column('component', sa.String(), index=True),
|
||||
sa.Column('message', sa.Text()),
|
||||
sa.Column('details', sa.Text(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_table('trades')
|
||||
op.drop_table('system_events')
|
||||
op.drop_table('arbitrage_opportunities')
|
14
requirements.txt
Normal file
14
requirements.txt
Normal file
@ -0,0 +1,14 @@
|
||||
fastapi>=0.103.1
|
||||
uvicorn>=0.23.2
|
||||
sqlalchemy>=2.0.20
|
||||
alembic>=1.12.0
|
||||
pydantic>=2.3.0
|
||||
pydantic-settings>=2.0.3
|
||||
solana>=0.30.0
|
||||
asyncio>=3.4.3
|
||||
aiohttp>=3.8.5
|
||||
loguru>=0.7.0
|
||||
python-dotenv>=1.0.0
|
||||
ruff>=0.0.292
|
||||
httpx>=0.25.0
|
||||
pytest>=7.4.2
|
Loading…
x
Reference in New Issue
Block a user