from datetime import timedelta
from typing import Any

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.orm import Session

from app.core.config import settings
from app.core.database import get_db
from app.core.dependencies import get_current_active_user
from app.core.security import create_access_token
from app.models.user import User
from app.schemas.token import Token, TokenRefresh
from app.schemas.user import UserCreate
from app.services import token as token_service
from app.services import user as user_service

router = APIRouter()


def _issue_token_pair(db: Session, user_id: Any) -> dict:
    """Create an access token and a refresh token for *user_id*.

    Shared by the register, login and refresh endpoints so the response
    payload is built in exactly one place.

    Args:
        db: Database session handed to the refresh-token service.
        user_id: Identifier used as the access token subject and as the
            owner of the new refresh token.

    Returns:
        A dict with the keys ``access_token``, ``token_type``,
        ``refresh_token`` and ``expires_at``.
    """
    access_token = create_access_token(
        subject=user_id,
        expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    refresh_token_obj = token_service.create_refresh_token(db, user_id=user_id)
    return {
        "access_token": access_token,
        "token_type": "bearer",
        "refresh_token": refresh_token_obj.token,
        # expires_at is taken from the refresh token record, not from the
        # access token lifetime computed above.
        "expires_at": refresh_token_obj.expires_at,
    }


@router.post("/register", response_model=Token)
def register(
    user_in: UserCreate,
    db: Session = Depends(get_db),
) -> Any:
    """
    Register a new user and return access token

    Responds with 400 when the email or the username is already taken.
    """
    # Email is checked before username, so a request that collides on both
    # reports the email conflict.
    if user_service.get_by_email(db, email=user_in.email):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="A user with this email already exists",
        )

    if user_service.get_by_username(db, username=user_in.username):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="A user with this username already exists",
        )

    user = user_service.create(db, obj_in=user_in)
    return _issue_token_pair(db, user.id)


@router.post("/login", response_model=Token)
def login(
    db: Session = Depends(get_db),
    form_data: OAuth2PasswordRequestForm = Depends(),
) -> Any:
    """
    OAuth2 compatible token login, get an access token for future requests

    Responds with 401 when authentication fails and with 400 when the
    authenticated user is inactive.
    """
    # The form's "username" field is passed on as username_or_email.
    user = user_service.authenticate(
        db, username_or_email=form_data.username, password=form_data.password
    )
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username/email or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    if not user_service.is_active(user):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Inactive user",
            headers={"WWW-Authenticate": "Bearer"},
        )

    return _issue_token_pair(db, user.id)


@router.post("/refresh", response_model=Token)
def refresh_token(
    token_in: TokenRefresh,
    db: Session = Depends(get_db),
) -> Any:
    """
    Refresh access token

    The presented refresh token is revoked and a new access/refresh token
    pair is returned. Responds with 401 when the refresh token is unknown,
    expired or revoked.
    """
    # Named stored_token so the local does not shadow this route function.
    stored_token = token_service.get_by_token(db, token=token_in.refresh_token)
    if not stored_token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid refresh token",
            headers={"WWW-Authenticate": "Bearer"},
        )
    if not token_service.is_token_valid(stored_token):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Refresh token expired or revoked",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Revoke the used token before issuing its replacement.
    # NOTE(review): the owning user's active status is not re-checked here;
    # confirm whether is_token_valid is expected to cover that.
    token_service.revoke_token(db, stored_token)
    return _issue_token_pair(db, stored_token.user_id)


@router.post("/logout", status_code=status.HTTP_204_NO_CONTENT, response_model=None)
def logout(
    token_in: TokenRefresh,
    db: Session = Depends(get_db),
) -> None:
    """
    Logout by revoking the refresh token

    An unknown refresh token is ignored; the response is 204 either way.
    """
    refresh_token = token_service.get_by_token(db, token=token_in.refresh_token)
    if refresh_token:
        token_service.revoke_token(db, refresh_token)
    return None


@router.post("/logout-all", status_code=status.HTTP_204_NO_CONTENT, response_model=None)
def logout_all(
    current_user: User = Depends(get_current_active_user),
    db: Session = Depends(get_db),
) -> None:
    """
    Logout from all devices by revoking all refresh tokens
    """
    token_service.revoke_all_user_tokens(db, user_id=current_user.id)
    return None