"""Authentication middleware.""" from typing import Callable, Optional from fastapi import Depends, FastAPI, Header, HTTPException, Request, status from fastapi.security import OAuth2PasswordBearer from jose import JWTError, jwt from sqlalchemy.orm import Session from app.core.config import settings from app.db.session import get_db from app.models.user import User from app.schemas.token import TokenPayload async def get_token_from_header( authorization: Optional[str] = Header(None), ) -> Optional[str]: """ Extract token from the Authorization header. Args: authorization: Authorization header Returns: Optional[str]: JWT token or None """ if not authorization: return None scheme, _, token = authorization.partition(" ") if scheme.lower() != "bearer": return None return token async def validate_token( token: Optional[str], db: Session, ) -> Optional[User]: """ Validate JWT token and return the user. Args: token: JWT token db: Database session Returns: Optional[User]: User or None if token is invalid """ if not token: return None try: payload = jwt.decode( token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM] ) token_data = TokenPayload(**payload) if token_data.sub is None: return None except (JWTError, ValueError): return None user = db.query(User).filter(User.id == token_data.sub).first() if not user or not user.is_active: return None return user class AuthMiddleware: """Authentication middleware for FastAPI applications.""" def __init__(self, app: FastAPI): """ Initialize middleware. Args: app: FastAPI application """ self.app = app async def __call__(self, request: Request, call_next: Callable): """ Process request through middleware. Args: request: FastAPI request call_next: Next middleware/endpoint in the chain Returns: Response: FastAPI response """ # Skip authentication for auth endpoints and docs path = request.url.path if ( path.startswith(f"{settings.API_V1_STR}/auth") or path == "/" or path == "/health" or path.startswith("/docs") or path.startswith("/redoc") or path.startswith("/openapi.json") ): return await call_next(request) # Get token from header token = await get_token_from_header(request.headers.get("Authorization")) # For protected endpoints, token is required if not token: return HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated", headers={"WWW-Authenticate": "Bearer"}, ) # Validate token db = next(get_db()) user = await validate_token(token, db) if not user: return HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid authentication credentials", headers={"WWW-Authenticate": "Bearer"}, ) # Attach user to request state request.state.user = user # Continue with the request return await call_next(request)