diff --git a/app/services/db_aggregator.py b/app/services/db_aggregator.py
new file mode 100644
index 0000000..830ec3f
--- /dev/null
+++ b/app/services/db_aggregator.py
@@ -0,0 +1,610 @@
+import json
+from datetime import datetime, timedelta
+from typing import List, Dict, Any, Optional, Tuple, Union
+from sqlalchemy.orm import Session
+from sqlalchemy import and_, or_, func, desc, text, case, cast, String
+from sqlalchemy.sql import select
+
+from app.models.transaction import Transaction
+from app.models.screening import AggregateCache
+
+class DatabaseAggregator:
+    """
+    Database-level aggregation service.
+
+    Pushes aggregation into the database with native SQL functions
+    (SUM, COUNT, AVG, ...) instead of loading rows into Python.
+    """
+
+    def __init__(self, db: Session):
+        self.db = db
+
+    def aggregate_transactions(
+        self,
+        aggregate_function: str,
+        field: str,
+        group_by: Optional[List[str]] = None,
+        filters: Optional[Dict[str, Any]] = None,
+        time_window: Optional[str] = None,
+        date_from: Optional[datetime] = None,
+        date_to: Optional[datetime] = None,
+        use_cache: bool = True
+    ) -> Dict[str, Any]:
+        """
+        Perform database-level aggregation on transactions.
+
+        Args:
+            aggregate_function: sum, count, avg, max, min, count_distinct
+            field: Field to aggregate on ("*" for count)
+            group_by: List of fields to group by
+            filters: Filter conditions
+            time_window: Relative time window (e.g. "24h", "7d")
+            date_from: Start date filter
+            date_to: End date filter
+            use_cache: Whether to use the aggregate cache
+
+        Returns:
+            Dictionary with aggregation results and execution metadata
+        """
+        # Generate cache key
+        cache_key = self._generate_cache_key(
+            aggregate_function, field, group_by, filters, time_window, date_from, date_to
+        )
+
+        # Check cache if enabled
+        if use_cache:
+            cached_result = self._get_cached_result(cache_key)
+            if cached_result:
+                return cached_result
+
+        # Build query
+        query = self._build_aggregation_query(
+            aggregate_function, field, group_by, filters, time_window, date_from, date_to
+        )
+
+        # Execute query
+        start_time = datetime.utcnow()
+        results = query.all()
+        execution_time = (datetime.utcnow() - start_time).total_seconds() * 1000
+
+        # Format results
+        formatted_results = self._format_aggregation_results(
+            results, aggregate_function, field, group_by
+        )
+
+        # Prepare response
+        response = {
+            "results": formatted_results,
+            "metadata": {
+                "aggregate_function": aggregate_function,
+                "field": field,
+                "group_by": group_by,
+                "filters": filters,
+                "time_window": time_window,
+                "execution_time_ms": execution_time,
+                "total_groups": len(formatted_results) if group_by else 1,
+                "cache_hit": False,
+                "computed_at": start_time.isoformat()
+            }
+        }
+
+        # Cache result if enabled
+        if use_cache:
+            self._cache_result(cache_key, response, time_window)
+
+        return response
+
+    def aggregate_by_user(
+        self,
+        aggregate_function: str,
+        field: str,
+        user_id: Optional[str] = None,
+        time_window: Optional[str] = None,
+        filters: Optional[Dict[str, Any]] = None
+    ) -> Dict[str, Any]:
+        """
+        Aggregate transactions by user.
+ """ + group_by = ["user_id"] + if user_id: + filters = filters or {} + filters["user_id"] = user_id + group_by = None # Don't group if filtering by specific user + + return self.aggregate_transactions( + aggregate_function=aggregate_function, + field=field, + group_by=group_by, + filters=filters, + time_window=time_window + ) + + def aggregate_by_account( + self, + aggregate_function: str, + field: str, + account_id: Optional[str] = None, + time_window: Optional[str] = None, + filters: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + """ + Aggregate transactions by account. + """ + group_by = ["account_id"] + if account_id: + filters = filters or {} + filters["account_id"] = account_id + group_by = None + + return self.aggregate_transactions( + aggregate_function=aggregate_function, + field=field, + group_by=group_by, + filters=filters, + time_window=time_window + ) + + def aggregate_by_time_period( + self, + aggregate_function: str, + field: str, + time_period: str = "hour", # hour, day, week, month + group_by: Optional[List[str]] = None, + filters: Optional[Dict[str, Any]] = None, + date_from: Optional[datetime] = None, + date_to: Optional[datetime] = None + ) -> Dict[str, Any]: + """ + Aggregate transactions by time periods with optional additional grouping. + """ + # Add time period to group by + time_group_by = [f"time_{time_period}"] + if group_by: + time_group_by.extend(group_by) + + # Build query with time-based grouping + query = self._build_time_aggregation_query( + aggregate_function, field, time_period, group_by, filters, date_from, date_to + ) + + # Execute query + start_time = datetime.utcnow() + results = query.all() + execution_time = (datetime.utcnow() - start_time).total_seconds() * 1000 + + # Format time-based results + formatted_results = self._format_time_aggregation_results( + results, aggregate_function, field, time_period, group_by + ) + + return { + "results": formatted_results, + "metadata": { + "aggregate_function": aggregate_function, + "field": field, + "time_period": time_period, + "group_by": group_by, + "filters": filters, + "execution_time_ms": execution_time, + "total_periods": len(formatted_results), + "computed_at": start_time.isoformat() + } + } + + def get_velocity_metrics( + self, + user_id: str, + time_window: str = "24h", + transaction_type: Optional[str] = None + ) -> Dict[str, Any]: + """ + Get comprehensive velocity metrics for a user. 
+ """ + filters = {"user_id": user_id} + if transaction_type: + filters["transaction_type"] = transaction_type + + # Get multiple aggregations in parallel + metrics = {} + + # Transaction count + count_result = self.aggregate_transactions( + "count", "*", filters=filters, time_window=time_window + ) + metrics["transaction_count"] = count_result["results"] + + # Amount sum + sum_result = self.aggregate_transactions( + "sum", "amount", filters=filters, time_window=time_window + ) + metrics["total_amount"] = sum_result["results"] + + # Average amount + avg_result = self.aggregate_transactions( + "avg", "amount", filters=filters, time_window=time_window + ) + metrics["average_amount"] = avg_result["results"] + + # Maximum amount + max_result = self.aggregate_transactions( + "max", "amount", filters=filters, time_window=time_window + ) + metrics["max_amount"] = max_result["results"] + + # Unique devices + device_result = self.aggregate_transactions( + "count_distinct", "device_id", filters=filters, time_window=time_window + ) + metrics["unique_devices"] = device_result["results"] + + # Unique channels + channel_result = self.aggregate_transactions( + "count_distinct", "channel", filters=filters, time_window=time_window + ) + metrics["unique_channels"] = channel_result["results"] + + return { + "user_id": user_id, + "time_window": time_window, + "metrics": metrics, + "computed_at": datetime.utcnow().isoformat() + } + + def _build_aggregation_query( + self, + aggregate_function: str, + field: str, + group_by: Optional[List[str]], + filters: Optional[Dict[str, Any]], + time_window: Optional[str], + date_from: Optional[datetime], + date_to: Optional[datetime] + ): + """ + Build SQLAlchemy query for aggregation. + """ + # Start with base query + query = self.db.query() + + # Build SELECT clause + select_fields = [] + + # Add group by fields + if group_by: + for group_field in group_by: + if hasattr(Transaction, group_field): + select_fields.append(getattr(Transaction, group_field)) + + # Add aggregation field + if aggregate_function == "count": + if field == "*": + agg_field = func.count() + else: + field_attr = getattr(Transaction, field, None) + if field_attr: + agg_field = func.count(field_attr) + else: + raise ValueError(f"Field {field} not found") + elif aggregate_function == "sum": + field_attr = getattr(Transaction, field, None) + if field_attr: + agg_field = func.sum(field_attr) + else: + raise ValueError(f"Field {field} not found") + elif aggregate_function == "avg": + field_attr = getattr(Transaction, field, None) + if field_attr: + agg_field = func.avg(field_attr) + else: + raise ValueError(f"Field {field} not found") + elif aggregate_function == "max": + field_attr = getattr(Transaction, field, None) + if field_attr: + agg_field = func.max(field_attr) + else: + raise ValueError(f"Field {field} not found") + elif aggregate_function == "min": + field_attr = getattr(Transaction, field, None) + if field_attr: + agg_field = func.min(field_attr) + else: + raise ValueError(f"Field {field} not found") + elif aggregate_function == "count_distinct": + field_attr = getattr(Transaction, field, None) + if field_attr: + agg_field = func.count(func.distinct(field_attr)) + else: + raise ValueError(f"Field {field} not found") + else: + raise ValueError(f"Unsupported aggregate function: {aggregate_function}") + + select_fields.append(agg_field.label('agg_value')) + + # Apply SELECT + query = query.add_columns(*select_fields).select_from(Transaction) + + # Apply filters + if filters: + for filter_field, 
+                if hasattr(Transaction, filter_field):
+                    field_attr = getattr(Transaction, filter_field)
+                    if isinstance(filter_value, list):
+                        query = query.filter(field_attr.in_(filter_value))
+                    else:
+                        query = query.filter(field_attr == filter_value)
+
+        # Apply time filters
+        if time_window:
+            time_delta = self._parse_time_window(time_window)
+            cutoff_time = datetime.utcnow() - time_delta
+            query = query.filter(Transaction.created_at >= cutoff_time)
+
+        if date_from:
+            query = query.filter(Transaction.created_at >= date_from)
+
+        if date_to:
+            query = query.filter(Transaction.created_at <= date_to)
+
+        # Apply GROUP BY
+        if group_by:
+            for group_field in group_by:
+                if hasattr(Transaction, group_field):
+                    query = query.group_by(getattr(Transaction, group_field))
+
+        # Add ORDER BY for consistent results
+        if group_by:
+            query = query.order_by(desc('agg_value'))
+
+        return query
+
+    def _build_time_aggregation_query(
+        self,
+        aggregate_function: str,
+        field: str,
+        time_period: str,
+        group_by: Optional[List[str]],
+        filters: Optional[Dict[str, Any]],
+        date_from: Optional[datetime],
+        date_to: Optional[datetime]
+    ):
+        """
+        Build a query for time-based aggregation.
+        """
+        # Time period extraction (strftime is SQLite-specific; other backends
+        # would need date_trunc or an equivalent expression)
+        if time_period == "hour":
+            time_expr = func.strftime('%Y-%m-%d %H:00:00', Transaction.created_at)
+        elif time_period == "day":
+            time_expr = func.strftime('%Y-%m-%d', Transaction.created_at)
+        elif time_period == "week":
+            time_expr = func.strftime('%Y-W%W', Transaction.created_at)
+        elif time_period == "month":
+            time_expr = func.strftime('%Y-%m', Transaction.created_at)
+        else:
+            raise ValueError(f"Unsupported time period: {time_period}")
+
+        # Start building the SELECT list
+        select_fields = [time_expr.label('time_period')]
+
+        # Add group by fields
+        if group_by:
+            for group_field in group_by:
+                if hasattr(Transaction, group_field):
+                    select_fields.append(getattr(Transaction, group_field))
+
+        # Add aggregation expression (time-based queries support count, sum and avg)
+        if aggregate_function == "count":
+            if field == "*":
+                agg_field = func.count()
+            else:
+                field_attr = getattr(Transaction, field, None)
+                if field_attr is not None:
+                    agg_field = func.count(field_attr)
+                else:
+                    raise ValueError(f"Field {field} not found")
+        elif aggregate_function == "sum":
+            field_attr = getattr(Transaction, field, None)
+            if field_attr is not None:
+                agg_field = func.sum(field_attr)
+            else:
+                raise ValueError(f"Field {field} not found")
+        elif aggregate_function == "avg":
+            field_attr = getattr(Transaction, field, None)
+            if field_attr is not None:
+                agg_field = func.avg(field_attr)
+            else:
+                raise ValueError(f"Field {field} not found")
+        else:
+            raise ValueError(f"Unsupported aggregate function: {aggregate_function}")
+
+        select_fields.append(agg_field.label('agg_value'))
+
+        # Build query
+        query = self.db.query(*select_fields)
+
+        # Apply filters
+        if filters:
+            for filter_field, filter_value in filters.items():
+                if hasattr(Transaction, filter_field):
+                    field_attr = getattr(Transaction, filter_field)
+                    if isinstance(filter_value, list):
+                        query = query.filter(field_attr.in_(filter_value))
+                    else:
+                        query = query.filter(field_attr == filter_value)
+
+        # Apply date filters
+        if date_from:
+            query = query.filter(Transaction.created_at >= date_from)
+        if date_to:
+            query = query.filter(Transaction.created_at <= date_to)
+
+        # Group by time period and any additional fields
+        query = query.group_by(time_expr)
+        if group_by:
+            for group_field in group_by:
+                if hasattr(Transaction, group_field):
+                    query = query.group_by(getattr(Transaction, group_field))
+
+        # Order by time period
+        query = query.order_by(time_expr)
+
+        return query
+
+    def _format_aggregation_results(
+        self,
+        results: List[Tuple],
+        aggregate_function: str,
+        field: str,
+        group_by: Optional[List[str]]
+    ) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
+        """
+        Format aggregation results into a consistent structure.
+        """
+        if not group_by:
+            # Single (ungrouped) result
+            if results:
+                value = results[0].agg_value
+                return {"value": float(value) if value is not None else 0}
+            return {"value": 0}
+
+        # Multiple results with grouping
+        formatted = []
+        for result in results:
+            result_dict = {}
+
+            # Add group by fields
+            for group_field in group_by:
+                result_dict[group_field] = getattr(result, group_field, None)
+
+            # Add aggregated value
+            result_dict["value"] = float(result.agg_value) if result.agg_value is not None else 0
+
+            formatted.append(result_dict)
+
+        return formatted
+
+    def _format_time_aggregation_results(
+        self,
+        results: List[Tuple],
+        aggregate_function: str,
+        field: str,
+        time_period: str,
+        group_by: Optional[List[str]]
+    ) -> List[Dict[str, Any]]:
+        """
+        Format time-based aggregation results.
+        """
+        formatted = []
+        for result in results:
+            result_dict = {
+                "time_period": result.time_period,
+                "value": float(result.agg_value) if result.agg_value is not None else 0
+            }
+
+            # Add group by fields if present
+            if group_by:
+                for group_field in group_by:
+                    result_dict[group_field] = getattr(result, group_field, None)
+
+            formatted.append(result_dict)
+
+        return formatted
+
+    def _generate_cache_key(
+        self,
+        aggregate_function: str,
+        field: str,
+        group_by: Optional[List[str]],
+        filters: Optional[Dict[str, Any]],
+        time_window: Optional[str],
+        date_from: Optional[datetime],
+        date_to: Optional[datetime]
+    ) -> str:
+        """
+        Generate a deterministic cache key for an aggregation request.
+        """
+        key_parts = [
+            f"agg:{aggregate_function}",
+            f"field:{field}",
+            f"group_by:{':'.join(group_by) if group_by else 'none'}",
+            f"filters:{json.dumps(filters, sort_keys=True) if filters else 'none'}",
+            f"time_window:{time_window or 'none'}",
+            f"date_from:{date_from.isoformat() if date_from else 'none'}",
+            f"date_to:{date_to.isoformat() if date_to else 'none'}"
+        ]
+        return "|".join(key_parts)
+
+    def _get_cached_result(self, cache_key: str) -> Optional[Dict[str, Any]]:
+        """
+        Return a cached aggregation result if a non-expired entry exists.
+        """
+        cached = self.db.query(AggregateCache).filter(
+            AggregateCache.cache_key == cache_key,
+            AggregateCache.expires_at > datetime.utcnow()
+        ).first()
+
+        if cached:
+            result = json.loads(cached.cache_value)
+            result["metadata"]["cache_hit"] = True
+            return result
+
+        return None
+
+    def _cache_result(
+        self,
+        cache_key: str,
+        result: Dict[str, Any],
+        time_window: Optional[str]
+    ):
+        """
+        Cache an aggregation result.
+ """ + # Determine cache duration + if time_window: + time_delta = self._parse_time_window(time_window) + # Cache for 10% of time window, min 1 minute, max 1 hour + cache_duration = max(min(time_delta * 0.1, timedelta(hours=1)), timedelta(minutes=1)) + else: + cache_duration = timedelta(minutes=5) + + expires_at = datetime.utcnow() + cache_duration + + # Upsert cache entry + existing = self.db.query(AggregateCache).filter( + AggregateCache.cache_key == cache_key + ).first() + + if existing: + existing.cache_value = json.dumps(result) + existing.expires_at = expires_at + else: + cache_entry = AggregateCache( + cache_key=cache_key, + cache_value=json.dumps(result), + expires_at=expires_at + ) + self.db.add(cache_entry) + + self.db.commit() + + def _parse_time_window(self, time_window: str) -> timedelta: + """ + Parse time window string into timedelta. + """ + if time_window.endswith('h'): + hours = int(time_window[:-1]) + return timedelta(hours=hours) + elif time_window.endswith('d'): + days = int(time_window[:-1]) + return timedelta(days=days) + elif time_window.endswith('m'): + minutes = int(time_window[:-1]) + return timedelta(minutes=minutes) + else: + raise ValueError(f"Unsupported time window format: {time_window}") + + def cleanup_expired_cache(self): + """ + Remove expired cache entries. + """ + self.db.query(AggregateCache).filter( + AggregateCache.expires_at < datetime.utcnow() + ).delete() + self.db.commit() \ No newline at end of file