diff --git a/helpers/pen_helpers.py b/helpers/pen_helpers.py new file mode 100644 index 0000000..9f40717 --- /dev/null +++ b/helpers/pen_helpers.py @@ -0,0 +1,115 @@ +from typing import List, Dict, Optional, Union, Any +from sqlalchemy.orm import Session +from models.pen import Pen +from schemas.pen import PenCreate, PenUpdate +from decimal import Decimal + +def validate_pen_price(price: int) -> bool: + """ + Validate that pen price is within acceptable range. + + Args: + price: The pen price to validate + + Returns: + bool: True if price is valid, False otherwise + """ + return price > 0 and price < 100000 + +def validate_pen_color(color: str) -> bool: + """ + Validate that pen color is an acceptable value. + + Args: + color: The pen color to validate + + Returns: + bool: True if color is valid, False otherwise + """ + valid_colors = ['black', 'blue', 'red', 'green', 'purple'] + return color.lower() in valid_colors + +def get_pen_by_name_and_brand(db: Session, name: str, brand: str) -> Optional[Pen]: + """ + Get a pen by name and brand combination. + + Args: + db: Database session + name: Pen name + brand: Pen brand + + Returns: + Pen object if found, None otherwise + """ + return db.query(Pen).filter(Pen.name == name, Pen.brand == brand).first() + +def create_pen_safely(db: Session, pen_data: PenCreate) -> Union[Pen, Dict[str, str]]: + """ + Create a new pen with validation and error handling. + + Args: + db: Database session + pen_data: Pen data for creation + + Returns: + Pen object if created successfully, error dict otherwise + """ + if not validate_pen_price(pen_data.price): + return {"error": "Invalid price range"} + + if not validate_pen_color(pen_data.color): + return {"error": "Invalid pen color"} + + existing_pen = get_pen_by_name_and_brand(db, pen_data.name, pen_data.brand) + if existing_pen: + return {"error": "Pen with this name and brand already exists"} + + db_pen = Pen( + name=pen_data.name, + brand=pen_data.brand, + color=pen_data.color.lower(), + price=pen_data.price, + in_stock=True, + description=pen_data.description + ) + + db.add(db_pen) + db.commit() + db.refresh(db_pen) + + return db_pen + +def update_pen_stock(db: Session, pen_id: int, in_stock: bool) -> Union[Pen, Dict[str, str]]: + """ + Update the stock status of a pen. + + Args: + db: Database session + pen_id: ID of pen to update + in_stock: New stock status + + Returns: + Updated pen object if successful, error dict otherwise + """ + pen = db.query(Pen).filter(Pen.id == pen_id).first() + if not pen: + return {"error": "Pen not found"} + + pen.in_stock = in_stock + db.commit() + db.refresh(pen) + + return pen + +def get_pens_by_brand(db: Session, brand: str) -> List[Pen]: + """ + Get all pens from a specific brand. + + Args: + db: Database session + brand: Brand name to filter by + + Returns: + List of pen objects for the brand + """ + return db.query(Pen).filter(Pen.brand == brand).all() \ No newline at end of file