import json
from datetime import datetime, timedelta
from typing import List, Dict, Any, Optional, Tuple, Union

from sqlalchemy.orm import Session
from sqlalchemy import and_, or_, func, desc, text, case, cast, String
from sqlalchemy.sql import select

from app.models.transaction import Transaction
from app.models.screening import AggregateCache


class DatabaseAggregator:
    """
    High-performance database-level aggregation service.
    Uses native SQL aggregation functions for optimal performance.
    """

    def __init__(self, db: Session):
        self.db = db

    def aggregate_transactions(
        self,
        aggregate_function: str,
        field: str,
        group_by: Optional[List[str]] = None,
        filters: Optional[Dict[str, Any]] = None,
        time_window: Optional[str] = None,
        date_from: Optional[datetime] = None,
        date_to: Optional[datetime] = None,
        use_cache: bool = True
    ) -> Dict[str, Any]:
        """
        Perform database-level aggregation on transactions.

        Args:
            aggregate_function: sum, count, avg, max, min, count_distinct
            field: Field to aggregate on (* for count)
            group_by: List of fields to group by
            filters: Filter conditions
            time_window: Time window (e.g., "24h", "7d")
            date_from: Start date filter
            date_to: End date filter
            use_cache: Whether to use caching

        Returns:
            Dictionary with aggregation results
        """
        # Generate cache key
        cache_key = self._generate_cache_key(
            aggregate_function, field, group_by, filters,
            time_window, date_from, date_to
        )

        # Check cache if enabled
        if use_cache:
            cached_result = self._get_cached_result(cache_key)
            if cached_result:
                return cached_result

        # Build query
        query = self._build_aggregation_query(
            aggregate_function, field, group_by, filters,
            time_window, date_from, date_to
        )

        # Execute query
        start_time = datetime.utcnow()
        results = query.all()
        execution_time = (datetime.utcnow() - start_time).total_seconds() * 1000

        # Format results
        formatted_results = self._format_aggregation_results(
            results, aggregate_function, field, group_by
        )

        # Prepare response
        response = {
            "results": formatted_results,
            "metadata": {
                "aggregate_function": aggregate_function,
                "field": field,
                "group_by": group_by,
                "filters": filters,
                "time_window": time_window,
                "execution_time_ms": execution_time,
                "total_groups": len(formatted_results) if group_by else 1,
                "cache_hit": False,
                "computed_at": start_time.isoformat()
            }
        }

        # Cache result if enabled
        if use_cache:
            self._cache_result(cache_key, response, time_window)

        return response

    def aggregate_by_user(
        self,
        aggregate_function: str,
        field: str,
        user_id: Optional[str] = None,
        time_window: Optional[str] = None,
        filters: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Aggregate transactions by user.
        """
        group_by = ["user_id"]
        if user_id:
            filters = filters or {}
            filters["user_id"] = user_id
            group_by = None  # Don't group if filtering by specific user

        return self.aggregate_transactions(
            aggregate_function=aggregate_function,
            field=field,
            group_by=group_by,
            filters=filters,
            time_window=time_window
        )
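    # Illustrative usage (a sketch, not executed here). It assumes the Transaction
    # model exposes `user_id`, `amount`, and `created_at` columns, which this
    # module relies on but does not define:
    #
    #   aggregator = DatabaseAggregator(db)
    #   totals = aggregator.aggregate_transactions(
    #       aggregate_function="sum",
    #       field="amount",
    #       group_by=["user_id"],
    #       time_window="24h",
    #   )
    #   # totals["results"] -> [{"user_id": "...", "value": 1234.5}, ...]
    #   # totals["metadata"]["cache_hit"] -> False on the first, uncached call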
""" group_by = ["account_id"] if account_id: filters = filters or {} filters["account_id"] = account_id group_by = None return self.aggregate_transactions( aggregate_function=aggregate_function, field=field, group_by=group_by, filters=filters, time_window=time_window ) def aggregate_by_time_period( self, aggregate_function: str, field: str, time_period: str = "hour", # hour, day, week, month group_by: Optional[List[str]] = None, filters: Optional[Dict[str, Any]] = None, date_from: Optional[datetime] = None, date_to: Optional[datetime] = None ) -> Dict[str, Any]: """ Aggregate transactions by time periods with optional additional grouping. """ # Add time period to group by time_group_by = [f"time_{time_period}"] if group_by: time_group_by.extend(group_by) # Build query with time-based grouping query = self._build_time_aggregation_query( aggregate_function, field, time_period, group_by, filters, date_from, date_to ) # Execute query start_time = datetime.utcnow() results = query.all() execution_time = (datetime.utcnow() - start_time).total_seconds() * 1000 # Format time-based results formatted_results = self._format_time_aggregation_results( results, aggregate_function, field, time_period, group_by ) return { "results": formatted_results, "metadata": { "aggregate_function": aggregate_function, "field": field, "time_period": time_period, "group_by": group_by, "filters": filters, "execution_time_ms": execution_time, "total_periods": len(formatted_results), "computed_at": start_time.isoformat() } } def get_velocity_metrics( self, user_id: str, time_window: str = "24h", transaction_type: Optional[str] = None ) -> Dict[str, Any]: """ Get comprehensive velocity metrics for a user. """ filters = {"user_id": user_id} if transaction_type: filters["transaction_type"] = transaction_type # Get multiple aggregations in parallel metrics = {} # Transaction count count_result = self.aggregate_transactions( "count", "*", filters=filters, time_window=time_window ) metrics["transaction_count"] = count_result["results"] # Amount sum sum_result = self.aggregate_transactions( "sum", "amount", filters=filters, time_window=time_window ) metrics["total_amount"] = sum_result["results"] # Average amount avg_result = self.aggregate_transactions( "avg", "amount", filters=filters, time_window=time_window ) metrics["average_amount"] = avg_result["results"] # Maximum amount max_result = self.aggregate_transactions( "max", "amount", filters=filters, time_window=time_window ) metrics["max_amount"] = max_result["results"] # Unique devices device_result = self.aggregate_transactions( "count_distinct", "device_id", filters=filters, time_window=time_window ) metrics["unique_devices"] = device_result["results"] # Unique channels channel_result = self.aggregate_transactions( "count_distinct", "channel", filters=filters, time_window=time_window ) metrics["unique_channels"] = channel_result["results"] return { "user_id": user_id, "time_window": time_window, "metrics": metrics, "computed_at": datetime.utcnow().isoformat() } def _build_aggregation_query( self, aggregate_function: str, field: str, group_by: Optional[List[str]], filters: Optional[Dict[str, Any]], time_window: Optional[str], date_from: Optional[datetime], date_to: Optional[datetime] ): """ Build SQLAlchemy query for aggregation. 
""" # Start with base query query = self.db.query() # Build SELECT clause select_fields = [] # Add group by fields if group_by: for group_field in group_by: if hasattr(Transaction, group_field): select_fields.append(getattr(Transaction, group_field)) # Add aggregation field if aggregate_function == "count": if field == "*": agg_field = func.count() else: field_attr = getattr(Transaction, field, None) if field_attr: agg_field = func.count(field_attr) else: raise ValueError(f"Field {field} not found") elif aggregate_function == "sum": field_attr = getattr(Transaction, field, None) if field_attr: agg_field = func.sum(field_attr) else: raise ValueError(f"Field {field} not found") elif aggregate_function == "avg": field_attr = getattr(Transaction, field, None) if field_attr: agg_field = func.avg(field_attr) else: raise ValueError(f"Field {field} not found") elif aggregate_function == "max": field_attr = getattr(Transaction, field, None) if field_attr: agg_field = func.max(field_attr) else: raise ValueError(f"Field {field} not found") elif aggregate_function == "min": field_attr = getattr(Transaction, field, None) if field_attr: agg_field = func.min(field_attr) else: raise ValueError(f"Field {field} not found") elif aggregate_function == "count_distinct": field_attr = getattr(Transaction, field, None) if field_attr: agg_field = func.count(func.distinct(field_attr)) else: raise ValueError(f"Field {field} not found") else: raise ValueError(f"Unsupported aggregate function: {aggregate_function}") select_fields.append(agg_field.label('agg_value')) # Apply SELECT query = query.add_columns(*select_fields).select_from(Transaction) # Apply filters if filters: for filter_field, filter_value in filters.items(): if hasattr(Transaction, filter_field): field_attr = getattr(Transaction, filter_field) if isinstance(filter_value, list): query = query.filter(field_attr.in_(filter_value)) else: query = query.filter(field_attr == filter_value) # Apply time filters if time_window: time_delta = self._parse_time_window(time_window) cutoff_time = datetime.utcnow() - time_delta query = query.filter(Transaction.created_at >= cutoff_time) if date_from: query = query.filter(Transaction.created_at >= date_from) if date_to: query = query.filter(Transaction.created_at <= date_to) # Apply GROUP BY if group_by: for group_field in group_by: if hasattr(Transaction, group_field): query = query.group_by(getattr(Transaction, group_field)) # Add ORDER BY for consistent results if group_by: query = query.order_by(desc('agg_value')) return query def _build_time_aggregation_query( self, aggregate_function: str, field: str, time_period: str, group_by: Optional[List[str]], filters: Optional[Dict[str, Any]], date_from: Optional[datetime], date_to: Optional[datetime] ): """ Build query for time-based aggregation. 
""" # Time period extraction if time_period == "hour": time_expr = func.strftime('%Y-%m-%d %H:00:00', Transaction.created_at) elif time_period == "day": time_expr = func.strftime('%Y-%m-%d', Transaction.created_at) elif time_period == "week": time_expr = func.strftime('%Y-W%W', Transaction.created_at) elif time_period == "month": time_expr = func.strftime('%Y-%m', Transaction.created_at) else: raise ValueError(f"Unsupported time period: {time_period}") # Start building query select_fields = [time_expr.label('time_period')] # Add group by fields if group_by: for group_field in group_by: if hasattr(Transaction, group_field): select_fields.append(getattr(Transaction, group_field)) # Add aggregation field if aggregate_function == "count": if field == "*": agg_field = func.count() else: field_attr = getattr(Transaction, field, None) if field_attr: agg_field = func.count(field_attr) else: raise ValueError(f"Field {field} not found") elif aggregate_function == "sum": field_attr = getattr(Transaction, field, None) if field_attr: agg_field = func.sum(field_attr) else: raise ValueError(f"Field {field} not found") elif aggregate_function == "avg": field_attr = getattr(Transaction, field, None) if field_attr: agg_field = func.avg(field_attr) else: raise ValueError(f"Field {field} not found") else: raise ValueError(f"Unsupported aggregate function: {aggregate_function}") select_fields.append(agg_field.label('agg_value')) # Build query query = self.db.query(*select_fields) # Apply filters if filters: for filter_field, filter_value in filters.items(): if hasattr(Transaction, filter_field): field_attr = getattr(Transaction, filter_field) if isinstance(filter_value, list): query = query.filter(field_attr.in_(filter_value)) else: query = query.filter(field_attr == filter_value) # Apply date filters if date_from: query = query.filter(Transaction.created_at >= date_from) if date_to: query = query.filter(Transaction.created_at <= date_to) # Group by time period and other fields query = query.group_by(time_expr) if group_by: for group_field in group_by: if hasattr(Transaction, group_field): query = query.group_by(getattr(Transaction, group_field)) # Order by time period query = query.order_by(time_expr) return query def _format_aggregation_results( self, results: List[Tuple], aggregate_function: str, field: str, group_by: Optional[List[str]] ) -> Union[Dict[str, Any], List[Dict[str, Any]]]: """ Format aggregation results into a consistent structure. """ if not group_by: # Single result if results: return {"value": float(results[0].agg_value) if results[0].agg_value else 0} else: return {"value": 0} # Multiple results with grouping formatted = [] for result in results: result_dict = {} # Add group by fields for i, group_field in enumerate(group_by): result_dict[group_field] = getattr(result, group_field, None) # Add aggregated value result_dict["value"] = float(result.agg_value) if result.agg_value else 0 formatted.append(result_dict) return formatted def _format_time_aggregation_results( self, results: List[Tuple], aggregate_function: str, field: str, time_period: str, group_by: Optional[List[str]] ) -> List[Dict[str, Any]]: """ Format time-based aggregation results. 
""" formatted = [] for result in results: result_dict = { "time_period": result.time_period, "value": float(result.agg_value) if result.agg_value else 0 } # Add group by fields if present if group_by: for group_field in group_by: result_dict[group_field] = getattr(result, group_field, None) formatted.append(result_dict) return formatted def _generate_cache_key( self, aggregate_function: str, field: str, group_by: Optional[List[str]], filters: Optional[Dict[str, Any]], time_window: Optional[str], date_from: Optional[datetime], date_to: Optional[datetime] ) -> str: """ Generate cache key for aggregation. """ key_parts = [ f"agg:{aggregate_function}", f"field:{field}", f"group_by:{':'.join(group_by) if group_by else 'none'}", f"filters:{json.dumps(filters, sort_keys=True) if filters else 'none'}", f"time_window:{time_window or 'none'}", f"date_from:{date_from.isoformat() if date_from else 'none'}", f"date_to:{date_to.isoformat() if date_to else 'none'}" ] return "|".join(key_parts) def _get_cached_result(self, cache_key: str) -> Optional[Dict[str, Any]]: """ Get cached aggregation result. """ cached = self.db.query(AggregateCache).filter( AggregateCache.cache_key == cache_key, AggregateCache.expires_at > datetime.utcnow() ).first() if cached: result = json.loads(cached.cache_value) result["metadata"]["cache_hit"] = True return result return None def _cache_result( self, cache_key: str, result: Dict[str, Any], time_window: Optional[str] ): """ Cache aggregation result. """ # Determine cache duration if time_window: time_delta = self._parse_time_window(time_window) # Cache for 10% of time window, min 1 minute, max 1 hour cache_duration = max(min(time_delta * 0.1, timedelta(hours=1)), timedelta(minutes=1)) else: cache_duration = timedelta(minutes=5) expires_at = datetime.utcnow() + cache_duration # Upsert cache entry existing = self.db.query(AggregateCache).filter( AggregateCache.cache_key == cache_key ).first() if existing: existing.cache_value = json.dumps(result) existing.expires_at = expires_at else: cache_entry = AggregateCache( cache_key=cache_key, cache_value=json.dumps(result), expires_at=expires_at ) self.db.add(cache_entry) self.db.commit() def _parse_time_window(self, time_window: str) -> timedelta: """ Parse time window string into timedelta. """ if time_window.endswith('h'): hours = int(time_window[:-1]) return timedelta(hours=hours) elif time_window.endswith('d'): days = int(time_window[:-1]) return timedelta(days=days) elif time_window.endswith('m'): minutes = int(time_window[:-1]) return timedelta(minutes=minutes) else: raise ValueError(f"Unsupported time window format: {time_window}") def cleanup_expired_cache(self): """ Remove expired cache entries. """ self.db.query(AggregateCache).filter( AggregateCache.expires_at < datetime.utcnow() ).delete() self.db.commit()