Add enhanced authentication features
- Add role-based access control (admin/user roles) - Implement refresh token functionality - Add token revocation (logout) capability - Create admin-only endpoints - Add role validation middleware - Update documentation
This commit is contained in:
parent
1fca93299a
commit
4cfde1a74a
25
README.md
25
README.md
@ -4,7 +4,10 @@ A robust, RESTful API for managing todos, built with FastAPI and SQLite.
|
|||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
- 🔐 JWT Authentication
|
- 🔐 Enhanced JWT Authentication
|
||||||
|
- Access and Refresh tokens
|
||||||
|
- Token revocation (logout)
|
||||||
|
- Role-based access control (User/Admin roles)
|
||||||
- 📝 Todo CRUD operations
|
- 📝 Todo CRUD operations
|
||||||
- 👤 User management
|
- 👤 User management
|
||||||
- 🔍 Advanced todo filtering and pagination
|
- 🔍 Advanced todo filtering and pagination
|
||||||
@ -63,6 +66,8 @@ The application can be configured using the following environment variables:
|
|||||||
|
|
||||||
- `POST /api/v1/auth/register` - Register a new user
|
- `POST /api/v1/auth/register` - Register a new user
|
||||||
- `POST /api/v1/auth/login` - Login and get access token
|
- `POST /api/v1/auth/login` - Login and get access token
|
||||||
|
- `POST /api/v1/auth/refresh` - Refresh access token using refresh token
|
||||||
|
- `POST /api/v1/auth/logout` - Logout and revoke refresh token
|
||||||
|
|
||||||
### Users
|
### Users
|
||||||
|
|
||||||
@ -71,6 +76,12 @@ The application can be configured using the following environment variables:
|
|||||||
- `PUT /api/v1/users/me` - Update current user
|
- `PUT /api/v1/users/me` - Update current user
|
||||||
- `GET /api/v1/users/{user_id}` - Get user by ID
|
- `GET /api/v1/users/{user_id}` - Get user by ID
|
||||||
|
|
||||||
|
### Admin
|
||||||
|
|
||||||
|
- `GET /api/v1/admin/users` - List all users (admin only)
|
||||||
|
- `GET /api/v1/admin/users/{user_id}` - Get user by ID (admin only)
|
||||||
|
- `PUT /api/v1/admin/users/{user_id}` - Update user (admin only)
|
||||||
|
|
||||||
### Todos
|
### Todos
|
||||||
|
|
||||||
- `GET /api/v1/todos/` - List todos (with filtering and pagination)
|
- `GET /api/v1/todos/` - List todos (with filtering and pagination)
|
||||||
@ -97,6 +108,7 @@ id: Integer (Primary Key)
|
|||||||
email: String (Unique, Indexed)
|
email: String (Unique, Indexed)
|
||||||
hashed_password: String
|
hashed_password: String
|
||||||
is_active: Boolean (Default: True)
|
is_active: Boolean (Default: True)
|
||||||
|
role: Enum(admin, user) (Default: user)
|
||||||
```
|
```
|
||||||
|
|
||||||
### Todo Model
|
### Todo Model
|
||||||
@ -109,6 +121,17 @@ is_completed: Boolean (Default: False)
|
|||||||
owner_id: Integer (Foreign Key to User)
|
owner_id: Integer (Foreign Key to User)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### RefreshToken Model
|
||||||
|
|
||||||
|
```
|
||||||
|
id: Integer (Primary Key)
|
||||||
|
token: String (Unique, Indexed)
|
||||||
|
expires_at: DateTime
|
||||||
|
created_at: DateTime
|
||||||
|
revoked: Boolean (Default: False)
|
||||||
|
user_id: Integer (Foreign Key to User)
|
||||||
|
```
|
||||||
|
|
||||||
## Development
|
## Development
|
||||||
|
|
||||||
### Code Structure
|
### Code Structure
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from functools import wraps
|
||||||
from typing import Generator
|
from typing import Generator
|
||||||
|
|
||||||
from fastapi import Depends, HTTPException, status
|
from fastapi import Depends, HTTPException, status
|
||||||
@ -9,7 +10,7 @@ from sqlalchemy.orm import Session
|
|||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.crud.crud_user import get_user
|
from app.crud.crud_user import get_user
|
||||||
from app.db.session import SessionLocal
|
from app.db.session import SessionLocal
|
||||||
from app.models.user import User
|
from app.models.user import User, UserRole
|
||||||
from app.schemas.token import TokenPayload
|
from app.schemas.token import TokenPayload
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(
|
oauth2_scheme = OAuth2PasswordBearer(
|
||||||
@ -49,4 +50,29 @@ def get_current_active_user(
|
|||||||
) -> User:
|
) -> User:
|
||||||
if not current_user.is_active:
|
if not current_user.is_active:
|
||||||
raise HTTPException(status_code=400, detail="Inactive user")
|
raise HTTPException(status_code=400, detail="Inactive user")
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_admin_user(
|
||||||
|
current_user: User = Depends(get_current_active_user),
|
||||||
|
) -> User:
|
||||||
|
if current_user.role != UserRole.ADMIN:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="The user doesn't have enough privileges",
|
||||||
|
)
|
||||||
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
|
def check_role(required_role: UserRole):
|
||||||
|
def decorator(func):
|
||||||
|
@wraps(func)
|
||||||
|
async def wrapper(*args, current_user: User = Depends(get_current_active_user), **kwargs):
|
||||||
|
if current_user.role != required_role and current_user.role != UserRole.ADMIN:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="The user doesn't have enough privileges",
|
||||||
|
)
|
||||||
|
return await func(*args, current_user=current_user, **kwargs)
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
@ -1,8 +1,9 @@
|
|||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
|
||||||
from app.api.v1.endpoints import auth, todos, users
|
from app.api.v1.endpoints import admin, auth, todos, users
|
||||||
|
|
||||||
api_router = APIRouter()
|
api_router = APIRouter()
|
||||||
api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
|
api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
|
||||||
api_router.include_router(users.router, prefix="/users", tags=["users"])
|
api_router.include_router(users.router, prefix="/users", tags=["users"])
|
||||||
api_router.include_router(todos.router, prefix="/todos", tags=["todos"])
|
api_router.include_router(todos.router, prefix="/todos", tags=["todos"])
|
||||||
|
api_router.include_router(admin.router, prefix="/admin", tags=["admin"])
|
66
app/api/v1/endpoints/admin.py
Normal file
66
app/api/v1/endpoints/admin.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
from typing import Any, List
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.api.deps import get_current_admin_user, get_db
|
||||||
|
from app.crud.crud_user import get_user, get_users, update_user
|
||||||
|
from app.models.user import User
|
||||||
|
from app.schemas.user import User as UserSchema
|
||||||
|
from app.schemas.user import UserUpdate
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/users", response_model=List[UserSchema])
|
||||||
|
def admin_read_users(
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
skip: int = 0,
|
||||||
|
limit: int = 100,
|
||||||
|
current_user: User = Depends(get_current_admin_user),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Retrieve all users (admin only).
|
||||||
|
"""
|
||||||
|
users = get_users(db, skip=skip, limit=limit)
|
||||||
|
return users
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/users/{user_id}", response_model=UserSchema)
|
||||||
|
def admin_update_user(
|
||||||
|
*,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
user_id: int,
|
||||||
|
user_in: UserUpdate,
|
||||||
|
current_user: User = Depends(get_current_admin_user),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Update a user (admin only).
|
||||||
|
"""
|
||||||
|
user = get_user(db, user_id=user_id)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="The user with this id does not exist in the system",
|
||||||
|
)
|
||||||
|
user = update_user(db, db_obj=user, obj_in=user_in)
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/users/{user_id}", response_model=UserSchema)
|
||||||
|
def admin_read_user(
|
||||||
|
*,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
user_id: int,
|
||||||
|
current_user: User = Depends(get_current_admin_user),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Get user by ID (admin only).
|
||||||
|
"""
|
||||||
|
user = get_user(db, user_id=user_id)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="The user with this id does not exist in the system",
|
||||||
|
)
|
||||||
|
return user
|
@ -1,15 +1,29 @@
|
|||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Body, Depends, HTTPException, status
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.api.deps import get_db
|
from app.api.deps import get_current_active_user, get_db
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.security import create_access_token
|
from app.core.security import (
|
||||||
from app.crud.crud_user import authenticate, create_user, get_user_by_email
|
create_access_token,
|
||||||
from app.schemas.token import Token
|
create_refresh_token,
|
||||||
|
)
|
||||||
|
from app.crud.crud_token import (
|
||||||
|
create_refresh_token as create_refresh_token_db,
|
||||||
|
get_refresh_token,
|
||||||
|
is_token_valid,
|
||||||
|
revoke_refresh_token,
|
||||||
|
)
|
||||||
|
from app.crud.crud_user import (
|
||||||
|
authenticate,
|
||||||
|
create_user,
|
||||||
|
get_user_by_email,
|
||||||
|
)
|
||||||
|
from app.models.user import User
|
||||||
|
from app.schemas.token import RefreshTokenCreate, Token
|
||||||
from app.schemas.user import UserCreate
|
from app.schemas.user import UserCreate
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@ -34,12 +48,29 @@ def login_access_token(
|
|||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user"
|
status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Create access token
|
||||||
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||||
|
access_token, access_token_expiry = create_access_token(
|
||||||
|
user.id, expires_delta=access_token_expires
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create refresh token
|
||||||
|
refresh_token, refresh_token_expiry = create_refresh_token(user.id)
|
||||||
|
|
||||||
|
# Save refresh token to database
|
||||||
|
refresh_token_data = RefreshTokenCreate(
|
||||||
|
token=refresh_token,
|
||||||
|
expires_at=refresh_token_expiry,
|
||||||
|
user_id=user.id
|
||||||
|
)
|
||||||
|
create_refresh_token_db(db=db, token_in=refresh_token_data)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"access_token": create_access_token(
|
"access_token": access_token,
|
||||||
user.id, expires_delta=access_token_expires
|
|
||||||
),
|
|
||||||
"token_type": "bearer",
|
"token_type": "bearer",
|
||||||
|
"refresh_token": refresh_token,
|
||||||
|
"expires_at": access_token_expiry,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -59,10 +90,84 @@ def register_user(
|
|||||||
detail="A user with this email already exists",
|
detail="A user with this email already exists",
|
||||||
)
|
)
|
||||||
user = create_user(db, obj_in=user_in)
|
user = create_user(db, obj_in=user_in)
|
||||||
|
|
||||||
|
# Create access token
|
||||||
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||||
|
access_token, access_token_expiry = create_access_token(
|
||||||
|
user.id, expires_delta=access_token_expires
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create refresh token
|
||||||
|
refresh_token, refresh_token_expiry = create_refresh_token(user.id)
|
||||||
|
|
||||||
|
# Save refresh token to database
|
||||||
|
refresh_token_data = RefreshTokenCreate(
|
||||||
|
token=refresh_token,
|
||||||
|
expires_at=refresh_token_expiry,
|
||||||
|
user_id=user.id
|
||||||
|
)
|
||||||
|
create_refresh_token_db(db=db, token_in=refresh_token_data)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"access_token": create_access_token(
|
"access_token": access_token,
|
||||||
user.id, expires_delta=access_token_expires
|
|
||||||
),
|
|
||||||
"token_type": "bearer",
|
"token_type": "bearer",
|
||||||
}
|
"refresh_token": refresh_token,
|
||||||
|
"expires_at": access_token_expiry,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/refresh", response_model=Token)
|
||||||
|
def refresh_access_token(
|
||||||
|
db: Session = Depends(get_db), refresh_token: str = Body(...)
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Get a new access token using a refresh token
|
||||||
|
"""
|
||||||
|
db_token = get_refresh_token(db=db, token=refresh_token)
|
||||||
|
if not db_token or not is_token_valid(db_token):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid or expired refresh token",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create new access token
|
||||||
|
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||||
|
access_token, access_token_expiry = create_access_token(
|
||||||
|
db_token.user_id, expires_delta=access_token_expires
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create new refresh token
|
||||||
|
new_refresh_token, refresh_token_expiry = create_refresh_token(db_token.user_id)
|
||||||
|
|
||||||
|
# Revoke old refresh token
|
||||||
|
revoke_refresh_token(db=db, token=refresh_token)
|
||||||
|
|
||||||
|
# Save new refresh token to database
|
||||||
|
refresh_token_data = RefreshTokenCreate(
|
||||||
|
token=new_refresh_token,
|
||||||
|
expires_at=refresh_token_expiry,
|
||||||
|
user_id=db_token.user_id
|
||||||
|
)
|
||||||
|
create_refresh_token_db(db=db, token_in=refresh_token_data)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"access_token": access_token,
|
||||||
|
"token_type": "bearer",
|
||||||
|
"refresh_token": new_refresh_token,
|
||||||
|
"expires_at": access_token_expiry,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/logout", status_code=status.HTTP_204_NO_CONTENT, response_model=None)
|
||||||
|
def logout(
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_active_user),
|
||||||
|
refresh_token: str = Body(...),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Logout by revoking the refresh token
|
||||||
|
"""
|
||||||
|
db_token = get_refresh_token(db=db, token=refresh_token)
|
||||||
|
if db_token and db_token.user_id == current_user.id:
|
||||||
|
revoke_refresh_token(db=db, token=refresh_token)
|
||||||
|
return None
|
@ -1,5 +1,6 @@
|
|||||||
|
import secrets
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Any, Union
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
from jose import jwt
|
from jose import jwt
|
||||||
from passlib.context import CryptContext
|
from passlib.context import CryptContext
|
||||||
@ -8,19 +9,52 @@ from app.core.config import settings
|
|||||||
|
|
||||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||||
|
|
||||||
|
# Define token types
|
||||||
|
ACCESS_TOKEN_TYPE = "access"
|
||||||
|
REFRESH_TOKEN_TYPE = "refresh"
|
||||||
|
|
||||||
|
# Expiration times
|
||||||
|
ACCESS_TOKEN_EXPIRE_MINUTES = settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||||
|
REFRESH_TOKEN_EXPIRE_DAYS = 30
|
||||||
|
|
||||||
|
|
||||||
def create_access_token(
|
def create_access_token(
|
||||||
subject: Union[str, Any], expires_delta: timedelta = None
|
subject: Union[str, Any], expires_delta: Optional[timedelta] = None
|
||||||
) -> str:
|
) -> Tuple[str, datetime]:
|
||||||
if expires_delta:
|
if expires_delta:
|
||||||
expire = datetime.utcnow() + expires_delta
|
expire = datetime.utcnow() + expires_delta
|
||||||
else:
|
else:
|
||||||
expire = datetime.utcnow() + timedelta(
|
expire = datetime.utcnow() + timedelta(
|
||||||
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
minutes=ACCESS_TOKEN_EXPIRE_MINUTES
|
||||||
)
|
)
|
||||||
to_encode = {"exp": expire, "sub": str(subject)}
|
to_encode = {"exp": expire, "sub": str(subject), "type": ACCESS_TOKEN_TYPE}
|
||||||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm="HS256")
|
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm="HS256")
|
||||||
return encoded_jwt
|
return encoded_jwt, expire
|
||||||
|
|
||||||
|
|
||||||
|
def create_refresh_token(
|
||||||
|
subject: Union[str, Any], expires_delta: Optional[timedelta] = None
|
||||||
|
) -> Tuple[str, datetime]:
|
||||||
|
if expires_delta:
|
||||||
|
expire = datetime.utcnow() + expires_delta
|
||||||
|
else:
|
||||||
|
expire = datetime.utcnow() + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
|
||||||
|
|
||||||
|
# Generate a secure random token
|
||||||
|
refresh_token = secrets.token_urlsafe(64)
|
||||||
|
|
||||||
|
# No need to encode anything as we'll store this in the database
|
||||||
|
return refresh_token, expire
|
||||||
|
|
||||||
|
|
||||||
|
def verify_token(token: str) -> Optional[Dict[str, Any]]:
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token, settings.SECRET_KEY, algorithms=["HS256"]
|
||||||
|
)
|
||||||
|
return payload
|
||||||
|
except jwt.JWTError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||||
|
67
app/crud/crud_token.py
Normal file
67
app/crud/crud_token.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.models.token import RefreshToken
|
||||||
|
from app.schemas.token import RefreshTokenCreate, RefreshTokenUpdate
|
||||||
|
|
||||||
|
|
||||||
|
def get_refresh_token(db: Session, token: str) -> Optional[RefreshToken]:
|
||||||
|
return db.query(RefreshToken).filter(RefreshToken.token == token).first()
|
||||||
|
|
||||||
|
|
||||||
|
def get_refresh_tokens_by_user(
|
||||||
|
db: Session, user_id: int, skip: int = 0, limit: int = 100
|
||||||
|
) -> List[RefreshToken]:
|
||||||
|
return (
|
||||||
|
db.query(RefreshToken)
|
||||||
|
.filter(RefreshToken.user_id == user_id)
|
||||||
|
.offset(skip)
|
||||||
|
.limit(limit)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_refresh_token(db: Session, token_in: RefreshTokenCreate) -> RefreshToken:
|
||||||
|
db_token = RefreshToken(
|
||||||
|
token=token_in.token,
|
||||||
|
expires_at=token_in.expires_at,
|
||||||
|
user_id=token_in.user_id,
|
||||||
|
created_at=datetime.utcnow(),
|
||||||
|
revoked=False,
|
||||||
|
)
|
||||||
|
db.add(db_token)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_token)
|
||||||
|
return db_token
|
||||||
|
|
||||||
|
|
||||||
|
def update_refresh_token(
|
||||||
|
db: Session, db_obj: RefreshToken, obj_in: RefreshTokenUpdate
|
||||||
|
) -> RefreshToken:
|
||||||
|
update_data = obj_in.model_dump(exclude_unset=True)
|
||||||
|
for field in update_data:
|
||||||
|
setattr(db_obj, field, update_data[field])
|
||||||
|
db.add(db_obj)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_obj)
|
||||||
|
return db_obj
|
||||||
|
|
||||||
|
|
||||||
|
def revoke_refresh_token(db: Session, token: str) -> Optional[RefreshToken]:
|
||||||
|
db_token = get_refresh_token(db, token=token)
|
||||||
|
if db_token:
|
||||||
|
db_token.revoked = True
|
||||||
|
db.add(db_token)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_token)
|
||||||
|
return db_token
|
||||||
|
|
||||||
|
|
||||||
|
def is_token_valid(token: RefreshToken) -> bool:
|
||||||
|
return (
|
||||||
|
token is not None
|
||||||
|
and not token.revoked
|
||||||
|
and token.expires_at > datetime.utcnow()
|
||||||
|
)
|
@ -26,6 +26,7 @@ def create_user(db: Session, obj_in: UserCreate) -> User:
|
|||||||
email=obj_in.email,
|
email=obj_in.email,
|
||||||
hashed_password=get_password_hash(obj_in.password),
|
hashed_password=get_password_hash(obj_in.password),
|
||||||
is_active=True,
|
is_active=True,
|
||||||
|
role=obj_in.role,
|
||||||
)
|
)
|
||||||
db.add(db_obj)
|
db.add(db_obj)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
19
app/models/token.py
Normal file
19
app/models/token.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
from app.db.base import Base
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshToken(Base):
|
||||||
|
__tablename__ = "refresh_tokens"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
|
token = Column(String, unique=True, index=True)
|
||||||
|
expires_at = Column(DateTime)
|
||||||
|
created_at = Column(DateTime, default=datetime.utcnow)
|
||||||
|
revoked = Column(Boolean, default=False)
|
||||||
|
user_id = Column(Integer, ForeignKey("users.id"))
|
||||||
|
|
||||||
|
user = relationship("User")
|
@ -1,9 +1,15 @@
|
|||||||
from sqlalchemy import Boolean, Column, Integer, String
|
from enum import Enum as PyEnum
|
||||||
|
from sqlalchemy import Boolean, Column, Integer, String, Enum
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
from app.db.base import Base
|
from app.db.base import Base
|
||||||
|
|
||||||
|
|
||||||
|
class UserRole(str, PyEnum):
|
||||||
|
ADMIN = "admin"
|
||||||
|
USER = "user"
|
||||||
|
|
||||||
|
|
||||||
class User(Base):
|
class User(Base):
|
||||||
__tablename__ = "users"
|
__tablename__ = "users"
|
||||||
|
|
||||||
@ -11,5 +17,6 @@ class User(Base):
|
|||||||
email = Column(String, unique=True, index=True)
|
email = Column(String, unique=True, index=True)
|
||||||
hashed_password = Column(String)
|
hashed_password = Column(String)
|
||||||
is_active = Column(Boolean, default=True)
|
is_active = Column(Boolean, default=True)
|
||||||
|
role = Column(Enum(UserRole), default=UserRole.USER)
|
||||||
|
|
||||||
todos = relationship("Todo", back_populates="owner")
|
todos = relationship("Todo", back_populates="owner")
|
@ -1,3 +1,4 @@
|
|||||||
|
from datetime import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@ -6,7 +7,33 @@ from pydantic import BaseModel
|
|||||||
class Token(BaseModel):
|
class Token(BaseModel):
|
||||||
access_token: str
|
access_token: str
|
||||||
token_type: str
|
token_type: str
|
||||||
|
refresh_token: Optional[str] = None
|
||||||
|
expires_at: Optional[datetime] = None
|
||||||
|
|
||||||
|
|
||||||
class TokenPayload(BaseModel):
|
class TokenPayload(BaseModel):
|
||||||
sub: Optional[int] = None
|
sub: Optional[int] = None
|
||||||
|
exp: Optional[datetime] = None
|
||||||
|
type: Optional[str] = "access"
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshTokenCreate(BaseModel):
|
||||||
|
token: str
|
||||||
|
expires_at: datetime
|
||||||
|
user_id: int
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshTokenUpdate(BaseModel):
|
||||||
|
revoked: Optional[bool] = None
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshTokenInDB(BaseModel):
|
||||||
|
id: int
|
||||||
|
token: str
|
||||||
|
expires_at: datetime
|
||||||
|
created_at: datetime
|
||||||
|
revoked: bool
|
||||||
|
user_id: int
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
@ -2,17 +2,21 @@ from typing import Optional
|
|||||||
|
|
||||||
from pydantic import BaseModel, EmailStr
|
from pydantic import BaseModel, EmailStr
|
||||||
|
|
||||||
|
from app.models.user import UserRole
|
||||||
|
|
||||||
|
|
||||||
# Shared properties
|
# Shared properties
|
||||||
class UserBase(BaseModel):
|
class UserBase(BaseModel):
|
||||||
email: Optional[EmailStr] = None
|
email: Optional[EmailStr] = None
|
||||||
is_active: Optional[bool] = True
|
is_active: Optional[bool] = True
|
||||||
|
role: Optional[UserRole] = None
|
||||||
|
|
||||||
|
|
||||||
# Properties to receive via API on creation
|
# Properties to receive via API on creation
|
||||||
class UserCreate(UserBase):
|
class UserCreate(UserBase):
|
||||||
email: EmailStr
|
email: EmailStr
|
||||||
password: str
|
password: str
|
||||||
|
role: UserRole = UserRole.USER
|
||||||
|
|
||||||
|
|
||||||
# Properties to receive via API on update
|
# Properties to receive via API on update
|
||||||
@ -22,6 +26,7 @@ class UserUpdate(UserBase):
|
|||||||
|
|
||||||
class UserInDBBase(UserBase):
|
class UserInDBBase(UserBase):
|
||||||
id: Optional[int] = None
|
id: Optional[int] = None
|
||||||
|
role: UserRole = UserRole.USER
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
30
migrations/versions/002_add_user_roles.py
Normal file
30
migrations/versions/002_add_user_roles.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
"""add user roles
|
||||||
|
|
||||||
|
Revision ID: 002
|
||||||
|
Revises: 001
|
||||||
|
Create Date: 2023-11-16
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '002'
|
||||||
|
down_revision = '001'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# Add role column to users table
|
||||||
|
with op.batch_alter_table('users', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('role', sa.Enum('admin', 'user', name='userrole'), nullable=True))
|
||||||
|
batch_op.execute("UPDATE users SET role = 'user' WHERE role IS NULL")
|
||||||
|
batch_op.alter_column('role', nullable=False, server_default='user')
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# Remove role column from users table
|
||||||
|
with op.batch_alter_table('users', schema=None) as batch_op:
|
||||||
|
batch_op.drop_column('role')
|
40
migrations/versions/003_add_refresh_tokens.py
Normal file
40
migrations/versions/003_add_refresh_tokens.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
"""add refresh tokens
|
||||||
|
|
||||||
|
Revision ID: 003
|
||||||
|
Revises: 002
|
||||||
|
Create Date: 2023-11-16
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '003'
|
||||||
|
down_revision = '002'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# Create refresh_tokens table
|
||||||
|
op.create_table(
|
||||||
|
'refresh_tokens',
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('token', sa.String(), nullable=True),
|
||||||
|
sa.Column('expires_at', sa.DateTime(), nullable=True),
|
||||||
|
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||||
|
sa.Column('revoked', sa.Boolean(), nullable=True),
|
||||||
|
sa.Column('user_id', sa.Integer(), nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_refresh_tokens_id'), 'refresh_tokens', ['id'], unique=False)
|
||||||
|
op.create_index(op.f('ix_refresh_tokens_token'), 'refresh_tokens', ['token'], unique=True)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# Drop refresh_tokens table
|
||||||
|
op.drop_index(op.f('ix_refresh_tokens_token'), table_name='refresh_tokens')
|
||||||
|
op.drop_index(op.f('ix_refresh_tokens_id'), table_name='refresh_tokens')
|
||||||
|
op.drop_table('refresh_tokens')
|
Loading…
x
Reference in New Issue
Block a user