import json
from datetime import datetime, timedelta
from typing import List, Dict, Any, Optional, Tuple, Union
from sqlalchemy.orm import Session
from sqlalchemy import func, desc
from app.models.transaction import Transaction
from app.models.screening import AggregateCache


class DatabaseAggregator:
"""
High-performance database-level aggregation service.
Uses native SQL aggregation functions for optimal performance.
"""
def __init__(self, db: Session):
self.db = db
def aggregate_transactions(
self,
aggregate_function: str,
field: str,
group_by: Optional[List[str]] = None,
filters: Optional[Dict[str, Any]] = None,
time_window: Optional[str] = None,
date_from: Optional[datetime] = None,
date_to: Optional[datetime] = None,
use_cache: bool = True
) -> Dict[str, Any]:
"""
Perform database-level aggregation on transactions.
Args:
aggregate_function: sum, count, avg, max, min, count_distinct
field: Field to aggregate on (* for count)
group_by: List of fields to group by
filters: Filter conditions
time_window: Time window (e.g., "24h", "7d")
date_from: Start date filter
date_to: End date filter
use_cache: Whether to use caching
Returns:
Dictionary with aggregation results
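        Example (illustrative; assumes `db` is an active Session over a
        populated Transaction table):
            aggregator = DatabaseAggregator(db)
            per_user = aggregator.aggregate_transactions(
                aggregate_function="sum",
                field="amount",
                group_by=["user_id"],
                time_window="24h",
            )
            # per_user["results"] -> [{"user_id": "...", "value": 1250.0}, ...]
            # per_user["metadata"]["cache_hit"] -> False on the first call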
"""
# Generate cache key
cache_key = self._generate_cache_key(
aggregate_function, field, group_by, filters, time_window, date_from, date_to
)
# Check cache if enabled
if use_cache:
cached_result = self._get_cached_result(cache_key)
if cached_result:
return cached_result
# Build query
query = self._build_aggregation_query(
aggregate_function, field, group_by, filters, time_window, date_from, date_to
)
# Execute query
start_time = datetime.utcnow()
results = query.all()
execution_time = (datetime.utcnow() - start_time).total_seconds() * 1000
# Format results
formatted_results = self._format_aggregation_results(
results, aggregate_function, field, group_by
)
# Prepare response
response = {
"results": formatted_results,
"metadata": {
"aggregate_function": aggregate_function,
"field": field,
"group_by": group_by,
"filters": filters,
"time_window": time_window,
"execution_time_ms": execution_time,
"total_groups": len(formatted_results) if group_by else 1,
"cache_hit": False,
"computed_at": start_time.isoformat()
}
}
# Cache result if enabled
if use_cache:
self._cache_result(cache_key, response, time_window)
return response
def aggregate_by_user(
self,
aggregate_function: str,
field: str,
user_id: Optional[str] = None,
time_window: Optional[str] = None,
filters: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
Aggregate transactions by user.
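        Thin wrapper over aggregate_transactions. Example (illustrative):
            # Per-user transaction counts over the last 7 days
            by_user = aggregator.aggregate_by_user("count", "*", time_window="7d")
            # Total amount for a single user (no grouping)
            one_user = aggregator.aggregate_by_user("sum", "amount", user_id="u-123")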
"""
group_by = ["user_id"]
if user_id:
filters = filters or {}
filters["user_id"] = user_id
group_by = None # Don't group if filtering by specific user
return self.aggregate_transactions(
aggregate_function=aggregate_function,
field=field,
group_by=group_by,
filters=filters,
time_window=time_window
)
def aggregate_by_account(
self,
aggregate_function: str,
field: str,
account_id: Optional[str] = None,
time_window: Optional[str] = None,
filters: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
Aggregate transactions by account.
"""
group_by = ["account_id"]
if account_id:
filters = filters or {}
filters["account_id"] = account_id
group_by = None
return self.aggregate_transactions(
aggregate_function=aggregate_function,
field=field,
group_by=group_by,
filters=filters,
time_window=time_window
)
def aggregate_by_time_period(
self,
aggregate_function: str,
field: str,
time_period: str = "hour", # hour, day, week, month
group_by: Optional[List[str]] = None,
filters: Optional[Dict[str, Any]] = None,
date_from: Optional[datetime] = None,
date_to: Optional[datetime] = None
) -> Dict[str, Any]:
"""
Aggregate transactions by time periods with optional additional grouping.
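        Example (illustrative; `channel` is one of the Transaction fields used
        elsewhere in this service):
            daily = aggregator.aggregate_by_time_period(
                "sum", "amount",
                time_period="day",
                group_by=["channel"],
                date_from=datetime(2025, 6, 1),
                date_to=datetime(2025, 6, 30),
            )
            # daily["results"] -> [{"time_period": "2025-06-01", "channel": "web", "value": ...}, ...]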
"""
        # Time-period bucketing and any extra grouping are handled inside the
        # query builder below.
# Build query with time-based grouping
query = self._build_time_aggregation_query(
aggregate_function, field, time_period, group_by, filters, date_from, date_to
)
# Execute query
start_time = datetime.utcnow()
results = query.all()
execution_time = (datetime.utcnow() - start_time).total_seconds() * 1000
# Format time-based results
formatted_results = self._format_time_aggregation_results(
results, aggregate_function, field, time_period, group_by
)
return {
"results": formatted_results,
"metadata": {
"aggregate_function": aggregate_function,
"field": field,
"time_period": time_period,
"group_by": group_by,
"filters": filters,
"execution_time_ms": execution_time,
"total_periods": len(formatted_results),
"computed_at": start_time.isoformat()
}
}
def get_velocity_metrics(
self,
user_id: str,
time_window: str = "24h",
transaction_type: Optional[str] = None
) -> Dict[str, Any]:
"""
Get comprehensive velocity metrics for a user.
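        The returned structure looks roughly like (illustrative values):
            {
                "user_id": "u-123",
                "time_window": "24h",
                "metrics": {
                    "transaction_count": {"value": 14.0},
                    "total_amount": {"value": 3120.5},
                    "unique_devices": {"value": 2.0},
                    ...
                },
                "computed_at": "2025-06-27T18:00:00"
            }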
"""
filters = {"user_id": user_id}
if transaction_type:
filters["transaction_type"] = transaction_type
        # Run the individual aggregations (sequentially; each call can be
        # served from the aggregate cache)
        metrics = {}
# Transaction count
count_result = self.aggregate_transactions(
"count", "*", filters=filters, time_window=time_window
)
metrics["transaction_count"] = count_result["results"]
# Amount sum
sum_result = self.aggregate_transactions(
"sum", "amount", filters=filters, time_window=time_window
)
metrics["total_amount"] = sum_result["results"]
# Average amount
avg_result = self.aggregate_transactions(
"avg", "amount", filters=filters, time_window=time_window
)
metrics["average_amount"] = avg_result["results"]
# Maximum amount
max_result = self.aggregate_transactions(
"max", "amount", filters=filters, time_window=time_window
)
metrics["max_amount"] = max_result["results"]
# Unique devices
device_result = self.aggregate_transactions(
"count_distinct", "device_id", filters=filters, time_window=time_window
)
metrics["unique_devices"] = device_result["results"]
# Unique channels
channel_result = self.aggregate_transactions(
"count_distinct", "channel", filters=filters, time_window=time_window
)
metrics["unique_channels"] = channel_result["results"]
return {
"user_id": user_id,
"time_window": time_window,
"metrics": metrics,
"computed_at": datetime.utcnow().isoformat()
}
def _build_aggregation_query(
self,
aggregate_function: str,
field: str,
group_by: Optional[List[str]],
filters: Optional[Dict[str, Any]],
time_window: Optional[str],
date_from: Optional[datetime],
date_to: Optional[datetime]
):
"""
Build SQLAlchemy query for aggregation.
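        For a sum of amount grouped by user_id within a time window, the
        generated SQL is roughly (sketch; the table name is assumed and the
        exact text depends on the dialect):
            SELECT user_id, sum(amount) AS agg_value
            FROM transactions
            WHERE created_at >= :cutoff
            GROUP BY user_id
            ORDER BY agg_value DESC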
"""
        # Build the SELECT column list first; the query is constructed once
        # all selected columns are known.
        select_fields = []
# Add group by fields
if group_by:
for group_field in group_by:
if hasattr(Transaction, group_field):
select_fields.append(getattr(Transaction, group_field))
        # Resolve the aggregate expression
        simple_funcs = {"sum": func.sum, "avg": func.avg, "max": func.max, "min": func.min}
        if aggregate_function == "count" and field == "*":
            agg_field = func.count()
        elif aggregate_function in ("count", "count_distinct") or aggregate_function in simple_funcs:
            field_attr = getattr(Transaction, field, None)
            if field_attr is None:
                raise ValueError(f"Field {field} not found")
            if aggregate_function == "count":
                agg_field = func.count(field_attr)
            elif aggregate_function == "count_distinct":
                agg_field = func.count(func.distinct(field_attr))
            else:
                agg_field = simple_funcs[aggregate_function](field_attr)
        else:
            raise ValueError(f"Unsupported aggregate function: {aggregate_function}")
        select_fields.append(agg_field.label('agg_value'))
        # Construct the query from the selected columns
        query = self.db.query(*select_fields).select_from(Transaction)
# Apply filters
if filters:
for filter_field, filter_value in filters.items():
if hasattr(Transaction, filter_field):
field_attr = getattr(Transaction, filter_field)
if isinstance(filter_value, list):
query = query.filter(field_attr.in_(filter_value))
else:
query = query.filter(field_attr == filter_value)
# Apply time filters
if time_window:
time_delta = self._parse_time_window(time_window)
cutoff_time = datetime.utcnow() - time_delta
query = query.filter(Transaction.created_at >= cutoff_time)
if date_from:
query = query.filter(Transaction.created_at >= date_from)
if date_to:
query = query.filter(Transaction.created_at <= date_to)
# Apply GROUP BY
if group_by:
for group_field in group_by:
if hasattr(Transaction, group_field):
query = query.group_by(getattr(Transaction, group_field))
# Add ORDER BY for consistent results
if group_by:
query = query.order_by(desc('agg_value'))
return query
def _build_time_aggregation_query(
self,
aggregate_function: str,
field: str,
time_period: str,
group_by: Optional[List[str]],
filters: Optional[Dict[str, Any]],
date_from: Optional[datetime],
date_to: Optional[datetime]
):
"""
Build query for time-based aggregation.
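        Note: time bucketing relies on strftime(), which is SQLite's
        date-formatting function; other backends (e.g. PostgreSQL) would need
        an equivalent such as date_trunc() or to_char(). With
        time_period="day" the buckets look like "2025-06-27".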
"""
# Time period extraction
if time_period == "hour":
time_expr = func.strftime('%Y-%m-%d %H:00:00', Transaction.created_at)
elif time_period == "day":
time_expr = func.strftime('%Y-%m-%d', Transaction.created_at)
elif time_period == "week":
time_expr = func.strftime('%Y-W%W', Transaction.created_at)
elif time_period == "month":
time_expr = func.strftime('%Y-%m', Transaction.created_at)
else:
raise ValueError(f"Unsupported time period: {time_period}")
# Start building query
select_fields = [time_expr.label('time_period')]
# Add group by fields
if group_by:
for group_field in group_by:
if hasattr(Transaction, group_field):
select_fields.append(getattr(Transaction, group_field))
        # Resolve the aggregate expression (count, sum and avg are supported here)
        if aggregate_function == "count" and field == "*":
            agg_field = func.count()
        elif aggregate_function in ("count", "sum", "avg"):
            field_attr = getattr(Transaction, field, None)
            if field_attr is None:
                raise ValueError(f"Field {field} not found")
            agg_field = {"count": func.count, "sum": func.sum, "avg": func.avg}[aggregate_function](field_attr)
        else:
            raise ValueError(f"Unsupported aggregate function: {aggregate_function}")
        select_fields.append(agg_field.label('agg_value'))
# Build query
query = self.db.query(*select_fields)
# Apply filters
if filters:
for filter_field, filter_value in filters.items():
if hasattr(Transaction, filter_field):
field_attr = getattr(Transaction, filter_field)
if isinstance(filter_value, list):
query = query.filter(field_attr.in_(filter_value))
else:
query = query.filter(field_attr == filter_value)
# Apply date filters
if date_from:
query = query.filter(Transaction.created_at >= date_from)
if date_to:
query = query.filter(Transaction.created_at <= date_to)
# Group by time period and other fields
query = query.group_by(time_expr)
if group_by:
for group_field in group_by:
if hasattr(Transaction, group_field):
query = query.group_by(getattr(Transaction, group_field))
# Order by time period
query = query.order_by(time_expr)
return query
def _format_aggregation_results(
self,
results: List[Tuple],
aggregate_function: str,
field: str,
group_by: Optional[List[str]]
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
"""
Format aggregation results into a consistent structure.
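        Without group_by a single {"value": ...} dict is returned; with
        group_by, a list with one dict per group, e.g. (illustrative values):
            [{"user_id": "u-1", "value": 420.0}, {"user_id": "u-2", "value": 99.5}]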
"""
if not group_by:
# Single result
            if results and results[0].agg_value is not None:
                return {"value": float(results[0].agg_value)}
            return {"value": 0.0}
# Multiple results with grouping
formatted = []
for result in results:
result_dict = {}
            # Add group by fields
            for group_field in group_by:
                result_dict[group_field] = getattr(result, group_field, None)
            # Add aggregated value
            result_dict["value"] = float(result.agg_value) if result.agg_value is not None else 0.0
formatted.append(result_dict)
return formatted
def _format_time_aggregation_results(
self,
results: List[Tuple],
aggregate_function: str,
field: str,
time_period: str,
group_by: Optional[List[str]]
) -> List[Dict[str, Any]]:
"""
Format time-based aggregation results.
"""
formatted = []
for result in results:
result_dict = {
"time_period": result.time_period,
"value": float(result.agg_value) if result.agg_value else 0
}
# Add group by fields if present
if group_by:
for group_field in group_by:
result_dict[group_field] = getattr(result, group_field, None)
formatted.append(result_dict)
return formatted
def _generate_cache_key(
self,
aggregate_function: str,
field: str,
group_by: Optional[List[str]],
filters: Optional[Dict[str, Any]],
time_window: Optional[str],
date_from: Optional[datetime],
date_to: Optional[datetime]
) -> str:
"""
Generate cache key for aggregation.
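        The key is a deterministic, human-readable string, e.g. (illustrative):
            agg:sum|field:amount|group_by:user_id|filters:{"user_id": "u-123"}|time_window:24h|date_from:none|date_to:none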
"""
key_parts = [
f"agg:{aggregate_function}",
f"field:{field}",
f"group_by:{':'.join(group_by) if group_by else 'none'}",
f"filters:{json.dumps(filters, sort_keys=True) if filters else 'none'}",
f"time_window:{time_window or 'none'}",
f"date_from:{date_from.isoformat() if date_from else 'none'}",
f"date_to:{date_to.isoformat() if date_to else 'none'}"
]
return "|".join(key_parts)
def _get_cached_result(self, cache_key: str) -> Optional[Dict[str, Any]]:
"""
Get cached aggregation result.
"""
cached = self.db.query(AggregateCache).filter(
AggregateCache.cache_key == cache_key,
AggregateCache.expires_at > datetime.utcnow()
).first()
if cached:
result = json.loads(cached.cache_value)
result["metadata"]["cache_hit"] = True
return result
return None
def _cache_result(
self,
cache_key: str,
result: Dict[str, Any],
time_window: Optional[str]
):
"""
Cache aggregation result.
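        The TTL is 10% of the time window, clamped between 1 minute and
        1 hour (e.g. "24h" -> 2.4h, capped to 1h; "30m" -> 3 minutes);
        without a time window the result is cached for 5 minutes.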
"""
# Determine cache duration
if time_window:
time_delta = self._parse_time_window(time_window)
# Cache for 10% of time window, min 1 minute, max 1 hour
cache_duration = max(min(time_delta * 0.1, timedelta(hours=1)), timedelta(minutes=1))
else:
cache_duration = timedelta(minutes=5)
expires_at = datetime.utcnow() + cache_duration
# Upsert cache entry
existing = self.db.query(AggregateCache).filter(
AggregateCache.cache_key == cache_key
).first()
if existing:
existing.cache_value = json.dumps(result)
existing.expires_at = expires_at
else:
cache_entry = AggregateCache(
cache_key=cache_key,
cache_value=json.dumps(result),
expires_at=expires_at
)
self.db.add(cache_entry)
self.db.commit()
def _parse_time_window(self, time_window: str) -> timedelta:
"""
Parse time window string into timedelta.
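        Supported suffixes are 'h' (hours), 'd' (days) and 'm' (minutes), e.g.:
            "24h" -> timedelta(hours=24)
            "7d"  -> timedelta(days=7)
            "30m" -> timedelta(minutes=30)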
"""
if time_window.endswith('h'):
hours = int(time_window[:-1])
return timedelta(hours=hours)
elif time_window.endswith('d'):
days = int(time_window[:-1])
return timedelta(days=days)
elif time_window.endswith('m'):
minutes = int(time_window[:-1])
return timedelta(minutes=minutes)
else:
raise ValueError(f"Unsupported time window format: {time_window}")
def cleanup_expired_cache(self):
"""
Remove expired cache entries.
"""
self.db.query(AggregateCache).filter(
AggregateCache.expires_at < datetime.utcnow()
).delete()
self.db.commit()