Add enhanced authentication features
- Add role-based access control (admin/user roles) - Implement refresh token functionality - Add token revocation (logout) capability - Create admin-only endpoints - Add role validation middleware - Update documentation
This commit is contained in:
parent
1fca93299a
commit
4cfde1a74a
25
README.md
25
README.md
@ -4,7 +4,10 @@ A robust, RESTful API for managing todos, built with FastAPI and SQLite.
|
||||
|
||||
## Features
|
||||
|
||||
- 🔐 JWT Authentication
|
||||
- 🔐 Enhanced JWT Authentication
|
||||
- Access and Refresh tokens
|
||||
- Token revocation (logout)
|
||||
- Role-based access control (User/Admin roles)
|
||||
- 📝 Todo CRUD operations
|
||||
- 👤 User management
|
||||
- 🔍 Advanced todo filtering and pagination
|
||||
@ -63,6 +66,8 @@ The application can be configured using the following environment variables:
|
||||
|
||||
- `POST /api/v1/auth/register` - Register a new user
|
||||
- `POST /api/v1/auth/login` - Login and get access token
|
||||
- `POST /api/v1/auth/refresh` - Refresh access token using refresh token
|
||||
- `POST /api/v1/auth/logout` - Logout and revoke refresh token
|
||||
|
||||
### Users
|
||||
|
||||
@ -71,6 +76,12 @@ The application can be configured using the following environment variables:
|
||||
- `PUT /api/v1/users/me` - Update current user
|
||||
- `GET /api/v1/users/{user_id}` - Get user by ID
|
||||
|
||||
### Admin
|
||||
|
||||
- `GET /api/v1/admin/users` - List all users (admin only)
|
||||
- `GET /api/v1/admin/users/{user_id}` - Get user by ID (admin only)
|
||||
- `PUT /api/v1/admin/users/{user_id}` - Update user (admin only)
|
||||
|
||||
### Todos
|
||||
|
||||
- `GET /api/v1/todos/` - List todos (with filtering and pagination)
|
||||
@ -97,6 +108,7 @@ id: Integer (Primary Key)
|
||||
email: String (Unique, Indexed)
|
||||
hashed_password: String
|
||||
is_active: Boolean (Default: True)
|
||||
role: Enum(admin, user) (Default: user)
|
||||
```
|
||||
|
||||
### Todo Model
|
||||
@ -109,6 +121,17 @@ is_completed: Boolean (Default: False)
|
||||
owner_id: Integer (Foreign Key to User)
|
||||
```
|
||||
|
||||
### RefreshToken Model
|
||||
|
||||
```
|
||||
id: Integer (Primary Key)
|
||||
token: String (Unique, Indexed)
|
||||
expires_at: DateTime
|
||||
created_at: DateTime
|
||||
revoked: Boolean (Default: False)
|
||||
user_id: Integer (Foreign Key to User)
|
||||
```
|
||||
|
||||
## Development
|
||||
|
||||
### Code Structure
|
||||
|
@ -1,3 +1,4 @@
|
||||
from functools import wraps
|
||||
from typing import Generator
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
@ -9,7 +10,7 @@ from sqlalchemy.orm import Session
|
||||
from app.core.config import settings
|
||||
from app.crud.crud_user import get_user
|
||||
from app.db.session import SessionLocal
|
||||
from app.models.user import User
|
||||
from app.models.user import User, UserRole
|
||||
from app.schemas.token import TokenPayload
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(
|
||||
@ -49,4 +50,29 @@ def get_current_active_user(
|
||||
) -> User:
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
return current_user
|
||||
return current_user
|
||||
|
||||
|
||||
def get_current_admin_user(
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
) -> User:
|
||||
if current_user.role != UserRole.ADMIN:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="The user doesn't have enough privileges",
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
def check_role(required_role: UserRole):
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, current_user: User = Depends(get_current_active_user), **kwargs):
|
||||
if current_user.role != required_role and current_user.role != UserRole.ADMIN:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="The user doesn't have enough privileges",
|
||||
)
|
||||
return await func(*args, current_user=current_user, **kwargs)
|
||||
return wrapper
|
||||
return decorator
|
@ -1,8 +1,9 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.v1.endpoints import auth, todos, users
|
||||
from app.api.v1.endpoints import admin, auth, todos, users
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
|
||||
api_router.include_router(users.router, prefix="/users", tags=["users"])
|
||||
api_router.include_router(todos.router, prefix="/todos", tags=["todos"])
|
||||
api_router.include_router(todos.router, prefix="/todos", tags=["todos"])
|
||||
api_router.include_router(admin.router, prefix="/admin", tags=["admin"])
|
66
app/api/v1/endpoints/admin.py
Normal file
66
app/api/v1/endpoints/admin.py
Normal file
@ -0,0 +1,66 @@
|
||||
from typing import Any, List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.deps import get_current_admin_user, get_db
|
||||
from app.crud.crud_user import get_user, get_users, update_user
|
||||
from app.models.user import User
|
||||
from app.schemas.user import User as UserSchema
|
||||
from app.schemas.user import UserUpdate
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/users", response_model=List[UserSchema])
|
||||
def admin_read_users(
|
||||
db: Session = Depends(get_db),
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
current_user: User = Depends(get_current_admin_user),
|
||||
) -> Any:
|
||||
"""
|
||||
Retrieve all users (admin only).
|
||||
"""
|
||||
users = get_users(db, skip=skip, limit=limit)
|
||||
return users
|
||||
|
||||
|
||||
@router.put("/users/{user_id}", response_model=UserSchema)
|
||||
def admin_update_user(
|
||||
*,
|
||||
db: Session = Depends(get_db),
|
||||
user_id: int,
|
||||
user_in: UserUpdate,
|
||||
current_user: User = Depends(get_current_admin_user),
|
||||
) -> Any:
|
||||
"""
|
||||
Update a user (admin only).
|
||||
"""
|
||||
user = get_user(db, user_id=user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The user with this id does not exist in the system",
|
||||
)
|
||||
user = update_user(db, db_obj=user, obj_in=user_in)
|
||||
return user
|
||||
|
||||
|
||||
@router.get("/users/{user_id}", response_model=UserSchema)
|
||||
def admin_read_user(
|
||||
*,
|
||||
db: Session = Depends(get_db),
|
||||
user_id: int,
|
||||
current_user: User = Depends(get_current_admin_user),
|
||||
) -> Any:
|
||||
"""
|
||||
Get user by ID (admin only).
|
||||
"""
|
||||
user = get_user(db, user_id=user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The user with this id does not exist in the system",
|
||||
)
|
||||
return user
|
@ -1,15 +1,29 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.deps import get_db
|
||||
from app.api.deps import get_current_active_user, get_db
|
||||
from app.core.config import settings
|
||||
from app.core.security import create_access_token
|
||||
from app.crud.crud_user import authenticate, create_user, get_user_by_email
|
||||
from app.schemas.token import Token
|
||||
from app.core.security import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
)
|
||||
from app.crud.crud_token import (
|
||||
create_refresh_token as create_refresh_token_db,
|
||||
get_refresh_token,
|
||||
is_token_valid,
|
||||
revoke_refresh_token,
|
||||
)
|
||||
from app.crud.crud_user import (
|
||||
authenticate,
|
||||
create_user,
|
||||
get_user_by_email,
|
||||
)
|
||||
from app.models.user import User
|
||||
from app.schemas.token import RefreshTokenCreate, Token
|
||||
from app.schemas.user import UserCreate
|
||||
|
||||
router = APIRouter()
|
||||
@ -34,12 +48,29 @@ def login_access_token(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user"
|
||||
)
|
||||
|
||||
# Create access token
|
||||
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
access_token, access_token_expiry = create_access_token(
|
||||
user.id, expires_delta=access_token_expires
|
||||
)
|
||||
|
||||
# Create refresh token
|
||||
refresh_token, refresh_token_expiry = create_refresh_token(user.id)
|
||||
|
||||
# Save refresh token to database
|
||||
refresh_token_data = RefreshTokenCreate(
|
||||
token=refresh_token,
|
||||
expires_at=refresh_token_expiry,
|
||||
user_id=user.id
|
||||
)
|
||||
create_refresh_token_db(db=db, token_in=refresh_token_data)
|
||||
|
||||
return {
|
||||
"access_token": create_access_token(
|
||||
user.id, expires_delta=access_token_expires
|
||||
),
|
||||
"access_token": access_token,
|
||||
"token_type": "bearer",
|
||||
"refresh_token": refresh_token,
|
||||
"expires_at": access_token_expiry,
|
||||
}
|
||||
|
||||
|
||||
@ -59,10 +90,84 @@ def register_user(
|
||||
detail="A user with this email already exists",
|
||||
)
|
||||
user = create_user(db, obj_in=user_in)
|
||||
|
||||
# Create access token
|
||||
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
access_token, access_token_expiry = create_access_token(
|
||||
user.id, expires_delta=access_token_expires
|
||||
)
|
||||
|
||||
# Create refresh token
|
||||
refresh_token, refresh_token_expiry = create_refresh_token(user.id)
|
||||
|
||||
# Save refresh token to database
|
||||
refresh_token_data = RefreshTokenCreate(
|
||||
token=refresh_token,
|
||||
expires_at=refresh_token_expiry,
|
||||
user_id=user.id
|
||||
)
|
||||
create_refresh_token_db(db=db, token_in=refresh_token_data)
|
||||
|
||||
return {
|
||||
"access_token": create_access_token(
|
||||
user.id, expires_delta=access_token_expires
|
||||
),
|
||||
"access_token": access_token,
|
||||
"token_type": "bearer",
|
||||
}
|
||||
"refresh_token": refresh_token,
|
||||
"expires_at": access_token_expiry,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=Token)
|
||||
def refresh_access_token(
|
||||
db: Session = Depends(get_db), refresh_token: str = Body(...)
|
||||
) -> Any:
|
||||
"""
|
||||
Get a new access token using a refresh token
|
||||
"""
|
||||
db_token = get_refresh_token(db=db, token=refresh_token)
|
||||
if not db_token or not is_token_valid(db_token):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired refresh token",
|
||||
)
|
||||
|
||||
# Create new access token
|
||||
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
access_token, access_token_expiry = create_access_token(
|
||||
db_token.user_id, expires_delta=access_token_expires
|
||||
)
|
||||
|
||||
# Create new refresh token
|
||||
new_refresh_token, refresh_token_expiry = create_refresh_token(db_token.user_id)
|
||||
|
||||
# Revoke old refresh token
|
||||
revoke_refresh_token(db=db, token=refresh_token)
|
||||
|
||||
# Save new refresh token to database
|
||||
refresh_token_data = RefreshTokenCreate(
|
||||
token=new_refresh_token,
|
||||
expires_at=refresh_token_expiry,
|
||||
user_id=db_token.user_id
|
||||
)
|
||||
create_refresh_token_db(db=db, token_in=refresh_token_data)
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"token_type": "bearer",
|
||||
"refresh_token": new_refresh_token,
|
||||
"expires_at": access_token_expiry,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/logout", status_code=status.HTTP_204_NO_CONTENT, response_model=None)
|
||||
def logout(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
refresh_token: str = Body(...),
|
||||
) -> Any:
|
||||
"""
|
||||
Logout by revoking the refresh token
|
||||
"""
|
||||
db_token = get_refresh_token(db=db, token=refresh_token)
|
||||
if db_token and db_token.user_id == current_user.id:
|
||||
revoke_refresh_token(db=db, token=refresh_token)
|
||||
return None
|
@ -1,5 +1,6 @@
|
||||
import secrets
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Union
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
from jose import jwt
|
||||
from passlib.context import CryptContext
|
||||
@ -8,19 +9,52 @@ from app.core.config import settings
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
# Define token types
|
||||
ACCESS_TOKEN_TYPE = "access"
|
||||
REFRESH_TOKEN_TYPE = "refresh"
|
||||
|
||||
# Expiration times
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
REFRESH_TOKEN_EXPIRE_DAYS = 30
|
||||
|
||||
|
||||
def create_access_token(
|
||||
subject: Union[str, Any], expires_delta: timedelta = None
|
||||
) -> str:
|
||||
subject: Union[str, Any], expires_delta: Optional[timedelta] = None
|
||||
) -> Tuple[str, datetime]:
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(
|
||||
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
minutes=ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
to_encode = {"exp": expire, "sub": str(subject)}
|
||||
to_encode = {"exp": expire, "sub": str(subject), "type": ACCESS_TOKEN_TYPE}
|
||||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm="HS256")
|
||||
return encoded_jwt
|
||||
return encoded_jwt, expire
|
||||
|
||||
|
||||
def create_refresh_token(
|
||||
subject: Union[str, Any], expires_delta: Optional[timedelta] = None
|
||||
) -> Tuple[str, datetime]:
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
# Generate a secure random token
|
||||
refresh_token = secrets.token_urlsafe(64)
|
||||
|
||||
# No need to encode anything as we'll store this in the database
|
||||
return refresh_token, expire
|
||||
|
||||
|
||||
def verify_token(token: str) -> Optional[Dict[str, Any]]:
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.SECRET_KEY, algorithms=["HS256"]
|
||||
)
|
||||
return payload
|
||||
except jwt.JWTError:
|
||||
return None
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
|
67
app/crud/crud_token.py
Normal file
67
app/crud/crud_token.py
Normal file
@ -0,0 +1,67 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.token import RefreshToken
|
||||
from app.schemas.token import RefreshTokenCreate, RefreshTokenUpdate
|
||||
|
||||
|
||||
def get_refresh_token(db: Session, token: str) -> Optional[RefreshToken]:
|
||||
return db.query(RefreshToken).filter(RefreshToken.token == token).first()
|
||||
|
||||
|
||||
def get_refresh_tokens_by_user(
|
||||
db: Session, user_id: int, skip: int = 0, limit: int = 100
|
||||
) -> List[RefreshToken]:
|
||||
return (
|
||||
db.query(RefreshToken)
|
||||
.filter(RefreshToken.user_id == user_id)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def create_refresh_token(db: Session, token_in: RefreshTokenCreate) -> RefreshToken:
|
||||
db_token = RefreshToken(
|
||||
token=token_in.token,
|
||||
expires_at=token_in.expires_at,
|
||||
user_id=token_in.user_id,
|
||||
created_at=datetime.utcnow(),
|
||||
revoked=False,
|
||||
)
|
||||
db.add(db_token)
|
||||
db.commit()
|
||||
db.refresh(db_token)
|
||||
return db_token
|
||||
|
||||
|
||||
def update_refresh_token(
|
||||
db: Session, db_obj: RefreshToken, obj_in: RefreshTokenUpdate
|
||||
) -> RefreshToken:
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
for field in update_data:
|
||||
setattr(db_obj, field, update_data[field])
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
|
||||
def revoke_refresh_token(db: Session, token: str) -> Optional[RefreshToken]:
|
||||
db_token = get_refresh_token(db, token=token)
|
||||
if db_token:
|
||||
db_token.revoked = True
|
||||
db.add(db_token)
|
||||
db.commit()
|
||||
db.refresh(db_token)
|
||||
return db_token
|
||||
|
||||
|
||||
def is_token_valid(token: RefreshToken) -> bool:
|
||||
return (
|
||||
token is not None
|
||||
and not token.revoked
|
||||
and token.expires_at > datetime.utcnow()
|
||||
)
|
@ -26,6 +26,7 @@ def create_user(db: Session, obj_in: UserCreate) -> User:
|
||||
email=obj_in.email,
|
||||
hashed_password=get_password_hash(obj_in.password),
|
||||
is_active=True,
|
||||
role=obj_in.role,
|
||||
)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
|
19
app/models/token.py
Normal file
19
app/models/token.py
Normal file
@ -0,0 +1,19 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.db.base import Base
|
||||
|
||||
|
||||
class RefreshToken(Base):
|
||||
__tablename__ = "refresh_tokens"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
token = Column(String, unique=True, index=True)
|
||||
expires_at = Column(DateTime)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
revoked = Column(Boolean, default=False)
|
||||
user_id = Column(Integer, ForeignKey("users.id"))
|
||||
|
||||
user = relationship("User")
|
@ -1,9 +1,15 @@
|
||||
from sqlalchemy import Boolean, Column, Integer, String
|
||||
from enum import Enum as PyEnum
|
||||
from sqlalchemy import Boolean, Column, Integer, String, Enum
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.db.base import Base
|
||||
|
||||
|
||||
class UserRole(str, PyEnum):
|
||||
ADMIN = "admin"
|
||||
USER = "user"
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
@ -11,5 +17,6 @@ class User(Base):
|
||||
email = Column(String, unique=True, index=True)
|
||||
hashed_password = Column(String)
|
||||
is_active = Column(Boolean, default=True)
|
||||
role = Column(Enum(UserRole), default=UserRole.USER)
|
||||
|
||||
todos = relationship("Todo", back_populates="owner")
|
@ -1,3 +1,4 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
@ -6,7 +7,33 @@ from pydantic import BaseModel
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
refresh_token: Optional[str] = None
|
||||
expires_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
sub: Optional[int] = None
|
||||
sub: Optional[int] = None
|
||||
exp: Optional[datetime] = None
|
||||
type: Optional[str] = "access"
|
||||
|
||||
|
||||
class RefreshTokenCreate(BaseModel):
|
||||
token: str
|
||||
expires_at: datetime
|
||||
user_id: int
|
||||
|
||||
|
||||
class RefreshTokenUpdate(BaseModel):
|
||||
revoked: Optional[bool] = None
|
||||
|
||||
|
||||
class RefreshTokenInDB(BaseModel):
|
||||
id: int
|
||||
token: str
|
||||
expires_at: datetime
|
||||
created_at: datetime
|
||||
revoked: bool
|
||||
user_id: int
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
@ -2,17 +2,21 @@ from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, EmailStr
|
||||
|
||||
from app.models.user import UserRole
|
||||
|
||||
|
||||
# Shared properties
|
||||
class UserBase(BaseModel):
|
||||
email: Optional[EmailStr] = None
|
||||
is_active: Optional[bool] = True
|
||||
role: Optional[UserRole] = None
|
||||
|
||||
|
||||
# Properties to receive via API on creation
|
||||
class UserCreate(UserBase):
|
||||
email: EmailStr
|
||||
password: str
|
||||
role: UserRole = UserRole.USER
|
||||
|
||||
|
||||
# Properties to receive via API on update
|
||||
@ -22,6 +26,7 @@ class UserUpdate(UserBase):
|
||||
|
||||
class UserInDBBase(UserBase):
|
||||
id: Optional[int] = None
|
||||
role: UserRole = UserRole.USER
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
30
migrations/versions/002_add_user_roles.py
Normal file
30
migrations/versions/002_add_user_roles.py
Normal file
@ -0,0 +1,30 @@
|
||||
"""add user roles
|
||||
|
||||
Revision ID: 002
|
||||
Revises: 001
|
||||
Create Date: 2023-11-16
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '002'
|
||||
down_revision = '001'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# Add role column to users table
|
||||
with op.batch_alter_table('users', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('role', sa.Enum('admin', 'user', name='userrole'), nullable=True))
|
||||
batch_op.execute("UPDATE users SET role = 'user' WHERE role IS NULL")
|
||||
batch_op.alter_column('role', nullable=False, server_default='user')
|
||||
|
||||
|
||||
def downgrade():
|
||||
# Remove role column from users table
|
||||
with op.batch_alter_table('users', schema=None) as batch_op:
|
||||
batch_op.drop_column('role')
|
40
migrations/versions/003_add_refresh_tokens.py
Normal file
40
migrations/versions/003_add_refresh_tokens.py
Normal file
@ -0,0 +1,40 @@
|
||||
"""add refresh tokens
|
||||
|
||||
Revision ID: 003
|
||||
Revises: 002
|
||||
Create Date: 2023-11-16
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '003'
|
||||
down_revision = '002'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# Create refresh_tokens table
|
||||
op.create_table(
|
||||
'refresh_tokens',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('token', sa.String(), nullable=True),
|
||||
sa.Column('expires_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('revoked', sa.Boolean(), nullable=True),
|
||||
sa.Column('user_id', sa.Integer(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_refresh_tokens_id'), 'refresh_tokens', ['id'], unique=False)
|
||||
op.create_index(op.f('ix_refresh_tokens_token'), 'refresh_tokens', ['token'], unique=True)
|
||||
|
||||
|
||||
def downgrade():
|
||||
# Drop refresh_tokens table
|
||||
op.drop_index(op.f('ix_refresh_tokens_token'), table_name='refresh_tokens')
|
||||
op.drop_index(op.f('ix_refresh_tokens_id'), table_name='refresh_tokens')
|
||||
op.drop_table('refresh_tokens')
|
Loading…
x
Reference in New Issue
Block a user