import json
import logging
import time
import uuid
from typing import Any, Dict, List, Optional
from datetime import datetime

from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks
from sqlalchemy.orm import Session
from sqlalchemy import desc, func

from app.db.session import get_db
from app.models.screening import ScreeningResult, ScreeningBatch
from app.models.transaction import Transaction
from app.models.rule import Rule
from app.services.rule_engine import RuleEngine
from app.core.schemas import (
    ScreeningResponse,
    BatchScreeningRequest,
    ScreeningBatch as ScreeningBatchSchema,
    PaginatedResponse,
    AggregateRequest,
    AggregateResponse,
)

logger = logging.getLogger(__name__)

router = APIRouter()

# Batch-filter keys that are applied as an equality test on the Transaction
# column of the same name.
_EQUALITY_FILTER_FIELDS = ("user_id", "account_id", "transaction_type", "channel", "status")

# Aggregates (besides "count") that compute_aggregate delegates to the SQL
# function of the same name via sqlalchemy.func.
_SQL_AGGREGATES = frozenset({"sum", "avg", "max", "min"})

# Number of successfully screened transactions between batch progress updates.
_PROGRESS_INTERVAL = 100


def _apply_transaction_filters(query, filters, date_from=None, date_to=None):
    """Narrow a query over Transaction by batch filters and a created_at range.

    Shared by create_batch_screening (which counts the matches) and
    process_batch_screening (which screens them) so both see the same set.

    Args:
        query: A SQLAlchemy query whose filters may reference Transaction columns.
        filters: Optional dict of filter values; unknown keys are ignored.
        date_from: Inclusive lower bound on Transaction.created_at (skipped if falsy).
        date_to: Inclusive upper bound on Transaction.created_at (skipped if falsy).

    Returns:
        The filtered query.
    """
    if filters:
        for name in _EQUALITY_FILTER_FIELDS:
            if name in filters:
                query = query.filter(getattr(Transaction, name) == filters[name])
        if "min_amount" in filters:
            query = query.filter(Transaction.amount >= filters["min_amount"])
        if "max_amount" in filters:
            query = query.filter(Transaction.amount <= filters["max_amount"])
    if date_from:
        query = query.filter(Transaction.created_at >= date_from)
    if date_to:
        query = query.filter(Transaction.created_at <= date_to)
    return query


def _build_db_result(result, screening_type: str) -> ScreeningResult:
    """Build an unsaved ScreeningResult row from a rule-engine result.

    The dict-valued ``details`` and ``aggregated_data`` fields are stored as
    JSON text; empty or missing values are stored as NULL.
    """
    return ScreeningResult(
        transaction_id=result.transaction_id,
        rule_id=result.rule_id,
        rule_name=result.rule_name,
        rule_version=result.rule_version,
        status=result.status,
        risk_score=result.risk_score,
        details=json.dumps(result.details) if result.details else None,
        aggregated_data=json.dumps(result.aggregated_data) if result.aggregated_data else None,
        screening_type=screening_type,
    )


def _serialize_result(db_result: ScreeningResult) -> Dict[str, Any]:
    """Convert a persisted ScreeningResult row into the API's response dict, decoding JSON columns."""
    return {
        "id": db_result.id,
        "transaction_id": db_result.transaction_id,
        "rule_id": db_result.rule_id,
        "rule_name": db_result.rule_name,
        "rule_version": db_result.rule_version,
        "status": db_result.status,
        "risk_score": db_result.risk_score,
        "details": json.loads(db_result.details) if db_result.details else None,
        "aggregated_data": json.loads(db_result.aggregated_data) if db_result.aggregated_data else None,
        "screening_type": db_result.screening_type,
        "created_at": db_result.created_at,
    }


def _batch_to_schema(batch: ScreeningBatch) -> ScreeningBatchSchema:
    """Convert a ScreeningBatch row to its schema, decoding the JSON-encoded rules_applied list."""
    schema = ScreeningBatchSchema.from_orm(batch)
    if batch.rules_applied:
        schema.rules_applied = json.loads(batch.rules_applied)
    return schema


def _total_pages(total: int, page_size: int) -> int:
    """Return the number of pages needed to show ``total`` items, ``page_size`` per page."""
    return (total + page_size - 1) // page_size


