from decimal import Decimal
from typing import Any, Dict, List, Optional, Union

from sqlalchemy.orm import Session, joinedload

from app.models.invoice import Invoice, InvoiceItem, InvoiceStatus
from app.models.product import Product
from app.schemas.invoice import InvoiceCreate, InvoiceUpdate, InvoiceItemCreate


def _get_owned_invoice(db: Session, user_id: int, invoice_id: int) -> Optional[Invoice]:
    """Return invoice ``invoice_id`` if it belongs to ``user_id``, else ``None``.

    Items are not eagerly loaded; use :func:`get_invoice` when they are needed.
    """
    return (
        db.query(Invoice)
        .filter(Invoice.id == invoice_id, Invoice.user_id == user_id)
        .first()
    )


def get_invoice(db: Session, user_id: int, invoice_id: int) -> Optional[Invoice]:
    """Get an invoice by ID for a specific user, with its items loaded.

    Returns ``None`` if no invoice with that ID belongs to the user.
    """
    return (
        db.query(Invoice)
        .filter(Invoice.id == invoice_id, Invoice.user_id == user_id)
        # Fetch the items in the same query instead of lazily on first access.
        .options(joinedload(Invoice.items))
        .first()
    )


def get_invoices(
    db: Session,
    user_id: int,
    skip: int = 0,
    limit: int = 100,
    status: Optional[InvoiceStatus] = None,
) -> List[Invoice]:
    """Get a page of invoices for a specific user, newest first.

    Args:
        db: Active database session.
        user_id: Owner whose invoices are listed.
        skip: Number of rows to skip (offset pagination).
        limit: Maximum number of rows to return.
        status: If given, only invoices with this status are returned.

    Returns:
        Invoices ordered by ``created_at`` descending.
    """
    query = db.query(Invoice).filter(Invoice.user_id == user_id)
    # Explicit None check so filtering never depends on the status value's truthiness.
    if status is not None:
        query = query.filter(Invoice.status == status)
    return (
        query
        # ``id`` breaks ties between equal timestamps so offset pagination is
        # stable (no rows skipped or repeated across pages).
        .order_by(Invoice.created_at.desc(), Invoice.id.desc())
        .offset(skip)
        .limit(limit)
        .all()
    )


def _to_decimal(value: Any) -> Decimal:
    """Coerce a numeric value to ``Decimal`` for exact money arithmetic.

    ``Decimal`` inputs are returned unchanged. Other numbers go through ``str``
    so that e.g. the float ``0.1`` becomes ``Decimal("0.1")`` rather than its
    binary expansion. Without this, mixing ``float`` and ``Decimal`` raises
    ``TypeError``.
    """
    if isinstance(value, Decimal):
        return value
    return Decimal(str(value))


def _calculate_invoice_item(item_in: InvoiceItemCreate, product: Optional[Product] = None) -> Dict[str, Any]:
    """Compute the stored fields for one invoice line.

    Args:
        item_in: Submitted line item (description, quantity, unit_price,
            tax_rate, product_id).
        product: If given, its ``price`` and ``tax_rate`` override the values
            submitted on ``item_in``.

    Returns:
        Keyword arguments for ``InvoiceItem``. ``tax_rate`` is treated as a
        percentage: ``tax_amount = subtotal * tax_rate / 100`` and
        ``total = subtotal + tax_amount``. Monetary values are ``Decimal``;
        no rounding is applied here.
    """
    if product:
        unit_price = _to_decimal(product.price)
        tax_rate = _to_decimal(product.tax_rate)
    else:
        unit_price = _to_decimal(item_in.unit_price)
        tax_rate = _to_decimal(item_in.tax_rate)

    subtotal = unit_price * _to_decimal(item_in.quantity)
    tax_amount = subtotal * (tax_rate / 100)
    total = subtotal + tax_amount

    return {
        "description": item_in.description,
        # Stored exactly as submitted; only the arithmetic above uses Decimal.
        "quantity": item_in.quantity,
        "unit_price": unit_price,
        "tax_rate": tax_rate,
        "tax_amount": tax_amount,
        "subtotal": subtotal,
        "total": total,
        "product_id": item_in.product_id,
    }


