from datetime import datetime, date
from typing import Any, Dict, List, Optional, Type, TypeVar, Generic, Union

from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel, Field, validator, create_model

T = TypeVar('T', bound=BaseModel)

# Closed sets of accepted values, shared by the validators below.
_ALLOWED_STATUSES = ("PENDING", "PAID", "CANCELLED")
_ALLOWED_SORT_FIELDS = ("date_created", "due_date", "total_amount", "customer_name")
_ALLOWED_SORT_ORDERS = ("asc", "desc")


class FilterableModel(Generic[T]):
    """
    Wrapper class for filtering fields in Pydantic models.

    Given a model class, it can derive a reduced model containing only a
    caller-selected subset of fields and serialize objects through it.
    """

    def __init__(self, model_class: Type[T]):
        self.model_class = model_class

    def _collect_annotations(self) -> Dict[str, Any]:
        """Return the annotations of the wrapped model, including inherited ones.

        A class's ``__annotations__`` only lists names declared directly on
        that class, so the MRO is walked base-first and merged; subclasses
        therefore override the annotations of their parents.
        """
        merged: Dict[str, Any] = {}
        for klass in reversed(self.model_class.__mro__):
            if klass is object or klass is BaseModel:
                continue
            merged.update(getattr(klass, "__annotations__", None) or {})
        return merged

    def _model_config(self) -> Any:
        """Return the wrapped model's configuration in the form ``create_model`` accepts.

        Newer pydantic releases expose it as ``model_config``; older ones as
        the ``__config__`` class. Returns None if neither is present.
        """
        config = getattr(self.model_class, "model_config", None)
        if config is None:
            config = getattr(self.model_class, "__config__", None)
        return config

    def filter_model(self, fields: Optional[str] = None) -> Type[BaseModel]:
        """
        Create a new model with only the specified fields from the original model.

        If fields is None, or contains no field names once blanks are
        discarded, the original model is returned unchanged. Names that do not
        correspond to a field of the original model are ignored.

        Args:
            fields (str, optional): Comma-separated string of fields to include.

        Returns:
            Type[BaseModel]: A new model class with only the specified fields.
        """
        if fields is None:
            return self.model_class

        # Strip whitespace and drop blanks so inputs like "" or "a,,b" or ","
        # never yield an empty field name.
        requested = {name.strip() for name in fields.split(",")}
        requested.discard("")
        if not requested:
            return self.model_class

        model_fields = self.model_class.__fields__
        selected: Dict[str, Any] = {
            name: (annotation, model_fields[name].default)
            for name, annotation in self._collect_annotations().items()
            if name in requested and name in model_fields
        }

        # The configuration has to be supplied while the class is being
        # built; assigning a ``Config`` attribute afterwards is not picked up.
        config = self._model_config()
        if config is not None:
            selected["__config__"] = config

        return create_model(f"{self.model_class.__name__}Filtered", **selected)

    def process_response(
        self, obj: Union[T, List[T]], fields: Optional[str] = None
    ) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """
        Process the response to include only the specified fields.

        Args:
            obj (Union[T, List[T]]): The object(s) to process
            fields (str, optional): Comma-separated string of fields to include.

        Returns:
            Union[Dict[str, Any], List[Dict[str, Any]]]: The filtered response
        """
        if fields is None:
            return jsonable_encoder(obj)

        filtered_model = self.filter_model(fields)

        if isinstance(obj, list):
            return [jsonable_encoder(filtered_model.parse_obj(item)) for item in obj]
        return jsonable_encoder(filtered_model.parse_obj(obj))


class InvoiceItemBase(BaseModel):
    """Fields shared by every invoice line-item schema."""

    description: str
    quantity: float
    unit_price: float


class InvoiceItemCreate(InvoiceItemBase):
    """Payload for creating an invoice line item; adds nothing to the base."""


class InvoiceItemDB(InvoiceItemBase):
    """Invoice line item including its identifiers and amount."""

    id: int
    invoice_id: int
    amount: float

    class Config:
        from_attributes = True


