Update refresh token expiration from 30 to 7 days and format code
This commit is contained in:
parent
ec714bf9f0
commit
b43b669cab
@ -137,7 +137,9 @@ def refresh_token(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Create new refresh token
|
# Create new refresh token
|
||||||
new_refresh_token = token_service.create_refresh_token(db, user_id=refresh_token.user_id)
|
new_refresh_token = token_service.create_refresh_token(
|
||||||
|
db, user_id=refresh_token.user_id
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"access_token": access_token,
|
"access_token": access_token,
|
||||||
|
@ -3,6 +3,7 @@ from pathlib import Path
|
|||||||
from pydantic import AnyHttpUrl, validator
|
from pydantic import AnyHttpUrl, validator
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
API_V1_STR: str = "/api/v1"
|
API_V1_STR: str = "/api/v1"
|
||||||
PROJECT_NAME: str = "User Authentication Service"
|
PROJECT_NAME: str = "User Authentication Service"
|
||||||
@ -36,6 +37,7 @@ class Settings(BaseSettings):
|
|||||||
case_sensitive = True
|
case_sensitive = True
|
||||||
env_file = ".env"
|
env_file = ".env"
|
||||||
|
|
||||||
|
|
||||||
# Create the DB directory if it doesn't exist
|
# Create the DB directory if it doesn't exist
|
||||||
Settings().DB_DIR.mkdir(parents=True, exist_ok=True)
|
Settings().DB_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
@ -6,13 +6,14 @@ from app.core.config import settings
|
|||||||
|
|
||||||
engine = create_engine(
|
engine = create_engine(
|
||||||
settings.SQLALCHEMY_DATABASE_URL,
|
settings.SQLALCHEMY_DATABASE_URL,
|
||||||
connect_args={"check_same_thread": False} # Only needed for SQLite
|
connect_args={"check_same_thread": False}, # Only needed for SQLite
|
||||||
)
|
)
|
||||||
|
|
||||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
|
|
||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
# Dependency to get the database session
|
# Dependency to get the database session
|
||||||
def get_db():
|
def get_db():
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
from fastapi import Depends, HTTPException, status
|
from fastapi import Depends, HTTPException, status
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
from jose import jwt
|
from jose import jwt
|
||||||
@ -23,9 +22,7 @@ def get_current_user(
|
|||||||
Get current user based on JWT token
|
Get current user based on JWT token
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(
|
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
token, settings.SECRET_KEY, algorithms=[ALGORITHM]
|
|
||||||
)
|
|
||||||
token_data = TokenPayload(**payload)
|
token_data = TokenPayload(**payload)
|
||||||
except (jwt.JWTError, ValidationError):
|
except (jwt.JWTError, ValidationError):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@ -37,8 +34,7 @@ def get_current_user(
|
|||||||
user = user_service.get_by_id(db, user_id=token_data.sub)
|
user = user_service.get_by_id(db, user_id=token_data.sub)
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||||
detail="User not found"
|
|
||||||
)
|
)
|
||||||
return user
|
return user
|
||||||
|
|
||||||
@ -51,8 +47,7 @@ def get_current_active_user(
|
|||||||
"""
|
"""
|
||||||
if not user_service.is_active(current_user):
|
if not user_service.is_active(current_user):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user"
|
||||||
detail="Inactive user"
|
|
||||||
)
|
)
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
@ -66,6 +61,6 @@ def get_current_active_superuser(
|
|||||||
if not user_service.is_superuser(current_user):
|
if not user_service.is_superuser(current_user):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail="The user doesn't have enough privileges"
|
detail="The user doesn't have enough privileges",
|
||||||
)
|
)
|
||||||
return current_user
|
return current_user
|
@ -8,7 +8,9 @@ class Token(Base):
|
|||||||
__tablename__ = "tokens"
|
__tablename__ = "tokens"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, index=True)
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
|
user_id = Column(
|
||||||
|
Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False
|
||||||
|
)
|
||||||
token = Column(String, unique=True, index=True, nullable=False)
|
token = Column(String, unique=True, index=True, nullable=False)
|
||||||
expires_at = Column(DateTime(timezone=True), nullable=False)
|
expires_at = Column(DateTime(timezone=True), nullable=False)
|
||||||
is_revoked = Column(Boolean, default=False)
|
is_revoked = Column(Boolean, default=False)
|
||||||
|
@ -15,4 +15,6 @@ class User(Base):
|
|||||||
is_active = Column(Boolean, default=True)
|
is_active = Column(Boolean, default=True)
|
||||||
is_superuser = Column(Boolean, default=False)
|
is_superuser = Column(Boolean, default=False)
|
||||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
|
updated_at = Column(
|
||||||
|
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||||
|
)
|
||||||
|
@ -14,5 +14,5 @@ __all__ = [
|
|||||||
"Message",
|
"Message",
|
||||||
"HTTPValidationError",
|
"HTTPValidationError",
|
||||||
"ResponseBase",
|
"ResponseBase",
|
||||||
"ResponseData"
|
"ResponseData",
|
||||||
]
|
]
|
@ -19,17 +19,14 @@ def create_refresh_token(
|
|||||||
if expires_delta:
|
if expires_delta:
|
||||||
expire = datetime.utcnow() + expires_delta
|
expire = datetime.utcnow() + expires_delta
|
||||||
else:
|
else:
|
||||||
expire = datetime.utcnow() + timedelta(days=30) # 30 days default
|
expire = datetime.utcnow() + timedelta(days=7) # 7 days default
|
||||||
|
|
||||||
# Generate a secure random token
|
# Generate a secure random token
|
||||||
token_value = secrets.token_urlsafe(64)
|
token_value = secrets.token_urlsafe(64)
|
||||||
|
|
||||||
# Create token in DB
|
# Create token in DB
|
||||||
db_token = Token(
|
db_token = Token(
|
||||||
user_id=user_id,
|
user_id=user_id, token=token_value, expires_at=expire, is_revoked=False
|
||||||
token=token_value,
|
|
||||||
expires_at=expire,
|
|
||||||
is_revoked=False
|
|
||||||
)
|
)
|
||||||
db.add(db_token)
|
db.add(db_token)
|
||||||
db.commit()
|
db.commit()
|
||||||
@ -80,9 +77,7 @@ def decode_token(token: str) -> Optional[dict]:
|
|||||||
Decode a JWT token
|
Decode a JWT token
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(
|
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
token, settings.SECRET_KEY, algorithms=[ALGORITHM]
|
|
||||||
)
|
|
||||||
return payload
|
return payload
|
||||||
except JWTError:
|
except JWTError:
|
||||||
return None
|
return None
|
@ -28,9 +28,7 @@ def get_by_username(db: Session, username: str) -> Optional[User]:
|
|||||||
return db.query(User).filter(User.username == username).first()
|
return db.query(User).filter(User.username == username).first()
|
||||||
|
|
||||||
|
|
||||||
def get_multi(
|
def get_multi(db: Session, *, skip: int = 0, limit: int = 100) -> List[User]:
|
||||||
db: Session, *, skip: int = 0, limit: int = 100
|
|
||||||
) -> List[User]:
|
|
||||||
"""
|
"""
|
||||||
Get multiple users with pagination
|
Get multiple users with pagination
|
||||||
"""
|
"""
|
||||||
|
3
main.py
3
main.py
@ -24,11 +24,14 @@ if settings.BACKEND_CORS_ORIGINS:
|
|||||||
|
|
||||||
app.include_router(api_router, prefix=settings.API_V1_STR)
|
app.include_router(api_router, prefix=settings.API_V1_STR)
|
||||||
|
|
||||||
|
|
||||||
# Health check endpoint
|
# Health check endpoint
|
||||||
@app.get("/health", status_code=200)
|
@app.get("/health", status_code=200)
|
||||||
def health_check():
|
def health_check():
|
||||||
return {"status": "healthy"}
|
return {"status": "healthy"}
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
|
uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
|
@ -65,7 +65,7 @@ def run_migrations_online() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
with connectable.connect() as connection:
|
with connectable.connect() as connection:
|
||||||
is_sqlite = connection.dialect.name == 'sqlite'
|
is_sqlite = connection.dialect.name == "sqlite"
|
||||||
context.configure(
|
context.configure(
|
||||||
connection=connection,
|
connection=connection,
|
||||||
target_metadata=target_metadata,
|
target_metadata=target_metadata,
|
||||||
|
@ -5,13 +5,14 @@ Revises:
|
|||||||
Create Date: 2023-10-10
|
Create Date: 2023-10-10
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from sqlalchemy.sql import func
|
from sqlalchemy.sql import func
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision = '001'
|
revision = "001"
|
||||||
down_revision = None
|
down_revision = None
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
@ -20,46 +21,51 @@ depends_on = None
|
|||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
# Create users table
|
# Create users table
|
||||||
op.create_table(
|
op.create_table(
|
||||||
'users',
|
"users",
|
||||||
sa.Column('id', sa.Integer(), nullable=False),
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
sa.Column('email', sa.String(), nullable=False),
|
sa.Column("email", sa.String(), nullable=False),
|
||||||
sa.Column('username', sa.String(), nullable=False),
|
sa.Column("username", sa.String(), nullable=False),
|
||||||
sa.Column('hashed_password', sa.String(), nullable=False),
|
sa.Column("hashed_password", sa.String(), nullable=False),
|
||||||
sa.Column('full_name', sa.String(), nullable=True),
|
sa.Column("full_name", sa.String(), nullable=True),
|
||||||
sa.Column('is_active', sa.Boolean(), default=True),
|
sa.Column("is_active", sa.Boolean(), default=True),
|
||||||
sa.Column('is_superuser', sa.Boolean(), default=False),
|
sa.Column("is_superuser", sa.Boolean(), default=False),
|
||||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=func.now()),
|
sa.Column("created_at", sa.DateTime(timezone=True), server_default=func.now()),
|
||||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=func.now(), onupdate=func.now()),
|
sa.Column(
|
||||||
sa.PrimaryKeyConstraint('id')
|
"updated_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=func.now(),
|
||||||
|
onupdate=func.now(),
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
)
|
)
|
||||||
|
|
||||||
op.create_index(op.f('ix_users_id'), 'users', ['id'], unique=False)
|
op.create_index(op.f("ix_users_id"), "users", ["id"], unique=False)
|
||||||
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
|
op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True)
|
||||||
op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=True)
|
op.create_index(op.f("ix_users_username"), "users", ["username"], unique=True)
|
||||||
|
|
||||||
# Create tokens table
|
# Create tokens table
|
||||||
op.create_table(
|
op.create_table(
|
||||||
'tokens',
|
"tokens",
|
||||||
sa.Column('id', sa.Integer(), nullable=False),
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
sa.Column('user_id', sa.Integer(), nullable=False),
|
sa.Column("user_id", sa.Integer(), nullable=False),
|
||||||
sa.Column('token', sa.String(), nullable=False),
|
sa.Column("token", sa.String(), nullable=False),
|
||||||
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
|
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
sa.Column('is_revoked', sa.Boolean(), default=False),
|
sa.Column("is_revoked", sa.Boolean(), default=False),
|
||||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=func.now()),
|
sa.Column("created_at", sa.DateTime(timezone=True), server_default=func.now()),
|
||||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
sa.PrimaryKeyConstraint('id')
|
sa.PrimaryKeyConstraint("id"),
|
||||||
)
|
)
|
||||||
|
|
||||||
op.create_index(op.f('ix_tokens_id'), 'tokens', ['id'], unique=False)
|
op.create_index(op.f("ix_tokens_id"), "tokens", ["id"], unique=False)
|
||||||
op.create_index(op.f('ix_tokens_token'), 'tokens', ['token'], unique=True)
|
op.create_index(op.f("ix_tokens_token"), "tokens", ["token"], unique=True)
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
op.drop_index(op.f('ix_tokens_token'), table_name='tokens')
|
op.drop_index(op.f("ix_tokens_token"), table_name="tokens")
|
||||||
op.drop_index(op.f('ix_tokens_id'), table_name='tokens')
|
op.drop_index(op.f("ix_tokens_id"), table_name="tokens")
|
||||||
op.drop_table('tokens')
|
op.drop_table("tokens")
|
||||||
|
|
||||||
op.drop_index(op.f('ix_users_username'), table_name='users')
|
op.drop_index(op.f("ix_users_username"), table_name="users")
|
||||||
op.drop_index(op.f('ix_users_email'), table_name='users')
|
op.drop_index(op.f("ix_users_email"), table_name="users")
|
||||||
op.drop_index(op.f('ix_users_id'), table_name='users')
|
op.drop_index(op.f("ix_users_id"), table_name="users")
|
||||||
op.drop_table('users')
|
op.drop_table("users")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user