107 lines
3.4 KiB
Python

from datetime import timedelta

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from app.database import get_db
from app.schemas.user import Token, TokenData, User, UserCreate
from app.utils import (
    ACCESS_TOKEN_EXPIRE_MINUTES,
    ALGORITHM,
    SECRET_KEY,
    create_access_token,
    create_user,
    get_user,
    get_user_by_email,
    get_user_by_username,
    verify_password,
)
# Every endpoint in this module is mounted under the /users prefix and grouped
# under the "users" tag in the generated OpenAPI schema.
router = APIRouter(prefix="/users", tags=["users"])

# Extracts the bearer token from the Authorization header. ``tokenUrl`` tells
# OpenAPI clients where to obtain a token: the POST /token login endpoint of
# this router (i.e. /users/token).
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="users/token")
async def get_current_user(
    token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)
):
    """Resolve the authenticated user from the request's bearer JWT.

    Intended for use as a FastAPI dependency on protected endpoints.

    Args:
        token: Raw JWT taken from the ``Authorization: Bearer`` header by
            ``oauth2_scheme``.
        db: Database session supplied by ``get_db``.

    Returns:
        The user whose username equals the token's ``sub`` claim.

    Raises:
        HTTPException: 401 (with ``WWW-Authenticate: Bearer``) when the token
            cannot be decoded/verified, its ``sub`` claim is missing or not a
            string, or no user with that username exists.
    """
    # NOTE(review): this is an ``async def`` that calls the synchronous
    # SQLAlchemy ``Session`` (through get_user_by_username), so the lookup
    # presumably runs on the event-loop thread. Consider a plain ``def``
    # dependency (run in FastAPI's threadpool) once no caller awaits it directly.
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Keep the try body to the one call that raises JWTError (bad signature,
    # malformed token, failed claim checks such as expiry).
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise credentials_exception from None

    username = payload.get("sub")
    # Reject a missing *or* non-string subject here. Otherwise a malformed
    # claim could fail TokenData validation and escape as a non-401 error.
    if not isinstance(username, str):
        raise credentials_exception
    token_data = TokenData(username=username)

    user = get_user_by_username(db, username=token_data.username)
    if user is None:
        raise credentials_exception
    return user
def _ensure_identity_available(db: Session, user: UserCreate) -> None:
    """Raise 400 if ``user``'s email or username already belongs to an account."""
    if get_user_by_email(db, email=user.email):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered"
        )
    if get_user_by_username(db, username=user.username):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="Username already taken"
        )


@router.post("/", response_model=User, status_code=status.HTTP_201_CREATED)
async def register_user(user: UserCreate, db: Session = Depends(get_db)):
    """Register a new user account.

    Args:
        user: Registration payload carrying ``email``, ``username`` and
            ``password``.
        db: Database session supplied by ``get_db``.

    Returns:
        Whatever ``create_user`` returns for the new account, serialized
        through the ``User`` response model with status 201.

    Raises:
        HTTPException: 400 "Email already registered" or "Username already
            taken" when either identifier is already in use.
        IntegrityError: re-raised from ``create_user`` when the database
            rejects the insert for a reason other than a duplicate
            email/username.
    """
    _ensure_identity_available(db, user)

    # The checks above and the insert below are not atomic: a concurrent
    # request can claim the same email/username in between. If the database
    # enforces uniqueness (presumably; confirm against the model), that race
    # surfaces as IntegrityError, which would otherwise become a 500.
    try:
        return create_user(db, user.email, user.username, user.password)
    except IntegrityError:
        # The session is unusable after a failed flush/commit until rolled back.
        db.rollback()
        # Re-check so the client gets the same precise 400 as the fast path.
        # If neither identifier is taken, the violation was something else.
        _ensure_identity_available(db, user)
        raise
@router.post("/token", response_model=Token)
async def login_for_access_token(
    form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)
):
    """Exchange a username/password form for a bearer access token.

    Args:
        form_data: OAuth2 password-flow form with ``username`` and ``password``.
        db: Database session supplied by ``get_db``.

    Returns:
        ``{"access_token": <jwt>, "token_type": "bearer"}``; the token's ``sub``
        claim is the username, and it expires after
        ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.

    Raises:
        HTTPException: 401 when the user does not exist or the password does
            not match (the same message in both cases).
    """
    account = get_user_by_username(db, username=form_data.username)
    # Short-circuits: the password is only checked when an account was found.
    # NOTE(review): because of that, an unknown username presumably answers
    # faster than a wrong password, which may allow timing-based username
    # enumeration; confirm whether that matters here.
    authenticated = bool(account) and verify_password(
        form_data.password, account.hashed_password
    )
    if not authenticated:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    encoded = create_access_token(data={"sub": account.username}, expires_delta=lifetime)
    return {"access_token": encoded, "token_type": "bearer"}
@router.get("/me", response_model=User)
async def read_users_me(
    current_user: User = Depends(get_current_user),
):
    """Return the account of the authenticated caller.

    Authentication is handled entirely by the ``get_current_user`` dependency.
    This route is registered before the parameterized ``/{user_id}`` route, so
    ``/users/me`` is matched here rather than parsed as a user id.
    """
    return current_user
@router.get("/{user_id}", response_model=User)
async def read_user(user_id: int, db: Session = Depends(get_db)):
    """Fetch a single user by numeric id.

    Args:
        user_id: Primary-key id taken from the path.
        db: Database session supplied by ``get_db``.

    Returns:
        The matching user, serialized through the ``User`` response model.

    Raises:
        HTTPException: 404 "User not found" when no user has that id.
    """
    # NOTE(review): this route has no authentication dependency, so any
    # caller can look up any user by id; confirm the User schema exposes
    # only public fields.
    found = get_user(db, user_id=user_id)
    if found is not None:
        return found
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")