import json
import logging
import math
import time
from collections.abc import Awaitable, Callable

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp

logger = logging.getLogger(__name__)


class RateLimiter:
    """
    Simple in-memory sliding-window rate limiter.

    Each client's accepted requests are stored as ``{timestamp: count}``.
    A request is rejected once the number of requests recorded in the last
    ``window_size`` seconds reaches ``rate_limit_per_minute``.

    State lives in this process's memory: limits are per-process and are
    lost on restart. For production use, consider using Redis or another
    distributed store.
    """

    def __init__(self, rate_limit_per_minute: int = 60):
        self.rate_limit_per_minute = rate_limit_per_minute
        # client_id -> {request timestamp: number of requests at that timestamp}
        self.requests: dict[str, dict[float, int]] = {}
        self.window_size = 60  # 1 minute in seconds
        # Time of the last full sweep that drops idle clients.
        self._last_sweep = time.time()

    def is_rate_limited(self, client_id: str) -> tuple[bool, dict]:
        """
        Check if a client is rate limited, recording the request if it is not.

        Args:
            client_id: Identifier for the client (usually IP address)

        Returns:
            Tuple of (is_limited, rate_limit_info). rate_limit_info contains:
              - "limit": the configured requests-per-window limit
              - "remaining": requests still allowed in the current window,
                counting this request if it was accepted
              - "reset": Unix time (rounded up) at which the oldest counted
                request leaves the window and frees a slot
              - "retry_after": whole seconds from now until "reset"
        """
        current_time = time.time()

        # Periodically drop clients with no requests left in the window so
        # memory does not grow with every distinct client_id ever seen.
        if current_time - self._last_sweep >= self.window_size:
            self._sweep_stale_clients(current_time)

        client_requests = self.requests.setdefault(client_id, {})

        # Clean up old records
        self._cleanup(client_id, current_time)

        # Count recent requests
        recent_requests = sum(client_requests.values())

        # Check if rate limit is exceeded
        is_limited = recent_requests >= self.rate_limit_per_minute

        # Record the request if it is allowed, and include it in the count so
        # "remaining" reflects the state *after* this request.
        if not is_limited:
            client_requests[current_time] = client_requests.get(current_time, 0) + 1
            recent_requests += 1

        remaining = max(0, self.rate_limit_per_minute - recent_requests)

        # A slot frees up when the oldest counted request expires; with no
        # recorded requests (e.g. a limit of 0) fall back to a full window.
        oldest = min(client_requests, default=current_time)
        reset_at = oldest + self.window_size

        return is_limited, {
            "limit": self.rate_limit_per_minute,
            "remaining": remaining,
            # Round up so clients never retry before the slot is actually free.
            "reset": math.ceil(reset_at),
            "retry_after": max(0, math.ceil(reset_at - current_time)),
        }

    def _cleanup(self, client_id: str, current_time: float) -> None:
        """
        Remove records for a client that are older than the window.

        Args:
            client_id: Identifier for the client (must exist in self.requests)
            current_time: Current timestamp
        """
        cutoff_time = current_time - self.window_size
        client_requests = self.requests[client_id]
        expired = [ts for ts in client_requests if ts < cutoff_time]
        for ts in expired:
            del client_requests[ts]

    def _sweep_stale_clients(self, current_time: float) -> None:
        """
        Clean up every client and forget those with no requests in the window.

        Args:
            current_time: Current timestamp
        """
        for client_id in list(self.requests):
            self._cleanup(client_id, current_time)
            if not self.requests[client_id]:
                del self.requests[client_id]
        self._last_sweep = current_time


class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    Middleware for rate limiting API requests.

    Requests whose path is whitelisted pass through untouched. Every other
    response gets X-RateLimit-Limit / -Remaining / -Reset headers. Clients
    over the limit receive a JSON 429 response with a Retry-After header.
    """

    def __init__(
        self,
        app: ASGIApp,
        rate_limit_per_minute: int = 60,
        whitelist_paths: list[str] | None = None,
        client_id_func: Callable[[Request], str] | None = None,
        trust_forwarded_for: bool = True,
    ):
        """
        Args:
            app: The wrapped ASGI application
            rate_limit_per_minute: Requests allowed per client per minute
            whitelist_paths: Paths exempt from rate limiting. Each entry
                matches itself exactly and any sub-path below it (``/docs``
                matches ``/docs/x`` but not ``/docs-internal``).
            client_id_func: Maps a request to a client identifier; defaults
                to the client IP address
            trust_forwarded_for: Whether the default client id function may
                use the X-Forwarded-For header. Only enable this behind a
                proxy that overwrites the header; otherwise clients can spoof
                it to evade the limit.
        """
        super().__init__(app)
        self.rate_limiter = RateLimiter(rate_limit_per_minute)
        self.whitelist_paths = whitelist_paths or ["/health", "/docs", "/redoc", "/openapi.json"]
        self.client_id_func = client_id_func or self._default_client_id
        self.trust_forwarded_for = trust_forwarded_for

    def _is_whitelisted(self, path: str) -> bool:
        """Return True if *path* equals a whitelisted path or lies below one."""
        for wl_path in self.whitelist_paths:
            prefix = wl_path.rstrip("/")
            # Segment-aware match: "/health" must not exempt "/health-records".
            if path in (wl_path, prefix) or path.startswith(prefix + "/"):
                return True
        return False

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """
        Process the request through rate limiting.

        Args:
            request: The incoming request
            call_next: The next handler in the middleware chain

        Returns:
            The downstream response, or a 429 response if the client is limited
        """
        # Skip rate limiting for whitelisted paths
        if self._is_whitelisted(request.url.path):
            return await call_next(request)

        client_id = self.client_id_func(request)
        is_limited, rate_limit_info = self.rate_limiter.is_rate_limited(client_id)

        if is_limited:
            logger.warning("Rate limit exceeded for client %s", client_id)
            # Response only accepts str/bytes content; serialize explicitly.
            response = Response(
                content=json.dumps({"detail": "Rate limit exceeded"}),
                status_code=429,
                media_type="application/json",
            )
            # Tell the client how long to back off (RFC 6585 / RFC 9110).
            response.headers["Retry-After"] = str(rate_limit_info["retry_after"])
        else:
            # Process the request normally
            response = await call_next(request)

        # Add rate limit headers to response
        response.headers["X-RateLimit-Limit"] = str(rate_limit_info["limit"])
        response.headers["X-RateLimit-Remaining"] = str(rate_limit_info["remaining"])
        response.headers["X-RateLimit-Reset"] = str(rate_limit_info["reset"])

        return response

    def _default_client_id(self, request: Request) -> str:
        """
        Default function to extract client identifier from request.

        Uses the first X-Forwarded-For hop when trusted and non-empty,
        otherwise the direct client address.

        Args:
            request: The incoming request

        Returns:
            Client identifier string ("unknown" if no address is available)
        """
        # SECURITY: X-Forwarded-For is client-controlled unless a trusted
        # proxy overwrites it; set trust_forwarded_for=False when the app is
        # not behind such a proxy.
        if self.trust_forwarded_for:
            forwarded_for = request.headers.get("X-Forwarded-For")
            if forwarded_for:
                first_hop = forwarded_for.split(",")[0].strip()
                # An empty first hop would lump unrelated clients together.
                if first_hop:
                    return first_hop

        # Fall back to the direct client address
        return request.client.host if request.client else "unknown"