Automated Action 4730c37915 Implement comprehensive transaction fraud monitoring API
- Created FastAPI application with transaction ingestion endpoints
- Built dynamic rule engine supporting velocity checks and aggregations
- Implemented real-time and batch screening capabilities
- Added rule management with versioning and rollback functionality
- Created comprehensive audit and reporting endpoints with pagination
- Set up SQLite database with proper migrations using Alembic
- Added intelligent caching for aggregate computations
- Included extensive API documentation and example rule definitions
- Configured CORS, health endpoints, and proper error handling
- Added support for time-windowed aggregations (sum, count, avg, max, min)
- Built background processing for high-volume batch screening
- Implemented field-agnostic rule conditions with flexible operators

Features include transaction ingestion, rule CRUD operations, real-time screening,
batch processing, aggregate computation, and comprehensive reporting, making the
service suitable for fintech fraud monitoring systems.
2025-06-27 16:00:48 +00:00
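For orientation, a velocity-style rule of the kind the engine below evaluates could be defined as follows. This is only a sketch: the rule name, priority, and threshold are illustrative, but the condition and action fields mirror the RuleCondition schema and the _calculate_risk_score handling in the code.

# Illustrative rule definition (hypothetical name and threshold); conditions and
# actions are stored as JSON strings on the Rule model before evaluation.
example_rule = {
    "name": "high_velocity_per_user",      # hypothetical rule name
    "priority": 10,
    "conditions": [
        {
            "field": "*",                  # count all transactions
            "operator": "gt",
            "value": 10,                   # more than 10 transactions...
            "aggregate_function": "count",
            "time_window": "1h",           # ...within the last hour...
            "group_by": ["user_id"]        # ...for the same user
        }
    ],
    "actions": [
        {"action_type": "flag", "parameters": {"risk_score": 75.0}}
    ]
}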

371 lines
14 KiB
Python

import json
import time
from datetime import datetime, timedelta
from typing import List, Dict, Any, Optional, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import func, desc
from app.models.transaction import Transaction
from app.models.rule import Rule
from app.models.screening import AggregateCache
from app.core.schemas import RuleCondition, ScreeningResultBase


class RuleEngine:
    def __init__(self, db: Session):
        self.db = db

    def evaluate_transaction(self, transaction_id: str, rule_ids: Optional[List[int]] = None) -> List[ScreeningResultBase]:
        """
        Evaluate a transaction against all active rules or specific rules.
        """
        # Get transaction
        transaction = self.db.query(Transaction).filter(Transaction.transaction_id == transaction_id).first()
        if not transaction:
            raise ValueError(f"Transaction {transaction_id} not found")

        # Get rules to evaluate
        query = self.db.query(Rule).filter(Rule.is_active)
        if rule_ids:
            query = query.filter(Rule.id.in_(rule_ids))
        rules = query.order_by(desc(Rule.priority)).all()

        results = []
        for rule in rules:
            result = self._evaluate_rule_against_transaction(transaction, rule)
            results.append(result)

        return results

    def _evaluate_rule_against_transaction(self, transaction: Transaction, rule: Rule) -> ScreeningResultBase:
        """
        Evaluate a single rule against a transaction.
        """
        start_time = time.time()

        try:
            conditions = json.loads(rule.conditions)
            actions = json.loads(rule.actions)

            # Evaluate all conditions
            condition_results = []
            aggregated_data = {}

            for condition in conditions:
                condition_obj = RuleCondition(**condition)
                result, agg_data = self._evaluate_condition(transaction, condition_obj)
                condition_results.append(result)
                if agg_data:
                    aggregated_data.update(agg_data)

            # All conditions must be true for the rule to trigger
            rule_triggered = all(condition_results)

            # Determine status and risk score
            if rule_triggered:
                status = "flagged"
                risk_score = self._calculate_risk_score(actions)
                details = {
                    "rule_triggered": True,
                    "conditions_met": len(condition_results),
                    "evaluation_time_ms": (time.time() - start_time) * 1000,
                    "actions": actions
                }
            else:
                status = "clean"
                risk_score = 0.0
                details = {
                    "rule_triggered": False,
                    "conditions_met": sum(condition_results),
                    "total_conditions": len(condition_results),
                    "evaluation_time_ms": (time.time() - start_time) * 1000
                }

            return ScreeningResultBase(
                transaction_id=transaction.transaction_id,
                rule_id=rule.id,
                rule_name=rule.name,
                rule_version=rule.version,
                status=status,
                risk_score=risk_score,
                details=details,
                aggregated_data=aggregated_data,
                screening_type="real_time"
            )
        except Exception as e:
            return ScreeningResultBase(
                transaction_id=transaction.transaction_id,
                rule_id=rule.id,
                rule_name=rule.name,
                rule_version=rule.version,
                status="error",
                risk_score=0.0,
                details={"error": str(e), "evaluation_time_ms": (time.time() - start_time) * 1000},
                aggregated_data={},
                screening_type="real_time"
            )

    def _evaluate_condition(self, transaction: Transaction, condition: RuleCondition) -> Tuple[bool, Dict[str, Any]]:
        """
        Evaluate a single condition against a transaction.
        """
        aggregated_data = {}

        # If this is an aggregate condition, compute the aggregate first
        if condition.aggregate_function:
            aggregate_value, agg_data = self._compute_aggregate(transaction, condition)
            aggregated_data = agg_data
            value_to_compare = aggregate_value
        else:
            # Get the field value from transaction
            value_to_compare = self._get_transaction_field_value(transaction, condition.field)

        # Perform comparison
        return self._compare_values(value_to_compare, condition.operator, condition.value), aggregated_data

    def _get_transaction_field_value(self, transaction: Transaction, field: str) -> Any:
        """
        Get a field value from a transaction object.
        """
        if hasattr(transaction, field):
            return getattr(transaction, field)

        # Check in metadata if field not found in main attributes
        if transaction.metadata:
            metadata = json.loads(transaction.metadata)
            return metadata.get(field)

        return None

    def _compute_aggregate(self, transaction: Transaction, condition: RuleCondition) -> Tuple[Any, Dict[str, Any]]:
        """
        Compute aggregate values based on condition parameters.
        """
        # Generate cache key
        cache_key = self._generate_cache_key(transaction, condition)

        # Check cache first
        cached_result = self._get_cached_aggregate(cache_key)
        if cached_result:
            return cached_result["value"], cached_result

        # Compute aggregate
        query = self.db.query(Transaction)

        # Apply time window filter
        if condition.time_window:
            time_delta = self._parse_time_window(condition.time_window)
            cutoff_time = datetime.utcnow() - time_delta
            query = query.filter(Transaction.created_at >= cutoff_time)

        # Apply group by filters
        if condition.group_by:
            for group_field in condition.group_by:
                group_value = self._get_transaction_field_value(transaction, group_field)
                if group_value is not None:
                    if group_field == "user_id":
                        query = query.filter(Transaction.user_id == group_value)
                    elif group_field == "account_id":
                        query = query.filter(Transaction.account_id == group_value)
                    elif group_field == "device_id":
                        query = query.filter(Transaction.device_id == group_value)
                    # Add more group by fields as needed

        # Apply aggregate function
        if condition.aggregate_function == "count":
            if condition.field == "*":
                result = query.count()
            else:
                # Count non-null values of specific field
                field_attr = getattr(Transaction, condition.field, None)
                if field_attr is not None:
                    result = query.filter(field_attr.isnot(None)).count()
                else:
                    result = 0
        elif condition.aggregate_function == "sum":
            field_attr = getattr(Transaction, condition.field, None)
            if field_attr is not None:
                result = query.with_entities(func.sum(field_attr)).scalar() or 0
            else:
                result = 0
        elif condition.aggregate_function == "avg":
            field_attr = getattr(Transaction, condition.field, None)
            if field_attr is not None:
                result = query.with_entities(func.avg(field_attr)).scalar() or 0
            else:
                result = 0
        elif condition.aggregate_function == "max":
            field_attr = getattr(Transaction, condition.field, None)
            if field_attr is not None:
                result = query.with_entities(func.max(field_attr)).scalar() or 0
            else:
                result = 0
        elif condition.aggregate_function == "min":
            field_attr = getattr(Transaction, condition.field, None)
            if field_attr is not None:
                result = query.with_entities(func.min(field_attr)).scalar() or 0
            else:
                result = 0
        else:
            raise ValueError(f"Unsupported aggregate function: {condition.aggregate_function}")

        # Cache the result
        aggregate_data = {
            "value": result,
            "function": condition.aggregate_function,
            "field": condition.field,
            "time_window": condition.time_window,
            "group_by": condition.group_by,
            "computed_at": datetime.utcnow().isoformat(),
            "cache_key": cache_key
        }
        self._cache_aggregate(cache_key, aggregate_data, condition.time_window)

        return result, aggregate_data

    def _generate_cache_key(self, transaction: Transaction, condition: RuleCondition) -> str:
        """
        Generate a cache key for aggregate computation.
        """
        key_parts = [
            condition.aggregate_function,
            condition.field,
            condition.time_window or "no_window"
        ]

        if condition.group_by:
            for group_field in condition.group_by:
                group_value = self._get_transaction_field_value(transaction, group_field)
                key_parts.append(f"{group_field}:{group_value}")

        return ":".join(str(part) for part in key_parts)

    def _get_cached_aggregate(self, cache_key: str) -> Optional[Dict[str, Any]]:
        """
        Retrieve cached aggregate result if not expired.
        """
        cached = self.db.query(AggregateCache).filter(
            AggregateCache.cache_key == cache_key,
            AggregateCache.expires_at > datetime.utcnow()
        ).first()

        if cached:
            return json.loads(cached.cache_value)

        return None

    def _cache_aggregate(self, cache_key: str, data: Dict[str, Any], time_window: Optional[str]):
        """
        Cache aggregate result with appropriate expiration.
        """
        # Determine cache expiration based on time window
        if time_window:
            time_delta = self._parse_time_window(time_window)
            # Cache for 10% of the time window, minimum 1 minute, maximum 1 hour
            cache_duration = max(min(time_delta * 0.1, timedelta(hours=1)), timedelta(minutes=1))
        else:
            cache_duration = timedelta(minutes=5)  # Default 5 minutes

        expires_at = datetime.utcnow() + cache_duration

        # Upsert cache entry
        existing = self.db.query(AggregateCache).filter(AggregateCache.cache_key == cache_key).first()
        if existing:
            existing.cache_value = json.dumps(data)
            existing.expires_at = expires_at
        else:
            cache_entry = AggregateCache(
                cache_key=cache_key,
                cache_value=json.dumps(data),
                expires_at=expires_at
            )
            self.db.add(cache_entry)

        self.db.commit()

    def _parse_time_window(self, time_window: str) -> timedelta:
        """
        Parse time window string into timedelta.
        Supports: 30m, 1h, 24h, 7d, 30d, etc.
        """
        if time_window.endswith('h'):
            hours = int(time_window[:-1])
            return timedelta(hours=hours)
        elif time_window.endswith('d'):
            days = int(time_window[:-1])
            return timedelta(days=days)
        elif time_window.endswith('m'):
            minutes = int(time_window[:-1])
            return timedelta(minutes=minutes)
        else:
            raise ValueError(f"Unsupported time window format: {time_window}")

    def _compare_values(self, left: Any, operator: str, right: Any) -> bool:
        """
        Compare two values using the specified operator.
        """
        if left is None:
            return False

        try:
            if operator == "eq":
                return left == right
            elif operator == "ne":
                return left != right
            elif operator == "gt":
                return float(left) > float(right)
            elif operator == "gte":
                return float(left) >= float(right)
            elif operator == "lt":
                return float(left) < float(right)
            elif operator == "lte":
                return float(left) <= float(right)
            elif operator == "in":
                return left in right
            elif operator == "not_in":
                return left not in right
            elif operator == "contains":
                return str(right).lower() in str(left).lower()
            elif operator == "starts_with":
                return str(left).lower().startswith(str(right).lower())
            elif operator == "ends_with":
                return str(left).lower().endswith(str(right).lower())
            else:
                raise ValueError(f"Unsupported operator: {operator}")
        except (ValueError, TypeError):
            return False

    def _calculate_risk_score(self, actions: List[Dict[str, Any]]) -> float:
        """
        Calculate risk score based on rule actions.
        """
        max_score = 0.0

        for action in actions:
            if action.get("action_type") == "score":
                score = action.get("parameters", {}).get("risk_score", 0.0)
                max_score = max(max_score, score)
            elif action.get("action_type") == "flag":
                score = action.get("parameters", {}).get("risk_score", 50.0)
                max_score = max(max_score, score)
            elif action.get("action_type") == "block":
                score = action.get("parameters", {}).get("risk_score", 100.0)
                max_score = max(max_score, score)

        return max_score

    def cleanup_expired_cache(self):
        """
        Remove expired cache entries.
        """
        self.db.query(AggregateCache).filter(
            AggregateCache.expires_at < datetime.utcnow()
        ).delete()
        self.db.commit()
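
A minimal usage sketch, assuming a SQLAlchemy Session obtained from the application's own database setup (the SessionLocal import path below is hypothetical, not taken from the repository):

# Hypothetical session factory import; substitute the app's actual database module.
from app.db.session import SessionLocal

db = SessionLocal()
try:
    engine = RuleEngine(db)
    # Screen one ingested transaction against every active rule, highest priority first.
    results = engine.evaluate_transaction("txn_0001")
    for r in results:
        print(r.rule_name, r.status, r.risk_score)
    # Periodically drop expired aggregate cache rows.
    engine.cleanup_expired_cache()
finally:
    db.close()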