3rd-project-7hltbg/helpers/fruit_helpers.py

116 lines
3.3 KiB
Python

from typing import List, Optional, Dict, Any
from sqlalchemy.orm import Session
from models.fruit import Fruit
from schemas.fruit import FruitCreate
from sqlalchemy.exc import IntegrityError
from fastapi import HTTPException
def create_fruit(db: Session, fruit_data: FruitCreate) -> Fruit:
    """
    Creates a new fruit in the database.

    Args:
        db (Session): The database session.
        fruit_data (FruitCreate): The data for the fruit to create.

    Returns:
        Fruit: The newly created fruit object, refreshed from the database
            after commit.

    Raises:
        HTTPException: status 400 if the insert raises ``IntegrityError``
            (reported to the client as a duplicate name); status 500 for any
            other exception. The session is rolled back before either is
            raised, and the original exception is chained as the cause.
    """
    try:
        # Pydantic v2 renamed ``.dict()`` to ``.model_dump()`` and deprecated
        # the old name. Prefer the new method when the model has it, and fall
        # back to ``.dict()`` so Pydantic v1 models keep working.
        dump = getattr(fruit_data, "model_dump", None) or fruit_data.dict
        db_fruit = Fruit(**dump())
        db.add(db_fruit)
        db.commit()
        db.refresh(db_fruit)
        return db_fruit
    except IntegrityError as err:
        db.rollback()
        raise HTTPException(
            status_code=400, detail="Fruit with this name already exists"
        ) from err
    except Exception as err:
        db.rollback()
        # NOTE(review): str(err) may expose internal/DB error text to API
        # clients; consider logging it and returning a generic message.
        raise HTTPException(status_code=500, detail=str(err)) from err
def get_fruit_by_name(db: Session, name: str) -> Optional[Fruit]:
    """
    Look up a single fruit whose ``name`` column equals the given value.

    No trimming or case normalization is applied to ``name`` here.

    Args:
        db (Session): Session used to run the query.
        name (str): Name value to compare against ``Fruit.name``.

    Returns:
        Optional[Fruit]: The first matching fruit, or None if no row matches.
    """
    name_matches = Fruit.name == name
    fruits = db.query(Fruit)
    return fruits.filter(name_matches).first()
def get_fruits_by_color(db: Session, color: str, shape: Optional[str] = None) -> List[Fruit]:
    """
    Fetch every fruit with the given color, optionally narrowed by shape.

    Args:
        db (Session): Session used to run the query.
        color (str): Value that ``Fruit.color`` must equal.
        shape (Optional[str]): If truthy, ``Fruit.shape`` must also equal it.
            ``None`` or an empty string means "any shape".

    Returns:
        List[Fruit]: All fruits satisfying every criterion (possibly empty).
    """
    criteria = [Fruit.color == color]
    # Only a truthy shape adds a constraint; "" is treated like None.
    if shape:
        criteria.append(Fruit.shape == shape)
    # Multiple criteria passed to a single filter() are combined with AND.
    return db.query(Fruit).filter(*criteria).all()
def validate_fruit_data(fruit_data: Dict[str, Any]) -> bool:
    """
    Validates fruit data before creation.

    The data is valid when it is a non-empty mapping in which each of
    ``name``, ``color`` and ``shape`` is present, is a ``str``, and contains
    at least one non-whitespace character.

    Args:
        fruit_data (Dict[str, Any]): The fruit data to validate.

    Returns:
        bool: True if the data is valid, False otherwise (including when
            ``fruit_data`` is None, empty, or not a mapping).
    """
    from collections.abc import Mapping

    # Guard non-mapping input so the function returns False instead of
    # raising AttributeError on ``.get``.
    if not fruit_data or not isinstance(fruit_data, Mapping):
        return False

    required_fields = ("name", "color", "shape")
    # Reject whitespace-only values: they pass a plain length check but
    # collapse to "" once stripped (e.g. by sanitize_fruit_name).
    return all(
        isinstance(fruit_data.get(field), str) and fruit_data[field].strip() != ""
        for field in required_fields
    )
def sanitize_fruit_name(name: str) -> str:
    """
    Return ``name`` with surrounding whitespace removed and letters lowercased.

    Args:
        name (str): Raw fruit name.

    Returns:
        str: The trimmed, lowercased name.
    """
    trimmed = name.strip()
    lowered = trimmed.lower()
    return lowered
def normalize_color_name(color: str) -> str:
    """
    Convert a color name to its display form.

    The input is trimmed and lowercased. Known colors are looked up in an
    explicit table; any other value is returned with only its first character
    uppercased (``str.capitalize``).

    Args:
        color (str): Raw color name.

    Returns:
        str: The normalized color name.
    """
    key = color.strip().lower()
    known_colors = {
        "red": "Red",
        "green": "Green",
        "yellow": "Yellow",
        "orange": "Orange",
        "purple": "Purple",
        "brown": "Brown",
    }
    if key in known_colors:
        return known_colors[key]
    return key.capitalize()