feat: Updated endpoint endpoints/get-fruit-by-color.get.py via AI

This commit is contained in:
Backend IM Bot 2025-04-14 14:40:57 +00:00
parent e0ec828be7
commit a7c3db994d
5 changed files with 41 additions and 15 deletions

View File

@ -0,0 +1,19 @@
"""add shape field to fruits table
Revision ID: 1a2b3c4d5e6f
Revises: 0002
Create Date: 2024-01-20 10:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '1a2b3c4d5e6f'
down_revision = '0002'
branch_labels = None
depends_on = None
def upgrade():
op.add_column('fruits', sa.Column('shape', sa.String(), nullable=False))
def downgrade():
op.drop_column('fruits', 'shape')

View File

@ -1,4 +1,4 @@
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from typing import List from typing import List
from core.database import get_db from core.database import get_db
@ -10,14 +10,11 @@ router = APIRouter()
@router.get("/get-fruit-by-color", response_model=List[FruitSchema]) @router.get("/get-fruit-by-color", response_model=List[FruitSchema])
async def get_fruits_by_color_endpoint( async def get_fruits_by_color_endpoint(
color: str, color: str,
shape: str = None,
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
"""Get fruits by their color"""
normalized_color = normalize_color_name(color) normalized_color = normalize_color_name(color)
fruits = get_fruits_by_color(db=db, color=normalized_color) fruits = get_fruits_by_color(db=db, color=normalized_color, shape=shape)
if not fruits: if not fruits:
raise HTTPException( raise HTTPException(status_code=404, detail=f"No fruits found with color {color}")
status_code=status.HTTP_404_NOT_FOUND,
detail=f"No fruits found with color: {normalized_color}"
)
return fruits return fruits

View File

@ -45,18 +45,22 @@ def get_fruit_by_name(db: Session, name: str) -> Optional[Fruit]:
""" """
return db.query(Fruit).filter(Fruit.name == name).first() return db.query(Fruit).filter(Fruit.name == name).first()
def get_fruits_by_color(db: Session, color: str) -> List[Fruit]: def get_fruits_by_color(db: Session, color: str, shape: Optional[str] = None) -> List[Fruit]:
""" """
Retrieves all fruits of a specific color. Retrieves all fruits of a specific color, optionally filtered by shape.
Args: Args:
db (Session): The database session. db (Session): The database session.
color (str): The color to filter by. color (str): The color to filter by.
shape (Optional[str]): Optional shape to filter by.
Returns: Returns:
List[Fruit]: A list of fruits matching the specified color. List[Fruit]: A list of fruits matching the specified criteria.
""" """
return db.query(Fruit).filter(Fruit.color == color).all() query = db.query(Fruit).filter(Fruit.color == color)
if shape:
query = query.filter(Fruit.shape == shape)
return query.all()
def validate_fruit_data(fruit_data: Dict[str, Any]) -> bool: def validate_fruit_data(fruit_data: Dict[str, Any]) -> bool:
""" """
@ -70,10 +74,12 @@ def validate_fruit_data(fruit_data: Dict[str, Any]) -> bool:
""" """
if not fruit_data: if not fruit_data:
return False return False
if not isinstance(fruit_data.get("name"), str) or not isinstance(fruit_data.get("color"), str): required_fields = ["name", "color", "shape"]
return False for field in required_fields:
if len(fruit_data.get("name", "")) < 1 or len(fruit_data.get("color", "")) < 1: if not isinstance(fruit_data.get(field), str):
return False return False
if len(fruit_data.get(field, "")) < 1:
return False
return True return True
def sanitize_fruit_name(name: str) -> str: def sanitize_fruit_name(name: str) -> str:

View File

@ -10,5 +10,6 @@ class Fruit(Base):
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
name = Column(String, nullable=False, index=True) name = Column(String, nullable=False, index=True)
color = Column(String, nullable=False) color = Column(String, nullable=False)
shape = Column(String, nullable=False)
created_at = Column(DateTime, default=func.now()) created_at = Column(DateTime, default=func.now())
updated_at = Column(DateTime, default=func.now(), onupdate=func.now()) updated_at = Column(DateTime, default=func.now(), onupdate=func.now())

View File

@ -6,6 +6,7 @@ from uuid import UUID
class FruitBase(BaseModel): class FruitBase(BaseModel):
name: str = Field(..., min_length=1, description="Name of the fruit") name: str = Field(..., min_length=1, description="Name of the fruit")
color: str = Field(..., min_length=1, description="Color of the fruit") color: str = Field(..., min_length=1, description="Color of the fruit")
shape: str = Field(..., min_length=1, description="Shape of the fruit")
class FruitCreate(FruitBase): class FruitCreate(FruitBase):
pass pass
@ -13,6 +14,7 @@ class FruitCreate(FruitBase):
class FruitUpdate(BaseModel): class FruitUpdate(BaseModel):
name: Optional[str] = Field(None, min_length=1, description="Name of the fruit") name: Optional[str] = Field(None, min_length=1, description="Name of the fruit")
color: Optional[str] = Field(None, min_length=1, description="Color of the fruit") color: Optional[str] = Field(None, min_length=1, description="Color of the fruit")
shape: Optional[str] = Field(None, min_length=1, description="Shape of the fruit")
class FruitSchema(FruitBase): class FruitSchema(FruitBase):
id: UUID id: UUID
@ -26,6 +28,7 @@ class FruitSchema(FruitBase):
"id": "f47ac10b-58cc-4372-a567-0e02b2c3d479", "id": "f47ac10b-58cc-4372-a567-0e02b2c3d479",
"name": "Apple", "name": "Apple",
"color": "Red", "color": "Red",
"shape": "Round",
"created_at": "2023-01-01T12:00:00", "created_at": "2023-01-01T12:00:00",
"updated_at": "2023-01-01T12:00:00" "updated_at": "2023-01-01T12:00:00"
} }