156 lines
5.0 KiB
Python
156 lines
5.0 KiB
Python
import json
import logging
import time
from collections.abc import Callable

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
|
|
|
|
# Module-level logger; the middleware below uses it to warn when a client is rejected.
logger = logging.getLogger(__name__)
|
|
|
|
class RateLimiter:
    """
    Simple in-memory sliding-window rate limiter.

    Each client may have at most ``rate_limit_per_minute`` accepted requests
    inside any rolling ``window_size``-second window. Rejected requests are
    not recorded, so they do not extend a client's lockout.

    State is a plain dict in this process only; it is not shared between
    worker processes. For production use, consider using Redis or another
    distributed store.
    """

    def __init__(
        self,
        rate_limit_per_minute: int = 60,
        *,
        clock: Callable[[], float] = time.time,
    ):
        """
        Args:
            rate_limit_per_minute: Maximum accepted requests per client per window.
            clock: Zero-argument callable returning the current Unix time in
                seconds. Defaults to ``time.time``; injectable for tests.
        """
        self.rate_limit_per_minute = rate_limit_per_minute
        # client_id -> {request timestamp: number of requests accepted at that timestamp}
        self.requests: dict[str, dict[float, int]] = {}
        self.window_size = 60  # 1 minute in seconds
        self.clock = clock
        # Time of the last full pass that evicts idle clients (see _sweep_stale_clients).
        self._last_sweep = clock()

    def is_rate_limited(self, client_id: str) -> tuple[bool, dict]:
        """
        Check whether a client is rate limited, recording the request if it is not.

        Args:
            client_id: Identifier for the client (usually IP address)

        Returns:
            Tuple of (is_limited, rate_limit_info), where rate_limit_info has:
              - "limit": the configured per-window limit
              - "remaining": requests still allowed in the current window,
                already accounting for this request when it was accepted
              - "reset": Unix time (whole seconds) at which the oldest counted
                request leaves the window, i.e. when a slot next frees up
        """
        current_time = self.clock()

        # Periodically evict clients with no requests left in the window so the
        # table cannot grow without bound as new client IDs keep appearing.
        if current_time - self._last_sweep >= self.window_size:
            self._sweep_stale_clients(current_time)

        window = self.requests.setdefault(client_id, {})
        self._cleanup(client_id, current_time)

        recent_requests = sum(window.values())
        is_limited = recent_requests >= self.rate_limit_per_minute

        if not is_limited:
            window[current_time] = window.get(current_time, 0) + 1
            # The accepted request itself consumes one slot of the window.
            recent_requests += 1

        remaining = max(0, self.rate_limit_per_minute - recent_requests)
        # In a sliding window a slot frees up when the oldest counted request expires.
        oldest = min(window, default=current_time)
        reset_at = oldest + self.window_size

        return is_limited, {
            "limit": self.rate_limit_per_minute,
            "remaining": remaining,
            "reset": int(reset_at),
        }

    def _cleanup(self, client_id: str, current_time: float) -> None:
        """
        Drop a client's timestamps that fall outside the current window.

        Args:
            client_id: Identifier for the client; must already be a key of ``self.requests``
            current_time: Current timestamp
        """
        cutoff_time = current_time - self.window_size
        window = self.requests[client_id]
        expired = [ts for ts in window if ts < cutoff_time]
        for ts in expired:
            del window[ts]

    def _sweep_stale_clients(self, current_time: float) -> None:
        """
        Remove every client that has no requests left inside the current window.

        Args:
            current_time: Current timestamp
        """
        for client_id in list(self.requests):
            self._cleanup(client_id, current_time)
            if not self.requests[client_id]:
                del self.requests[client_id]
        self._last_sweep = current_time
