Update refresh token expiration from 30 to 7 days and format code

This commit is contained in:
Automated Action 2025-05-17 16:59:58 +00:00
parent ec714bf9f0
commit b43b669cab
19 changed files with 122 additions and 116 deletions

View File

@ -33,7 +33,7 @@ def register(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="A user with this email already exists", detail="A user with this email already exists",
) )
# Check if user with this username already exists # Check if user with this username already exists
user = user_service.get_by_username(db, username=user_in.username) user = user_service.get_by_username(db, username=user_in.username)
if user: if user:
@ -41,19 +41,19 @@ def register(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="A user with this username already exists", detail="A user with this username already exists",
) )
# Create new user # Create new user
user = user_service.create(db, obj_in=user_in) user = user_service.create(db, obj_in=user_in)
# Create access token # Create access token
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token( access_token = create_access_token(
subject=user.id, expires_delta=access_token_expires subject=user.id, expires_delta=access_token_expires
) )
# Create refresh token # Create refresh token
refresh_token_obj = token_service.create_refresh_token(db, user_id=user.id) refresh_token_obj = token_service.create_refresh_token(db, user_id=user.id)
return { return {
"access_token": access_token, "access_token": access_token,
"token_type": "bearer", "token_type": "bearer",
@ -79,23 +79,23 @@ def login(
detail="Incorrect username/email or password", detail="Incorrect username/email or password",
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
if not user_service.is_active(user): if not user_service.is_active(user):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="Inactive user", detail="Inactive user",
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
# Create access token # Create access token
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token( access_token = create_access_token(
subject=user.id, expires_delta=access_token_expires subject=user.id, expires_delta=access_token_expires
) )
# Create refresh token # Create refresh token
refresh_token_obj = token_service.create_refresh_token(db, user_id=user.id) refresh_token_obj = token_service.create_refresh_token(db, user_id=user.id)
return { return {
"access_token": access_token, "access_token": access_token,
"token_type": "bearer", "token_type": "bearer",
@ -119,26 +119,28 @@ def refresh_token(
detail="Invalid refresh token", detail="Invalid refresh token",
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
if not token_service.is_token_valid(refresh_token): if not token_service.is_token_valid(refresh_token):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Refresh token expired or revoked", detail="Refresh token expired or revoked",
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
# Revoke the used refresh token # Revoke the used refresh token
token_service.revoke_token(db, refresh_token) token_service.revoke_token(db, refresh_token)
# Create new access token # Create new access token
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token( access_token = create_access_token(
subject=refresh_token.user_id, expires_delta=access_token_expires subject=refresh_token.user_id, expires_delta=access_token_expires
) )
# Create new refresh token # Create new refresh token
new_refresh_token = token_service.create_refresh_token(db, user_id=refresh_token.user_id) new_refresh_token = token_service.create_refresh_token(
db, user_id=refresh_token.user_id
)
return { return {
"access_token": access_token, "access_token": access_token,
"token_type": "bearer", "token_type": "bearer",
@ -170,4 +172,4 @@ def logout_all(
Logout from all devices by revoking all refresh tokens Logout from all devices by revoking all refresh tokens
""" """
token_service.revoke_all_user_tokens(db, user_id=current_user.id) token_service.revoke_all_user_tokens(db, user_id=current_user.id)
return None return None

View File

@ -71,14 +71,14 @@ def create_user(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="A user with this email already exists", detail="A user with this email already exists",
) )
user = user_service.get_by_username(db, username=user_in.username) user = user_service.get_by_username(db, username=user_in.username)
if user: if user:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="A user with this username already exists", detail="A user with this username already exists",
) )
user = user_service.create(db, obj_in=user_in) user = user_service.create(db, obj_in=user_in)
return user return user
@ -121,4 +121,4 @@ def update_user(
detail="User not found", detail="User not found",
) )
user = user_service.update(db, db_obj=user, obj_in=user_in) user = user_service.update(db, db_obj=user, obj_in=user_in)
return user return user

View File

@ -4,4 +4,4 @@ from app.api.v1.endpoints import auth, users
api_router = APIRouter() api_router = APIRouter()
api_router.include_router(auth.router, prefix="/auth", tags=["authentication"]) api_router.include_router(auth.router, prefix="/auth", tags=["authentication"])
api_router.include_router(users.router, prefix="/users", tags=["users"]) api_router.include_router(users.router, prefix="/users", tags=["users"])

View File

@ -3,17 +3,18 @@ from pathlib import Path
from pydantic import AnyHttpUrl, validator from pydantic import AnyHttpUrl, validator
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
class Settings(BaseSettings): class Settings(BaseSettings):
API_V1_STR: str = "/api/v1" API_V1_STR: str = "/api/v1"
PROJECT_NAME: str = "User Authentication Service" PROJECT_NAME: str = "User Authentication Service"
PROJECT_DESCRIPTION: str = "API service for user authentication" PROJECT_DESCRIPTION: str = "API service for user authentication"
PROJECT_VERSION: str = "0.1.0" PROJECT_VERSION: str = "0.1.0"
# Secret key for JWT token and other security mechanisms # Secret key for JWT token and other security mechanisms
SECRET_KEY: str = "YOUR_SUPER_SECRET_KEY_CHANGE_THIS_IN_PRODUCTION" SECRET_KEY: str = "YOUR_SUPER_SECRET_KEY_CHANGE_THIS_IN_PRODUCTION"
# 60 minutes * 24 hours * 8 days = 8 days in minutes # 60 minutes * 24 hours * 8 days = 8 days in minutes
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8 ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8
# CORS # CORS
BACKEND_CORS_ORIGINS: List[AnyHttpUrl] = [] BACKEND_CORS_ORIGINS: List[AnyHttpUrl] = []
@ -24,19 +25,20 @@ class Settings(BaseSettings):
elif isinstance(v, (list, str)): elif isinstance(v, (list, str)):
return v return v
raise ValueError(v) raise ValueError(v)
# Database # Database
DB_DIR: Path = Path("/app/storage/db") DB_DIR: Path = Path("/app/storage/db")
SQLALCHEMY_DATABASE_URL: str = f"sqlite:///{DB_DIR}/db.sqlite" SQLALCHEMY_DATABASE_URL: str = f"sqlite:///{DB_DIR}/db.sqlite"
# Token related # Token related
TOKEN_URL: str = f"{API_V1_STR}/auth/login" TOKEN_URL: str = f"{API_V1_STR}/auth/login"
class Config: class Config:
case_sensitive = True case_sensitive = True
env_file = ".env" env_file = ".env"
# Create the DB directory if it doesn't exist # Create the DB directory if it doesn't exist
Settings().DB_DIR.mkdir(parents=True, exist_ok=True) Settings().DB_DIR.mkdir(parents=True, exist_ok=True)
settings = Settings() settings = Settings()

View File

@ -6,17 +6,18 @@ from app.core.config import settings
engine = create_engine( engine = create_engine(
settings.SQLALCHEMY_DATABASE_URL, settings.SQLALCHEMY_DATABASE_URL,
connect_args={"check_same_thread": False} # Only needed for SQLite connect_args={"check_same_thread": False}, # Only needed for SQLite
) )
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base() Base = declarative_base()
# Dependency to get the database session # Dependency to get the database session
def get_db(): def get_db():
db = SessionLocal() db = SessionLocal()
try: try:
yield db yield db
finally: finally:
db.close() db.close()

View File

@ -1,4 +1,3 @@
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from jose import jwt from jose import jwt
@ -23,9 +22,7 @@ def get_current_user(
Get current user based on JWT token Get current user based on JWT token
""" """
try: try:
payload = jwt.decode( payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
token, settings.SECRET_KEY, algorithms=[ALGORITHM]
)
token_data = TokenPayload(**payload) token_data = TokenPayload(**payload)
except (jwt.JWTError, ValidationError): except (jwt.JWTError, ValidationError):
raise HTTPException( raise HTTPException(
@ -33,12 +30,11 @@ def get_current_user(
detail="Could not validate credentials", detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
user = user_service.get_by_id(db, user_id=token_data.sub) user = user_service.get_by_id(db, user_id=token_data.sub)
if not user: if not user:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
detail="User not found"
) )
return user return user
@ -51,8 +47,7 @@ def get_current_active_user(
""" """
if not user_service.is_active(current_user): if not user_service.is_active(current_user):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user"
detail="Inactive user"
) )
return current_user return current_user
@ -66,6 +61,6 @@ def get_current_active_superuser(
if not user_service.is_superuser(current_user): if not user_service.is_superuser(current_user):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="The user doesn't have enough privileges" detail="The user doesn't have enough privileges",
) )
return current_user return current_user

View File

@ -39,4 +39,4 @@ def get_password_hash(password: str) -> str:
""" """
Hash a password for storing Hash a password for storing
""" """
return pwd_context.hash(password) return pwd_context.hash(password)

View File

@ -2,4 +2,4 @@ from app.models.user import User
from app.models.token import Token from app.models.token import Token
# Add all models here so they can be imported from app.models # Add all models here so they can be imported from app.models
__all__ = ["User", "Token"] __all__ = ["User", "Token"]

View File

@ -8,8 +8,10 @@ class Token(Base):
__tablename__ = "tokens" __tablename__ = "tokens"
id = Column(Integer, primary_key=True, index=True) id = Column(Integer, primary_key=True, index=True)
user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False) user_id = Column(
Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False
)
token = Column(String, unique=True, index=True, nullable=False) token = Column(String, unique=True, index=True, nullable=False)
expires_at = Column(DateTime(timezone=True), nullable=False) expires_at = Column(DateTime(timezone=True), nullable=False)
is_revoked = Column(Boolean, default=False) is_revoked = Column(Boolean, default=False)
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())

View File

@ -15,4 +15,6 @@ class User(Base):
is_active = Column(Boolean, default=True) is_active = Column(Boolean, default=True)
is_superuser = Column(Boolean, default=False) is_superuser = Column(Boolean, default=False)
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now()) updated_at = Column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)

View File

@ -4,15 +4,15 @@ from app.schemas.user import User, UserCreate, UserInDB, UserUpdate
# Add all schemas here so they can be imported from app.schemas # Add all schemas here so they can be imported from app.schemas
__all__ = [ __all__ = [
"User", "User",
"UserCreate", "UserCreate",
"UserInDB", "UserInDB",
"UserUpdate", "UserUpdate",
"Token", "Token",
"TokenPayload", "TokenPayload",
"TokenRefresh", "TokenRefresh",
"Message", "Message",
"HTTPValidationError", "HTTPValidationError",
"ResponseBase", "ResponseBase",
"ResponseData" "ResponseData",
] ]

View File

@ -23,4 +23,4 @@ class ResponseBase(BaseModel):
class ResponseData(ResponseBase): class ResponseData(ResponseBase):
data: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None data: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None

View File

@ -17,4 +17,4 @@ class TokenPayload(BaseModel):
class TokenRefresh(BaseModel): class TokenRefresh(BaseModel):
refresh_token: str refresh_token: str

View File

@ -36,4 +36,4 @@ class User(UserInDBBase):
class UserInDB(UserInDBBase): class UserInDB(UserInDBBase):
hashed_password: str hashed_password: str

View File

@ -19,22 +19,19 @@ def create_refresh_token(
if expires_delta: if expires_delta:
expire = datetime.utcnow() + expires_delta expire = datetime.utcnow() + expires_delta
else: else:
expire = datetime.utcnow() + timedelta(days=30) # 30 days default expire = datetime.utcnow() + timedelta(days=7) # 7 days default
# Generate a secure random token # Generate a secure random token
token_value = secrets.token_urlsafe(64) token_value = secrets.token_urlsafe(64)
# Create token in DB # Create token in DB
db_token = Token( db_token = Token(
user_id=user_id, user_id=user_id, token=token_value, expires_at=expire, is_revoked=False
token=token_value,
expires_at=expire,
is_revoked=False
) )
db.add(db_token) db.add(db_token)
db.commit() db.commit()
db.refresh(db_token) db.refresh(db_token)
return db_token return db_token
@ -71,7 +68,7 @@ def revoke_all_user_tokens(db: Session, user_id: int) -> None:
tokens = db.query(Token).filter(Token.user_id == user_id).all() tokens = db.query(Token).filter(Token.user_id == user_id).all()
for token in tokens: for token in tokens:
token.is_revoked = True token.is_revoked = True
db.commit() db.commit()
@ -80,9 +77,7 @@ def decode_token(token: str) -> Optional[dict]:
Decode a JWT token Decode a JWT token
""" """
try: try:
payload = jwt.decode( payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
token, settings.SECRET_KEY, algorithms=[ALGORITHM]
)
return payload return payload
except JWTError: except JWTError:
return None return None

View File

@ -28,9 +28,7 @@ def get_by_username(db: Session, username: str) -> Optional[User]:
return db.query(User).filter(User.username == username).first() return db.query(User).filter(User.username == username).first()
def get_multi( def get_multi(db: Session, *, skip: int = 0, limit: int = 100) -> List[User]:
db: Session, *, skip: int = 0, limit: int = 100
) -> List[User]:
""" """
Get multiple users with pagination Get multiple users with pagination
""" """
@ -65,16 +63,16 @@ def update(
update_data = obj_in update_data = obj_in
else: else:
update_data = obj_in.dict(exclude_unset=True) update_data = obj_in.dict(exclude_unset=True)
if "password" in update_data and update_data["password"]: if "password" in update_data and update_data["password"]:
hashed_password = get_password_hash(update_data["password"]) hashed_password = get_password_hash(update_data["password"])
del update_data["password"] del update_data["password"]
update_data["hashed_password"] = hashed_password update_data["hashed_password"] = hashed_password
for field in update_data: for field in update_data:
if field in update_data: if field in update_data:
setattr(db_obj, field, update_data[field]) setattr(db_obj, field, update_data[field])
db.add(db_obj) db.add(db_obj)
db.commit() db.commit()
db.refresh(db_obj) db.refresh(db_obj)
@ -108,4 +106,4 @@ def is_superuser(user: User) -> bool:
""" """
Check if user is superuser Check if user is superuser
""" """
return user.is_superuser return user.is_superuser

View File

@ -24,11 +24,14 @@ if settings.BACKEND_CORS_ORIGINS:
app.include_router(api_router, prefix=settings.API_V1_STR) app.include_router(api_router, prefix=settings.API_V1_STR)
# Health check endpoint # Health check endpoint
@app.get("/health", status_code=200) @app.get("/health", status_code=200)
def health_check(): def health_check():
return {"status": "healthy"} return {"status": "healthy"}
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn
uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)

View File

@ -65,7 +65,7 @@ def run_migrations_online() -> None:
) )
with connectable.connect() as connection: with connectable.connect() as connection:
is_sqlite = connection.dialect.name == 'sqlite' is_sqlite = connection.dialect.name == "sqlite"
context.configure( context.configure(
connection=connection, connection=connection,
target_metadata=target_metadata, target_metadata=target_metadata,
@ -79,4 +79,4 @@ def run_migrations_online() -> None:
if context.is_offline_mode(): if context.is_offline_mode():
run_migrations_offline() run_migrations_offline()
else: else:
run_migrations_online() run_migrations_online()

View File

@ -1,17 +1,18 @@
"""create user token tables """create user token tables
Revision ID: 001 Revision ID: 001
Revises: Revises:
Create Date: 2023-10-10 Create Date: 2023-10-10
""" """
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.sql import func from sqlalchemy.sql import func
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '001' revision = "001"
down_revision = None down_revision = None
branch_labels = None branch_labels = None
depends_on = None depends_on = None
@ -20,46 +21,51 @@ depends_on = None
def upgrade() -> None: def upgrade() -> None:
# Create users table # Create users table
op.create_table( op.create_table(
'users', "users",
sa.Column('id', sa.Integer(), nullable=False), sa.Column("id", sa.Integer(), nullable=False),
sa.Column('email', sa.String(), nullable=False), sa.Column("email", sa.String(), nullable=False),
sa.Column('username', sa.String(), nullable=False), sa.Column("username", sa.String(), nullable=False),
sa.Column('hashed_password', sa.String(), nullable=False), sa.Column("hashed_password", sa.String(), nullable=False),
sa.Column('full_name', sa.String(), nullable=True), sa.Column("full_name", sa.String(), nullable=True),
sa.Column('is_active', sa.Boolean(), default=True), sa.Column("is_active", sa.Boolean(), default=True),
sa.Column('is_superuser', sa.Boolean(), default=False), sa.Column("is_superuser", sa.Boolean(), default=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=func.now()), sa.Column("created_at", sa.DateTime(timezone=True), server_default=func.now()),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=func.now(), onupdate=func.now()), sa.Column(
sa.PrimaryKeyConstraint('id') "updated_at",
sa.DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
),
sa.PrimaryKeyConstraint("id"),
) )
op.create_index(op.f('ix_users_id'), 'users', ['id'], unique=False) op.create_index(op.f("ix_users_id"), "users", ["id"], unique=False)
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True) op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True)
op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=True) op.create_index(op.f("ix_users_username"), "users", ["username"], unique=True)
# Create tokens table # Create tokens table
op.create_table( op.create_table(
'tokens', "tokens",
sa.Column('id', sa.Integer(), nullable=False), sa.Column("id", sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False), sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column('token', sa.String(), nullable=False), sa.Column("token", sa.String(), nullable=False),
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False), sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.Column('is_revoked', sa.Boolean(), default=False), sa.Column("is_revoked", sa.Boolean(), default=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=func.now()), sa.Column("created_at", sa.DateTime(timezone=True), server_default=func.now()),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'), sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint('id') sa.PrimaryKeyConstraint("id"),
) )
op.create_index(op.f('ix_tokens_id'), 'tokens', ['id'], unique=False) op.create_index(op.f("ix_tokens_id"), "tokens", ["id"], unique=False)
op.create_index(op.f('ix_tokens_token'), 'tokens', ['token'], unique=True) op.create_index(op.f("ix_tokens_token"), "tokens", ["token"], unique=True)
def downgrade() -> None: def downgrade() -> None:
op.drop_index(op.f('ix_tokens_token'), table_name='tokens') op.drop_index(op.f("ix_tokens_token"), table_name="tokens")
op.drop_index(op.f('ix_tokens_id'), table_name='tokens') op.drop_index(op.f("ix_tokens_id"), table_name="tokens")
op.drop_table('tokens') op.drop_table("tokens")
op.drop_index(op.f('ix_users_username'), table_name='users') op.drop_index(op.f("ix_users_username"), table_name="users")
op.drop_index(op.f('ix_users_email'), table_name='users') op.drop_index(op.f("ix_users_email"), table_name="users")
op.drop_index(op.f('ix_users_id'), table_name='users') op.drop_index(op.f("ix_users_id"), table_name="users")
op.drop_table('users') op.drop_table("users")