import logging
from typing import Any, Dict, List, Optional

from fastapi import HTTPException
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from models.fruit import Fruit
from schemas.fruit import FruitCreate

logger = logging.getLogger(__name__)

# Fields that validate_fruit_data requires to be non-blank strings.
_REQUIRED_FRUIT_FIELDS = ("name", "color", "shape")


def _fruit_payload(fruit_data: FruitCreate) -> Dict[str, Any]:
    """Return the schema's field values as a plain dict.

    Prefers Pydantic v2's ``model_dump`` and falls back to the v1 ``dict``
    method, which is deprecated in v2.
    """
    dump = getattr(fruit_data, "model_dump", None)
    if callable(dump):
        return dump()
    return fruit_data.dict()


def create_fruit(db: Session, fruit_data: FruitCreate) -> Fruit:
    """Create a new fruit in the database.

    Args:
        db (Session): The database session.
        fruit_data (FruitCreate): The data for the fruit to create.

    Returns:
        Fruit: The newly created fruit object, refreshed from the database.

    Raises:
        HTTPException: 400 if the insert violates an integrity constraint
            (reported as a duplicate name); 500 for any other failure.
    """
    try:
        db_fruit = Fruit(**_fruit_payload(fruit_data))
        db.add(db_fruit)
        db.commit()
        db.refresh(db_fruit)
        return db_fruit
    except IntegrityError as exc:
        db.rollback()
        raise HTTPException(
            status_code=400, detail="Fruit with this name already exists"
        ) from exc
    except Exception as exc:
        db.rollback()
        # Log the real cause server-side. Do not echo str(exc) to the client:
        # database errors can contain SQL statements or connection details.
        logger.exception("Unexpected error while creating fruit")
        raise HTTPException(
            status_code=500, detail="Failed to create fruit"
        ) from exc


def get_fruit_by_name(db: Session, name: str) -> Optional[Fruit]:
    """Retrieve a fruit by its exact name.

    Args:
        db (Session): The database session.
        name (str): The name of the fruit to retrieve.

    Returns:
        Optional[Fruit]: The first matching fruit, or None if there is no match.
    """
    return db.query(Fruit).filter(Fruit.name == name).first()


def get_fruits_by_color(
    db: Session, color: str, shape: Optional[str] = None
) -> List[Fruit]:
    """Retrieve all fruits of a given color, optionally filtered by shape.

    Args:
        db (Session): The database session.
        color (str): The color to filter by (exact match).
        shape (Optional[str]): Shape to filter by. A falsy value (None or an
            empty string) applies no shape filter.

    Returns:
        List[Fruit]: All fruits matching the criteria.
    """
    query = db.query(Fruit).filter(Fruit.color == color)
    if shape:
        query = query.filter(Fruit.shape == shape)
    return query.all()


def validate_fruit_data(fruit_data: Dict[str, Any]) -> bool:
    """Validate fruit data before creation.

    Each of ``name``, ``color`` and ``shape`` must be present and must be a
    string containing at least one non-whitespace character.

    Args:
        fruit_data (Dict[str, Any]): The fruit data to validate.

    Returns:
        bool: True if the data is valid, False otherwise.
    """
    if not fruit_data:
        return False
    for field in _REQUIRED_FRUIT_FIELDS:
        value = fruit_data.get(field)
        # Reject whitespace-only strings too: they would sanitize to "".
        if not isinstance(value, str) or not value.strip():
            return False
    return True


def sanitize_fruit_name(name: str) -> str:
    """Strip leading/trailing whitespace from a fruit name and lowercase it.

    Args:
        name (str): The fruit name to sanitize.

    Returns:
        str: The sanitized fruit name.
    """
    return name.strip().lower()


def normalize_color_name(color: str) -> str:
    """Normalize a color name to capitalized form (e.g. " rED " -> "Red").

    Args:
        color (str): The color name to normalize.

    Returns:
        str: The stripped color, with its first character upper-cased and
        the rest lower-cased.
    """
    # The previous lookup table only mapped lowercase names to their
    # capitalized form, which is exactly what capitalize() produces.
    return color.strip().lower().capitalize()