"""Pydantic schemas for invoices and invoice items, plus a helper that trims
response models down to a caller-selected subset of fields."""

from datetime import datetime
from typing import Any, Dict, Generic, List, Optional, Set, Type, TypeVar, Union

from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel, Field, create_model, validator

T = TypeVar("T", bound=BaseModel)

# Statuses accepted by InvoiceStatusUpdate, in the order listed in its error message.
INVOICE_STATUSES = ("PENDING", "PAID", "CANCELLED")


class FilterableModel(Generic[T]):
    """
    Wrapper class for filtering fields in Pydantic models.
    """

    def __init__(self, model_class: Type[T]):
        self.model_class = model_class

    def filter_model(self, fields: Optional[str] = None) -> Type[BaseModel]:
        """
        Create a new model with only the specified fields from the original model.

        Fields are looked up in the model's full field registry, so fields
        inherited from base models (e.g. ``InvoiceBase`` fields on
        ``InvoiceDB``) can be selected. Names that are not fields of the model
        are ignored. A new model class is built on every call.

        Args:
            fields (str, optional): Comma-separated string of fields to include.
                Blank entries (from "a,,b" or " ") are skipped.

        Returns:
            Type[BaseModel]: ``self.model_class`` itself when ``fields`` is None
            or contains only blank entries; otherwise a new model class named
            ``<Original>Filtered`` holding the selected fields and the original
            model's config (so settings such as ``from_attributes`` still apply).
        """
        if fields is None:
            return self.model_class

        wanted = {name.strip() for name in fields.split(",") if name.strip()}
        if not wanted:
            return self.model_class

        return create_model(
            f"{self.model_class.__name__}Filtered",
            # Config must be supplied at class-creation time; attaching a
            # Config attribute to an already-built model has no effect.
            __config__=self._config(),
            **self._field_definitions(wanted),
        )

    def _config(self) -> Any:
        """Return the original model's config in the form ``create_model`` accepts."""
        model_config = getattr(self.model_class, "model_config", None)
        if model_config is not None:  # Pydantic v2: a ConfigDict mapping
            return dict(model_config)
        return self.model_class.__config__  # Pydantic v1: a Config class

    def _field_definitions(self, wanted: Set[str]) -> Dict[str, Any]:
        """Map each wanted field name to a ``(type, default)`` pair for ``create_model``.

        Iterates the model's field registry (declaration order, base-class
        fields included) instead of the class's own ``__annotations__``.
        """
        definitions: Dict[str, Any] = {}
        v2_fields = getattr(self.model_class, "model_fields", None)
        if v2_fields is not None:
            # Pydantic v2: FieldInfo.annotation keeps the full declared type.
            for name, info in v2_fields.items():
                if name in wanted:
                    definitions[name] = (
                        info.annotation,
                        self._default_for(info.is_required(), info.default, info.default_factory),
                    )
        else:
            # Pydantic v1: outer_type_ strips Optional[...], but a None default
            # makes v1 treat the field as optional again.
            for name, model_field in self.model_class.__fields__.items():
                if name in wanted:
                    definitions[name] = (
                        model_field.outer_type_,
                        self._default_for(
                            model_field.required, model_field.default, model_field.default_factory
                        ),
                    )
        return definitions

    @staticmethod
    def _default_for(required: bool, default: Any, default_factory: Any) -> Any:
        """Translate a field's required/default/default_factory into a ``create_model`` default."""
        if required:
            return ...  # Ellipsis marks the field as required
        if default_factory is not None:
            return Field(default_factory=default_factory)
        return default

    def process_response(
        self, obj: Union[T, List[T]], fields: Optional[str] = None
    ) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """
        Process the response to include only the specified fields.

        Args:
            obj (Union[T, List[T]]): The object(s) to process
            fields (str, optional): Comma-separated string of fields to include.

        Returns:
            Union[Dict[str, Any], List[Dict[str, Any]]]: The filtered response,
            passed through ``jsonable_encoder``. A list input yields a list.
        """
        if fields is None:
            return jsonable_encoder(obj)

        # Build the filtered class once and reuse it for every list element.
        filtered_model = self.filter_model(fields)
        # NOTE(review): when obj is an ORM object or a different model's
        # instance (not a dict), Pydantic v2 can only read it if the model
        # config enables from_attributes — true for the invoice DB models.
        if isinstance(obj, list):
            return [jsonable_encoder(filtered_model.parse_obj(item)) for item in obj]
        return jsonable_encoder(filtered_model.parse_obj(obj))


class InvoiceItemBase(BaseModel):
    """Fields shared by every invoice-item schema."""

    description: str
    quantity: float
    unit_price: float


class InvoiceItemCreate(InvoiceItemBase):
    """Payload for a new invoice item."""

    pass


class InvoiceItemDB(InvoiceItemBase):
    """Invoice item as read back from storage."""

    id: int
    invoice_id: int
    amount: float

    class Config:
        # Allow validation from attribute-bearing objects (presumably ORM rows).
        from_attributes = True


class InvoiceBase(BaseModel):
    """Fields shared by every invoice schema."""

    customer_name: str
    customer_email: Optional[str] = None
    customer_address: Optional[str] = None
    due_date: datetime
    notes: Optional[str] = None


class InvoiceCreate(InvoiceBase):
    """Payload for a new invoice; must contain at least one item."""

    items: List[InvoiceItemCreate]

    @validator("items")
    def validate_items(cls, v):
        """Reject an empty item list."""
        if not v:
            raise ValueError("Invoice must have at least one item")
        return v


class InvoiceUpdate(BaseModel):
    """Partial-update payload; every field is optional."""

    customer_name: Optional[str] = None
    customer_email: Optional[str] = None
    customer_address: Optional[str] = None
    due_date: Optional[datetime] = None
    # NOTE(review): unlike InvoiceStatusUpdate, this status is not checked
    # against INVOICE_STATUSES — confirm whether callers rely on that.
    status: Optional[str] = None
    notes: Optional[str] = None


class InvoiceDB(InvoiceBase):
    """Invoice as read back from storage, including its items."""

    id: int
    invoice_number: str
    date_created: datetime
    total_amount: float
    status: str
    items: List[InvoiceItemDB]

    class Config:
        # Allow validation from attribute-bearing objects (presumably ORM rows).
        from_attributes = True


# Create filterable versions of our models
invoice_db_filterable = FilterableModel(InvoiceDB)
invoice_item_db_filterable = FilterableModel(InvoiceItemDB)


class InvoiceSearchQuery(BaseModel):
    """Lookup of a single invoice by its invoice number."""

    invoice_number: str


class InvoiceStatusUpdate(BaseModel):
    """Payload that changes an invoice's status."""

    status: str = Field(..., description="New status for the invoice")

    @validator("status")
    def validate_status(cls, v):
        """Accept only values listed in INVOICE_STATUSES (case-sensitive)."""
        if v not in INVOICE_STATUSES:
            raise ValueError(f"Status must be one of {', '.join(INVOICE_STATUSES)}")
        return v