@router.post("/transactions/{transaction_id}", response_model=ScreeningResponse)
async def screen_transaction(
    transaction_id: str,
    rule_ids: Optional[List[int]] = None,
    db: Session = Depends(get_db),
):
    """
    Screen a single transaction against fraud detection rules in real-time.

    This endpoint evaluates a transaction against all active rules or specific rules
    and returns the screening results immediately. Suitable for real-time fraud detection.

    Args:
        transaction_id: The ID of the transaction to screen
        rule_ids: Optional list of specific rule IDs to apply (if None, applies all active rules)

    Returns:
        ScreeningResponse with detailed results and overall risk assessment

    Raises:
        HTTPException: 404 when the rule engine raises ValueError; 500 for any
            other failure, after rolling back the session so no partial results
            are persisted.
    """
    start_time = time.time()
    rule_engine = RuleEngine(db)

    try:
        try:
            # list() so the results can be iterated more than once below.
            screening_results = list(rule_engine.evaluate_transaction(transaction_id, rule_ids))
        except ValueError as e:
            # Only the rule engine's ValueError maps to 404 (presumably "unknown
            # transaction"); a ValueError raised while persisting or building the
            # response (e.g. a pydantic validation error) is a server error.
            raise HTTPException(status_code=404, detail=str(e)) from e

        db_results = [_build_db_result(result, result.screening_type) for result in screening_results]
        db.add_all(db_results)
        db.commit()
        # Refresh to get database-assigned IDs.
        for db_result in db_results:
            db.refresh(db_result)

        total_risk_score = sum((result.risk_score for result in screening_results), 0.0)
        any_flagged = any(result.status == "flagged" for result in screening_results)

        return ScreeningResponse(
            transaction_id=transaction_id,
            results=[_serialize_result(db_result) for db_result in db_results],
            overall_status="flagged" if any_flagged else "clean",
            total_risk_score=total_risk_score,
            screening_duration_ms=(time.time() - start_time) * 1000,
        )
    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        logger.exception("Screening failed for transaction %s", transaction_id)
        raise HTTPException(status_code=500, detail=f"Screening failed: {str(e)}") from e


@router.post("/batch", response_model=ScreeningBatchSchema)
async def create_batch_screening(
    request: BatchScreeningRequest,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
):
    """
    Create a batch screening job for multiple transactions.

    This endpoint creates a batch job that will screen multiple transactions
    based on the provided filters. The job runs in the background and results
    can be retrieved using the batch endpoints.

    Args:
        request: Batch screening configuration including filters and rules

    Returns:
        ScreeningBatchSchema with batch job details (status "pending")
    """
    batch_id = str(uuid.uuid4())

    # Count matches with the same filter helper the background task uses, so
    # total_transactions agrees with what is actually screened.
    query = _apply_transaction_filters(
        db.query(Transaction),
        request.transaction_filters,
        request.date_from,
        request.date_to,
    )
    total_transactions = query.count()

    # Fall back to every active rule when no rule IDs were requested.
    rule_ids_to_apply = request.rule_ids
    if not rule_ids_to_apply:
        active_rules = db.query(Rule.id).filter(Rule.is_active).all()
        rule_ids_to_apply = [rule.id for rule in active_rules]

    batch = ScreeningBatch(
        batch_id=batch_id,
        name=request.name,
        description=request.description,
        status="pending",
        total_transactions=total_transactions,
        processed_transactions=0,
        flagged_transactions=0,
        rules_applied=json.dumps(rule_ids_to_apply),
    )
    db.add(batch)
    db.commit()
    db.refresh(batch)

    # The background task receives the internal primary key and a plain-dict
    # copy of the request, and opens its own session.
    background_tasks.add_task(
        process_batch_screening,
        batch.id,
        request.dict(),
        rule_ids_to_apply,
    )

    result = ScreeningBatchSchema.from_orm(batch)
    result.rules_applied = rule_ids_to_apply
    return result