def _load_user_products(db: Session, user_id: int, items: List[InvoiceItemCreate]) -> Dict[int, Product]:
    """Fetch the user's products referenced by ``items`` in one query, keyed by id.

    Products that do not exist or belong to another user are simply absent
    from the result.
    """
    product_ids = {item.product_id for item in items if item.product_id}
    if not product_ids:
        return {}
    products = (
        db.query(Product)
        .filter(Product.id.in_(list(product_ids)), Product.user_id == user_id)
        .all()
    )
    return {product.id: product for product in products}


def create_invoice(db: Session, user_id: int, invoice_in: InvoiceCreate) -> Invoice:
    """Create a new invoice, with its items and totals, for a specific user.

    An item that references a ``product_id`` takes its price and tax rate
    from that product, but only if the product belongs to ``user_id``.
    Otherwise the submitted item values are used, and the item is stored
    without a product link, so an invoice can never point at another user's
    product.

    The invoice's ``subtotal``, ``tax_amount`` and ``total`` are the sums of
    its items' values. The session is committed before returning.
    """
    invoice_data = invoice_in.model_dump(exclude={"items"})

    # Totals start at zero and are filled in once the items are computed.
    db_invoice = Invoice(**invoice_data, user_id=user_id, subtotal=0, tax_amount=0, total=0)
    db.add(db_invoice)
    db.flush()  # assigns db_invoice.id so the items can reference it

    # One query for all referenced products instead of one per item.
    products_by_id = _load_user_products(db, user_id, invoice_in.items)

    subtotal = Decimal("0.0")
    tax_amount = Decimal("0.0")
    total = Decimal("0.0")

    for item_in in invoice_in.items:
        product = products_by_id.get(item_in.product_id) if item_in.product_id else None

        item_data = _calculate_invoice_item(item_in, product)
        if product is None:
            # The product is missing or not owned by this user: never persist
            # a link to it.
            item_data["product_id"] = None

        subtotal += item_data["subtotal"]
        tax_amount += item_data["tax_amount"]
        total += item_data["total"]

        db.add(InvoiceItem(**item_data, invoice_id=db_invoice.id))

    db_invoice.subtotal = subtotal
    db_invoice.tax_amount = tax_amount
    db_invoice.total = total

    db.commit()
    db.refresh(db_invoice)
    return db_invoice


def update_invoice(
    db: Session,
    user_id: int,
    db_invoice: Invoice,
    invoice_in: Union[InvoiceUpdate, Dict[str, Any]],
) -> Invoice:
    """Apply a partial update to ``db_invoice`` and commit.

    Only keys that also appear in ``db_invoice.to_dict()`` are applied. For an
    ``InvoiceUpdate``, only fields explicitly set by the caller are considered
    (``exclude_unset=True``). Items and totals are not recomputed here.

    NOTE(review): ``user_id`` is unused; ownership is assumed to have been
    checked by the caller when loading ``db_invoice``.
    """
    invoice_data = db_invoice.to_dict()
    if isinstance(invoice_in, dict):
        update_data = invoice_in
    else:
        update_data = invoice_in.model_dump(exclude_unset=True)

    for field in invoice_data:
        if field in update_data:
            setattr(db_invoice, field, update_data[field])

    db.add(db_invoice)
    db.commit()
    db.refresh(db_invoice)
    return db_invoice


def delete_invoice(db: Session, user_id: int, invoice_id: int) -> Optional[Invoice]:
    """Delete a user's invoice and commit.

    Returns the deleted instance, or ``None`` if no invoice with that ID
    belongs to the user.
    """
    invoice = _get_owned_invoice(db, user_id, invoice_id)
    if invoice is None:
        return None

    db.delete(invoice)
    db.commit()
    return invoice


def update_invoice_status(
    db: Session,
    user_id: int,
    invoice_id: int,
    status: InvoiceStatus,
) -> Optional[Invoice]:
    """Set the status of a user's invoice and commit.

    Returns the refreshed invoice, or ``None`` if no invoice with that ID
    belongs to the user.
    """
    invoice = _get_owned_invoice(db, user_id, invoice_id)
    if invoice is None:
        return None

    invoice.status = status
    db.add(invoice)
    db.commit()
    db.refresh(invoice)
    return invoice