class InvoiceBase(BaseModel):
    """Fields shared by the invoice create and read schemas."""

    customer_name: str
    customer_email: Optional[str] = None
    customer_address: Optional[str] = None
    due_date: datetime
    notes: Optional[str] = None


class InvoiceCreate(InvoiceBase):
    """Payload for creating an invoice together with its line items."""

    items: List[InvoiceItemCreate]

    @validator("items")
    def validate_items(cls, v):
        """Reject an invoice that has no line items."""
        if not v:
            raise ValueError("Invoice must have at least one item")
        return v


class InvoiceUpdate(BaseModel):
    """Partial-update payload; every field is optional and defaults to None."""

    customer_name: Optional[str] = None
    customer_email: Optional[str] = None
    customer_address: Optional[str] = None
    due_date: Optional[datetime] = None
    status: Optional[str] = None
    notes: Optional[str] = None


class InvoiceDB(InvoiceBase):
    """Invoice including its identifiers, totals, status and line items."""

    id: int
    invoice_number: str
    date_created: datetime
    total_amount: float
    status: str
    user_id: int
    items: List[InvoiceItemDB]

    class Config:
        from_attributes = True


# Create filterable versions of our models
invoice_db_filterable = FilterableModel(InvoiceDB)
invoice_item_db_filterable = FilterableModel(InvoiceItemDB)


class InvoiceSearchQuery(BaseModel):
    """Query carrying the invoice number to search for."""

    invoice_number: str


class InvoiceStatusUpdate(BaseModel):
    """Payload for changing the status of an invoice."""

    status: str = Field(..., description="New status for the invoice")

    @validator("status")
    def validate_status(cls, v):
        """Accept only one of the known invoice statuses (case-sensitive)."""
        if v not in _ALLOWED_STATUSES:
            raise ValueError(f"Status must be one of {', '.join(_ALLOWED_STATUSES)}")
        return v


class InvoiceAdvancedFilter(BaseModel):
    """
    Schema for advanced filtering of invoices.
    """

    # Date range filtering
    created_after: Optional[date] = Field(
        None, description="Filter invoices created on or after this date (YYYY-MM-DD)"
    )
    created_before: Optional[date] = Field(
        None, description="Filter invoices created on or before this date (YYYY-MM-DD)"
    )
    due_after: Optional[date] = Field(
        None, description="Filter invoices due on or after this date (YYYY-MM-DD)"
    )
    due_before: Optional[date] = Field(
        None, description="Filter invoices due on or before this date (YYYY-MM-DD)"
    )

    # Customer filtering
    customer_name: Optional[str] = Field(
        None, description="Filter invoices by customer name (case-insensitive, partial match)"
    )
    customer_email: Optional[str] = Field(
        None, description="Filter invoices by customer email (case-insensitive, partial match)"
    )

    # Amount range filtering
    min_amount: Optional[float] = Field(
        None, description="Filter invoices with total amount greater than or equal to this value"
    )
    max_amount: Optional[float] = Field(
        None, description="Filter invoices with total amount less than or equal to this value"
    )

    # Advanced sorting
    sort_by: Optional[str] = Field(
        None, description="Field to sort by (date_created, due_date, total_amount, customer_name)"
    )
    sort_order: Optional[str] = Field(
        "desc", description="Sort order: 'asc' for ascending, 'desc' for descending"
    )

    @validator("sort_by")
    def validate_sort_by(cls, v):
        """Allow None, otherwise require one of the sortable field names."""
        if v is None:
            return v
        if v not in _ALLOWED_SORT_FIELDS:
            raise ValueError(f"Sort field must be one of {', '.join(_ALLOWED_SORT_FIELDS)}")
        return v

    @validator("sort_order")
    def validate_sort_order(cls, v):
        """Allow None, otherwise require 'asc'/'desc' in any case and return it lower-cased."""
        if v is None:
            return v
        normalized = v.lower()
        if normalized not in _ALLOWED_SORT_ORDERS:
            raise ValueError(f"Sort order must be one of {', '.join(_ALLOWED_SORT_ORDERS)}")
        return normalized