
- Create User model and database schema
- Add JWT authentication with secure password hashing
- Create authentication endpoints for registration and login
- Update invoice routes to require authentication
- Ensure users can only access their own invoices
- Update documentation in README.md
215 lines
6.8 KiB
Python
215 lines
6.8 KiB
Python
from datetime import datetime, date
|
|
from typing import Any, Dict, List, Optional, Type, TypeVar, Generic, Union
|
|
|
|
from fastapi.encoders import jsonable_encoder
|
|
from pydantic import BaseModel, Field, validator, create_model
|
|
|
|
|
|
# Type variable for the model class wrapped by FilterableModel; bound to
# Pydantic's BaseModel so only Pydantic models are accepted.
T = TypeVar("T", bound=BaseModel)
|
|
|
|
|
|
class FilterableModel(Generic[T]):
    """
    Wrap a Pydantic model class so responses can be trimmed to a subset of fields.

    Uses the Pydantic v2 API (``model_fields`` / ``model_config`` /
    ``model_validate``) and falls back to the v1 API (``__fields__`` /
    ``__config__`` / ``parse_obj``) when the v2 attributes are absent.
    """

    def __init__(self, model_class: Type[T]):
        self.model_class = model_class

    @staticmethod
    def _parse_fields(fields: Optional[str]) -> set:
        """Split a comma-separated field list into a set of non-blank names."""
        if not fields:
            return set()
        return {name.strip() for name in fields.split(",") if name.strip()}

    def filter_model(self, fields: Optional[str] = None) -> Type[BaseModel]:
        """
        Create a new model with only the specified fields from the original model.

        If ``fields`` is None, empty, or contains only blank entries, the
        original model class is returned unchanged. Requested names that are
        not fields of the model are ignored.

        Inherited fields are included. Field definitions are read from the
        model's resolved field map, not from ``__annotations__``, which only
        holds the fields declared directly on the class itself.

        Args:
            fields (str, optional): Comma-separated string of fields to include.

        Returns:
            Type[BaseModel]: A new model class with only the specified fields,
            carrying the original model's configuration (e.g. ``from_attributes``).
        """
        field_set = self._parse_fields(fields)
        if not field_set:
            return self.model_class

        model_fields = getattr(self.model_class, "model_fields", None)
        if model_fields is not None:
            # Pydantic v2: FieldInfo keeps the annotation separately. Passing
            # the FieldInfo itself as the default preserves defaults,
            # default factories, aliases and descriptions.
            definitions = {
                name: (info.annotation, info)
                for name, info in model_fields.items()
                if name in field_set
            }
            config = dict(self.model_class.model_config)
        else:
            # Pydantic v1: ModelField exposes the declared type and its FieldInfo.
            definitions = {
                name: (field.outer_type_, field.field_info)
                for name, field in self.model_class.__fields__.items()
                if name in field_set
            }
            config = self.model_class.__config__

        # Config must be supplied when the class is created. Assigning a
        # ``Config`` attribute to an already-built model has no effect on
        # validation, so ``from_attributes`` would be silently lost.
        return create_model(
            f"{self.model_class.__name__}Filtered",
            __config__=config,
            **definitions,
        )

    def process_response(
        self, obj: Union[T, List[T]], fields: Optional[str] = None
    ) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """
        Process the response to include only the specified fields.

        Args:
            obj (Union[T, List[T]]): The object(s) to process. A list or tuple
                is processed item by item and returned as a list.
            fields (str, optional): Comma-separated string of fields to include.
                If None, ``obj`` is JSON-encoded without any filtering.

        Returns:
            Union[Dict[str, Any], List[Dict[str, Any]]]: The filtered response
        """
        if fields is None:
            return jsonable_encoder(obj)

        filtered_model = self.filter_model(fields)
        # Prefer the v2 name; ``parse_obj`` is the v1 name (deprecated in v2,
        # where it simply forwards to ``model_validate``).
        validate = getattr(filtered_model, "model_validate", None) or filtered_model.parse_obj

        if isinstance(obj, (list, tuple)):
            return [jsonable_encoder(validate(item)) for item in obj]
        return jsonable_encoder(validate(obj))
|
|
|
|
|
|
# Line-item fields shared by the request schema (InvoiceItemCreate) and the
# stored representation (InvoiceItemDB).
class InvoiceItemBase(BaseModel):
    description: str
    quantity: float
    unit_price: float
|
|
|
|
|
|
# Line item as supplied when creating an invoice (see InvoiceCreate.items).
# It adds nothing beyond the shared InvoiceItemBase fields.
class InvoiceItemCreate(InvoiceItemBase):
    ...
|
|
|
|
|
|
# Line item as read back from storage: the shared item fields plus its own id,
# the id of the invoice it belongs to, and an `amount` value (presumably
# quantity * unit_price; it is not computed here).
class InvoiceItemDB(InvoiceItemBase):
    id: int
    invoice_id: int
    amount: float

    # Permit validation from objects' attributes (e.g. ORM rows), not only dicts.
    class Config:
        from_attributes = True
|
|
|
|
|
|
# Invoice fields shared by the create schema and the stored representation.
# Only the customer name and due date are required; the rest default to None.
class InvoiceBase(BaseModel):
    customer_name: str
    customer_email: Optional[str] = None
    customer_address: Optional[str] = None
    due_date: datetime
    notes: Optional[str] = None
