from typing import List, Optional, Dict, Any, Union, Tuple

from sqlalchemy import func, desc, asc, and_, or_
from sqlalchemy.orm import Session, joinedload

from app.crud.base import CRUDBase
from app.models.arbitrage import Arbitrage, ArbitrageLeg
from app.models.pool import Pool
from app.models.dex import Dex
from app.models.transaction import Transaction
from app.models.block import Block
from app.schemas.arbitrage import ArbitrageCreate, ArbitrageUpdate, ArbitrageLegCreate


class CRUDArbitrage(CRUDBase[Arbitrage, ArbitrageCreate, ArbitrageUpdate]):
    """CRUD operations for arbitrages.

    Paginated getters apply ``skip``/``limit`` as SQL OFFSET/LIMIT after
    ordering. They always append ``Arbitrage.id DESC`` as the final sort key,
    so rows that tie on the primary sort column come back in a stable order
    from one page to the next.
    """

    @staticmethod
    def _ordered_page(
        query: Any, *order_by: Any, skip: int, limit: int
    ) -> List[Arbitrage]:
        """Order ``query`` by ``order_by`` then ``Arbitrage.id`` desc and fetch one page."""
        # Without a unique final sort key, the relative order of tied rows is
        # unspecified in SQL, so OFFSET-based pages could repeat or skip rows.
        return (
            query.order_by(*order_by, desc(Arbitrage.id))
            .offset(skip)
            .limit(limit)
            .all()
        )

    def get_with_legs(self, db: Session, *, arbitrage_id: int) -> Optional[Arbitrage]:
        """Return one arbitrage with its legs, each leg's pool, and each pool's DEX.

        The related rows are eager-loaded with joined loads. Returns ``None``
        when no arbitrage has ``arbitrage_id``.
        """
        return (
            db.query(Arbitrage)
            .options(
                joinedload(Arbitrage.legs)
                .joinedload(ArbitrageLeg.pool)
                .joinedload(Pool.dex)
            )
            .filter(Arbitrage.id == arbitrage_id)
            .first()
        )

    def get_by_transaction(self, db: Session, *, tx_id: int) -> List[Arbitrage]:
        """Return every arbitrage whose ``transaction_id`` equals ``tx_id`` (unpaginated)."""
        return db.query(Arbitrage).filter(Arbitrage.transaction_id == tx_id).all()

    def get_by_initiator(
        self, db: Session, *, initiator: str, skip: int = 0, limit: int = 100
    ) -> List[Arbitrage]:
        """Return one page of arbitrages with ``initiator_address == initiator``, newest ``created_at`` first."""
        query = db.query(Arbitrage).filter(Arbitrage.initiator_address == initiator)
        return self._ordered_page(
            query, desc(Arbitrage.created_at), skip=skip, limit=limit
        )

    def get_by_token(
        self, db: Session, *, token: str, skip: int = 0, limit: int = 100
    ) -> List[Arbitrage]:
        """Return one page of arbitrages with ``start_token_address == token``, newest ``created_at`` first."""
        query = db.query(Arbitrage).filter(Arbitrage.start_token_address == token)
        return self._ordered_page(
            query, desc(Arbitrage.created_at), skip=skip, limit=limit
        )

    def get_successful_arbitrages(
        self, db: Session, *, skip: int = 0, limit: int = 100
    ) -> List[Arbitrage]:
        """Return one page of successful arbitrages, highest ``profit_percentage`` first."""
        # `== True` builds a SQL comparison; `is True` would be a Python identity test.
        query = db.query(Arbitrage).filter(Arbitrage.success == True)  # noqa: E712
        return self._ordered_page(
            query, desc(Arbitrage.profit_percentage), skip=skip, limit=limit
        )

    def get_failed_arbitrages(
        self, db: Session, *, skip: int = 0, limit: int = 100
    ) -> List[Arbitrage]:
        """Return one page of arbitrages with ``success == False``, newest ``created_at`` first."""
        query = db.query(Arbitrage).filter(Arbitrage.success == False)  # noqa: E712
        return self._ordered_page(
            query, desc(Arbitrage.created_at), skip=skip, limit=limit
        )

    def get_arbitrages_by_profit_range(
        self,
        db: Session,
        *,
        min_profit: float,
        max_profit: Optional[float] = None,
        skip: int = 0,
        limit: int = 100,
    ) -> List[Arbitrage]:
        """Return one page of successful arbitrages within a profit-percentage range.

        Both bounds are inclusive. When ``max_profit`` is ``None``, there is
        no upper bound. Results are ordered by highest ``profit_percentage``
        first.
        """
        query = db.query(Arbitrage).filter(
            Arbitrage.success == True,  # noqa: E712
            Arbitrage.profit_percentage >= min_profit,
        )
        if max_profit is not None:
            query = query.filter(Arbitrage.profit_percentage <= max_profit)
        return self._ordered_page(
            query, desc(Arbitrage.profit_percentage), skip=skip, limit=limit
        )

    def get_arbitrages_by_time_range(
        self,
        db: Session,
        *,
        start_time: Any,
        end_time: Any,
        skip: int = 0,
        limit: int = 100,
    ) -> List[Arbitrage]:
        """Return one page of arbitrages whose block time lies in ``[start_time, end_time]``.

        The time is taken from ``Block.block_time``, reached via the
        arbitrage's transaction. Results are ordered by latest block time
        first. ``start_time``/``end_time`` must be comparable with
        ``Block.block_time``; their exact type is not visible here.
        """
        query = (
            db.query(Arbitrage)
            .join(Transaction, Arbitrage.transaction_id == Transaction.id)
            .join(Block, Transaction.block_id == Block.id)
            .filter(Block.block_time >= start_time, Block.block_time <= end_time)
        )
        return self._ordered_page(
            query, desc(Block.block_time), skip=skip, limit=limit
        )

    def get_arbitrages_by_dex(
        self, db: Session, *, dex_id: int, skip: int = 0, limit: int = 100
    ) -> List[Arbitrage]:
        """Return one page of arbitrages with at least one leg in a pool of DEX ``dex_id``.

        Results are ordered newest ``created_at`` first.
        """
        # An EXISTS filter yields each arbitrage once, however many of its legs
        # match. This avoids a whole-row DISTINCT over the leg/pool join fan-out.
        query = db.query(Arbitrage).filter(
            Arbitrage.legs.any(ArbitrageLeg.pool.has(Pool.dex_id == dex_id))
        )
        return self._ordered_page(
            query, desc(Arbitrage.created_at), skip=skip, limit=limit
        )

    def get_arbitrage_stats(self, db: Session) -> Dict[str, Any]:
        """Return aggregate statistics over all stored arbitrages.

        Keys:
            total_arbitrages / successful_arbitrages / failed_arbitrages: row counts.
            total_profit: sum of ``profit_amount`` over successful rows (0 if none).
            avg_profit_percentage / max_profit_percentage: over successful rows (0 if none).
            most_used_dexes: up to 5 ``{"name", "count"}`` dicts. ``count`` is the
                number of legs of successful arbitrages whose pool belongs to a DEX
                with that name.
            most_arbitraged_tokens: up to 5 ``{"symbol", "count"}`` dicts, counting
                successful arbitrages per non-NULL ``start_token_symbol``.
        """
        is_success = Arbitrage.success == True  # noqa: E712

        total = db.query(func.count(Arbitrage.id)).scalar() or 0
        successful = (
            db.query(func.count(Arbitrage.id)).filter(is_success).scalar() or 0
        )
        # NOTE(review): this counts every row whose `success` is not True
        # (including any NULLs). It may therefore differ from
        # get_failed_arbitrages, which filters on `success == False`.
        failed = total - successful

        # One aggregate query (always exactly one row) instead of three round trips.
        profit_sum, profit_pct_avg, profit_pct_max = (
            db.query(
                func.sum(Arbitrage.profit_amount),
                func.avg(Arbitrage.profit_percentage),
                func.max(Arbitrage.profit_percentage),
            )
            .filter(is_success)
            .one()
        )

        # Most used DEXes: one count per leg of a successful arbitrage.
        dex_rows = (
            db.query(Dex.name, func.count(Dex.id).label("count"))
            .join(Pool, Dex.id == Pool.dex_id)
            .join(ArbitrageLeg, Pool.id == ArbitrageLeg.pool_id)
            .join(Arbitrage, ArbitrageLeg.arbitrage_id == Arbitrage.id)
            .filter(is_success)
            .group_by(Dex.name)
            .order_by(desc("count"))
            .limit(5)
            .all()
        )
        most_used_dexes = [{"name": name, "count": count} for name, count in dex_rows]

        # Most arbitraged tokens, keyed by start-token symbol.
        token_rows = (
            db.query(
                Arbitrage.start_token_symbol,
                func.count(Arbitrage.id).label("count"),
            )
            .filter(is_success, Arbitrage.start_token_symbol.isnot(None))
            .group_by(Arbitrage.start_token_symbol)
            .order_by(desc("count"))
            .limit(5)
            .all()
        )
        # NULL symbols are filtered out above; "Unknown" only covers empty strings.
        most_arbitraged_tokens = [
            {"symbol": symbol or "Unknown", "count": count}
            for symbol, count in token_rows
        ]

        return {
            "total_arbitrages": total,
            "successful_arbitrages": successful,
            "failed_arbitrages": failed,
            "total_profit": profit_sum or 0,
            "avg_profit_percentage": profit_pct_avg or 0,
            "max_profit_percentage": profit_pct_max or 0,
            "most_used_dexes": most_used_dexes,
            "most_arbitraged_tokens": most_arbitraged_tokens,
        }

    def create_with_legs(
        self, db: Session, *, obj_in: ArbitrageCreate, legs: List[ArbitrageLegCreate]
    ) -> Arbitrage:
        """Create an arbitrage and its legs in a single commit.

        ``legs_count`` is set to the number of legs created, and each leg's
        ``arbitrage_id`` is forced to the new arbitrage's id. If anything fails
        before the commit completes, the session is rolled back and the error
        is re-raised. This ensures no half-written arbitrage stays pending in
        the caller's session.

        Returns:
            The committed and refreshed ``Arbitrage``.
        """
        try:
            db_obj = Arbitrage(**obj_in.model_dump())
            db.add(db_obj)
            db.flush()  # assigns db_obj.id without committing

            legs_count = 0
            for leg_in in legs:
                leg_data = leg_in.model_dump()
                leg_data["arbitrage_id"] = db_obj.id
                db.add(ArbitrageLeg(**leg_data))
                legs_count += 1  # counted here so any iterable of legs works

            db_obj.legs_count = legs_count
            db.commit()
        except Exception:
            db.rollback()
            raise

        db.refresh(db_obj)
        return db_obj


# Create a single instance for use in dependency injection
arbitrage = CRUDArbitrage(Arbitrage)