155 lines
4.5 KiB
Python

from datetime import datetime
from typing import Any, Dict, List, Optional, Type, TypeVar, Generic, Union
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel, Field, validator, create_model
T = TypeVar('T', bound=BaseModel)


class FilterableModel(Generic[T]):
    """
    Wrapper class for filtering fields in Pydantic models.

    Builds "subset" model classes that contain only a chosen list of fields of
    the wrapped model, and serializes objects through them. Generated subset
    classes are cached per wrapper instance, keyed by the selected field names,
    so repeated requests for the same fields reuse one class instead of
    rebuilding a Pydantic model on every call.
    """

    def __init__(self, model_class: Type[T]):
        self.model_class = model_class
        # Generated subset models keyed by the tuple of selected field names.
        # Keys are always subsets of the model's real fields, so this is bounded.
        self._filtered_cache: Dict[tuple, Type[BaseModel]] = {}

    def _model_fields(self) -> Dict[str, Any]:
        """Return the wrapped model's field map, including inherited fields."""
        # Pydantic v2 exposes ``model_fields``; ``__fields__`` is the v1 name
        # (still available as a deprecated alias in v2).
        model_fields = getattr(self.model_class, "model_fields", None)
        if model_fields is None:
            model_fields = self.model_class.__fields__
        return model_fields

    def _model_config(self) -> Any:
        """Return the wrapped model's config in the form ``create_model(__config__=...)`` takes."""
        config = getattr(self.model_class, "model_config", None)  # Pydantic v2: dict
        if config is None:
            config = getattr(self.model_class, "__config__", None)  # Pydantic v1: class
        return config

    @staticmethod
    def _field_annotation(field: Any) -> Any:
        """Return the declared type of a field object (v2 ``FieldInfo`` / v1 ``ModelField``)."""
        annotation = getattr(field, "annotation", None)
        if annotation is None:
            # Older Pydantic v1 releases only expose ``outer_type_``.
            annotation = getattr(field, "outer_type_", Any)
        return annotation

    def filter_model(self, fields: Optional[str] = None) -> Type[BaseModel]:
        """
        Create a new model with only the specified fields from the original model.

        Names are matched against every field of the model, including fields
        inherited from parent models. Unknown names are ignored, as are blank
        entries (e.g. from ``"a,,b"`` or a trailing comma). If ``fields`` is
        None, empty, or contains only blank entries, the original model is
        returned unchanged.

        Args:
            fields (str, optional): Comma-separated string of fields to include.

        Returns:
            Type[BaseModel]: A model class holding only the requested fields,
            in the original declaration order and with the original model's
            config (e.g. ``from_attributes``).
        """
        if not fields:
            return self.model_class
        requested = {name.strip() for name in fields.split(",")}
        requested.discard("")
        if not requested:
            return self.model_class

        # Walk the full field map rather than the class's own ``__annotations__``,
        # which omits fields declared on parent models. Iterating it also keeps
        # the original field order.
        selected = {
            name: field
            for name, field in self._model_fields().items()
            if name in requested
        }
        cache_key = tuple(selected)
        cached = self._filtered_cache.get(cache_key)
        if cached is not None:
            return cached

        # Config must be supplied at creation time: Pydantic reads it while
        # building the class, so attaching a ``Config`` attribute afterwards is
        # ignored (and ``from_attributes`` would be lost for ORM/object input).
        filtered_model = create_model(
            f"{self.model_class.__name__}Filtered",
            __config__=self._model_config(),
            **{
                name: (self._field_annotation(field), field.default)
                for name, field in selected.items()
            },
        )
        self._filtered_cache[cache_key] = filtered_model
        return filtered_model

    def process_response(
        self,
        obj: Union[T, List[T]],
        fields: Optional[str] = None,
    ) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """
        Process the response to include only the specified fields.

        Args:
            obj (Union[T, List[T]]): The object(s) to process. Each item is
                validated through the filtered model, so anything that model
                accepts (dicts, model instances, attribute-bearing objects when
                the config enables ``from_attributes``) may be passed.
            fields (str, optional): Comma-separated string of fields to include.
                When None, ``obj`` is encoded as-is without filtering.

        Returns:
            Union[Dict[str, Any], List[Dict[str, Any]]]: The JSON-compatible
            filtered response; a list when ``obj`` is a list.
        """
        if fields is None:
            return jsonable_encoder(obj)
        filtered_model = self.filter_model(fields)
        if isinstance(obj, list):
            return [jsonable_encoder(filtered_model.parse_obj(item)) for item in obj]
        return jsonable_encoder(filtered_model.parse_obj(obj))
