"""CRUD helpers for shopping carts and their line items.

Every mutating helper in this module commits its own transaction on the
session it is given.
"""
import uuid
from typing import List, Optional, TypeVar

from sqlalchemy.orm import Session

from app.models.cart import Cart
from app.models.cart_item import CartItem
from app.models.product import Product
from app.schemas.cart import CartCreate, CartItemCreate, CartItemUpdate

_ModelT = TypeVar("_ModelT")


def _commit_and_refresh(db: Session, obj: _ModelT) -> _ModelT:
    """Add *obj* to the session, commit, reload it from the database and return it."""
    db.add(obj)
    db.commit()
    db.refresh(obj)
    return obj


def _find_cart_item(
    db: Session, *, item_id: str, cart_id: Optional[str] = None
) -> Optional[CartItem]:
    """Return the cart item with id *item_id*, optionally restricted to *cart_id*."""
    query = db.query(CartItem).filter(CartItem.id == item_id)
    if cart_id is not None:
        query = query.filter(CartItem.cart_id == cart_id)
    return query.first()


def get_by_user_id(db: Session, user_id: str) -> Optional[Cart]:
    """Return the first cart whose ``user_id`` matches, or ``None`` if there is none."""
    return db.query(Cart).filter(Cart.user_id == user_id).first()


def get_cart_items(db: Session, cart_id: str) -> List[CartItem]:
    """Return every item belonging to the cart *cart_id* (possibly empty)."""
    return db.query(CartItem).filter(CartItem.cart_id == cart_id).all()


def create_cart(db: Session, *, obj_in: CartCreate) -> Cart:
    """Insert a new cart for ``obj_in.user_id`` with a fresh UUID4 id and return it."""
    db_obj = Cart(
        id=str(uuid.uuid4()),
        user_id=obj_in.user_id,
    )
    return _commit_and_refresh(db, db_obj)


def ensure_cart_exists(db: Session, user_id: str) -> Cart:
    """Return the user's cart, creating (and committing) an empty one if none exists."""
    # NOTE(review): this is check-then-insert; two concurrent calls for the same
    # user could both create a cart unless Cart.user_id is unique — confirm.
    cart = get_by_user_id(db, user_id=user_id)
    if cart is None:
        cart = create_cart(db, obj_in=CartCreate(user_id=user_id))
    return cart


def add_to_cart(db: Session, *, cart_id: str, obj_in: CartItemCreate) -> CartItem:
    """Add ``obj_in.quantity`` of ``obj_in.product_id`` to the cart *cart_id*.

    If the product is already in the cart, its quantity is increased instead of
    creating a second line item. Returns the created or updated item.
    """
    cart_item = (
        db.query(CartItem)
        .filter(
            CartItem.cart_id == cart_id,
            CartItem.product_id == obj_in.product_id,
        )
        .first()
    )
    if cart_item is not None:
        # Merge into the existing line rather than duplicating the product.
        cart_item.quantity += obj_in.quantity
        return _commit_and_refresh(db, cart_item)

    db_obj = CartItem(
        id=str(uuid.uuid4()),
        cart_id=cart_id,
        product_id=obj_in.product_id,
        quantity=obj_in.quantity,
    )
    return _commit_and_refresh(db, db_obj)


def update_cart_item(
    db: Session,
    *,
    item_id: str,
    obj_in: CartItemUpdate,
    cart_id: Optional[str] = None,
) -> Optional[CartItem]:
    """Set the quantity of cart item *item_id* to ``obj_in.quantity``.

    Args:
        db: Active session; the change is committed.
        item_id: Id of the cart item to update.
        obj_in: Carries the new quantity.
        cart_id: If given, only an item belonging to this cart is updated, so
            callers can prevent editing items in someone else's cart.

    Returns:
        The refreshed item, or ``None`` if no matching item exists.
    """
    cart_item = _find_cart_item(db, item_id=item_id, cart_id=cart_id)
    if cart_item is None:
        return None
    cart_item.quantity = obj_in.quantity
    return _commit_and_refresh(db, cart_item)


def remove_from_cart(
    db: Session, *, item_id: str, cart_id: Optional[str] = None
) -> Optional[CartItem]:
    """Delete cart item *item_id* and return the deleted object.

    Args:
        db: Active session; the deletion is committed.
        item_id: Id of the cart item to delete.
        cart_id: If given, only an item belonging to this cart is deleted.

    Returns:
        The deleted item, or ``None`` if no matching item exists.
    """
    cart_item = _find_cart_item(db, item_id=item_id, cart_id=cart_id)
    if cart_item is None:
        return None
    db.delete(cart_item)
    db.commit()
    return cart_item


def clear_cart(db: Session, *, cart_id: str) -> None:
    """Delete every item in the cart *cart_id* with one bulk DELETE and commit."""
    db.query(CartItem).filter(CartItem.cart_id == cart_id).delete()
    db.commit()


def calculate_cart_total(db: Session, *, cart_id: str) -> float:
    """Return the sum of ``price * quantity`` over the items in cart *cart_id*.

    Items whose product row no longer exists are ignored (the inner join drops
    them). An empty cart totals ``0.0``.
    """
    # One joined query instead of one Product lookup per cart item.
    rows = (
        db.query(CartItem.quantity, Product.price)
        .join(Product, Product.id == CartItem.product_id)
        .filter(CartItem.cart_id == cart_id)
        .all()
    )
    # Sum in the column's own type (int start works for both float and Decimal),
    # then convert once: adding a Decimal price to a float accumulator would
    # raise TypeError.
    total = sum(price * quantity for quantity, price in rows)
    return float(total)