|
|
|
|
class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    Middleware for rate limiting API requests.

    Requests whose path is whitelisted bypass the limiter entirely. All other
    requests are counted per client ID. Once a client exceeds the limit it
    receives ``429 Too Many Requests`` with a JSON body and a ``Retry-After``
    header. Every counted response carries ``X-RateLimit-Limit``,
    ``X-RateLimit-Remaining`` and ``X-RateLimit-Reset`` headers.
    """

    def __init__(
        self,
        app: ASGIApp,
        rate_limit_per_minute: int = 60,
        whitelist_paths: list | None = None,
        client_id_func: Callable[[Request], str] | None = None,
        trust_forwarded_for: bool = True,
    ):
        """
        Args:
            app: The wrapped ASGI application.
            rate_limit_per_minute: Maximum accepted requests per client per minute.
            whitelist_paths: Paths exempt from rate limiting (each also exempts
                its sub-paths). ``None`` selects the defaults below; an explicit
                empty list exempts nothing.
            client_id_func: Maps a request to a client identifier. Defaults to
                ``_default_client_id``.
            trust_forwarded_for: Whether ``_default_client_id`` may derive the
                client from the ``X-Forwarded-For`` header. Set to ``False``
                when not running behind a proxy that controls that header;
                otherwise clients can forge it to evade the limit.
        """
        super().__init__(app)
        self.rate_limiter = RateLimiter(rate_limit_per_minute)
        if whitelist_paths is None:
            whitelist_paths = ["/health", "/docs", "/redoc", "/openapi.json"]
        self.whitelist_paths = whitelist_paths
        self.trust_forwarded_for = trust_forwarded_for
        self.client_id_func = client_id_func or self._default_client_id

    async def dispatch(self, request: Request, call_next) -> Response:
        """
        Process the request through rate limiting.

        Args:
            request: The incoming request
            call_next: The next handler in the middleware chain

        Returns:
            The downstream response, or a 429 response when the client is
            over its limit; both carry the X-RateLimit-* headers.
        """
        if self._is_whitelisted(request.url.path):
            return await call_next(request)

        client_id = self.client_id_func(request)
        is_limited, rate_limit_info = self.rate_limiter.is_rate_limited(client_id)

        if is_limited:
            logger.warning("Rate limit exceeded for client %s", client_id)
            # Response renders only str/bytes content; a dict would fail at render
            # time and turn the 429 into a 500, so serialize the JSON explicitly.
            response = Response(
                content=json.dumps({"detail": "Rate limit exceeded"}),
                status_code=429,
                media_type="application/json",
            )
            # Seconds until a slot frees up; at least 1 so clients never retry instantly.
            retry_after = max(1, int(rate_limit_info["reset"] - time.time()))
            response.headers["Retry-After"] = str(retry_after)
        else:
            response = await call_next(request)

        response.headers["X-RateLimit-Limit"] = str(rate_limit_info["limit"])
        response.headers["X-RateLimit-Remaining"] = str(rate_limit_info["remaining"])
        response.headers["X-RateLimit-Reset"] = str(rate_limit_info["reset"])

        return response

    def _is_whitelisted(self, path: str) -> bool:
        """
        Return True if *path* equals a whitelisted path or lies beneath it.

        Matching respects path-segment boundaries, so "/docs" exempts "/docs"
        and "/docs/oauth2-redirect" but not "/docsecret".

        Args:
            path: The request URL path

        Returns:
            Whether the path bypasses rate limiting
        """
        for wl_path in self.whitelist_paths:
            if path == wl_path or path.startswith(wl_path.rstrip("/") + "/"):
                return True
        return False

    def _default_client_id(self, request: Request) -> str:
        """
        Default function to extract client identifier from request.
        Uses the client's IP address.

        Args:
            request: The incoming request

        Returns:
            Client identifier string: the first X-Forwarded-For hop when trusted
            and non-empty, else the socket peer host, else "unknown"
        """
        if self.trust_forwarded_for:
            forwarded_for = request.headers.get("X-Forwarded-For")
            if forwarded_for:
                # NOTE: the leftmost hop is whatever the original client sent and can
                # be forged; rely on it only behind a proxy that overwrites this header.
                first_hop = forwarded_for.split(",")[0].strip()
                if first_hop:
                    return first_hop

        # Fall back to the direct client address
        return request.client.host if request.client else "unknown"
|