101 lines
3.0 KiB
Python

from typing import List, Optional
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.models.category import Category
from app.models.product import Product
from app.schemas.product import ProductCreate, ProductUpdate
async def get_products(
    db: AsyncSession,
    skip: int = 0,
    limit: int = 100,
    anime_title: Optional[str] = None,
    character_name: Optional[str] = None,
) -> List[Product]:
    """Return one page of products, optionally filtered, with categories loaded.

    Args:
        db: Async session used to run the query.
        skip: Number of matching rows to skip (SQL ``OFFSET``).
        limit: Maximum number of rows to return (SQL ``LIMIT``).
        anime_title: If truthy, keep only products whose ``anime_title``
            equals this value exactly. ``None`` or ``""`` disables the filter.
        character_name: If truthy, keep only products whose
            ``character_name`` equals this value exactly. ``None`` or ``""``
            disables the filter.

    Returns:
        Matching products ordered by ``id``, each with ``categories`` eagerly
        loaded via ``selectinload``.
    """
    query = select(Product).options(selectinload(Product.categories))
    if anime_title:
        query = query.where(Product.anime_title == anime_title)
    if character_name:
        query = query.where(Product.character_name == character_name)
    # OFFSET/LIMIT without ORDER BY lets the database return rows in any
    # order, so consecutive pages could overlap or skip rows. Ordering by the
    # primary key makes pagination stable.
    query = query.order_by(Product.id).offset(skip).limit(limit)
    result = await db.execute(query)
    return list(result.scalars().all())
async def get_product(db: AsyncSession, product_id: int) -> Optional[Product]:
    """Look up a single product by primary key, eagerly loading its categories.

    Args:
        db: Async session used to run the query.
        product_id: Value compared against ``Product.id``.

    Returns:
        The matching product with ``categories`` loaded via ``selectinload``,
        or ``None`` when no row matches.
    """
    eager_categories = selectinload(Product.categories)
    stmt = select(Product).options(eager_categories)
    stmt = stmt.where(Product.id == product_id)
    return (await db.execute(stmt)).scalar_one_or_none()
async def create_product(db: AsyncSession, product: ProductCreate) -> Product:
    """Create and commit a new product, optionally linking existing categories.

    Args:
        db: Async session used for the lookups and the insert. This function
            commits it.
        product: Creation payload. Its ``category_ids`` field is not passed to
            the ``Product`` constructor. Instead, it is resolved to
            ``Category`` rows and assigned to ``categories``. Every other
            dumped field becomes a constructor keyword argument.

    Returns:
        The persisted product, re-read with ``categories`` eagerly loaded.
    """
    # category_ids feeds the relationship, not a column, so keep it out of
    # the constructor kwargs.
    category_ids = product.category_ids
    product_data = product.model_dump(exclude={"category_ids"})
    db_product = Product(**product_data)

    if category_ids:
        # NOTE(review): IDs with no matching Category row are silently
        # dropped by get_categories_by_ids; confirm that is intended.
        db_product.categories = await get_categories_by_ids(db, category_ids)

    db.add(db_product)
    await db.commit()
    # Reload column state, including ``id``, which the re-read below needs.
    await db.refresh(db_product)
    # refresh() expires all attributes but only reloads columns. A lazily
    # loaded ``categories`` would stay unloaded, and touching it later on an
    # AsyncSession would attempt implicit IO and fail. Re-reading through
    # get_product populates it via selectinload, matching the read paths.
    # (Assumes the relationship is not eager by default; the explicit
    # selectinload in the read helpers suggests so.)
    reloaded = await get_product(db, db_product.id)
    return reloaded if reloaded is not None else db_product
async def update_product(
    db: AsyncSession, db_product: Product, product_update: ProductUpdate
) -> Product:
    """Apply a partial update to an existing product and commit it.

    Args:
        db: Async session that ``db_product`` was loaded through. This
            function commits it.
        db_product: Product instance to modify in place.
        product_update: Partial payload. Only fields explicitly set by the
            client (``exclude_unset=True``) are written onto ``db_product``.
            ``category_ids`` is handled separately: ``None`` (absent or
            explicitly null) leaves categories unchanged, a list replaces
            them wholesale, and an empty list clears them.

    Returns:
        The updated product, re-read with ``categories`` eagerly loaded.
    """
    update_data = product_update.model_dump(exclude_unset=True)
    # Not a column; applied to the relationship below.
    category_ids = update_data.pop("category_ids", None)

    for field, value in update_data.items():
        setattr(db_product, field, value)

    if category_ids is not None:
        db_product.categories = await get_categories_by_ids(db, category_ids)

    await db.commit()
    # Reload column state, including ``id``, which the re-read below needs.
    await db.refresh(db_product)
    # refresh() expires all attributes but only reloads columns. A lazily
    # loaded ``categories`` would stay unloaded, and touching it later on an
    # AsyncSession would attempt implicit IO and fail. Re-reading through
    # get_product populates it via selectinload, matching the read paths.
    # (Assumes the relationship is not eager by default; the explicit
    # selectinload in the read helpers suggests so.)
    reloaded = await get_product(db, db_product.id)
    return reloaded if reloaded is not None else db_product
async def delete_product(db: AsyncSession, db_product: Product) -> None:
    """Delete ``db_product`` and commit the session immediately.

    Args:
        db: Async session used for the delete. This function commits it.
        db_product: Persistent product instance to remove.
    """
    # Mark for deletion first; the DELETE is flushed as part of the commit.
    await db.delete(db_product)
    await db.commit()
async def get_categories_by_ids(
    db: AsyncSession, category_ids: List[int]
) -> List[Category]:
    """Fetch the ``Category`` rows whose ``id`` appears in ``category_ids``.

    A falsy ``category_ids`` (e.g. an empty list) returns a fresh empty list
    without querying the database. IDs with no matching row are simply
    absent from the result. The result order is whatever the database
    returns, not necessarily the order of ``category_ids``.
    """
    matches: List[Category] = []
    if category_ids:
        stmt = select(Category).where(Category.id.in_(category_ids))
        matches = (await db.execute(stmt)).scalars().all()
    return matches