class InvoiceItemBase(BaseModel):
    # Fields common to every invoice line item.
    # (Documented with comments, not a class docstring: Pydantic copies class
    # docstrings into the generated JSON schema / OpenAPI description.)
    description: str
    # NOTE(review): float is used for quantity and money; binary floats can
    # introduce rounding error in totals — consider Decimal if exactness matters.
    quantity: float
    unit_price: float
class InvoiceItemCreate(InvoiceItemBase):
    # Input model for a line item when creating an invoice; currently adds
    # nothing to InvoiceItemBase.
    pass
class InvoiceItemDB(InvoiceItemBase):
    # Line item with its own id, the id of the invoice it belongs to, and a
    # stored amount — presumably the persisted representation (verify).
    id: int
    invoice_id: int
    # NOTE(review): presumably quantity * unit_price, but nothing in this model
    # computes or checks that relationship — confirm against the code that writes it.
    amount: float

    class Config:
        # Lets the model validate from objects' attributes (e.g. ORM rows)
        # instead of only from dicts; this is the Pydantic v2 option name.
        from_attributes = True
class InvoiceBase(BaseModel):
    # Customer and due-date fields shared by InvoiceCreate and InvoiceDB.
    customer_name: str
    # Plain string: no e-mail format validation is applied here.
    customer_email: Optional[str] = None
    customer_address: Optional[str] = None
    # NOTE(review): both naive and timezone-aware datetimes are accepted;
    # confirm which convention callers are expected to use.
    due_date: datetime
    notes: Optional[str] = None
class InvoiceCreate(InvoiceBase):
    # Input model for creating an invoice together with its line items.
    items: List[InvoiceItemCreate]

    @validator("items")
    def validate_items(cls, v):
        """Reject an invoice that has no line items.

        Raises:
            ValueError: if ``items`` is an empty list.
        """
        # An empty list is falsy, so ``not v`` is the whole emptiness check.
        if not v:
            raise ValueError("Invoice must have at least one item")
        return v
class InvoiceUpdate(BaseModel):
    # Update payload: every field is optional and defaults to None, so a
    # payload may contain any subset of these fields.
    # NOTE(review): None is used both for "not sent" and "explicitly null";
    # callers that need to tell them apart must look at which fields were set.
    customer_name: Optional[str] = None
    customer_email: Optional[str] = None
    customer_address: Optional[str] = None
    due_date: Optional[datetime] = None
    # NOTE(review): unlike InvoiceStatusUpdate, this status is not restricted
    # to PENDING/PAID/CANCELLED — any string passes validation here. Confirm
    # whether that bypass is intended.
    status: Optional[str] = None
    notes: Optional[str] = None
class InvoiceDB(InvoiceBase):
    # Full invoice: InvoiceBase's customer fields plus identifier, invoice
    # number, creation timestamp, total, status and nested line items —
    # presumably the persisted representation (verify).
    id: int
    invoice_number: str
    date_created: datetime
    # NOTE(review): float for a monetary total can accumulate rounding error.
    total_amount: float
    # Plain string; this model does not restrict the allowed values.
    status: str
    items: List[InvoiceItemDB]

    class Config:
        # Allows validation from objects' attributes (e.g. ORM rows) as well
        # as dicts; Pydantic v2 option name.
        from_attributes = True
# Create filterable versions of our models.
# Module-level wrappers whose filter_model/process_response restrict the
# serialized output to a comma-separated list of field names.
invoice_db_filterable = FilterableModel(InvoiceDB)
invoice_item_db_filterable = FilterableModel(InvoiceItemDB)
class InvoiceSearchQuery(BaseModel):
    # Request model carrying only the invoice number to look up.
    invoice_number: str
class InvoiceStatusUpdate(BaseModel):
    # Payload that changes only an invoice's status. Accepted values are the
    # exact, case-sensitive strings listed in validate_status.
    status: str = Field(..., description="New status for the invoice")

    @validator("status")
    def validate_status(cls, v):
        """Return ``v`` unchanged when it is an accepted status; otherwise raise ValueError."""
        valid_statuses = ("PENDING", "PAID", "CANCELLED")
        if v in valid_statuses:
            return v
        raise ValueError(f"Status must be one of {', '.join(valid_statuses)}")