|
|
|
|
|
|
# Request body for creating an invoice: the shared invoice fields plus the
# line items, of which there must be at least one.
class InvoiceCreate(InvoiceBase):
    items: List[InvoiceItemCreate]

    @validator("items")
    def validate_items(cls, value):
        """Reject an invoice whose item list is empty.

        Raises:
            ValueError: if ``value`` is empty (falsy).
        """
        # ``not value`` already covers the empty list; a separate length
        # check would be redundant.
        if not value:
            raise ValueError("Invoice must have at least one item")
        return value
|
|
|
|
|
|
# Update body for an invoice. Every field is optional and defaults to None
# (presumably meaning "leave unchanged"; confirm against the route handler).
# `status`, when provided, must be one of the values InvoiceStatusUpdate
# accepts, so a general update cannot set an arbitrary status string.
class InvoiceUpdate(BaseModel):
    customer_name: Optional[str] = None
    customer_email: Optional[str] = None
    customer_address: Optional[str] = None
    due_date: Optional[datetime] = None
    status: Optional[str] = None
    notes: Optional[str] = None

    @validator("status")
    def validate_status(cls, value):
        """Reject a status outside PENDING / PAID / CANCELLED; None passes through.

        Raises:
            ValueError: if a non-None status is not an allowed value.
        """
        allowed_statuses = ["PENDING", "PAID", "CANCELLED"]
        if value is not None and value not in allowed_statuses:
            raise ValueError(f"Status must be one of {', '.join(allowed_statuses)}")
        return value
|
|
|
|
|
|
# Invoice as read back from storage: the shared invoice fields plus its stored
# identifiers, creation timestamp, total, status, owner and line items.
class InvoiceDB(InvoiceBase):
    id: int
    invoice_number: str
    date_created: datetime
    total_amount: float
    status: str
    user_id: int  # presumably the owning user's id; confirm against the User model
    items: List[InvoiceItemDB]

    # Permit validation from objects' attributes (e.g. ORM rows), not only dicts.
    class Config:
        from_attributes = True
|
|
|
|
|
|
# Ready-made field-filtering wrappers for the two storage-facing schemas
# (presumably used by routes to trim responses via process_response).
invoice_db_filterable = FilterableModel(model_class=InvoiceDB)
invoice_item_db_filterable = FilterableModel(model_class=InvoiceItemDB)
|
|
|
|
|
|
# Payload carrying the invoice number to look an invoice up by.
class InvoiceSearchQuery(BaseModel):
    invoice_number: str
|
|
|
|
|
|
# Request body that changes only an invoice's status.
class InvoiceStatusUpdate(BaseModel):
    status: str = Field(..., description="New status for the invoice")

    @validator("status")
    def validate_status(cls, value):
        """Accept only PENDING, PAID or CANCELLED (case-sensitive).

        Raises:
            ValueError: if ``value`` is not one of the permitted statuses.
        """
        permitted = ("PENDING", "PAID", "CANCELLED")
        if value in permitted:
            return value
        raise ValueError(f"Status must be one of {', '.join(permitted)}")
|
|
|
|
|
|
class InvoiceAdvancedFilter(BaseModel):
    """
    Schema for advanced filtering of invoices.
    """

    # Date windows. The descriptions promise inclusive bounds; the actual
    # comparison is done by whatever query code consumes this schema.
    created_after: Optional[date] = Field(
        default=None,
        description="Filter invoices created on or after this date (YYYY-MM-DD)",
    )
    created_before: Optional[date] = Field(
        default=None,
        description="Filter invoices created on or before this date (YYYY-MM-DD)",
    )
    due_after: Optional[date] = Field(
        default=None,
        description="Filter invoices due on or after this date (YYYY-MM-DD)",
    )
    due_before: Optional[date] = Field(
        default=None,
        description="Filter invoices due on or before this date (YYYY-MM-DD)",
    )

    # Customer text filters (matching semantics live in the query layer).
    customer_name: Optional[str] = Field(
        default=None,
        description="Filter invoices by customer name (case-insensitive, partial match)",
    )
    customer_email: Optional[str] = Field(
        default=None,
        description="Filter invoices by customer email (case-insensitive, partial match)",
    )

    # Inclusive bounds on the invoice total.
    min_amount: Optional[float] = Field(
        default=None,
        description="Filter invoices with total amount greater than or equal to this value",
    )
    max_amount: Optional[float] = Field(
        default=None,
        description="Filter invoices with total amount less than or equal to this value",
    )

    # Sorting: the column is restricted by validate_sort_by; the direction
    # defaults to "desc" and is lower-cased by validate_sort_order.
    sort_by: Optional[str] = Field(
        default=None,
        description="Field to sort by (date_created, due_date, total_amount, customer_name)",
    )
    sort_order: Optional[str] = Field(
        default="desc",
        description="Sort order: 'asc' for ascending, 'desc' for descending",
    )

    @validator("sort_by")
    def validate_sort_by(cls, value):
        """Allow only the four sortable columns; None is passed through unchanged.

        Raises:
            ValueError: if ``value`` is a non-None, unsupported column name.
        """
        sortable = ("date_created", "due_date", "total_amount", "customer_name")
        if value is None or value in sortable:
            return value
        raise ValueError(f"Sort field must be one of {', '.join(sortable)}")

    @validator("sort_order")
    def validate_sort_order(cls, value):
        """Normalise the sort direction to lower case; None is passed through.

        Raises:
            ValueError: if the lower-cased value is neither "asc" nor "desc".
        """
        if value is None:
            return value
        directions = ("asc", "desc")
        lowered = value.lower()
        if lowered not in directions:
            raise ValueError(f"Sort order must be one of {', '.join(directions)}")
        return lowered