from typing import Optional

from fastapi import Depends, HTTPException, Request, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import Response

from app.core.config import settings
from app.core.security import ALGORITHM


class AuthMiddleware(BaseHTTPMiddleware):
    """Best-effort JWT middleware that attaches the caller's user id to the request.

    This middleware never rejects a request. When the request carries a
    ``Authorization: Bearer <token>`` header whose token decodes successfully
    with ``settings.SECRET_KEY`` / ``ALGORITHM``, the token's ``sub`` claim is
    stored on ``request.state.user_id``. In every other case
    ``request.state.user_id`` is ``None``. Endpoints remain responsible for
    enforcing authentication.
    """

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        """Populate ``request.state.user_id`` (or ``None``) and forward the request.

        Args:
            request: The incoming request.
            call_next: The next handler in the middleware chain.

        Returns:
            The response produced by ``call_next``.
        """
        # Always define the attribute so handlers can read it directly, without
        # guarding against AttributeError when no valid token was supplied.
        request.state.user_id = None

        if self._should_skip_auth(request.url.path):
            return await call_next(request)

        token = self._extract_bearer_token(request.headers.get("Authorization"))
        if token is None:
            # No usable bearer token: proceed anonymously; endpoints decide.
            return await call_next(request)

        try:
            payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
        except JWTError:
            # Invalid, malformed or expired token: deliberately continue without
            # user context rather than failing the request here.
            pass
        else:
            request.state.user_id = payload.get("sub")

        return await call_next(request)

    @staticmethod
    def _extract_bearer_token(auth_header: Optional[str]) -> Optional[str]:
        """Return the token from a ``Bearer`` Authorization header, or ``None``.

        Only the leading scheme is stripped, so a token that happens to contain
        the text "Bearer " is not altered. The scheme name is compared
        case-insensitively, as auth-schemes are case-insensitive per RFC 7235.

        Args:
            auth_header: Raw value of the ``Authorization`` header, if any.

        Returns:
            The token string, or ``None`` when the header is missing, uses a
            different scheme, or carries an empty token.
        """
        if not auth_header:
            return None
        scheme, _, token = auth_header.partition(" ")
        if scheme.lower() != "bearer":
            return None
        token = token.strip()
        return token or None

    def _should_skip_auth(self, path: str) -> bool:
        """Return True for public paths whose requests skip token inspection.

        A path matches an entry when it equals the entry or is nested beneath
        it (e.g. ``/docs/oauth2-redirect`` matches ``/docs``). A mere textual
        prefix does not match (``/healthz`` does not match ``/health``).

        Args:
            path: The request URL path.

        Returns:
            True if ``path`` is one of the public paths, else False.
        """
        skip_paths = (
            "/docs",
            "/redoc",
            "/openapi.json",
            f"{settings.API_V1_STR}/auth/login",
            f"{settings.API_V1_STR}/auth/register",
            "/health",
        )
        return any(
            path == skip_path or path.startswith(skip_path + "/")
            for skip_path in skip_paths
        )