96 lines
3.6 KiB
Python

import uvicorn
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.base import BaseHTTPMiddleware
from app.api.v1.api import api_router
from app.core.config import settings
class CustomCORSMiddleware(BaseHTTPMiddleware):
    """Middleware that sets CORS headers on responses for allowed origins.

    The request's ``Origin`` header is validated against
    ``settings.CORS_ORIGINS``, whose entries may be exact origins, the
    catch-all ``"*"``, or wildcard patterns such as ``https://*.vercel.app``.
    """

    def is_origin_allowed(self, origin: str) -> bool:
        """Return whether *origin* is permitted by ``settings.CORS_ORIGINS``.

        Args:
            origin: Value of the request's ``Origin`` header (may be empty).

        Returns:
            ``True`` if the origin is listed verbatim, if the list contains
            ``"*"``, or if the origin matches a wildcard entry; otherwise
            ``False``. An empty origin is never allowed.
        """
        if not origin:
            return False

        allowed_origins = settings.CORS_ORIGINS

        # Exact entry, or the catch-all wildcard.
        if origin in allowed_origins or "*" in allowed_origins:
            return True

        # Wildcard entries (e.g. https://*.vercel.app): the origin must begin
        # with the text before the first "*" and end with the text after the
        # last "*".  The pattern is split as-is; removing the "*" beforehand
        # would make prefix and suffix both equal the whole pattern, so no
        # real origin could ever match.
        for allowed_origin in allowed_origins:
            if "*" not in allowed_origin:
                continue
            parts = allowed_origin.split("*")
            prefix, suffix = parts[0], parts[-1]
            # The length guard stops prefix and suffix from overlapping
            # inside a short origin (pattern "ab*bc" must not match "abc").
            if (
                len(origin) >= len(prefix) + len(suffix)
                and origin.startswith(prefix)
                and origin.endswith(suffix)
            ):
                return True

        return False

    @staticmethod
    def _set_origin_headers(response: Response, origin: str) -> None:
        """Echo *origin* back and allow credentials on *response*."""
        response.headers["Access-Control-Allow-Origin"] = origin
        response.headers["Access-Control-Allow-Credentials"] = "true"

    async def dispatch(self, request: Request, call_next):
        """Answer preflight requests directly and decorate all other responses.

        Args:
            request: The incoming HTTP request.
            call_next: Callable that forwards the request down the stack and
                returns the downstream response.

        Returns:
            For ``OPTIONS`` requests, an empty 200 response that carries the
            CORS preflight headers only when the origin is allowed. For every
            other method, the downstream response, with the allow-origin and
            allow-credentials headers added when the origin is allowed.
        """
        origin = request.headers.get("origin", "")
        origin_allowed = self.is_origin_allowed(origin)

        if request.method == "OPTIONS":
            # Preflight: short-circuit without calling the application.
            response = Response()
            if origin_allowed:
                self._set_origin_headers(response, origin)
                response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS, PATCH"
                response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, Accept, Origin, X-Requested-With, X-CSRF-Token, Access-Control-Allow-Credentials"
                response.headers["Access-Control-Max-Age"] = "600"
            response.status_code = 200
            return response

        # Regular request: let the application respond, then add the headers.
        response = await call_next(request)
        if origin_allowed:
            self._set_origin_headers(response, origin)
        return response
# Application instance; the OpenAPI schema and interactive docs are served
# at the explicit paths given below.
app = FastAPI(
    title=settings.PROJECT_NAME,
    description=settings.PROJECT_DESCRIPTION,
    version=settings.PROJECT_VERSION,
    openapi_url="/openapi.json",
    docs_url="/docs",
    redoc_url="/redoc",
)

# Starlette wraps middleware in reverse registration order: the middleware
# added LAST is the outermost one and sees every request first.  The stock
# CORSMiddleware only does exact matching on allow_origins, so if it ran
# first it would reject preflights from origins that are allowed solely by a
# wildcard pattern (e.g. https://*.vercel.app).  It is therefore registered
# first (inner, as a backup) and the custom middleware last (outer).
if not settings.USE_CUSTOM_CORS_ONLY:
    app.add_middleware(
        CORSMiddleware,
        allow_origins=settings.CORS_ORIGINS,
        allow_credentials=True,
        allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
        allow_headers=["Content-Type", "Authorization", "Accept", "Origin", "X-Requested-With", "X-CSRF-Token", "Access-Control-Allow-Credentials"],
        expose_headers=["Content-Length", "Content-Type"],
        max_age=600,  # 10 minutes cache for preflight requests
    )

# Registered last so that it takes priority over the backup above.
app.add_middleware(CustomCORSMiddleware)

# Include API router
app.include_router(api_router)
# Root health check endpoint
@app.get("/health", tags=["health"])
async def health_check():
    # Health probe: always answers with a fixed payload and checks nothing
    # else.  Kept as a comment rather than a docstring so the generated
    # OpenAPI description of this route is unchanged.
    payload = {"status": "healthy"}
    return payload
if __name__ == "__main__":
    # Run the server directly when this module is executed as a script:
    # listens on all interfaces, port 8000, with auto-reload enabled.
    server_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("main:app", **server_options)