diff --git a/app/api/v1/endpoints/auth.py b/app/api/v1/endpoints/auth.py index eb8e92c..02138f7 100644 --- a/app/api/v1/endpoints/auth.py +++ b/app/api/v1/endpoints/auth.py @@ -33,7 +33,7 @@ def register( status_code=status.HTTP_400_BAD_REQUEST, detail="A user with this email already exists", ) - + # Check if user with this username already exists user = user_service.get_by_username(db, username=user_in.username) if user: @@ -41,19 +41,19 @@ def register( status_code=status.HTTP_400_BAD_REQUEST, detail="A user with this username already exists", ) - + # Create new user user = user_service.create(db, obj_in=user_in) - + # Create access token access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) access_token = create_access_token( subject=user.id, expires_delta=access_token_expires ) - + # Create refresh token refresh_token_obj = token_service.create_refresh_token(db, user_id=user.id) - + return { "access_token": access_token, "token_type": "bearer", @@ -79,23 +79,23 @@ def login( detail="Incorrect username/email or password", headers={"WWW-Authenticate": "Bearer"}, ) - + if not user_service.is_active(user): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user", headers={"WWW-Authenticate": "Bearer"}, ) - + # Create access token access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) access_token = create_access_token( subject=user.id, expires_delta=access_token_expires ) - + # Create refresh token refresh_token_obj = token_service.create_refresh_token(db, user_id=user.id) - + return { "access_token": access_token, "token_type": "bearer", @@ -119,26 +119,28 @@ def refresh_token( detail="Invalid refresh token", headers={"WWW-Authenticate": "Bearer"}, ) - + if not token_service.is_token_valid(refresh_token): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Refresh token expired or revoked", headers={"WWW-Authenticate": "Bearer"}, ) - + # Revoke the used refresh token token_service.revoke_token(db, refresh_token) - + # Create new access token access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) access_token = create_access_token( subject=refresh_token.user_id, expires_delta=access_token_expires ) - + # Create new refresh token - new_refresh_token = token_service.create_refresh_token(db, user_id=refresh_token.user_id) - + new_refresh_token = token_service.create_refresh_token( + db, user_id=refresh_token.user_id + ) + return { "access_token": access_token, "token_type": "bearer", @@ -170,4 +172,4 @@ def logout_all( Logout from all devices by revoking all refresh tokens """ token_service.revoke_all_user_tokens(db, user_id=current_user.id) - return None \ No newline at end of file + return None diff --git a/app/api/v1/endpoints/users.py b/app/api/v1/endpoints/users.py index 24a888d..cae2bd8 100644 --- a/app/api/v1/endpoints/users.py +++ b/app/api/v1/endpoints/users.py @@ -71,14 +71,14 @@ def create_user( status_code=status.HTTP_400_BAD_REQUEST, detail="A user with this email already exists", ) - + user = user_service.get_by_username(db, username=user_in.username) if user: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="A user with this username already exists", ) - + user = user_service.create(db, obj_in=user_in) return user @@ -121,4 +121,4 @@ def update_user( detail="User not found", ) user = user_service.update(db, db_obj=user, obj_in=user_in) - return user \ No newline at end of file + return user diff --git a/app/api/v1/router.py b/app/api/v1/router.py index c960076..cec7c85 100644 --- a/app/api/v1/router.py +++ b/app/api/v1/router.py @@ -4,4 +4,4 @@ from app.api.v1.endpoints import auth, users api_router = APIRouter() api_router.include_router(auth.router, prefix="/auth", tags=["authentication"]) -api_router.include_router(users.router, prefix="/users", tags=["users"]) \ No newline at end of file +api_router.include_router(users.router, prefix="/users", tags=["users"]) diff --git a/app/core/config.py b/app/core/config.py index 7f2b8a8..e9b52fb 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -3,17 +3,18 @@ from pathlib import Path from pydantic import AnyHttpUrl, validator from pydantic_settings import BaseSettings + class Settings(BaseSettings): API_V1_STR: str = "/api/v1" PROJECT_NAME: str = "User Authentication Service" PROJECT_DESCRIPTION: str = "API service for user authentication" PROJECT_VERSION: str = "0.1.0" - + # Secret key for JWT token and other security mechanisms SECRET_KEY: str = "YOUR_SUPER_SECRET_KEY_CHANGE_THIS_IN_PRODUCTION" # 60 minutes * 24 hours * 8 days = 8 days in minutes ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8 - + # CORS BACKEND_CORS_ORIGINS: List[AnyHttpUrl] = [] @@ -24,19 +25,20 @@ class Settings(BaseSettings): elif isinstance(v, (list, str)): return v raise ValueError(v) - + # Database DB_DIR: Path = Path("/app/storage/db") SQLALCHEMY_DATABASE_URL: str = f"sqlite:///{DB_DIR}/db.sqlite" - + # Token related TOKEN_URL: str = f"{API_V1_STR}/auth/login" - + class Config: case_sensitive = True env_file = ".env" + # Create the DB directory if it doesn't exist Settings().DB_DIR.mkdir(parents=True, exist_ok=True) -settings = Settings() \ No newline at end of file +settings = Settings() diff --git a/app/core/database.py b/app/core/database.py index ebd475d..cf2af71 100644 --- a/app/core/database.py +++ b/app/core/database.py @@ -6,17 +6,18 @@ from app.core.config import settings engine = create_engine( settings.SQLALCHEMY_DATABASE_URL, - connect_args={"check_same_thread": False} # Only needed for SQLite + connect_args={"check_same_thread": False}, # Only needed for SQLite ) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) Base = declarative_base() + # Dependency to get the database session def get_db(): db = SessionLocal() try: yield db finally: - db.close() \ No newline at end of file + db.close() diff --git a/app/core/dependencies.py b/app/core/dependencies.py index ee3716b..fee67ce 100644 --- a/app/core/dependencies.py +++ b/app/core/dependencies.py @@ -1,4 +1,3 @@ - from fastapi import Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer from jose import jwt @@ -23,9 +22,7 @@ def get_current_user( Get current user based on JWT token """ try: - payload = jwt.decode( - token, settings.SECRET_KEY, algorithms=[ALGORITHM] - ) + payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM]) token_data = TokenPayload(**payload) except (jwt.JWTError, ValidationError): raise HTTPException( @@ -33,12 +30,11 @@ def get_current_user( detail="Could not validate credentials", headers={"WWW-Authenticate": "Bearer"}, ) - + user = user_service.get_by_id(db, user_id=token_data.sub) if not user: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="User not found" + status_code=status.HTTP_404_NOT_FOUND, detail="User not found" ) return user @@ -51,8 +47,7 @@ def get_current_active_user( """ if not user_service.is_active(current_user): raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Inactive user" + status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user" ) return current_user @@ -66,6 +61,6 @@ def get_current_active_superuser( if not user_service.is_superuser(current_user): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="The user doesn't have enough privileges" + detail="The user doesn't have enough privileges", ) - return current_user \ No newline at end of file + return current_user diff --git a/app/core/security.py b/app/core/security.py index 95d69f7..edccdac 100644 --- a/app/core/security.py +++ b/app/core/security.py @@ -39,4 +39,4 @@ def get_password_hash(password: str) -> str: """ Hash a password for storing """ - return pwd_context.hash(password) \ No newline at end of file + return pwd_context.hash(password) diff --git a/app/models/__init__.py b/app/models/__init__.py index 9044645..acd1835 100644 --- a/app/models/__init__.py +++ b/app/models/__init__.py @@ -2,4 +2,4 @@ from app.models.user import User from app.models.token import Token # Add all models here so they can be imported from app.models -__all__ = ["User", "Token"] \ No newline at end of file +__all__ = ["User", "Token"] diff --git a/app/models/token.py b/app/models/token.py index 553795d..dabd9a8 100644 --- a/app/models/token.py +++ b/app/models/token.py @@ -8,8 +8,10 @@ class Token(Base): __tablename__ = "tokens" id = Column(Integer, primary_key=True, index=True) - user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False) + user_id = Column( + Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False + ) token = Column(String, unique=True, index=True, nullable=False) expires_at = Column(DateTime(timezone=True), nullable=False) is_revoked = Column(Boolean, default=False) - created_at = Column(DateTime(timezone=True), server_default=func.now()) \ No newline at end of file + created_at = Column(DateTime(timezone=True), server_default=func.now()) diff --git a/app/models/user.py b/app/models/user.py index a2fd899..af2d979 100644 --- a/app/models/user.py +++ b/app/models/user.py @@ -15,4 +15,6 @@ class User(Base): is_active = Column(Boolean, default=True) is_superuser = Column(Boolean, default=False) created_at = Column(DateTime(timezone=True), server_default=func.now()) - updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now()) \ No newline at end of file + updated_at = Column( + DateTime(timezone=True), server_default=func.now(), onupdate=func.now() + ) diff --git a/app/schemas/__init__.py b/app/schemas/__init__.py index c8c14f3..4f32ad4 100644 --- a/app/schemas/__init__.py +++ b/app/schemas/__init__.py @@ -4,15 +4,15 @@ from app.schemas.user import User, UserCreate, UserInDB, UserUpdate # Add all schemas here so they can be imported from app.schemas __all__ = [ - "User", - "UserCreate", - "UserInDB", - "UserUpdate", - "Token", - "TokenPayload", + "User", + "UserCreate", + "UserInDB", + "UserUpdate", + "Token", + "TokenPayload", "TokenRefresh", "Message", "HTTPValidationError", "ResponseBase", - "ResponseData" -] \ No newline at end of file + "ResponseData", +] diff --git a/app/schemas/message.py b/app/schemas/message.py index ba88cae..9dae7cf 100644 --- a/app/schemas/message.py +++ b/app/schemas/message.py @@ -23,4 +23,4 @@ class ResponseBase(BaseModel): class ResponseData(ResponseBase): - data: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None \ No newline at end of file + data: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None diff --git a/app/schemas/token.py b/app/schemas/token.py index 45fe13d..11da9af 100644 --- a/app/schemas/token.py +++ b/app/schemas/token.py @@ -17,4 +17,4 @@ class TokenPayload(BaseModel): class TokenRefresh(BaseModel): - refresh_token: str \ No newline at end of file + refresh_token: str diff --git a/app/schemas/user.py b/app/schemas/user.py index 721046d..6581352 100644 --- a/app/schemas/user.py +++ b/app/schemas/user.py @@ -36,4 +36,4 @@ class User(UserInDBBase): class UserInDB(UserInDBBase): - hashed_password: str \ No newline at end of file + hashed_password: str diff --git a/app/services/token.py b/app/services/token.py index 0986432..cc8570d 100644 --- a/app/services/token.py +++ b/app/services/token.py @@ -19,22 +19,19 @@ def create_refresh_token( if expires_delta: expire = datetime.utcnow() + expires_delta else: - expire = datetime.utcnow() + timedelta(days=30) # 30 days default - + expire = datetime.utcnow() + timedelta(days=7) # 7 days default + # Generate a secure random token token_value = secrets.token_urlsafe(64) - + # Create token in DB db_token = Token( - user_id=user_id, - token=token_value, - expires_at=expire, - is_revoked=False + user_id=user_id, token=token_value, expires_at=expire, is_revoked=False ) db.add(db_token) db.commit() db.refresh(db_token) - + return db_token @@ -71,7 +68,7 @@ def revoke_all_user_tokens(db: Session, user_id: int) -> None: tokens = db.query(Token).filter(Token.user_id == user_id).all() for token in tokens: token.is_revoked = True - + db.commit() @@ -80,9 +77,7 @@ def decode_token(token: str) -> Optional[dict]: Decode a JWT token """ try: - payload = jwt.decode( - token, settings.SECRET_KEY, algorithms=[ALGORITHM] - ) + payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM]) return payload except JWTError: - return None \ No newline at end of file + return None diff --git a/app/services/user.py b/app/services/user.py index 86b19ad..fb715a2 100644 --- a/app/services/user.py +++ b/app/services/user.py @@ -28,9 +28,7 @@ def get_by_username(db: Session, username: str) -> Optional[User]: return db.query(User).filter(User.username == username).first() -def get_multi( - db: Session, *, skip: int = 0, limit: int = 100 -) -> List[User]: +def get_multi(db: Session, *, skip: int = 0, limit: int = 100) -> List[User]: """ Get multiple users with pagination """ @@ -65,16 +63,16 @@ def update( update_data = obj_in else: update_data = obj_in.dict(exclude_unset=True) - + if "password" in update_data and update_data["password"]: hashed_password = get_password_hash(update_data["password"]) del update_data["password"] update_data["hashed_password"] = hashed_password - + for field in update_data: if field in update_data: setattr(db_obj, field, update_data[field]) - + db.add(db_obj) db.commit() db.refresh(db_obj) @@ -108,4 +106,4 @@ def is_superuser(user: User) -> bool: """ Check if user is superuser """ - return user.is_superuser \ No newline at end of file + return user.is_superuser diff --git a/main.py b/main.py index ee9d08e..e00eda7 100644 --- a/main.py +++ b/main.py @@ -24,11 +24,14 @@ if settings.BACKEND_CORS_ORIGINS: app.include_router(api_router, prefix=settings.API_V1_STR) + # Health check endpoint @app.get("/health", status_code=200) def health_check(): return {"status": "healthy"} + if __name__ == "__main__": import uvicorn - uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True) \ No newline at end of file + + uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True) diff --git a/migrations/env.py b/migrations/env.py index 98a5563..a028607 100644 --- a/migrations/env.py +++ b/migrations/env.py @@ -65,7 +65,7 @@ def run_migrations_online() -> None: ) with connectable.connect() as connection: - is_sqlite = connection.dialect.name == 'sqlite' + is_sqlite = connection.dialect.name == "sqlite" context.configure( connection=connection, target_metadata=target_metadata, @@ -79,4 +79,4 @@ def run_migrations_online() -> None: if context.is_offline_mode(): run_migrations_offline() else: - run_migrations_online() \ No newline at end of file + run_migrations_online() diff --git a/migrations/versions/001_create_user_token_tables.py b/migrations/versions/001_create_user_token_tables.py index 664a5a0..cc57237 100644 --- a/migrations/versions/001_create_user_token_tables.py +++ b/migrations/versions/001_create_user_token_tables.py @@ -1,17 +1,18 @@ """create user token tables Revision ID: 001 -Revises: +Revises: Create Date: 2023-10-10 """ + from alembic import op import sqlalchemy as sa from sqlalchemy.sql import func # revision identifiers, used by Alembic. -revision = '001' +revision = "001" down_revision = None branch_labels = None depends_on = None @@ -20,46 +21,51 @@ depends_on = None def upgrade() -> None: # Create users table op.create_table( - 'users', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('email', sa.String(), nullable=False), - sa.Column('username', sa.String(), nullable=False), - sa.Column('hashed_password', sa.String(), nullable=False), - sa.Column('full_name', sa.String(), nullable=True), - sa.Column('is_active', sa.Boolean(), default=True), - sa.Column('is_superuser', sa.Boolean(), default=False), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=func.now()), - sa.Column('updated_at', sa.DateTime(timezone=True), server_default=func.now(), onupdate=func.now()), - sa.PrimaryKeyConstraint('id') + "users", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("email", sa.String(), nullable=False), + sa.Column("username", sa.String(), nullable=False), + sa.Column("hashed_password", sa.String(), nullable=False), + sa.Column("full_name", sa.String(), nullable=True), + sa.Column("is_active", sa.Boolean(), default=True), + sa.Column("is_superuser", sa.Boolean(), default=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=func.now()), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=func.now(), + onupdate=func.now(), + ), + sa.PrimaryKeyConstraint("id"), ) - - op.create_index(op.f('ix_users_id'), 'users', ['id'], unique=False) - op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True) - op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=True) - + + op.create_index(op.f("ix_users_id"), "users", ["id"], unique=False) + op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True) + op.create_index(op.f("ix_users_username"), "users", ["username"], unique=True) + # Create tokens table op.create_table( - 'tokens', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('user_id', sa.Integer(), nullable=False), - sa.Column('token', sa.String(), nullable=False), - sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False), - sa.Column('is_revoked', sa.Boolean(), default=False), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=func.now()), - sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('id') + "tokens", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("token", sa.String(), nullable=False), + sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("is_revoked", sa.Boolean(), default=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=func.now()), + sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), ) - - op.create_index(op.f('ix_tokens_id'), 'tokens', ['id'], unique=False) - op.create_index(op.f('ix_tokens_token'), 'tokens', ['token'], unique=True) + + op.create_index(op.f("ix_tokens_id"), "tokens", ["id"], unique=False) + op.create_index(op.f("ix_tokens_token"), "tokens", ["token"], unique=True) def downgrade() -> None: - op.drop_index(op.f('ix_tokens_token'), table_name='tokens') - op.drop_index(op.f('ix_tokens_id'), table_name='tokens') - op.drop_table('tokens') - - op.drop_index(op.f('ix_users_username'), table_name='users') - op.drop_index(op.f('ix_users_email'), table_name='users') - op.drop_index(op.f('ix_users_id'), table_name='users') - op.drop_table('users') \ No newline at end of file + op.drop_index(op.f("ix_tokens_token"), table_name="tokens") + op.drop_index(op.f("ix_tokens_id"), table_name="tokens") + op.drop_table("tokens") + + op.drop_index(op.f("ix_users_username"), table_name="users") + op.drop_index(op.f("ix_users_email"), table_name="users") + op.drop_index(op.f("ix_users_id"), table_name="users") + op.drop_table("users")