
- Created FastAPI application with transaction ingestion endpoints
- Built dynamic rule engine supporting velocity checks and aggregations
- Implemented real-time and batch screening capabilities
- Added rule management with versioning and rollback functionality
- Created comprehensive audit and reporting endpoints with pagination
- Set up SQLite database with migrations managed by Alembic
- Added intelligent caching for aggregate computations
- Included extensive API documentation and example rule definitions
- Configured CORS, health endpoints, and proper error handling
- Added support for time-windowed aggregations (sum, count, avg, max, min)
- Built background processing for high-volume batch screening
- Implemented field-agnostic rule conditions with flexible operators

Together these cover transaction ingestion, rule CRUD operations, real-time screening, batch processing, aggregation computation, and reporting suitable for fintech fraud-monitoring systems. An illustrative rule definition is sketched below, followed by the rule engine implementation.
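For orientation, here is a rough sketch of a rule in the format the engine consumes. The authoritative schema is `RuleCondition` in `app.core.schemas` (not shown here), so the keys below are inferred from how the engine reads them; optional keys such as `aggregate_function`, `time_window`, and `group_by` are assumed to default to None, and field names like `amount` are illustrative.

# Illustrative only: a rule that triggers when a user has made more than five
# transactions in the past hour and the current transaction's amount is at
# least 100. Keys mirror what RuleEngine reads (field, operator, value,
# aggregate_function, time_window, group_by); the real schema lives in
# app.core.schemas.RuleCondition.
example_rule = {
    "name": "high_velocity_user",
    "priority": 10,
    "is_active": True,
    "conditions": [
        {   # aggregate (velocity) condition: count of the user's transactions in 1h
            "field": "*",
            "operator": "gt",
            "value": 5,
            "aggregate_function": "count",
            "time_window": "1h",
            "group_by": ["user_id"],
        },
        {   # plain field condition evaluated against the current transaction
            "field": "amount",
            "operator": "gte",
            "value": 100,
        },
    ],
    "actions": [
        {"action_type": "flag", "parameters": {"risk_score": 75.0}},
    ],
}

On the `Rule` model, `conditions` and `actions` are persisted as JSON strings; the engine calls `json.loads` on both before evaluation.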
import json
import time
from datetime import datetime, timedelta
from typing import List, Dict, Any, Optional, Tuple

from sqlalchemy.orm import Session
from sqlalchemy import func, desc

from app.models.transaction import Transaction
from app.models.rule import Rule
from app.models.screening import AggregateCache
from app.core.schemas import RuleCondition, ScreeningResultBase

class RuleEngine:
    def __init__(self, db: Session):
        self.db = db

    def evaluate_transaction(self, transaction_id: str, rule_ids: Optional[List[int]] = None) -> List[ScreeningResultBase]:
        """
        Evaluate a transaction against all active rules or specific rules.
        """
        # Get transaction
        transaction = self.db.query(Transaction).filter(Transaction.transaction_id == transaction_id).first()
        if not transaction:
            raise ValueError(f"Transaction {transaction_id} not found")

        # Get rules to evaluate
        query = self.db.query(Rule).filter(Rule.is_active)
        if rule_ids:
            query = query.filter(Rule.id.in_(rule_ids))

        rules = query.order_by(desc(Rule.priority)).all()

        results = []
        for rule in rules:
            result = self._evaluate_rule_against_transaction(transaction, rule)
            results.append(result)

        return results

    def _evaluate_rule_against_transaction(self, transaction: Transaction, rule: Rule) -> ScreeningResultBase:
        """
        Evaluate a single rule against a transaction.
        """
        start_time = time.time()

        try:
            conditions = json.loads(rule.conditions)
            actions = json.loads(rule.actions)

            # Evaluate all conditions
            condition_results = []
            aggregated_data = {}

            for condition in conditions:
                condition_obj = RuleCondition(**condition)
                result, agg_data = self._evaluate_condition(transaction, condition_obj)
                condition_results.append(result)
                if agg_data:
                    aggregated_data.update(agg_data)

            # All conditions must be true for the rule to trigger
            rule_triggered = all(condition_results)

            # Determine status and risk score
            if rule_triggered:
                status = "flagged"
                risk_score = self._calculate_risk_score(actions)
                details = {
                    "rule_triggered": True,
                    "conditions_met": len(condition_results),
                    "evaluation_time_ms": (time.time() - start_time) * 1000,
                    "actions": actions
                }
            else:
                status = "clean"
                risk_score = 0.0
                details = {
                    "rule_triggered": False,
                    "conditions_met": sum(condition_results),
                    "total_conditions": len(condition_results),
                    "evaluation_time_ms": (time.time() - start_time) * 1000
                }

            return ScreeningResultBase(
                transaction_id=transaction.transaction_id,
                rule_id=rule.id,
                rule_name=rule.name,
                rule_version=rule.version,
                status=status,
                risk_score=risk_score,
                details=details,
                aggregated_data=aggregated_data,
                screening_type="real_time"
            )

        except Exception as e:
            return ScreeningResultBase(
                transaction_id=transaction.transaction_id,
                rule_id=rule.id,
                rule_name=rule.name,
                rule_version=rule.version,
                status="error",
                risk_score=0.0,
                details={"error": str(e), "evaluation_time_ms": (time.time() - start_time) * 1000},
                aggregated_data={},
                screening_type="real_time"
            )

    def _evaluate_condition(self, transaction: Transaction, condition: RuleCondition) -> Tuple[bool, Dict[str, Any]]:
        """
        Evaluate a single condition against a transaction.
        """
        aggregated_data = {}

        # If this is an aggregate condition, compute the aggregate first
        if condition.aggregate_function:
            aggregate_value, agg_data = self._compute_aggregate(transaction, condition)
            aggregated_data = agg_data
            value_to_compare = aggregate_value
        else:
            # Get the field value from the transaction
            value_to_compare = self._get_transaction_field_value(transaction, condition.field)

        # Perform comparison
        return self._compare_values(value_to_compare, condition.operator, condition.value), aggregated_data

    def _get_transaction_field_value(self, transaction: Transaction, field: str) -> Any:
        """
        Get a field value from a transaction object.
        """
        if hasattr(transaction, field):
            return getattr(transaction, field)

        # Fall back to the metadata JSON blob if the field is not a direct attribute.
        # Note: this assumes the Transaction model exposes a JSON string under the
        # `metadata` attribute; on SQLAlchemy declarative models that name is
        # reserved, so the column may need to be mapped under a different name.
        if transaction.metadata:
            metadata = json.loads(transaction.metadata)
            return metadata.get(field)

        return None

    def _compute_aggregate(self, transaction: Transaction, condition: RuleCondition) -> Tuple[Any, Dict[str, Any]]:
        """
        Compute aggregate values based on condition parameters.
        """
        # Generate cache key
        cache_key = self._generate_cache_key(transaction, condition)

        # Check cache first
        cached_result = self._get_cached_aggregate(cache_key)
        if cached_result:
            return cached_result["value"], cached_result

        # Compute aggregate
        query = self.db.query(Transaction)

        # Apply time window filter
        if condition.time_window:
            time_delta = self._parse_time_window(condition.time_window)
            cutoff_time = datetime.utcnow() - time_delta
            query = query.filter(Transaction.created_at >= cutoff_time)

        # Apply group-by filters
        if condition.group_by:
            for group_field in condition.group_by:
                group_value = self._get_transaction_field_value(transaction, group_field)
                if group_value is not None:
                    if group_field == "user_id":
                        query = query.filter(Transaction.user_id == group_value)
                    elif group_field == "account_id":
                        query = query.filter(Transaction.account_id == group_value)
                    elif group_field == "device_id":
                        query = query.filter(Transaction.device_id == group_value)
                    # Add more group-by fields as needed

        # Apply aggregate function
        if condition.aggregate_function == "count":
            if condition.field == "*":
                result = query.count()
            else:
                # Count non-null values of the specific field
                field_attr = getattr(Transaction, condition.field, None)
                if field_attr is not None:
                    result = query.filter(field_attr.isnot(None)).count()
                else:
                    result = 0

        elif condition.aggregate_function in ("sum", "avg", "max", "min"):
            field_attr = getattr(Transaction, condition.field, None)
            if field_attr is not None:
                sql_func = getattr(func, condition.aggregate_function)
                result = query.with_entities(sql_func(field_attr)).scalar() or 0
            else:
                result = 0

        else:
            raise ValueError(f"Unsupported aggregate function: {condition.aggregate_function}")

        # Cache the result
        aggregate_data = {
            "value": result,
            "function": condition.aggregate_function,
            "field": condition.field,
            "time_window": condition.time_window,
            "group_by": condition.group_by,
            "computed_at": datetime.utcnow().isoformat(),
            "cache_key": cache_key
        }

        self._cache_aggregate(cache_key, aggregate_data, condition.time_window)

        return result, aggregate_data

    def _generate_cache_key(self, transaction: Transaction, condition: RuleCondition) -> str:
        """
        Generate a cache key for aggregate computation.
        """
        key_parts = [
            condition.aggregate_function,
            condition.field,
            condition.time_window or "no_window"
        ]

        if condition.group_by:
            for group_field in condition.group_by:
                group_value = self._get_transaction_field_value(transaction, group_field)
                key_parts.append(f"{group_field}:{group_value}")

        return ":".join(str(part) for part in key_parts)

    def _get_cached_aggregate(self, cache_key: str) -> Optional[Dict[str, Any]]:
        """
        Retrieve cached aggregate result if not expired.
        """
        cached = self.db.query(AggregateCache).filter(
            AggregateCache.cache_key == cache_key,
            AggregateCache.expires_at > datetime.utcnow()
        ).first()

        if cached:
            return json.loads(cached.cache_value)

        return None

    def _cache_aggregate(self, cache_key: str, data: Dict[str, Any], time_window: Optional[str]):
        """
        Cache aggregate result with appropriate expiration.
        """
        # Determine cache expiration based on the time window
        if time_window:
            time_delta = self._parse_time_window(time_window)
            # Cache for 10% of the time window, minimum 1 minute, maximum 1 hour
            cache_duration = max(min(time_delta * 0.1, timedelta(hours=1)), timedelta(minutes=1))
        else:
            cache_duration = timedelta(minutes=5)  # Default 5 minutes

        expires_at = datetime.utcnow() + cache_duration

        # Upsert cache entry
        existing = self.db.query(AggregateCache).filter(AggregateCache.cache_key == cache_key).first()
        if existing:
            existing.cache_value = json.dumps(data)
            existing.expires_at = expires_at
        else:
            cache_entry = AggregateCache(
                cache_key=cache_key,
                cache_value=json.dumps(data),
                expires_at=expires_at
            )
            self.db.add(cache_entry)

        self.db.commit()

    def _parse_time_window(self, time_window: str) -> timedelta:
        """
        Parse a time window string into a timedelta.
        Supports minute, hour, and day windows: 30m, 1h, 24h, 7d, 30d, etc.
        """
        if time_window.endswith('h'):
            hours = int(time_window[:-1])
            return timedelta(hours=hours)
        elif time_window.endswith('d'):
            days = int(time_window[:-1])
            return timedelta(days=days)
        elif time_window.endswith('m'):
            minutes = int(time_window[:-1])
            return timedelta(minutes=minutes)
        else:
            raise ValueError(f"Unsupported time window format: {time_window}")

    def _compare_values(self, left: Any, operator: str, right: Any) -> bool:
        """
        Compare two values using the specified operator.
        """
        if left is None:
            return False

        try:
            if operator == "eq":
                return left == right
            elif operator == "ne":
                return left != right
            elif operator == "gt":
                return float(left) > float(right)
            elif operator == "gte":
                return float(left) >= float(right)
            elif operator == "lt":
                return float(left) < float(right)
            elif operator == "lte":
                return float(left) <= float(right)
            elif operator == "in":
                return left in right
            elif operator == "not_in":
                return left not in right
            elif operator == "contains":
                return str(right).lower() in str(left).lower()
            elif operator == "starts_with":
                return str(left).lower().startswith(str(right).lower())
            elif operator == "ends_with":
                return str(left).lower().endswith(str(right).lower())
        except (ValueError, TypeError):
            # Type coercion or membership checks failed; treat as a non-match
            return False

        # Raised outside the try block so an unknown operator is surfaced as a
        # rule evaluation error rather than silently returning False
        raise ValueError(f"Unsupported operator: {operator}")

    def _calculate_risk_score(self, actions: List[Dict[str, Any]]) -> float:
        """
        Calculate risk score based on rule actions.
        """
        max_score = 0.0

        for action in actions:
            if action.get("action_type") == "score":
                score = action.get("parameters", {}).get("risk_score", 0.0)
                max_score = max(max_score, score)
            elif action.get("action_type") == "flag":
                score = action.get("parameters", {}).get("risk_score", 50.0)
                max_score = max(max_score, score)
            elif action.get("action_type") == "block":
                score = action.get("parameters", {}).get("risk_score", 100.0)
                max_score = max(max_score, score)

        return max_score

    def cleanup_expired_cache(self):
        """
        Remove expired cache entries.
        """
        self.db.query(AggregateCache).filter(
            AggregateCache.expires_at < datetime.utcnow()
        ).delete()
        self.db.commit()
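
A minimal usage sketch of the engine, assuming a standard SQLAlchemy session factory (here called `SessionLocal` in `app.core.database`, which is not shown and may live elsewhere in the project) and a transaction that has already been ingested:

# Usage sketch only. SessionLocal and the transaction ID are placeholders;
# adjust the import to wherever the project creates its SQLAlchemy sessions.
from app.core.database import SessionLocal  # assumed session factory

db = SessionLocal()
try:
    engine = RuleEngine(db)
    results = engine.evaluate_transaction("txn_123")  # hypothetical transaction ID
    for screening in results:
        print(screening.rule_name, screening.status, screening.risk_score)
    engine.cleanup_expired_cache()  # periodic housekeeping of the aggregate cache
finally:
    db.close()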