import re

from fastapi import Request

# Upper bound (10 MiB) on the declared request body size accepted by
# validate_request_size(). Kept as a module constant rather than a function
# parameter so FastAPI does not expose it as a client-controlled query param.
MAX_REQUEST_SIZE = 10 * 1024 * 1024

# Headers added to every response by security_headers_middleware().
SECURITY_HEADERS = {
    "X-Content-Type-Options": "nosniff",
    "X-Frame-Options": "DENY",
    # NOTE(review): X-XSS-Protection is ignored/deprecated in modern browsers
    # and OWASP now recommends "0" or omitting it — confirm before changing.
    "X-XSS-Protection": "1; mode=block",
    "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
    "Content-Security-Policy": "default-src 'self'",
    "Referrer-Policy": "strict-origin-when-cross-origin",
}


class ValidationMiddleware:
    """Regex-blocklist checks for suspicious request input.

    This is a coarse heuristic filter. It complements, and does not replace,
    parameterized queries and context-aware output escaping.
    """

    # Header names (lower-case) whose values are scanned with validate_input().
    DANGEROUS_HEADERS = frozenset({'x-forwarded-host', 'x-original-url', 'x-rewrite-url'})

    def __init__(self):
        self.suspicious_patterns = [
            # XSS: opening <script> tag. Matching only the tag opening keeps the
            # scan linear and still fires when the closing tag is absent or on
            # another line.
            r'<script\b',
            r'union\s+select',  # SQL injection
            r'drop\s+table',  # SQL injection
            r'insert\s+into',  # SQL injection
            r'delete\s+from',  # SQL injection
            # SQL injection (UPDATE <table> SET). A single token (\S+) instead of
            # `.*` avoids catastrophic backtracking on long whitespace runs.
            r'update\s+\S+\s+set',
            r'exec\s*\(',  # Command injection
            r'eval\s*\(',  # Code injection
            r'javascript:',  # XSS
            r'vbscript:',  # XSS
            r'data:text/html',  # Data URL XSS
        ]
        self.compiled_patterns = [
            re.compile(pattern, re.IGNORECASE) for pattern in self.suspicious_patterns
        ]

    def validate_input(self, text: str) -> bool:
        """Return True if *text* matches none of the suspicious patterns.

        Empty (falsy) input is treated as safe.
        """
        if not text:
            return True
        return not any(pattern.search(text) for pattern in self.compiled_patterns)

    def sanitize_headers(self, headers: dict) -> bool:
        """Validate request headers; nothing is modified despite the name.

        Returns False if any header value contains CR or LF (header
        injection), or if a header listed in DANGEROUS_HEADERS has a value
        that fails validate_input(). Returns True otherwise.
        """
        for header_name, header_value in headers.items():
            value = str(header_value)
            # Header injection: raw CR/LF must never appear in a header value.
            if '\r' in value or '\n' in value:
                return False
            if header_name.lower() in self.DANGEROUS_HEADERS and not self.validate_input(value):
                return False
        return True

    def validate_json_payload(self, payload: dict) -> bool:
        """Validate every string inside a decoded JSON document.

        Walks nested dicts and lists with an explicit stack, so deep nesting
        cannot raise RecursionError. Runs validate_input() on:

        * *payload* itself when it is a bare string,
        * every string dict key,
        * every string value.

        Other scalars (numbers, booleans, None) are ignored. Returns False on
        the first failing string, True otherwise.
        """
        pending = [payload]
        # ids of containers already expanded; guards against cyclic Python
        # objects (not producible by json.loads, but cheap to defend against).
        seen = set()
        while pending:
            node = pending.pop()
            if isinstance(node, str):
                if not self.validate_input(node):
                    return False
            elif isinstance(node, (dict, list)):
                if id(node) in seen:
                    continue
                seen.add(id(node))
                if isinstance(node, dict):
                    pending.extend(node.keys())
                    pending.extend(node.values())
                else:
                    pending.extend(node)
        return True


validation_middleware = ValidationMiddleware()


def validate_request_size(request: Request) -> bool:
    """Return False if the declared Content-Length is malformed or too large.

    Only the Content-Length header is inspected. A request without one (for
    example chunked transfer encoding) passes, so the actual body size must be
    enforced elsewhere if that matters.
    """
    content_length = request.headers.get('content-length')
    if not content_length:
        return True
    try:
        size = int(content_length)
    except ValueError:
        return False
    # A negative length is malformed and must not slip under the limit.
    return 0 <= size <= MAX_REQUEST_SIZE


async def security_headers_middleware(request: Request, call_next):
    """Add security headers to the downstream response.

    Intended for registration as a FastAPI/Starlette HTTP middleware, e.g.
    ``app.middleware("http")(security_headers_middleware)``. There,
    ``call_next`` returns an awaitable that must be awaited to obtain the
    response. Existing headers with the same names are overwritten.
    """
    response = await call_next(request)
    for header_name, header_value in SECURITY_HEADERS.items():
        response.headers[header_name] = header_value
    return response