@router.get("/batch", response_model=PaginatedResponse)
async def get_screening_batches(
    page: int = Query(1, ge=1),
    page_size: int = Query(100, ge=1, le=1000),
    status: Optional[str] = None,
    db: Session = Depends(get_db),
):
    """
    Retrieve screening batch jobs with filtering and pagination.

    Batches are ordered newest first (by created_at).
    """
    query = db.query(ScreeningBatch)
    if status:
        query = query.filter(ScreeningBatch.status == status)

    total = query.count()

    offset = (page - 1) * page_size
    batches = query.order_by(desc(ScreeningBatch.created_at)).offset(offset).limit(page_size).all()

    return PaginatedResponse(
        items=[_batch_to_schema(batch) for batch in batches],
        total=total,
        page=page,
        page_size=page_size,
        total_pages=_total_pages(total, page_size),
    )


@router.get("/batch/{batch_id}", response_model=ScreeningBatchSchema)
async def get_screening_batch(
    batch_id: str,
    db: Session = Depends(get_db),
):
    """
    Retrieve a specific screening batch by its public batch_id (UUID string).

    Raises:
        HTTPException: 404 if no batch has that batch_id.
    """
    batch = db.query(ScreeningBatch).filter(ScreeningBatch.batch_id == batch_id).first()
    if not batch:
        raise HTTPException(status_code=404, detail="Batch not found")
    return _batch_to_schema(batch)


@router.get("/results", response_model=PaginatedResponse)
async def get_screening_results(
    page: int = Query(1, ge=1),
    page_size: int = Query(100, ge=1, le=1000),
    transaction_id: Optional[str] = None,
    rule_id: Optional[int] = None,
    status: Optional[str] = None,
    screening_type: Optional[str] = None,
    min_risk_score: Optional[float] = None,
    db: Session = Depends(get_db),
):
    """
    Retrieve screening results with filtering and pagination.

    This endpoint provides audit and reporting capabilities for screening results.
    Results can be filtered by various criteria for compliance and analysis and
    are ordered newest first (by created_at).
    """
    query = db.query(ScreeningResult)

    if transaction_id:
        query = query.filter(ScreeningResult.transaction_id == transaction_id)
    # `is not None` so that rule_id=0 is still a usable filter.
    if rule_id is not None:
        query = query.filter(ScreeningResult.rule_id == rule_id)
    if status:
        query = query.filter(ScreeningResult.status == status)
    if screening_type:
        query = query.filter(ScreeningResult.screening_type == screening_type)
    if min_risk_score is not None:
        query = query.filter(ScreeningResult.risk_score >= min_risk_score)

    total = query.count()

    offset = (page - 1) * page_size
    results = query.order_by(desc(ScreeningResult.created_at)).offset(offset).limit(page_size).all()

    return PaginatedResponse(
        items=[_serialize_result(result) for result in results],
        total=total,
        page=page,
        page_size=page_size,
        total_pages=_total_pages(total, page_size),
    )


@router.post("/aggregate", response_model=AggregateResponse)
async def compute_aggregate(
    request: AggregateRequest,
    db: Session = Depends(get_db),
):
    """
    Compute aggregate values for transactions.

    This endpoint allows computing various aggregations on transaction data such as
    sum, count, average, min and max. Results are computed on every request and are
    not cached (``cache_hit`` is always False).

    Filters whose key is not an attribute of Transaction are ignored, and an
    unknown ``field`` yields 0. ``group_by`` is not implemented as SQL GROUP BY;
    see the note in the response.

    Example request:
    {
        "aggregate_function": "sum",
        "field": "amount",
        "group_by": ["user_id"],
        "filters": {"transaction_type": "debit"},
        "time_window": "24h"
    }

    Raises:
        HTTPException: 400 for an unsupported aggregate function, 500 for any
            other failure.
    """
    start_time = time.time()

    try:
        query = db.query(Transaction)

        if request.filters:
            for field, value in request.filters.items():
                if hasattr(Transaction, field):
                    query = query.filter(getattr(Transaction, field) == value)

        if request.date_from:
            query = query.filter(Transaction.created_at >= request.date_from)
        if request.date_to:
            query = query.filter(Transaction.created_at <= request.date_to)

        if request.time_window:
            rule_engine = RuleEngine(db)
            time_delta = rule_engine._parse_time_window(request.time_window)
            cutoff_time = datetime.utcnow() - time_delta
            query = query.filter(Transaction.created_at >= cutoff_time)

        aggregate_name = request.aggregate_function
        if aggregate_name == "count":
            if request.field == "*":
                result = query.count()
            else:
                field_attr = getattr(Transaction, request.field, None)
                # Count only rows where the field is non-NULL, like SQL COUNT(col).
                result = query.filter(field_attr.isnot(None)).count() if field_attr is not None else 0
        elif aggregate_name in _SQL_AGGREGATES:
            field_attr = getattr(Transaction, request.field, None)
            if field_attr is not None:
                sql_function = getattr(func, aggregate_name)
                # The SQL aggregate is NULL over an empty set; report that as 0.
                result = query.with_entities(sql_function(field_attr)).scalar() or 0
            else:
                result = 0
        else:
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported aggregate function: {request.aggregate_function}",
            )

        if request.group_by:
            # This is a simplified implementation
            # In a production system, you might want to use SQL GROUP BY
            result = {"total": result, "note": "Group by aggregation simplified for demo"}

        computation_time = (time.time() - start_time) * 1000

        return AggregateResponse(
            result={"value": result} if not request.group_by else result,
            cache_hit=False,
            computation_time_ms=computation_time,
        )
    except HTTPException:
        # Keep deliberate client errors (e.g. the 400 above) instead of masking them as 500.
        raise
    except Exception as e:
        logger.exception("Aggregate computation failed")
        raise HTTPException(status_code=500, detail=f"Aggregate computation failed: {str(e)}") from e


