147 lines
4.1 KiB
Python

from datetime import timedelta
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from jose import jwt
from pydantic import ValidationError
from sqlalchemy.orm import Session
from app import schemas
from app.api import deps
from app.core import security
from app.core.config import settings
from app.services import user as user_service
# Router holding the authentication endpoints defined in this module
# (/register, /login, /refresh-token); presumably included into the main
# FastAPI app elsewhere — confirm against the app's router setup.
router = APIRouter()
@router.post("/register", response_model=schemas.User)
def register(
    *,
    db: Session = Depends(deps.get_db),
    user_in: schemas.UserCreate,
) -> Any:
    """
    Create a new user account.

    Uniqueness is checked in two steps: the email first, then the username.
    Either collision aborts with HTTP 400 before ``user_service.create`` is
    called.

    Returns the object produced by ``user_service.create``; FastAPI
    serializes it through the ``schemas.User`` response model.
    """
    if user_service.get_by_email(db, email=user_in.email):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered",
        )

    if user_service.get_by_username(db, username=user_in.username):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Username already taken",
        )

    return user_service.create(db, obj_in=user_in)
@router.post("/login", response_model=schemas.Token)
def login(
    db: Session = Depends(deps.get_db),
    form_data: OAuth2PasswordRequestForm = Depends(),
) -> Any:
    """
    Exchange OAuth2 password-form credentials for a token pair.

    ``form_data.username`` is forwarded to ``user_service.authenticate`` as
    ``email_or_username``, so either identifier is accepted.

    Bad credentials yield HTTP 401 with a ``WWW-Authenticate: Bearer``
    header. An inactive account yields HTTP 400.

    Returns a dict with:
    - ``access_token``: lifetime ``settings.ACCESS_TOKEN_EXPIRE_MINUTES``.
    - ``refresh_token``.
    - ``token_type``: always ``"bearer"``.
    """
    account = user_service.authenticate(
        db,
        email_or_username=form_data.username,
        password=form_data.password,
    )
    if not account:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email/username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    if not user_service.is_active(account):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Inactive user",
        )

    lifetime = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    # Dict values evaluate left to right: the access token is minted before
    # the refresh token, matching the original statement order.
    return {
        "access_token": security.create_access_token(
            account.id, expires_delta=lifetime
        ),
        "refresh_token": security.create_refresh_token(account.id),
        "token_type": "bearer",
    }
@router.post("/refresh-token", response_model=schemas.Token)
def refresh_token(
    db: Session = Depends(deps.get_db),
    token_data: schemas.RefreshToken = None,
) -> Any:
    """
    Exchange a refresh token for a new access/refresh token pair.

    The supplied token is decoded with ``settings.SECRET_KEY`` and
    ``settings.ALGORITHM``, validated against ``schemas.TokenPayload``, and
    must carry ``"type": "refresh"``. The user named by its ``sub`` claim
    must exist and be active.

    Raises:
        HTTPException(400): no request body was supplied, the token is not a
            refresh token, or the user is inactive.
        HTTPException(403): the token cannot be decoded/validated, or its
            ``sub`` claim is missing or not an integer user id.
        HTTPException(404): no user exists for the token's subject.
    """
    # ``token_data`` defaults to None, which makes the body optional. Without
    # this guard a missing body would crash with AttributeError (HTTP 500).
    if token_data is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Refresh token is required",
        )

    credentials_error = HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Could not validate credentials",
    )

    try:
        payload = jwt.decode(
            token_data.refresh_token,
            settings.SECRET_KEY,
            algorithms=[settings.ALGORITHM],
        )
        token_payload = schemas.TokenPayload(**payload)
    except (jwt.JWTError, ValidationError) as exc:
        raise credentials_error from exc

    # Reject access tokens (or any other token type) presented here.
    if payload.get("type") != "refresh":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid refresh token",
        )

    # A missing ``sub`` (int(None) -> TypeError) or a non-numeric one
    # (ValueError) would otherwise surface as an unhandled HTTP 500.
    try:
        user_id = int(token_payload.sub)
    except (TypeError, ValueError) as exc:
        raise credentials_error from exc

    user = user_service.get_by_id(db, id=user_id)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found",
        )
    if not user_service.is_active(user):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Inactive user",
        )

    access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = security.create_access_token(
        user.id, expires_delta=access_token_expires
    )
    # Named so it does not shadow this endpoint function's own name.
    new_refresh_token = security.create_refresh_token(user.id)
    return {
        "access_token": access_token,
        "refresh_token": new_refresh_token,
        "token_type": "bearer",
    }