Update code via agent code generation
This commit is contained in:
parent 4730c37915
commit 59f241e534
610
app/services/db_aggregator.py
Normal file
@@ -0,0 +1,610 @@
import json
from datetime import datetime, timedelta
from typing import List, Dict, Any, Optional, Tuple, Union
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_, func, desc, text, case, cast, String
from sqlalchemy.sql import select

from app.models.transaction import Transaction
from app.models.screening import AggregateCache


class DatabaseAggregator:
    """
    High-performance database-level aggregation service.

    Uses native SQL aggregation functions for optimal performance.
    """

    def __init__(self, db: Session):
        self.db = db
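    # Illustrative usage (a sketch, not part of this module): assumes a SQLAlchemy
    # session factory such as SessionLocal exists elsewhere in the app.
    #
    #   db = SessionLocal()
    #   aggregator = DatabaseAggregator(db)
    #   daily_spend = aggregator.aggregate_transactions(
    #       "sum", "amount", filters={"user_id": "u-123"}, time_window="24h"
    #   )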

    def aggregate_transactions(
        self,
        aggregate_function: str,
        field: str,
        group_by: Optional[List[str]] = None,
        filters: Optional[Dict[str, Any]] = None,
        time_window: Optional[str] = None,
        date_from: Optional[datetime] = None,
        date_to: Optional[datetime] = None,
        use_cache: bool = True
    ) -> Dict[str, Any]:
        """
        Perform database-level aggregation on transactions.

        Args:
            aggregate_function: sum, count, avg, max, min, count_distinct
            field: Field to aggregate on ("*" for count)
            group_by: List of fields to group by
            filters: Filter conditions
            time_window: Time window (e.g., "24h", "7d")
            date_from: Start date filter
            date_to: End date filter
            use_cache: Whether to use caching

        Returns:
            Dictionary with aggregation results
        """
        # Generate cache key
        cache_key = self._generate_cache_key(
            aggregate_function, field, group_by, filters, time_window, date_from, date_to
        )

        # Check cache if enabled
        if use_cache:
            cached_result = self._get_cached_result(cache_key)
            if cached_result:
                return cached_result

        # Build query
        query = self._build_aggregation_query(
            aggregate_function, field, group_by, filters, time_window, date_from, date_to
        )

        # Execute query
        start_time = datetime.utcnow()
        results = query.all()
        execution_time = (datetime.utcnow() - start_time).total_seconds() * 1000

        # Format results
        formatted_results = self._format_aggregation_results(
            results, aggregate_function, field, group_by
        )

        # Prepare response
        response = {
            "results": formatted_results,
            "metadata": {
                "aggregate_function": aggregate_function,
                "field": field,
                "group_by": group_by,
                "filters": filters,
                "time_window": time_window,
                "execution_time_ms": execution_time,
                "total_groups": len(formatted_results) if group_by else 1,
                "cache_hit": False,
                "computed_at": start_time.isoformat()
            }
        }

        # Cache result if enabled
        if use_cache:
            self._cache_result(cache_key, response, time_window)

        return response

    def aggregate_by_user(
        self,
        aggregate_function: str,
        field: str,
        user_id: Optional[str] = None,
        time_window: Optional[str] = None,
        filters: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Aggregate transactions by user.
        """
        group_by = ["user_id"]
        if user_id:
            filters = filters or {}
            filters["user_id"] = user_id
            group_by = None  # Don't group if filtering by a specific user

        return self.aggregate_transactions(
            aggregate_function=aggregate_function,
            field=field,
            group_by=group_by,
            filters=filters,
            time_window=time_window
        )

    def aggregate_by_account(
        self,
        aggregate_function: str,
        field: str,
        account_id: Optional[str] = None,
        time_window: Optional[str] = None,
        filters: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Aggregate transactions by account.
        """
        group_by = ["account_id"]
        if account_id:
            filters = filters or {}
            filters["account_id"] = account_id
            group_by = None

        return self.aggregate_transactions(
            aggregate_function=aggregate_function,
            field=field,
            group_by=group_by,
            filters=filters,
            time_window=time_window
        )

    def aggregate_by_time_period(
        self,
        aggregate_function: str,
        field: str,
        time_period: str = "hour",  # hour, day, week, month
        group_by: Optional[List[str]] = None,
        filters: Optional[Dict[str, Any]] = None,
        date_from: Optional[datetime] = None,
        date_to: Optional[datetime] = None
    ) -> Dict[str, Any]:
        """
        Aggregate transactions by time periods with optional additional grouping.
        """
        # Add time period to group by
        # (note: kept for reference only; the time bucket itself is added inside
        # _build_time_aggregation_query, so time_group_by is not used below)
        time_group_by = [f"time_{time_period}"]
        if group_by:
            time_group_by.extend(group_by)

        # Build query with time-based grouping
        query = self._build_time_aggregation_query(
            aggregate_function, field, time_period, group_by, filters, date_from, date_to
        )

        # Execute query
        start_time = datetime.utcnow()
        results = query.all()
        execution_time = (datetime.utcnow() - start_time).total_seconds() * 1000

        # Format time-based results
        formatted_results = self._format_time_aggregation_results(
            results, aggregate_function, field, time_period, group_by
        )

        return {
            "results": formatted_results,
            "metadata": {
                "aggregate_function": aggregate_function,
                "field": field,
                "time_period": time_period,
                "group_by": group_by,
                "filters": filters,
                "execution_time_ms": execution_time,
                "total_periods": len(formatted_results),
                "computed_at": start_time.isoformat()
            }
        }

    def get_velocity_metrics(
        self,
        user_id: str,
        time_window: str = "24h",
        transaction_type: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Get comprehensive velocity metrics for a user.
        """
        filters = {"user_id": user_id}
        if transaction_type:
            filters["transaction_type"] = transaction_type

        # Run several aggregations over the same window (executed sequentially)
        metrics = {}

        # Transaction count
        count_result = self.aggregate_transactions(
            "count", "*", filters=filters, time_window=time_window
        )
        metrics["transaction_count"] = count_result["results"]

        # Amount sum
        sum_result = self.aggregate_transactions(
            "sum", "amount", filters=filters, time_window=time_window
        )
        metrics["total_amount"] = sum_result["results"]

        # Average amount
        avg_result = self.aggregate_transactions(
            "avg", "amount", filters=filters, time_window=time_window
        )
        metrics["average_amount"] = avg_result["results"]

        # Maximum amount
        max_result = self.aggregate_transactions(
            "max", "amount", filters=filters, time_window=time_window
        )
        metrics["max_amount"] = max_result["results"]

        # Unique devices
        device_result = self.aggregate_transactions(
            "count_distinct", "device_id", filters=filters, time_window=time_window
        )
        metrics["unique_devices"] = device_result["results"]

        # Unique channels
        channel_result = self.aggregate_transactions(
            "count_distinct", "channel", filters=filters, time_window=time_window
        )
        metrics["unique_channels"] = channel_result["results"]

        return {
            "user_id": user_id,
            "time_window": time_window,
            "metrics": metrics,
            "computed_at": datetime.utcnow().isoformat()
        }

    def _build_aggregation_query(
        self,
        aggregate_function: str,
        field: str,
        group_by: Optional[List[str]],
        filters: Optional[Dict[str, Any]],
        time_window: Optional[str],
        date_from: Optional[datetime],
        date_to: Optional[datetime]
    ):
        """
        Build SQLAlchemy query for aggregation.
        """
        # Build SELECT clause
        select_fields = []

        # Add group-by fields
        if group_by:
            for group_field in group_by:
                if hasattr(Transaction, group_field):
                    select_fields.append(getattr(Transaction, group_field))

        # Add aggregation expression
        if aggregate_function == "count" and field == "*":
            agg_field = func.count()
        else:
            field_attr = getattr(Transaction, field, None)
            if field_attr is None:
                raise ValueError(f"Field {field} not found")
            if aggregate_function == "count":
                agg_field = func.count(field_attr)
            elif aggregate_function == "sum":
                agg_field = func.sum(field_attr)
            elif aggregate_function == "avg":
                agg_field = func.avg(field_attr)
            elif aggregate_function == "max":
                agg_field = func.max(field_attr)
            elif aggregate_function == "min":
                agg_field = func.min(field_attr)
            elif aggregate_function == "count_distinct":
                agg_field = func.count(func.distinct(field_attr))
            else:
                raise ValueError(f"Unsupported aggregate function: {aggregate_function}")

        select_fields.append(agg_field.label('agg_value'))

        # Apply SELECT
        query = self.db.query(*select_fields).select_from(Transaction)

        # Apply filters
        if filters:
            for filter_field, filter_value in filters.items():
                if hasattr(Transaction, filter_field):
                    field_attr = getattr(Transaction, filter_field)
                    if isinstance(filter_value, list):
                        query = query.filter(field_attr.in_(filter_value))
                    else:
                        query = query.filter(field_attr == filter_value)

        # Apply time filters
        if time_window:
            time_delta = self._parse_time_window(time_window)
            cutoff_time = datetime.utcnow() - time_delta
            query = query.filter(Transaction.created_at >= cutoff_time)

        if date_from:
            query = query.filter(Transaction.created_at >= date_from)

        if date_to:
            query = query.filter(Transaction.created_at <= date_to)

        # Apply GROUP BY, then order by the aggregate value for consistent results
        if group_by:
            for group_field in group_by:
                if hasattr(Transaction, group_field):
                    query = query.group_by(getattr(Transaction, group_field))
            query = query.order_by(desc('agg_value'))

        return query

    def _build_time_aggregation_query(
        self,
        aggregate_function: str,
        field: str,
        time_period: str,
        group_by: Optional[List[str]],
        filters: Optional[Dict[str, Any]],
        date_from: Optional[datetime],
        date_to: Optional[datetime]
    ):
        """
        Build query for time-based aggregation.
        """
        # Time period extraction
        if time_period == "hour":
            time_expr = func.strftime('%Y-%m-%d %H:00:00', Transaction.created_at)
        elif time_period == "day":
            time_expr = func.strftime('%Y-%m-%d', Transaction.created_at)
        elif time_period == "week":
            time_expr = func.strftime('%Y-W%W', Transaction.created_at)
        elif time_period == "month":
            time_expr = func.strftime('%Y-%m', Transaction.created_at)
        else:
            raise ValueError(f"Unsupported time period: {time_period}")
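        # Note: func.strftime is SQLite-specific. On PostgreSQL an equivalent bucket
        # would be something like func.date_trunc('hour', Transaction.created_at)
        # (an assumption; this module uses strftime as written).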

        # Start building the SELECT list with the time bucket
        select_fields = [time_expr.label('time_period')]

        # Add group-by fields
        if group_by:
            for group_field in group_by:
                if hasattr(Transaction, group_field):
                    select_fields.append(getattr(Transaction, group_field))

        # Add aggregation expression (count, sum and avg are supported here)
        if aggregate_function == "count" and field == "*":
            agg_field = func.count()
        else:
            field_attr = getattr(Transaction, field, None)
            if field_attr is None:
                raise ValueError(f"Field {field} not found")
            if aggregate_function == "count":
                agg_field = func.count(field_attr)
            elif aggregate_function == "sum":
                agg_field = func.sum(field_attr)
            elif aggregate_function == "avg":
                agg_field = func.avg(field_attr)
            else:
                raise ValueError(f"Unsupported aggregate function: {aggregate_function}")

        select_fields.append(agg_field.label('agg_value'))

        # Build query
        query = self.db.query(*select_fields)

        # Apply filters
        if filters:
            for filter_field, filter_value in filters.items():
                if hasattr(Transaction, filter_field):
                    field_attr = getattr(Transaction, filter_field)
                    if isinstance(filter_value, list):
                        query = query.filter(field_attr.in_(filter_value))
                    else:
                        query = query.filter(field_attr == filter_value)

        # Apply date filters
        if date_from:
            query = query.filter(Transaction.created_at >= date_from)
        if date_to:
            query = query.filter(Transaction.created_at <= date_to)

        # Group by time period and any additional fields
        query = query.group_by(time_expr)
        if group_by:
            for group_field in group_by:
                if hasattr(Transaction, group_field):
                    query = query.group_by(getattr(Transaction, group_field))

        # Order by time period
        query = query.order_by(time_expr)

        return query

    def _format_aggregation_results(
        self,
        results: List[Tuple],
        aggregate_function: str,
        field: str,
        group_by: Optional[List[str]]
    ) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """
        Format aggregation results into a consistent structure.
        """
        if not group_by:
            # Single result
            if results:
                return {"value": float(results[0].agg_value) if results[0].agg_value else 0}
            else:
                return {"value": 0}

        # Multiple results with grouping
        formatted = []
        for result in results:
            result_dict = {}

            # Add group-by fields
            for group_field in group_by:
                result_dict[group_field] = getattr(result, group_field, None)

            # Add aggregated value
            result_dict["value"] = float(result.agg_value) if result.agg_value else 0

            formatted.append(result_dict)

        return formatted

    def _format_time_aggregation_results(
        self,
        results: List[Tuple],
        aggregate_function: str,
        field: str,
        time_period: str,
        group_by: Optional[List[str]]
    ) -> List[Dict[str, Any]]:
        """
        Format time-based aggregation results.
        """
        formatted = []
        for result in results:
            result_dict = {
                "time_period": result.time_period,
                "value": float(result.agg_value) if result.agg_value else 0
            }

            # Add group-by fields if present
            if group_by:
                for group_field in group_by:
                    result_dict[group_field] = getattr(result, group_field, None)

            formatted.append(result_dict)

        return formatted

    def _generate_cache_key(
        self,
        aggregate_function: str,
        field: str,
        group_by: Optional[List[str]],
        filters: Optional[Dict[str, Any]],
        time_window: Optional[str],
        date_from: Optional[datetime],
        date_to: Optional[datetime]
    ) -> str:
        """
        Generate cache key for aggregation.
        """
        key_parts = [
            f"agg:{aggregate_function}",
            f"field:{field}",
            f"group_by:{':'.join(group_by) if group_by else 'none'}",
            f"filters:{json.dumps(filters, sort_keys=True) if filters else 'none'}",
            f"time_window:{time_window or 'none'}",
            f"date_from:{date_from.isoformat() if date_from else 'none'}",
            f"date_to:{date_to.isoformat() if date_to else 'none'}"
        ]
        return "|".join(key_parts)

    def _get_cached_result(self, cache_key: str) -> Optional[Dict[str, Any]]:
        """
        Get cached aggregation result.
        """
        cached = self.db.query(AggregateCache).filter(
            AggregateCache.cache_key == cache_key,
            AggregateCache.expires_at > datetime.utcnow()
        ).first()

        if cached:
            result = json.loads(cached.cache_value)
            result["metadata"]["cache_hit"] = True
            return result

        return None

    def _cache_result(
        self,
        cache_key: str,
        result: Dict[str, Any],
        time_window: Optional[str]
    ):
        """
        Cache aggregation result.
        """
        # Determine cache duration
        if time_window:
            time_delta = self._parse_time_window(time_window)
            # Cache for 10% of the time window, min 1 minute, max 1 hour
            cache_duration = max(min(time_delta * 0.1, timedelta(hours=1)), timedelta(minutes=1))
        else:
            cache_duration = timedelta(minutes=5)

        expires_at = datetime.utcnow() + cache_duration
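        # Worked examples of the rule above: a "24h" window caches for 1 hour (the cap),
        # a "1h" window for 6 minutes, and a "5m" window for 1 minute (the floor).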

        # Upsert cache entry
        existing = self.db.query(AggregateCache).filter(
            AggregateCache.cache_key == cache_key
        ).first()

        if existing:
            existing.cache_value = json.dumps(result)
            existing.expires_at = expires_at
        else:
            cache_entry = AggregateCache(
                cache_key=cache_key,
                cache_value=json.dumps(result),
                expires_at=expires_at
            )
            self.db.add(cache_entry)

        self.db.commit()

    def _parse_time_window(self, time_window: str) -> timedelta:
        """
        Parse time window string into timedelta.
        """
        if time_window.endswith('h'):
            hours = int(time_window[:-1])
            return timedelta(hours=hours)
        elif time_window.endswith('d'):
            days = int(time_window[:-1])
            return timedelta(days=days)
        elif time_window.endswith('m'):
            minutes = int(time_window[:-1])
            return timedelta(minutes=minutes)
        else:
            raise ValueError(f"Unsupported time window format: {time_window}")

    def cleanup_expired_cache(self):
        """
        Remove expired cache entries.
        """
        self.db.query(AggregateCache).filter(
            AggregateCache.expires_at < datetime.utcnow()
        ).delete()
        self.db.commit()