95 lines
3.3 KiB
Python

import os
from datetime import datetime, timedelta, timezone
from typing import Optional, Union

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy.orm import Session

from api.schemas.user import TokenData
from db.database import get_db
from db.models import User

# Key used to sign and verify JWTs. In production, set the SECRET_KEY
# environment variable; the literal below is only a development fallback and
# is used when the variable is unset or empty.
# To generate a suitable value run: openssl rand -hex 32
SECRET_KEY = os.environ.get("SECRET_KEY") or (
    "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"  # noqa: S105
)
# JWT signing algorithm passed to jose when encoding and decoding tokens.
ALGORITHM = "HS256"
# Default access-token lifetime, applied when no explicit lifetime is given.
ACCESS_TOKEN_EXPIRE_MINUTES = 30
# passlib context used for hashing and verifying passwords (bcrypt scheme).
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# FastAPI dependency that extracts the bearer token from incoming requests.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/auth/token")


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return whether ``plain_password`` matches ``hashed_password``.

    The comparison is delegated to the module-level ``pwd_context``.
    """
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches


def get_password_hash(password: str) -> str:
    """Return the hash of ``password`` produced by the module-level ``pwd_context``."""
    hashed = pwd_context.hash(password)
    return hashed


def authenticate_user(db: Session, username: str, password: str) -> Union[User, bool]:
    """Check a username/password pair against the database.

    Args:
        db: Session used to look up the user by ``username``.
        username: Username to search for.
        password: Plain-text password to verify against the stored hash.

    Returns:
        The matching ``User`` on success, or ``False`` when no user with that
        username exists or the password does not verify.

    Raises:
        HTTPException: 400 "Inactive user" when the credentials are valid but
            the account is not active.
    """
    candidate = db.query(User).filter(User.username == username).first()
    if not candidate:
        return False
    if not verify_password(password, candidate.hashed_password):
        return False
    # Credentials are valid from here on; only the active flag can still reject.
    if candidate.is_active:
        return candidate
    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user",
    )


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT carrying ``data`` plus an ``exp`` claim.

    Args:
        data: Claims to embed in the token. The dict is copied, so the
            caller's object is not modified.
        expires_delta: Lifetime of the token. When ``None``, the token lives
            for ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.

    Returns:
        The encoded token, signed with ``SECRET_KEY`` using ``ALGORITHM``.
    """
    # Compare against None rather than truthiness: timedelta(0) is falsy, and
    # an explicit zero lifetime must not silently become the default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    # Timezone-aware "now" in UTC; datetime.utcnow() returns a naive value and
    # is deprecated since Python 3.12.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode = data.copy()
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)


async def get_current_user(
    token: str = Depends(oauth2_scheme), db: Session = Depends(get_db),
) -> User:
    """Resolve the bearer token to an active ``User``.

    The token is decoded with ``SECRET_KEY`` / ``ALGORITHM``, its ``sub`` claim
    is taken as the username, and the user is loaded from ``db``.

    Raises:
        HTTPException: 401 "Could not validate credentials" when the token
            cannot be decoded, has no ``sub`` claim, or names a user that is
            missing or inactive.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as err:
        raise unauthorized from err

    subject: str = claims.get("sub")
    if subject is None:
        raise unauthorized
    token_data = TokenData(username=subject)

    account = db.query(User).filter(User.username == token_data.username).first()
    if account is not None and account.is_active:
        return account
    raise unauthorized


async def get_current_active_user(current_user: User = Depends(get_current_user)) -> User:
    """Return ``current_user`` if the account is active.

    Raises:
        HTTPException: 400 "Inactive user" when ``current_user.is_active`` is falsy.
    """
    if current_user.is_active:
        return current_user
    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user")


async def get_current_superuser(current_user: User = Depends(get_current_user)) -> User:
    """Return ``current_user`` if the account has superuser rights.

    Raises:
        HTTPException: 403 "Not enough permissions" when
            ``current_user.is_superuser`` is falsy.
    """
    if current_user.is_superuser:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions",
    )