diff --git a/app/api/endpoints/auth.py b/app/api/endpoints/auth.py index 3fa2d3f..9fffaf7 100644 --- a/app/api/endpoints/auth.py +++ b/app/api/endpoints/auth.py @@ -26,11 +26,11 @@ def login_access_token( """ OAuth2 compatible token login, get an access token for future requests """ - user = user_service.authenticate(db, email=form_data.username, password=form_data.password) + user = user_service.authenticate(db, username=form_data.username, password=form_data.password) if not user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Incorrect email or password", + detail="Incorrect username or password", ) elif not user_service.is_active(user): raise HTTPException( @@ -55,15 +55,21 @@ def register_user( """ Register a new user and return an access token """ - user = user_service.get_by_email(db, email=form_data.username) - if user: + existing_user = user_service.get_by_username(db, username=form_data.username) + if existing_user: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="A user with this email already exists", + detail="A user with this username already exists", ) - # Create user - user_in = UserCreate(email=form_data.username, password=form_data.password) + # Create user - using username as the primary identifier, but also require an email + # For OAuth2PasswordRequestForm, we'll use the username field from the form + # and set a default email based on the username + user_in = UserCreate( + username=form_data.username, + email=f"{form_data.username}@example.com", # Default email since OAuth2 form doesn't have email field + password=form_data.password + ) user = user_service.create(db, obj_in=user_in) # Generate token diff --git a/app/api/endpoints/users.py b/app/api/endpoints/users.py index 12dea56..b9f6ef2 100644 --- a/app/api/endpoints/users.py +++ b/app/api/endpoints/users.py @@ -38,12 +38,22 @@ def create_user( """ Create new user. """ + # Check if username already exists + user = user_service.get_by_username(db, username=user_in.username) + if user: + raise HTTPException( + status_code=400, + detail="A user with this username already exists", + ) + + # Also check email for uniqueness user = user_service.get_by_email(db, email=user_in.email) if user: raise HTTPException( status_code=400, detail="A user with this email already exists", ) + user = user_service.create(db, obj_in=user_in) return user @@ -54,6 +64,7 @@ def update_user_me( db: Session = Depends(get_db), password: str = Body(None), full_name: str = Body(None), + username: str = Body(None), email: EmailStr = Body(None), current_user: User = Depends(get_current_active_user), ) -> Any: @@ -62,12 +73,32 @@ def update_user_me( """ current_user_data = jsonable_encoder(current_user) user_in = UserUpdate(**current_user_data) + if password is not None: user_in.password = password if full_name is not None: user_in.full_name = full_name - if email is not None: + + # Check for username uniqueness if username is being updated + if username is not None and username != current_user.username: + existing_user = user_service.get_by_username(db, username=username) + if existing_user: + raise HTTPException( + status_code=400, + detail="A user with this username already exists", + ) + user_in.username = username + + # Check for email uniqueness if email is being updated + if email is not None and email != current_user.email: + existing_user = user_service.get_by_email(db, email=email) + if existing_user: + raise HTTPException( + status_code=400, + detail="A user with this email already exists", + ) user_in.email = email + user = user_service.update(db, db_obj=current_user, obj_in=user_in) return user diff --git a/app/models/user.py b/app/models/user.py index 22fdfcb..62de70e 100644 --- a/app/models/user.py +++ b/app/models/user.py @@ -7,6 +7,7 @@ from app.db.base_class import Base class User(Base): id = Column(String, primary_key=True, index=True) + username = Column(String, unique=True, index=True) email = Column(String, unique=True, index=True) hashed_password = Column(String) full_name = Column(String, index=True) diff --git a/app/schemas/user.py b/app/schemas/user.py index 5acecd9..968a0ce 100644 --- a/app/schemas/user.py +++ b/app/schemas/user.py @@ -5,6 +5,7 @@ from pydantic import BaseModel, EmailStr # Shared properties class UserBase(BaseModel): + username: Optional[str] = None email: Optional[EmailStr] = None full_name: Optional[str] = None is_active: Optional[bool] = True @@ -13,6 +14,7 @@ class UserBase(BaseModel): # Properties to receive via API on creation class UserCreate(UserBase): + username: str email: EmailStr password: str diff --git a/app/services/user.py b/app/services/user.py index dc6a1f6..828bc39 100644 --- a/app/services/user.py +++ b/app/services/user.py @@ -16,9 +16,14 @@ def get_by_email(db: Session, email: str) -> Optional[User]: return db.query(User).filter(User.email == email).first() +def get_by_username(db: Session, username: str) -> Optional[User]: + return db.query(User).filter(User.username == username).first() + + def create(db: Session, *, obj_in: UserCreate) -> User: db_obj = User( id=str(uuid.uuid4()), + username=obj_in.username, email=obj_in.email, hashed_password=get_password_hash(obj_in.password), full_name=obj_in.full_name, @@ -46,8 +51,8 @@ def update(db: Session, *, db_obj: User, obj_in: UserUpdate) -> User: return db_obj -def authenticate(db: Session, *, email: str, password: str) -> Optional[User]: - user = get_by_email(db, email=email) +def authenticate(db: Session, *, username: str, password: str) -> Optional[User]: + user = get_by_username(db, username=username) if not user: return None if not verify_password(password, user.hashed_password): diff --git a/migrations/versions/2a3b4c5d6e7f_add_username_field.py b/migrations/versions/2a3b4c5d6e7f_add_username_field.py new file mode 100644 index 0000000..bc46334 --- /dev/null +++ b/migrations/versions/2a3b4c5d6e7f_add_username_field.py @@ -0,0 +1,42 @@ +"""Add username field + +Revision ID: 2a3b4c5d6e7f +Revises: 1a2b3c4d5e6f +Create Date: 2023-11-17 00:00:00.000000 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '2a3b4c5d6e7f' +down_revision = '1a2b3c4d5e6f' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Add username column to user table + with op.batch_alter_table('user') as batch_op: + batch_op.add_column(sa.Column('username', sa.String(), nullable=True)) + batch_op.create_index(op.f('ix_user_username'), 'username', unique=True) + + # For existing users, initialize username from email + # This is just SQL template code - in a real migration with existing data, + # you would need to handle this appropriately + op.execute(""" + UPDATE "user" SET username = email + WHERE username IS NULL + """) + + # Now make username non-nullable for future records + with op.batch_alter_table('user') as batch_op: + batch_op.alter_column('username', nullable=False) + + +def downgrade() -> None: + # Remove username column from user table + with op.batch_alter_table('user') as batch_op: + batch_op.drop_index(op.f('ix_user_username')) + batch_op.drop_column('username') \ No newline at end of file