from datetime import timedelta
from typing import Any

from fastapi import APIRouter, Body, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.orm import Session

from app.api.deps import get_current_active_user, get_db
from app.core.config import settings
from app.core.security import (
    create_access_token,
    create_refresh_token,
)
from app.crud.crud_token import (
    create_refresh_token as create_refresh_token_db,
    get_refresh_token,
    is_token_valid,
    revoke_refresh_token,
)
from app.crud.crud_user import (
    authenticate,
    create_user,
    get_user_by_email,
)
from app.models.user import User
from app.schemas.token import RefreshTokenCreate, Token
from app.schemas.user import UserCreate

router = APIRouter()


def _issue_tokens(db: Session, user_id: Any) -> dict:
    """Mint an access/refresh token pair for ``user_id`` and persist the refresh token.

    Shared by ``/login``, ``/register`` and ``/refresh`` so all three endpoints
    issue tokens the same way.

    Args:
        db: Database session used to store the new refresh token.
        user_id: Identifier of the user the tokens are issued for.

    Returns:
        A dict matching the ``Token`` response schema: ``access_token``,
        ``token_type`` (always ``"bearer"``), ``refresh_token`` and
        ``expires_at`` (the access token's expiry).
    """
    access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token, access_token_expiry = create_access_token(
        user_id, expires_delta=access_token_expires
    )

    refresh_token, refresh_token_expiry = create_refresh_token(user_id)
    # Store the refresh token server-side so it can later be validated and revoked.
    create_refresh_token_db(
        db=db,
        token_in=RefreshTokenCreate(
            token=refresh_token,
            expires_at=refresh_token_expiry,
            user_id=user_id,
        ),
    )

    return {
        "access_token": access_token,
        "token_type": "bearer",
        "refresh_token": refresh_token,
        "expires_at": access_token_expiry,
    }


@router.post("/login", response_model=Token)
def login_access_token(
    db: Session = Depends(get_db), form_data: OAuth2PasswordRequestForm = Depends()
) -> Any:
    """
    OAuth2 compatible token login, get an access token for future requests.

    The OAuth2 ``username`` form field carries the user's email. Responds with
    400 for bad credentials or an inactive account.
    """
    user = authenticate(db, email=form_data.username, password=form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Incorrect email or password",
        )
    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user"
        )

    return _issue_tokens(db, user.id)


@router.post("/register", response_model=Token)
def register_user(
    *,
    db: Session = Depends(get_db),
    user_in: UserCreate,
) -> Any:
    """
    Register a new user and return an access token.

    Responds with 400 if a user with the same email already exists.
    """
    if get_user_by_email(db, email=user_in.email):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="A user with this email already exists",
        )
    user = create_user(db, obj_in=user_in)

    return _issue_tokens(db, user.id)


@router.post("/refresh", response_model=Token)
def refresh_access_token(
    db: Session = Depends(get_db), refresh_token: str = Body(...)
) -> Any:
    """
    Get a new access token using a refresh token.

    The presented refresh token is rotated: it is revoked and a new refresh
    token is issued alongside the new access token. Responds with 401 if the
    token is unknown, invalid per ``is_token_valid``, or its owner no longer
    exists or is inactive.
    """
    db_token = get_refresh_token(db=db, token=refresh_token)
    if not db_token or not is_token_valid(db_token):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired refresh token",
        )

    # Re-check the owner, as /login does: a user deactivated (or deleted) after
    # logging in must not be able to keep minting tokens with an old refresh token.
    user = db.query(User).filter(User.id == db_token.user_id).first()
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired refresh token",
        )
    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user"
        )

    # Rotate: the old refresh token becomes unusable once the new pair is issued.
    revoke_refresh_token(db=db, token=refresh_token)
    return _issue_tokens(db, db_token.user_id)


@router.post("/logout", status_code=status.HTTP_204_NO_CONTENT, response_model=None)
def logout(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_active_user),
    refresh_token: str = Body(...),
) -> Any:
    """
    Logout by revoking the refresh token.

    The token is revoked only if it exists and belongs to the authenticated
    user. Otherwise the request is a silent no-op, so the endpoint does not
    reveal whether someone else's token exists. Always responds 204.
    """
    db_token = get_refresh_token(db=db, token=refresh_token)
    if db_token and db_token.user_id == current_user.id:
        revoke_refresh_token(db=db, token=refresh_token)
    return None