import json
import time
from datetime import datetime, timedelta
from typing import List, Dict, Any, Optional, Tuple

from sqlalchemy.orm import Session
from sqlalchemy import func, desc

from app.models.transaction import Transaction
from app.models.rule import Rule
from app.models.screening import AggregateCache
from app.core.schemas import RuleCondition, ScreeningResultBase


class RuleEngine:
    def __init__(self, db: Session):
        self.db = db

    def evaluate_transaction(self, transaction_id: str, rule_ids: Optional[List[int]] = None) -> List[ScreeningResultBase]:
        """
        Evaluate a transaction against all active rules or specific rules.
        """
        # Get transaction
        transaction = self.db.query(Transaction).filter(Transaction.transaction_id == transaction_id).first()
        if not transaction:
            raise ValueError(f"Transaction {transaction_id} not found")

        # Get rules to evaluate
        query = self.db.query(Rule).filter(Rule.is_active)
        if rule_ids:
            query = query.filter(Rule.id.in_(rule_ids))
        rules = query.order_by(desc(Rule.priority)).all()

        results = []
        for rule in rules:
            result = self._evaluate_rule_against_transaction(transaction, rule)
            results.append(result)
        return results

    def _evaluate_rule_against_transaction(self, transaction: Transaction, rule: Rule) -> ScreeningResultBase:
        """
        Evaluate a single rule against a transaction.
        """
        start_time = time.time()
        try:
            conditions = json.loads(rule.conditions)
            actions = json.loads(rule.actions)

            # Evaluate all conditions
            condition_results = []
            aggregated_data = {}
            for condition in conditions:
                condition_obj = RuleCondition(**condition)
                result, agg_data = self._evaluate_condition(transaction, condition_obj)
                condition_results.append(result)
                if agg_data:
                    aggregated_data.update(agg_data)

            # All conditions must be true for the rule to trigger
            rule_triggered = all(condition_results)

            # Determine status and risk score
            if rule_triggered:
                status = "flagged"
                risk_score = self._calculate_risk_score(actions)
                details = {
                    "rule_triggered": True,
                    "conditions_met": len(condition_results),
                    "evaluation_time_ms": (time.time() - start_time) * 1000,
                    "actions": actions,
                }
            else:
                status = "clean"
                risk_score = 0.0
                details = {
                    "rule_triggered": False,
                    "conditions_met": sum(condition_results),
                    "total_conditions": len(condition_results),
                    "evaluation_time_ms": (time.time() - start_time) * 1000,
                }

            return ScreeningResultBase(
                transaction_id=transaction.transaction_id,
                rule_id=rule.id,
                rule_name=rule.name,
                rule_version=rule.version,
                status=status,
                risk_score=risk_score,
                details=details,
                aggregated_data=aggregated_data,
                screening_type="real_time",
            )
        except Exception as e:
            return ScreeningResultBase(
                transaction_id=transaction.transaction_id,
                rule_id=rule.id,
                rule_name=rule.name,
                rule_version=rule.version,
                status="error",
                risk_score=0.0,
                details={"error": str(e), "evaluation_time_ms": (time.time() - start_time) * 1000},
                aggregated_data={},
                screening_type="real_time",
            )
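
    # Illustrative shapes only: the engine assumes `rule.conditions` and
    # `rule.actions` are JSON arrays roughly along these lines (the field
    # names are inferred from the parsing above; the authoritative schema is
    # RuleCondition and the action handlers, not this comment):
    #
    #   conditions = [
    #       {"field": "amount", "operator": "gt", "value": 10000},
    #       {"field": "*", "operator": "gte", "value": 5,
    #        "aggregate_function": "count", "time_window": "24h",
    #        "group_by": ["user_id"]},
    #   ]
    #   actions = [{"action_type": "flag", "parameters": {"risk_score": 75.0}}]
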
""" aggregated_data = {} # If this is an aggregate condition, compute the aggregate first if condition.aggregate_function: aggregate_value, agg_data = self._compute_aggregate(transaction, condition) aggregated_data = agg_data value_to_compare = aggregate_value else: # Get the field value from transaction value_to_compare = self._get_transaction_field_value(transaction, condition.field) # Perform comparison return self._compare_values(value_to_compare, condition.operator, condition.value), aggregated_data def _get_transaction_field_value(self, transaction: Transaction, field: str) -> Any: """ Get a field value from a transaction object. """ if hasattr(transaction, field): return getattr(transaction, field) # Check in metadata if field not found in main attributes if transaction.metadata: metadata = json.loads(transaction.metadata) return metadata.get(field) return None def _compute_aggregate(self, transaction: Transaction, condition: RuleCondition) -> Tuple[Any, Dict[str, Any]]: """ Compute aggregate values based on condition parameters. """ # Generate cache key cache_key = self._generate_cache_key(transaction, condition) # Check cache first cached_result = self._get_cached_aggregate(cache_key) if cached_result: return cached_result["value"], cached_result # Compute aggregate query = self.db.query(Transaction) # Apply time window filter if condition.time_window: time_delta = self._parse_time_window(condition.time_window) cutoff_time = datetime.utcnow() - time_delta query = query.filter(Transaction.created_at >= cutoff_time) # Apply group by filters if condition.group_by: for group_field in condition.group_by: group_value = self._get_transaction_field_value(transaction, group_field) if group_value is not None: if group_field == "user_id": query = query.filter(Transaction.user_id == group_value) elif group_field == "account_id": query = query.filter(Transaction.account_id == group_value) elif group_field == "device_id": query = query.filter(Transaction.device_id == group_value) # Add more group by fields as needed # Apply aggregate function if condition.aggregate_function == "count": if condition.field == "*": result = query.count() else: # Count non-null values of specific field field_attr = getattr(Transaction, condition.field, None) if field_attr: result = query.filter(field_attr.isnot(None)).count() else: result = 0 elif condition.aggregate_function == "sum": field_attr = getattr(Transaction, condition.field, None) if field_attr: result = query.with_entities(func.sum(field_attr)).scalar() or 0 else: result = 0 elif condition.aggregate_function == "avg": field_attr = getattr(Transaction, condition.field, None) if field_attr: result = query.with_entities(func.avg(field_attr)).scalar() or 0 else: result = 0 elif condition.aggregate_function == "max": field_attr = getattr(Transaction, condition.field, None) if field_attr: result = query.with_entities(func.max(field_attr)).scalar() or 0 else: result = 0 elif condition.aggregate_function == "min": field_attr = getattr(Transaction, condition.field, None) if field_attr: result = query.with_entities(func.min(field_attr)).scalar() or 0 else: result = 0 else: raise ValueError(f"Unsupported aggregate function: {condition.aggregate_function}") # Cache the result aggregate_data = { "value": result, "function": condition.aggregate_function, "field": condition.field, "time_window": condition.time_window, "group_by": condition.group_by, "computed_at": datetime.utcnow().isoformat(), "cache_key": cache_key } self._cache_aggregate(cache_key, aggregate_data, 
    def _generate_cache_key(self, transaction: Transaction, condition: RuleCondition) -> str:
        """
        Generate a cache key for aggregate computation.
        """
        key_parts = [
            condition.aggregate_function,
            condition.field,
            condition.time_window or "no_window",
        ]
        if condition.group_by:
            for group_field in condition.group_by:
                group_value = self._get_transaction_field_value(transaction, group_field)
                key_parts.append(f"{group_field}:{group_value}")
        return ":".join(str(part) for part in key_parts)

    def _get_cached_aggregate(self, cache_key: str) -> Optional[Dict[str, Any]]:
        """
        Retrieve cached aggregate result if not expired.
        """
        cached = self.db.query(AggregateCache).filter(
            AggregateCache.cache_key == cache_key,
            AggregateCache.expires_at > datetime.utcnow(),
        ).first()
        if cached:
            return json.loads(cached.cache_value)
        return None

    def _cache_aggregate(self, cache_key: str, data: Dict[str, Any], time_window: Optional[str]):
        """
        Cache aggregate result with appropriate expiration.
        """
        # Determine cache expiration based on time window
        if time_window:
            time_delta = self._parse_time_window(time_window)
            # Cache for 10% of the time window, minimum 1 minute, maximum 1 hour
            cache_duration = max(min(time_delta * 0.1, timedelta(hours=1)), timedelta(minutes=1))
        else:
            cache_duration = timedelta(minutes=5)  # Default 5 minutes

        expires_at = datetime.utcnow() + cache_duration

        # Upsert cache entry
        existing = self.db.query(AggregateCache).filter(AggregateCache.cache_key == cache_key).first()
        if existing:
            existing.cache_value = json.dumps(data)
            existing.expires_at = expires_at
        else:
            cache_entry = AggregateCache(
                cache_key=cache_key,
                cache_value=json.dumps(data),
                expires_at=expires_at,
            )
            self.db.add(cache_entry)
        self.db.commit()

    def _parse_time_window(self, time_window: str) -> timedelta:
        """
        Parse time window string into timedelta.
        Supports minutes, hours, and days: 30m, 1h, 24h, 7d, 30d, etc.
        """
        if time_window.endswith('h'):
            hours = int(time_window[:-1])
            return timedelta(hours=hours)
        elif time_window.endswith('d'):
            days = int(time_window[:-1])
            return timedelta(days=days)
        elif time_window.endswith('m'):
            minutes = int(time_window[:-1])
            return timedelta(minutes=minutes)
        else:
            raise ValueError(f"Unsupported time window format: {time_window}")

    def _compare_values(self, left: Any, operator: str, right: Any) -> bool:
        """
        Compare two values using the specified operator.
        """
        # Validate the operator before comparing, so a misconfigured rule
        # raises instead of being swallowed by the except clause below.
        if operator not in {
            "eq", "ne", "gt", "gte", "lt", "lte", "in", "not_in",
            "contains", "starts_with", "ends_with",
        }:
            raise ValueError(f"Unsupported operator: {operator}")

        if left is None:
            return False
        try:
            if operator == "eq":
                return left == right
            elif operator == "ne":
                return left != right
            elif operator == "gt":
                return float(left) > float(right)
            elif operator == "gte":
                return float(left) >= float(right)
            elif operator == "lt":
                return float(left) < float(right)
            elif operator == "lte":
                return float(left) <= float(right)
            elif operator == "in":
                return left in right
            elif operator == "not_in":
                return left not in right
            elif operator == "contains":
                return str(right).lower() in str(left).lower()
            elif operator == "starts_with":
                return str(left).lower().startswith(str(right).lower())
            else:  # "ends_with"
                return str(left).lower().endswith(str(right).lower())
        except (ValueError, TypeError):
            # Failed coercion (e.g. float("abc")) means the condition does not match
            return False
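
    # Comparison semantics in brief: numeric operators coerce both sides with
    # float(), so _compare_values("150", "gt", 100) is True; string operators
    # are case-insensitive, so ("JohnDoe@example.com", "contains", "johndoe")
    # matches; a None left-hand value or a failed coercion yields False.
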
""" max_score = 0.0 for action in actions: if action.get("action_type") == "score": score = action.get("parameters", {}).get("risk_score", 0.0) max_score = max(max_score, score) elif action.get("action_type") == "flag": score = action.get("parameters", {}).get("risk_score", 50.0) max_score = max(max_score, score) elif action.get("action_type") == "block": score = action.get("parameters", {}).get("risk_score", 100.0) max_score = max(max_score, score) return max_score def cleanup_expired_cache(self): """ Remove expired cache entries. """ self.db.query(AggregateCache).filter( AggregateCache.expires_at < datetime.utcnow() ).delete() self.db.commit()