diff --git a/alembic/versions/20250414_144046_ebe9468d_update_fruit.py b/alembic/versions/20250414_144046_ebe9468d_update_fruit.py new file mode 100644 index 0000000..42973df --- /dev/null +++ b/alembic/versions/20250414_144046_ebe9468d_update_fruit.py @@ -0,0 +1,19 @@ +"""add shape field to fruits table +Revision ID: 1a2b3c4d5e6f +Revises: 0002 +Create Date: 2024-01-20 10:00:00.000000 +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = '1a2b3c4d5e6f' +down_revision = '0002' +branch_labels = None +depends_on = None + +def upgrade(): + op.add_column('fruits', sa.Column('shape', sa.String(), nullable=False)) + +def downgrade(): + op.drop_column('fruits', 'shape') \ No newline at end of file diff --git a/endpoints/get-fruit-by-color.get.py b/endpoints/get-fruit-by-color.get.py index ffea04e..cf54e2f 100644 --- a/endpoints/get-fruit-by-color.get.py +++ b/endpoints/get-fruit-by-color.get.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.orm import Session from typing import List from core.database import get_db @@ -10,14 +10,11 @@ router = APIRouter() @router.get("/get-fruit-by-color", response_model=List[FruitSchema]) async def get_fruits_by_color_endpoint( color: str, + shape: str = None, db: Session = Depends(get_db) ): - """Get fruits by their color""" normalized_color = normalize_color_name(color) - fruits = get_fruits_by_color(db=db, color=normalized_color) + fruits = get_fruits_by_color(db=db, color=normalized_color, shape=shape) if not fruits: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"No fruits found with color: {normalized_color}" - ) + raise HTTPException(status_code=404, detail=f"No fruits found with color {color}") return fruits \ No newline at end of file diff --git a/helpers/fruit_helpers.py b/helpers/fruit_helpers.py index 9df05e0..010b597 100644 --- a/helpers/fruit_helpers.py +++ b/helpers/fruit_helpers.py @@ -45,18 +45,22 @@ def get_fruit_by_name(db: Session, name: str) -> Optional[Fruit]: """ return db.query(Fruit).filter(Fruit.name == name).first() -def get_fruits_by_color(db: Session, color: str) -> List[Fruit]: +def get_fruits_by_color(db: Session, color: str, shape: Optional[str] = None) -> List[Fruit]: """ - Retrieves all fruits of a specific color. + Retrieves all fruits of a specific color, optionally filtered by shape. Args: db (Session): The database session. color (str): The color to filter by. + shape (Optional[str]): Optional shape to filter by. Returns: - List[Fruit]: A list of fruits matching the specified color. + List[Fruit]: A list of fruits matching the specified criteria. """ - return db.query(Fruit).filter(Fruit.color == color).all() + query = db.query(Fruit).filter(Fruit.color == color) + if shape: + query = query.filter(Fruit.shape == shape) + return query.all() def validate_fruit_data(fruit_data: Dict[str, Any]) -> bool: """ @@ -70,10 +74,12 @@ def validate_fruit_data(fruit_data: Dict[str, Any]) -> bool: """ if not fruit_data: return False - if not isinstance(fruit_data.get("name"), str) or not isinstance(fruit_data.get("color"), str): - return False - if len(fruit_data.get("name", "")) < 1 or len(fruit_data.get("color", "")) < 1: - return False + required_fields = ["name", "color", "shape"] + for field in required_fields: + if not isinstance(fruit_data.get(field), str): + return False + if len(fruit_data.get(field, "")) < 1: + return False return True def sanitize_fruit_name(name: str) -> str: diff --git a/models/fruit.py b/models/fruit.py index 67332f9..3087fd4 100644 --- a/models/fruit.py +++ b/models/fruit.py @@ -10,5 +10,6 @@ class Fruit(Base): id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) name = Column(String, nullable=False, index=True) color = Column(String, nullable=False) + shape = Column(String, nullable=False) created_at = Column(DateTime, default=func.now()) updated_at = Column(DateTime, default=func.now(), onupdate=func.now()) \ No newline at end of file diff --git a/schemas/fruit.py b/schemas/fruit.py index 2c18a4a..7f7d0ad 100644 --- a/schemas/fruit.py +++ b/schemas/fruit.py @@ -6,6 +6,7 @@ from uuid import UUID class FruitBase(BaseModel): name: str = Field(..., min_length=1, description="Name of the fruit") color: str = Field(..., min_length=1, description="Color of the fruit") + shape: str = Field(..., min_length=1, description="Shape of the fruit") class FruitCreate(FruitBase): pass @@ -13,6 +14,7 @@ class FruitCreate(FruitBase): class FruitUpdate(BaseModel): name: Optional[str] = Field(None, min_length=1, description="Name of the fruit") color: Optional[str] = Field(None, min_length=1, description="Color of the fruit") + shape: Optional[str] = Field(None, min_length=1, description="Shape of the fruit") class FruitSchema(FruitBase): id: UUID @@ -26,6 +28,7 @@ class FruitSchema(FruitBase): "id": "f47ac10b-58cc-4372-a567-0e02b2c3d479", "name": "Apple", "color": "Red", + "shape": "Round", "created_at": "2023-01-01T12:00:00", "updated_at": "2023-01-01T12:00:00" }