diff --git a/README.md b/README.md index 9610833..3a8b347 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,10 @@ A robust, RESTful API for managing todos, built with FastAPI and SQLite. ## Features -- 🔐 JWT Authentication +- 🔐 Enhanced JWT Authentication + - Access and Refresh tokens + - Token revocation (logout) + - Role-based access control (User/Admin roles) - 📝 Todo CRUD operations - 👤 User management - 🔍 Advanced todo filtering and pagination @@ -63,6 +66,8 @@ The application can be configured using the following environment variables: - `POST /api/v1/auth/register` - Register a new user - `POST /api/v1/auth/login` - Login and get access token +- `POST /api/v1/auth/refresh` - Refresh access token using refresh token +- `POST /api/v1/auth/logout` - Logout and revoke refresh token ### Users @@ -71,6 +76,12 @@ The application can be configured using the following environment variables: - `PUT /api/v1/users/me` - Update current user - `GET /api/v1/users/{user_id}` - Get user by ID +### Admin + +- `GET /api/v1/admin/users` - List all users (admin only) +- `GET /api/v1/admin/users/{user_id}` - Get user by ID (admin only) +- `PUT /api/v1/admin/users/{user_id}` - Update user (admin only) + ### Todos - `GET /api/v1/todos/` - List todos (with filtering and pagination) @@ -97,6 +108,7 @@ id: Integer (Primary Key) email: String (Unique, Indexed) hashed_password: String is_active: Boolean (Default: True) +role: Enum(admin, user) (Default: user) ``` ### Todo Model @@ -109,6 +121,17 @@ is_completed: Boolean (Default: False) owner_id: Integer (Foreign Key to User) ``` +### RefreshToken Model + +``` +id: Integer (Primary Key) +token: String (Unique, Indexed) +expires_at: DateTime +created_at: DateTime +revoked: Boolean (Default: False) +user_id: Integer (Foreign Key to User) +``` + ## Development ### Code Structure diff --git a/app/api/deps.py b/app/api/deps.py index f50aa66..89f43f8 100644 --- a/app/api/deps.py +++ b/app/api/deps.py @@ -1,3 +1,4 @@ +from functools import wraps from typing import Generator from fastapi import Depends, HTTPException, status @@ -9,7 +10,7 @@ from sqlalchemy.orm import Session from app.core.config import settings from app.crud.crud_user import get_user from app.db.session import SessionLocal -from app.models.user import User +from app.models.user import User, UserRole from app.schemas.token import TokenPayload oauth2_scheme = OAuth2PasswordBearer( @@ -49,4 +50,29 @@ def get_current_active_user( ) -> User: if not current_user.is_active: raise HTTPException(status_code=400, detail="Inactive user") - return current_user \ No newline at end of file + return current_user + + +def get_current_admin_user( + current_user: User = Depends(get_current_active_user), +) -> User: + if current_user.role != UserRole.ADMIN: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="The user doesn't have enough privileges", + ) + return current_user + + +def check_role(required_role: UserRole): + def decorator(func): + @wraps(func) + async def wrapper(*args, current_user: User = Depends(get_current_active_user), **kwargs): + if current_user.role != required_role and current_user.role != UserRole.ADMIN: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="The user doesn't have enough privileges", + ) + return await func(*args, current_user=current_user, **kwargs) + return wrapper + return decorator \ No newline at end of file diff --git a/app/api/v1/api.py b/app/api/v1/api.py index ab71af9..d8ab29c 100644 --- a/app/api/v1/api.py +++ b/app/api/v1/api.py @@ -1,8 +1,9 @@ from fastapi import APIRouter -from app.api.v1.endpoints import auth, todos, users +from app.api.v1.endpoints import admin, auth, todos, users api_router = APIRouter() api_router.include_router(auth.router, prefix="/auth", tags=["auth"]) api_router.include_router(users.router, prefix="/users", tags=["users"]) -api_router.include_router(todos.router, prefix="/todos", tags=["todos"]) \ No newline at end of file +api_router.include_router(todos.router, prefix="/todos", tags=["todos"]) +api_router.include_router(admin.router, prefix="/admin", tags=["admin"]) \ No newline at end of file diff --git a/app/api/v1/endpoints/admin.py b/app/api/v1/endpoints/admin.py new file mode 100644 index 0000000..1969ed9 --- /dev/null +++ b/app/api/v1/endpoints/admin.py @@ -0,0 +1,66 @@ +from typing import Any, List + +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.orm import Session + +from app.api.deps import get_current_admin_user, get_db +from app.crud.crud_user import get_user, get_users, update_user +from app.models.user import User +from app.schemas.user import User as UserSchema +from app.schemas.user import UserUpdate + +router = APIRouter() + + +@router.get("/users", response_model=List[UserSchema]) +def admin_read_users( + db: Session = Depends(get_db), + skip: int = 0, + limit: int = 100, + current_user: User = Depends(get_current_admin_user), +) -> Any: + """ + Retrieve all users (admin only). + """ + users = get_users(db, skip=skip, limit=limit) + return users + + +@router.put("/users/{user_id}", response_model=UserSchema) +def admin_update_user( + *, + db: Session = Depends(get_db), + user_id: int, + user_in: UserUpdate, + current_user: User = Depends(get_current_admin_user), +) -> Any: + """ + Update a user (admin only). + """ + user = get_user(db, user_id=user_id) + if not user: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="The user with this id does not exist in the system", + ) + user = update_user(db, db_obj=user, obj_in=user_in) + return user + + +@router.get("/users/{user_id}", response_model=UserSchema) +def admin_read_user( + *, + db: Session = Depends(get_db), + user_id: int, + current_user: User = Depends(get_current_admin_user), +) -> Any: + """ + Get user by ID (admin only). + """ + user = get_user(db, user_id=user_id) + if not user: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="The user with this id does not exist in the system", + ) + return user \ No newline at end of file diff --git a/app/api/v1/endpoints/auth.py b/app/api/v1/endpoints/auth.py index dca87fc..f3a685d 100644 --- a/app/api/v1/endpoints/auth.py +++ b/app/api/v1/endpoints/auth.py @@ -1,15 +1,29 @@ from datetime import timedelta from typing import Any -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Body, Depends, HTTPException, status from fastapi.security import OAuth2PasswordRequestForm from sqlalchemy.orm import Session -from app.api.deps import get_db +from app.api.deps import get_current_active_user, get_db from app.core.config import settings -from app.core.security import create_access_token -from app.crud.crud_user import authenticate, create_user, get_user_by_email -from app.schemas.token import Token +from app.core.security import ( + create_access_token, + create_refresh_token, +) +from app.crud.crud_token import ( + create_refresh_token as create_refresh_token_db, + get_refresh_token, + is_token_valid, + revoke_refresh_token, +) +from app.crud.crud_user import ( + authenticate, + create_user, + get_user_by_email, +) +from app.models.user import User +from app.schemas.token import RefreshTokenCreate, Token from app.schemas.user import UserCreate router = APIRouter() @@ -34,12 +48,29 @@ def login_access_token( raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user" ) + + # Create access token access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + access_token, access_token_expiry = create_access_token( + user.id, expires_delta=access_token_expires + ) + + # Create refresh token + refresh_token, refresh_token_expiry = create_refresh_token(user.id) + + # Save refresh token to database + refresh_token_data = RefreshTokenCreate( + token=refresh_token, + expires_at=refresh_token_expiry, + user_id=user.id + ) + create_refresh_token_db(db=db, token_in=refresh_token_data) + return { - "access_token": create_access_token( - user.id, expires_delta=access_token_expires - ), + "access_token": access_token, "token_type": "bearer", + "refresh_token": refresh_token, + "expires_at": access_token_expiry, } @@ -59,10 +90,84 @@ def register_user( detail="A user with this email already exists", ) user = create_user(db, obj_in=user_in) + + # Create access token access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + access_token, access_token_expiry = create_access_token( + user.id, expires_delta=access_token_expires + ) + + # Create refresh token + refresh_token, refresh_token_expiry = create_refresh_token(user.id) + + # Save refresh token to database + refresh_token_data = RefreshTokenCreate( + token=refresh_token, + expires_at=refresh_token_expiry, + user_id=user.id + ) + create_refresh_token_db(db=db, token_in=refresh_token_data) + return { - "access_token": create_access_token( - user.id, expires_delta=access_token_expires - ), + "access_token": access_token, "token_type": "bearer", - } \ No newline at end of file + "refresh_token": refresh_token, + "expires_at": access_token_expiry, + } + + +@router.post("/refresh", response_model=Token) +def refresh_access_token( + db: Session = Depends(get_db), refresh_token: str = Body(...) +) -> Any: + """ + Get a new access token using a refresh token + """ + db_token = get_refresh_token(db=db, token=refresh_token) + if not db_token or not is_token_valid(db_token): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or expired refresh token", + ) + + # Create new access token + access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + access_token, access_token_expiry = create_access_token( + db_token.user_id, expires_delta=access_token_expires + ) + + # Create new refresh token + new_refresh_token, refresh_token_expiry = create_refresh_token(db_token.user_id) + + # Revoke old refresh token + revoke_refresh_token(db=db, token=refresh_token) + + # Save new refresh token to database + refresh_token_data = RefreshTokenCreate( + token=new_refresh_token, + expires_at=refresh_token_expiry, + user_id=db_token.user_id + ) + create_refresh_token_db(db=db, token_in=refresh_token_data) + + return { + "access_token": access_token, + "token_type": "bearer", + "refresh_token": new_refresh_token, + "expires_at": access_token_expiry, + } + + +@router.post("/logout", status_code=status.HTTP_204_NO_CONTENT, response_model=None) +def logout( + db: Session = Depends(get_db), + current_user: User = Depends(get_current_active_user), + refresh_token: str = Body(...), +) -> Any: + """ + Logout by revoking the refresh token + """ + db_token = get_refresh_token(db=db, token=refresh_token) + if db_token and db_token.user_id == current_user.id: + revoke_refresh_token(db=db, token=refresh_token) + return None \ No newline at end of file diff --git a/app/core/security.py b/app/core/security.py index 6693013..426665e 100644 --- a/app/core/security.py +++ b/app/core/security.py @@ -1,5 +1,6 @@ +import secrets from datetime import datetime, timedelta -from typing import Any, Union +from typing import Any, Dict, Optional, Tuple, Union from jose import jwt from passlib.context import CryptContext @@ -8,19 +9,52 @@ from app.core.config import settings pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") +# Define token types +ACCESS_TOKEN_TYPE = "access" +REFRESH_TOKEN_TYPE = "refresh" + +# Expiration times +ACCESS_TOKEN_EXPIRE_MINUTES = settings.ACCESS_TOKEN_EXPIRE_MINUTES +REFRESH_TOKEN_EXPIRE_DAYS = 30 + def create_access_token( - subject: Union[str, Any], expires_delta: timedelta = None -) -> str: + subject: Union[str, Any], expires_delta: Optional[timedelta] = None +) -> Tuple[str, datetime]: if expires_delta: expire = datetime.utcnow() + expires_delta else: expire = datetime.utcnow() + timedelta( - minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES + minutes=ACCESS_TOKEN_EXPIRE_MINUTES ) - to_encode = {"exp": expire, "sub": str(subject)} + to_encode = {"exp": expire, "sub": str(subject), "type": ACCESS_TOKEN_TYPE} encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm="HS256") - return encoded_jwt + return encoded_jwt, expire + + +def create_refresh_token( + subject: Union[str, Any], expires_delta: Optional[timedelta] = None +) -> Tuple[str, datetime]: + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS) + + # Generate a secure random token + refresh_token = secrets.token_urlsafe(64) + + # No need to encode anything as we'll store this in the database + return refresh_token, expire + + +def verify_token(token: str) -> Optional[Dict[str, Any]]: + try: + payload = jwt.decode( + token, settings.SECRET_KEY, algorithms=["HS256"] + ) + return payload + except jwt.JWTError: + return None def verify_password(plain_password: str, hashed_password: str) -> bool: diff --git a/app/crud/crud_token.py b/app/crud/crud_token.py new file mode 100644 index 0000000..3f6c7c3 --- /dev/null +++ b/app/crud/crud_token.py @@ -0,0 +1,67 @@ +from datetime import datetime +from typing import List, Optional + +from sqlalchemy.orm import Session + +from app.models.token import RefreshToken +from app.schemas.token import RefreshTokenCreate, RefreshTokenUpdate + + +def get_refresh_token(db: Session, token: str) -> Optional[RefreshToken]: + return db.query(RefreshToken).filter(RefreshToken.token == token).first() + + +def get_refresh_tokens_by_user( + db: Session, user_id: int, skip: int = 0, limit: int = 100 +) -> List[RefreshToken]: + return ( + db.query(RefreshToken) + .filter(RefreshToken.user_id == user_id) + .offset(skip) + .limit(limit) + .all() + ) + + +def create_refresh_token(db: Session, token_in: RefreshTokenCreate) -> RefreshToken: + db_token = RefreshToken( + token=token_in.token, + expires_at=token_in.expires_at, + user_id=token_in.user_id, + created_at=datetime.utcnow(), + revoked=False, + ) + db.add(db_token) + db.commit() + db.refresh(db_token) + return db_token + + +def update_refresh_token( + db: Session, db_obj: RefreshToken, obj_in: RefreshTokenUpdate +) -> RefreshToken: + update_data = obj_in.model_dump(exclude_unset=True) + for field in update_data: + setattr(db_obj, field, update_data[field]) + db.add(db_obj) + db.commit() + db.refresh(db_obj) + return db_obj + + +def revoke_refresh_token(db: Session, token: str) -> Optional[RefreshToken]: + db_token = get_refresh_token(db, token=token) + if db_token: + db_token.revoked = True + db.add(db_token) + db.commit() + db.refresh(db_token) + return db_token + + +def is_token_valid(token: RefreshToken) -> bool: + return ( + token is not None + and not token.revoked + and token.expires_at > datetime.utcnow() + ) \ No newline at end of file diff --git a/app/crud/crud_user.py b/app/crud/crud_user.py index 344db03..49d1c9e 100644 --- a/app/crud/crud_user.py +++ b/app/crud/crud_user.py @@ -26,6 +26,7 @@ def create_user(db: Session, obj_in: UserCreate) -> User: email=obj_in.email, hashed_password=get_password_hash(obj_in.password), is_active=True, + role=obj_in.role, ) db.add(db_obj) db.commit() diff --git a/app/models/token.py b/app/models/token.py new file mode 100644 index 0000000..8c24a46 --- /dev/null +++ b/app/models/token.py @@ -0,0 +1,19 @@ +from datetime import datetime + +from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String +from sqlalchemy.orm import relationship + +from app.db.base import Base + + +class RefreshToken(Base): + __tablename__ = "refresh_tokens" + + id = Column(Integer, primary_key=True, index=True) + token = Column(String, unique=True, index=True) + expires_at = Column(DateTime) + created_at = Column(DateTime, default=datetime.utcnow) + revoked = Column(Boolean, default=False) + user_id = Column(Integer, ForeignKey("users.id")) + + user = relationship("User") \ No newline at end of file diff --git a/app/models/user.py b/app/models/user.py index aefc85f..5b7f67a 100644 --- a/app/models/user.py +++ b/app/models/user.py @@ -1,9 +1,15 @@ -from sqlalchemy import Boolean, Column, Integer, String +from enum import Enum as PyEnum +from sqlalchemy import Boolean, Column, Integer, String, Enum from sqlalchemy.orm import relationship from app.db.base import Base +class UserRole(str, PyEnum): + ADMIN = "admin" + USER = "user" + + class User(Base): __tablename__ = "users" @@ -11,5 +17,6 @@ class User(Base): email = Column(String, unique=True, index=True) hashed_password = Column(String) is_active = Column(Boolean, default=True) + role = Column(Enum(UserRole), default=UserRole.USER) todos = relationship("Todo", back_populates="owner") \ No newline at end of file diff --git a/app/schemas/token.py b/app/schemas/token.py index 69541e2..e09301c 100644 --- a/app/schemas/token.py +++ b/app/schemas/token.py @@ -1,3 +1,4 @@ +from datetime import datetime from typing import Optional from pydantic import BaseModel @@ -6,7 +7,33 @@ from pydantic import BaseModel class Token(BaseModel): access_token: str token_type: str + refresh_token: Optional[str] = None + expires_at: Optional[datetime] = None class TokenPayload(BaseModel): - sub: Optional[int] = None \ No newline at end of file + sub: Optional[int] = None + exp: Optional[datetime] = None + type: Optional[str] = "access" + + +class RefreshTokenCreate(BaseModel): + token: str + expires_at: datetime + user_id: int + + +class RefreshTokenUpdate(BaseModel): + revoked: Optional[bool] = None + + +class RefreshTokenInDB(BaseModel): + id: int + token: str + expires_at: datetime + created_at: datetime + revoked: bool + user_id: int + + class Config: + from_attributes = True \ No newline at end of file diff --git a/app/schemas/user.py b/app/schemas/user.py index ad36fac..f3a2390 100644 --- a/app/schemas/user.py +++ b/app/schemas/user.py @@ -2,17 +2,21 @@ from typing import Optional from pydantic import BaseModel, EmailStr +from app.models.user import UserRole + # Shared properties class UserBase(BaseModel): email: Optional[EmailStr] = None is_active: Optional[bool] = True + role: Optional[UserRole] = None # Properties to receive via API on creation class UserCreate(UserBase): email: EmailStr password: str + role: UserRole = UserRole.USER # Properties to receive via API on update @@ -22,6 +26,7 @@ class UserUpdate(UserBase): class UserInDBBase(UserBase): id: Optional[int] = None + role: UserRole = UserRole.USER class Config: from_attributes = True diff --git a/migrations/versions/002_add_user_roles.py b/migrations/versions/002_add_user_roles.py new file mode 100644 index 0000000..454d843 --- /dev/null +++ b/migrations/versions/002_add_user_roles.py @@ -0,0 +1,30 @@ +"""add user roles + +Revision ID: 002 +Revises: 001 +Create Date: 2023-11-16 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '002' +down_revision = '001' +branch_labels = None +depends_on = None + + +def upgrade(): + # Add role column to users table + with op.batch_alter_table('users', schema=None) as batch_op: + batch_op.add_column(sa.Column('role', sa.Enum('admin', 'user', name='userrole'), nullable=True)) + batch_op.execute("UPDATE users SET role = 'user' WHERE role IS NULL") + batch_op.alter_column('role', nullable=False, server_default='user') + + +def downgrade(): + # Remove role column from users table + with op.batch_alter_table('users', schema=None) as batch_op: + batch_op.drop_column('role') \ No newline at end of file diff --git a/migrations/versions/003_add_refresh_tokens.py b/migrations/versions/003_add_refresh_tokens.py new file mode 100644 index 0000000..bfaa5ad --- /dev/null +++ b/migrations/versions/003_add_refresh_tokens.py @@ -0,0 +1,40 @@ +"""add refresh tokens + +Revision ID: 003 +Revises: 002 +Create Date: 2023-11-16 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '003' +down_revision = '002' +branch_labels = None +depends_on = None + + +def upgrade(): + # Create refresh_tokens table + op.create_table( + 'refresh_tokens', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('token', sa.String(), nullable=True), + sa.Column('expires_at', sa.DateTime(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.Column('revoked', sa.Boolean(), nullable=True), + sa.Column('user_id', sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_refresh_tokens_id'), 'refresh_tokens', ['id'], unique=False) + op.create_index(op.f('ix_refresh_tokens_token'), 'refresh_tokens', ['token'], unique=True) + + +def downgrade(): + # Drop refresh_tokens table + op.drop_index(op.f('ix_refresh_tokens_token'), table_name='refresh_tokens') + op.drop_index(op.f('ix_refresh_tokens_id'), table_name='refresh_tokens') + op.drop_table('refresh_tokens') \ No newline at end of file