def process_batch_screening(batch_id: int, request_data: dict, rule_ids: List[int]):
    """
    Background task that screens every transaction selected by a batch request.

    It opens its own session from SessionLocal, presumably because the
    request-scoped session from get_db is not available after the response is
    sent. Each transaction's results are committed on their own. A failure on
    one transaction rolls back only that transaction, is logged, and the loop
    continues. Progress counters are written every _PROGRESS_INTERVAL
    successfully screened transactions and again at the end. Any other failure
    marks the batch "failed".

    Args:
        batch_id: ScreeningBatch.id primary key (not the public batch_id UUID).
        request_data: The BatchScreeningRequest as produced by ``request.dict()``.
        rule_ids: Rule IDs passed to RuleEngine.evaluate_transaction for every transaction.
    """
    from app.db.session import SessionLocal

    db = SessionLocal()
    batch = None
    try:
        batch = db.query(ScreeningBatch).filter(ScreeningBatch.id == batch_id).first()
        if not batch:
            return

        batch.status = "processing"
        batch.started_at = datetime.utcnow()
        db.commit()

        request_obj = BatchScreeningRequest(**request_data)
        # Same filters as create_batch_screening used to compute total_transactions.
        id_query = _apply_transaction_filters(
            db.query(Transaction.transaction_id),
            request_obj.transaction_filters,
            request_obj.date_from,
            request_obj.date_to,
        )
        # Fetch only the IDs. With the default expire_on_commit=True, full ORM
        # rows would be expired by every commit below and re-SELECTed one by one.
        transaction_ids = [transaction_id for (transaction_id,) in id_query.all()]

        rule_engine = RuleEngine(db)
        processed_count = 0
        flagged_count = 0

        for transaction_id in transaction_ids:
            try:
                screening_results = list(rule_engine.evaluate_transaction(transaction_id, rule_ids))
                db.add_all([_build_db_result(result, "batch") for result in screening_results])
                transaction_flagged = any(result.status == "flagged" for result in screening_results)
                db.commit()
            except Exception:
                # Roll back so the session stays usable and no partial results
                # for this transaction are committed later; then keep going.
                db.rollback()
                logger.exception("Error processing transaction %s", transaction_id)
                continue

            processed_count += 1
            if transaction_flagged:
                flagged_count += 1

            if processed_count % _PROGRESS_INTERVAL == 0:
                batch.processed_transactions = processed_count
                batch.flagged_transactions = flagged_count
                db.commit()

        batch.status = "completed"
        batch.processed_transactions = processed_count
        batch.flagged_transactions = flagged_count
        batch.completed_at = datetime.utcnow()
        db.commit()
    except Exception:
        logger.exception("Batch screening failed for batch %s", batch_id)
        try:
            # The session may be in a failed state; reset it before writing the status.
            db.rollback()
            if batch is not None:
                batch.status = "failed"
                batch.completed_at = datetime.utcnow()
                db.commit()
        except Exception:
            logger.exception("Could not mark batch %s as failed", batch_id)
    finally:
        db.close()