
- Add role-based access control (admin/user roles)
- Implement refresh-token functionality
- Add token revocation (logout) capability
- Create admin-only endpoints
- Add role-validation middleware
- Update documentation
173 lines
5.2 KiB
Python
173 lines
5.2 KiB
Python
from datetime import timedelta
|
|
from typing import Any
|
|
|
|
from fastapi import APIRouter, Body, Depends, HTTPException, status
|
|
from fastapi.security import OAuth2PasswordRequestForm
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.api.deps import get_current_active_user, get_db
|
|
from app.core.config import settings
|
|
from app.core.security import (
|
|
create_access_token,
|
|
create_refresh_token,
|
|
)
|
|
from app.crud.crud_token import (
|
|
create_refresh_token as create_refresh_token_db,
|
|
get_refresh_token,
|
|
is_token_valid,
|
|
revoke_refresh_token,
|
|
)
|
|
from app.crud.crud_user import (
|
|
authenticate,
|
|
create_user,
|
|
get_user_by_email,
|
|
)
|
|
from app.models.user import User
|
|
from app.schemas.token import RefreshTokenCreate, Token
|
|
from app.schemas.user import UserCreate
|
|
|
|
# Router collecting this module's authentication endpoints:
# /login, /register, /refresh and /logout.
router = APIRouter()
|
|
|
|
|
|
@router.post("/login", response_model=Token)
def login_access_token(
    db: Session = Depends(get_db), form_data: OAuth2PasswordRequestForm = Depends()
) -> Any:
    """Exchange OAuth2 password-form credentials for an access/refresh token pair.

    The form's ``username`` field is passed to ``authenticate`` as the email.

    Raises:
        HTTPException: 400 if the credentials are rejected, or 400 if the
            matching account is not active.

    Returns:
        A mapping shaped like the ``Token`` schema: the access token, the
        ``"bearer"`` token type, a newly persisted refresh token, and the
        access token's expiry.
    """
    account = authenticate(db, email=form_data.username, password=form_data.password)

    # Guard clauses: each failure raises, so no else-branches are needed.
    if not account:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Incorrect email or password",
        )
    if not account.is_active:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user"
        )

    lifetime = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token, access_expiry = create_access_token(
        account.id, expires_delta=lifetime
    )
    refresh_token, refresh_expiry = create_refresh_token(account.id)

    # Persist the refresh token so /refresh and /logout can look it up later.
    create_refresh_token_db(
        db=db,
        token_in=RefreshTokenCreate(
            token=refresh_token,
            expires_at=refresh_expiry,
            user_id=account.id,
        ),
    )

    return dict(
        access_token=access_token,
        token_type="bearer",
        refresh_token=refresh_token,
        expires_at=access_expiry,
    )
|
|
|
|
|
|
@router.post("/register", response_model=Token)
def register_user(
    *,
    db: Session = Depends(get_db),
    user_in: UserCreate,
) -> Any:
    """Create a new account and immediately log it in.

    Raises:
        HTTPException: 400 if ``get_user_by_email`` already finds an account
            with ``user_in.email``.

    Returns:
        A mapping shaped like the ``Token`` schema for the new user: access
        token, ``"bearer"`` token type, a newly persisted refresh token, and
        the access token's expiry.
    """
    existing = get_user_by_email(db, email=user_in.email)
    if existing:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="A user with this email already exists",
        )

    new_user = create_user(db, obj_in=user_in)

    lifetime = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token, access_expiry = create_access_token(
        new_user.id, expires_delta=lifetime
    )
    refresh_token, refresh_expiry = create_refresh_token(new_user.id)

    # Persist the refresh token so it can later be rotated or revoked.
    create_refresh_token_db(
        db=db,
        token_in=RefreshTokenCreate(
            token=refresh_token,
            expires_at=refresh_expiry,
            user_id=new_user.id,
        ),
    )

    return dict(
        access_token=access_token,
        token_type="bearer",
        refresh_token=refresh_token,
        expires_at=access_expiry,
    )
|
|
|
|
|
|
@router.post("/refresh", response_model=Token)
def refresh_access_token(
    db: Session = Depends(get_db), refresh_token: str = Body(...)
) -> Any:
    """Rotate a refresh token: revoke it and issue a new access/refresh pair.

    Args:
        db: Database session.
        refresh_token: The refresh token previously issued by /login,
            /register or /refresh. It is declared as a single, non-embedded
            ``Body`` parameter.

    Raises:
        HTTPException: 401 if the token is unknown, fails ``is_token_valid``,
            or its owning user no longer exists; 400 if the owning user is
            inactive.

    Returns:
        A mapping shaped like the ``Token`` schema containing a new access
        token, the ``"bearer"`` token type, the replacement refresh token, and
        the access token's expiry.
    """
    db_token = get_refresh_token(db=db, token=refresh_token)
    if not db_token or not is_token_valid(db_token):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired refresh token",
        )

    # Re-check the owner on every refresh. Without this, a user deactivated
    # (or deleted) after logging in could keep rotating refresh tokens and
    # minting access tokens forever, bypassing the is_active check in /login.
    user = db.query(User).filter(User.id == db_token.user_id).first()
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired refresh token",
        )
    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user"
        )

    # Create new access token
    access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token, access_token_expiry = create_access_token(
        user.id, expires_delta=access_token_expires
    )

    # Create the replacement refresh token
    new_refresh_token, refresh_token_expiry = create_refresh_token(user.id)

    # Revoke the presented token so it cannot be used again
    revoke_refresh_token(db=db, token=refresh_token)

    # Save the replacement refresh token to the database
    refresh_token_data = RefreshTokenCreate(
        token=new_refresh_token,
        expires_at=refresh_token_expiry,
        user_id=user.id,
    )
    create_refresh_token_db(db=db, token_in=refresh_token_data)

    return {
        "access_token": access_token,
        "token_type": "bearer",
        "refresh_token": new_refresh_token,
        "expires_at": access_token_expiry,
    }
|
|
|
|
|
|
@router.post("/logout", status_code=status.HTTP_204_NO_CONTENT, response_model=None)
def logout(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_active_user),
    refresh_token: str = Body(...),
) -> Any:
    """Revoke one of the caller's refresh tokens; always answers 204.

    The caller must pass ``get_current_active_user``. The token is revoked
    only when ``get_refresh_token`` finds it and it belongs to the caller.
    Otherwise the request is accepted without doing anything.
    """
    stored = get_refresh_token(db=db, token=refresh_token)

    # Unknown or foreign token: silent no-op (presumably so the endpoint does
    # not reveal which tokens exist or who owns them).
    if not stored or stored.user_id != current_user.id:
        return None

    revoke_refresh_token(db=db, token=refresh_token)
    return None