156 lines
5.0 KiB
Python

import logging
import time
from collections.abc import Callable

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from starlette.types import ASGIApp

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class RateLimiter:
    """
    Simple in-memory sliding-window rate limiter.

    Each client may make at most ``rate_limit_per_minute`` requests within
    any ``window_size``-second window. State lives in this process only, so
    for production use consider Redis or another distributed store.
    """

    def __init__(self, rate_limit_per_minute: int = 60):
        self.rate_limit_per_minute = rate_limit_per_minute
        # client_id -> {request timestamp -> number of requests at that timestamp}
        self.requests: dict[str, dict[float, int]] = {}
        self.window_size = 60  # 1 minute in seconds
        # Time of the last sweep over all clients (see _purge_idle_clients).
        self._last_purge = time.time()

    def is_rate_limited(self, client_id: str) -> tuple[bool, dict]:
        """
        Check if a client is rate limited and, if not, record the request.

        Args:
            client_id: Identifier for the client (usually IP address)

        Returns:
            Tuple of (is_limited, rate_limit_info). ``rate_limit_info`` has
            the keys ``limit`` (configured maximum), ``remaining`` (requests
            still allowed in the current window, counting this one when it
            was accepted) and ``reset`` (epoch second at which the oldest
            request in the window expires and capacity is freed).
        """
        current_time = time.time()

        # Keep memory bounded: occasionally drop clients that went idle.
        self._purge_idle_clients(current_time)
        # Drop this client's requests that fell out of the window.
        self._cleanup(client_id, current_time)

        window = self.requests.get(client_id, {})
        recent_requests = sum(window.values())
        is_limited = recent_requests >= self.rate_limit_per_minute

        if not is_limited:
            # Only accepted requests consume quota; a record is only created
            # for clients that actually have a request in the window.
            window = self.requests.setdefault(client_id, {})
            window[current_time] = window.get(current_time, 0) + 1
            recent_requests += 1

        # Count the request just recorded so the caller sees what is left
        # *after* this request, not before it.
        remaining = max(0, self.rate_limit_per_minute - recent_requests)
        # Capacity returns when the oldest request in the window expires.
        oldest = min(window) if window else current_time
        reset_at = oldest + self.window_size

        return is_limited, {
            "limit": self.rate_limit_per_minute,
            "remaining": remaining,
            "reset": int(reset_at),
        }

    def _cleanup(self, client_id: str, current_time: float) -> None:
        """
        Clean up old records for a client.

        Removes timestamps older than the window and deletes the client's
        record entirely once it is empty. Unknown clients are ignored.

        Args:
            client_id: Identifier for the client
            current_time: Current timestamp
        """
        window = self.requests.get(client_id)
        if window is None:
            return

        cutoff_time = current_time - self.window_size
        for ts in [ts for ts in window if ts < cutoff_time]:
            del window[ts]

        if not window:
            del self.requests[client_id]

    def _purge_idle_clients(self, current_time: float) -> None:
        """
        Drop records of clients with no requests left in the window.

        Runs at most once per ``window_size`` seconds so the sweep over all
        clients does not happen on every request.

        Args:
            current_time: Current timestamp
        """
        if current_time - self._last_purge < self.window_size:
            return
        self._last_purge = current_time
        # Iterate over a snapshot: _cleanup may delete keys.
        for client_id in list(self.requests):
            self._cleanup(client_id, current_time)
class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    Middleware for rate limiting API requests.

    Requests whose path starts with a whitelisted prefix bypass the limiter.
    All other requests are counted per client; once the limit is reached the
    middleware answers with 429 instead of calling the application. Every
    rate-limited response path carries ``X-RateLimit-*`` headers.
    """

    def __init__(
        self,
        app: ASGIApp,
        rate_limit_per_minute: int = 60,
        whitelist_paths: list | None = None,
        client_id_func: Callable[[Request], str] | None = None,
    ):
        """
        Args:
            app: The wrapped ASGI application
            rate_limit_per_minute: Maximum requests per client per minute
            whitelist_paths: Path prefixes exempt from rate limiting; when
                omitted (or empty) the health and docs endpoints are used
            client_id_func: Callable mapping a request to a client
                identifier; defaults to the client's IP address
        """
        super().__init__(app)
        self.rate_limiter = RateLimiter(rate_limit_per_minute)
        self.whitelist_paths = whitelist_paths or ["/health", "/docs", "/redoc", "/openapi.json"]
        self.client_id_func = client_id_func or self._default_client_id

    async def dispatch(self, request: Request, call_next) -> Response:
        """
        Process the request through rate limiting.

        Args:
            request: The incoming request
            call_next: The next handler in the middleware chain

        Returns:
            The downstream response, or a 429 JSON response when the client
            has exhausted its quota
        """
        # Skip rate limiting for whitelisted paths. This is a plain prefix
        # match, so "/health" also exempts e.g. "/healthz".
        path = request.url.path
        if any(path.startswith(wl_path) for wl_path in self.whitelist_paths):
            return await call_next(request)

        client_id = self.client_id_func(request)
        is_limited, rate_limit_info = self.rate_limiter.is_rate_limited(client_id)

        if is_limited:
            logger.warning("Rate limit exceeded for client %s", client_id)
            # A plain Response cannot render a dict body (it would raise and
            # surface as a 500); JSONResponse serializes it properly.
            response = JSONResponse(
                content={"detail": "Rate limit exceeded"},
                status_code=429,
            )
            # Tell well-behaved clients how long to back off, in seconds.
            retry_after = max(0, rate_limit_info["reset"] - int(time.time()))
            response.headers["Retry-After"] = str(retry_after)
        else:
            # Process the request normally
            response = await call_next(request)

        # Add rate limit headers to response
        response.headers["X-RateLimit-Limit"] = str(rate_limit_info["limit"])
        response.headers["X-RateLimit-Remaining"] = str(rate_limit_info["remaining"])
        response.headers["X-RateLimit-Reset"] = str(rate_limit_info["reset"])
        return response

    def _default_client_id(self, request: Request) -> str:
        """
        Default function to extract client identifier from request.

        Uses the first address in ``X-Forwarded-For`` when present, otherwise
        the address of the direct peer.

        Args:
            request: The incoming request

        Returns:
            Client identifier string, or "unknown" when no address is available
        """
        # NOTE(security): X-Forwarded-For is supplied by the client unless a
        # trusted proxy overwrites it. When the app is reachable directly, a
        # caller can dodge the limit by varying this header; pass a custom
        # client_id_func in that deployment.
        forwarded_for = request.headers.get("X-Forwarded-For")
        if forwarded_for:
            # Get the first IP in the list (client IP)
            first_hop = forwarded_for.split(",")[0].strip()
            # A malformed header such as ", 10.0.0.1" yields an empty first
            # entry; do not bucket those clients under the empty string.
            if first_hop:
                return first_hop

        # Fall back to the direct client address
        return request.client.host if request.client else "unknown"