2025-06-08 21:57:05 +00:00

134 lines
3.5 KiB
Python

"""Authentication middleware."""
from typing import Callable, Optional

from fastapi import Depends, FastAPI, Header, HTTPException, Request, status
from fastapi.responses import JSONResponse
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from sqlalchemy.orm import Session

from app.core.config import settings
from app.db.session import get_db
from app.models.user import User
from app.schemas.token import TokenPayload
async def get_token_from_header(
    authorization: Optional[str] = Header(None),
) -> Optional[str]:
    """
    Extract the bearer token from the Authorization header.

    Args:
        authorization: Raw Authorization header value, if present.

    Returns:
        Optional[str]: The token when the header uses the Bearer scheme
        (matched case-insensitively) and carries a non-empty token;
        None otherwise.
    """
    if not authorization:
        return None

    scheme, _, token = authorization.strip().partition(" ")
    if scheme.lower() != "bearer":
        return None

    # Tolerate extra spaces between the scheme and the token, and report a
    # missing token ("Bearer" / "Bearer ") as None rather than "" so the
    # result matches the documented "token or None" contract.
    return token.strip() or None
async def validate_token(
    token: Optional[str],
    db: Session,
) -> Optional[User]:
    """
    Resolve a JWT to the active user it identifies.

    Args:
        token: Encoded JWT, or None/empty when the request carried none
        db: Database session used to look up the user

    Returns:
        Optional[User]: The active user whose id equals the token's ``sub``
        claim, or None when the token is missing, fails decoding or payload
        validation, has no ``sub``, or the user is absent or inactive
    """
    if not token:
        return None

    try:
        claims = jwt.decode(
            token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
        )
        subject = TokenPayload(**claims).sub
    except (JWTError, ValueError):
        # Decoding or payload-validation failure: treat as unauthenticated.
        return None

    if subject is None:
        return None

    account = db.query(User).filter(User.id == subject).first()
    if account and account.is_active:
        return account
    return None
class AuthMiddleware:
    """
    Authentication middleware for FastAPI applications.

    Requests to public paths (auth endpoints, root, health check and the API
    docs/schema) pass straight through. Every other request must carry an
    ``Authorization: Bearer <token>`` header that resolves to an active
    user; that user is stored on ``request.state.user`` before the request
    continues. Otherwise a 401 JSON response is returned.
    """

    # Paths that bypass authentication only on an exact match.
    _PUBLIC_PATHS = ("/", "/health")
    # Path prefixes that bypass authentication (interactive docs and schema).
    _PUBLIC_PREFIXES = ("/docs", "/redoc", "/openapi.json")

    def __init__(self, app: FastAPI):
        """
        Initialize middleware.

        Args:
            app: FastAPI application
        """
        self.app = app

    @classmethod
    def _is_public_path(cls, path: str) -> bool:
        """Return True when ``path`` is exempt from authentication."""
        # NOTE(review): these are plain prefix matches, so e.g.
        # "<API_V1_STR>/authors" or "/docs-internal" would also be treated
        # as public -- confirm that is intended.
        public_prefixes = (f"{settings.API_V1_STR}/auth",) + cls._PUBLIC_PREFIXES
        return path in cls._PUBLIC_PATHS or path.startswith(public_prefixes)

    @staticmethod
    def _unauthorized(detail: str) -> JSONResponse:
        """Build the 401 response sent when authentication fails."""
        # The middleware must hand back a response object. An HTTPException
        # instance is an exception, not a response, so returning one does
        # not produce a 401 for the client.
        return JSONResponse(
            status_code=status.HTTP_401_UNAUTHORIZED,
            content={"detail": detail},
            headers={"WWW-Authenticate": "Bearer"},
        )

    async def __call__(self, request: Request, call_next: Callable):
        """
        Process request through middleware.

        Args:
            request: FastAPI request
            call_next: Next middleware/endpoint in the chain

        Returns:
            The downstream response, or a 401 JSON response when the request
            targets a protected path without valid credentials.
        """
        # Skip authentication for auth endpoints and docs
        if self._is_public_path(request.url.path):
            return await call_next(request)

        # For protected endpoints, token is required
        token = await get_token_from_header(request.headers.get("Authorization"))
        if not token:
            return self._unauthorized("Not authenticated")

        # Keep a handle on the generator so whatever cleanup follows its
        # ``yield`` runs deterministically instead of waiting for garbage
        # collection. Assumes get_db is a generator function -- TODO confirm.
        db_gen = get_db()
        try:
            user = await validate_token(token, next(db_gen))
            if not user:
                return self._unauthorized("Invalid authentication credentials")

            # Attach user to request state
            request.state.user = user

            # Continue with the request. The session is closed only after
            # this returns, so it stays open while downstream code reads
            # request.state.user.
            return await call_next(request)
        finally:
            db_gen.close()