144 lines
5.6 KiB
Python
144 lines
5.6 KiB
Python
import uvicorn
|
|
from fastapi import FastAPI, Request, Response
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
|
from app.api.v1.api import api_router
|
|
from app.core.config import settings
|
|
|
|
|
|
class CustomCORSMiddleware(BaseHTTPMiddleware):
    """Custom middleware to ensure CORS headers are set correctly for all responses.

    * Every ``OPTIONS`` request is answered here and never forwarded to the
      application: 200 with full CORS headers when the ``Origin`` is allowed,
      otherwise an empty 204 with no CORS headers.
    * Every other request is passed through; if its ``Origin`` is allowed the
      downstream response gets ``Access-Control-Allow-Origin`` (the reflected
      origin), ``Access-Control-Allow-Credentials``,
      ``Access-Control-Expose-Headers`` and ``Vary: Origin``.

    Entries of ``settings.CORS_ORIGINS`` may be an exact origin, ``"*"``, or a
    pattern containing a single ``*`` such as ``"https://*.vercel.app"``.

    NOTE(review): with ``"*"`` configured, *any* non-empty origin is reflected
    together with ``Access-Control-Allow-Credentials: true``, i.e. every site
    may make credentialed requests. Confirm this is intended outside dev.
    """

    # Request headers a preflight may ask for (comma-separated header value).
    _ALLOWED_HEADERS = (
        "Authorization, Content-Type, Accept, Accept-Language, "
        "Content-Language, Content-Length, Origin, X-Requested-With, "
        "X-CSRF-Token, Access-Control-Allow-Origin, Access-Control-Allow-Credentials, "
        "X-HTTP-Method-Override"
    )
    _ALLOWED_METHODS = "GET, POST, PUT, DELETE, OPTIONS, PATCH"
    # Response headers frontend JavaScript may read. Browsers only honour this
    # header on the actual response, not on the preflight response.
    _EXPOSED_HEADERS = "Content-Length, Content-Type, Authorization"
    _PREFLIGHT_MAX_AGE = "3600"  # seconds (1 hour preflight cache)

    def is_origin_allowed(self, origin: str) -> bool:
        """Check if the origin is allowed based on the CORS_ORIGINS settings.

        Args:
            origin: Value of the request's ``Origin`` header ("" if absent).

        Returns:
            True on a direct match, when ``"*"`` is configured, or when a
            single-``*`` pattern matches; False for an empty origin or no match.
        """
        if not origin:
            return False

        allowed_origins = settings.CORS_ORIGINS

        # Direct match, or wildcard "*" in the allowed origins list
        if origin in allowed_origins or "*" in allowed_origins:
            return True

        # Pattern matching (e.g., https://*.vercel.app)
        return any(
            self._matches_pattern(origin, pattern)
            for pattern in allowed_origins
            if "*" in pattern
        )

    @staticmethod
    def _matches_pattern(origin: str, pattern: str) -> bool:
        """Return True if ``origin`` matches a pattern containing exactly one ``*``.

        The ``*`` stands for any (possibly empty) run of characters between the
        pattern's prefix and suffix. Prefix and suffix must not overlap inside
        ``origin``: a bare ``startswith``/``endswith`` test would wrongly accept
        ``"https://app.com"`` for ``"https://app*app.com"``. Patterns with more
        than one ``*`` never match.
        """
        prefix, _, suffix = pattern.partition("*")
        if "*" in suffix:
            return False
        return (
            len(origin) >= len(prefix) + len(suffix)
            and origin.startswith(prefix)
            and origin.endswith(suffix)
        )

    @staticmethod
    def _add_vary_origin(response: Response) -> None:
        """Add ``Origin`` to the response's ``Vary`` header, keeping existing values."""
        existing = response.headers.get("Vary")
        if not existing:
            response.headers["Vary"] = "Origin"
            return
        tokens = {token.strip().lower() for token in existing.split(",")}
        # "Vary: *" already covers every request header.
        if "origin" not in tokens and "*" not in tokens:
            response.headers["Vary"] = f"{existing}, Origin"

    async def dispatch(self, request: Request, call_next):
        """Answer OPTIONS requests directly; add CORS headers to all other responses.

        Args:
            request: Incoming request.
            call_next: Forwards ``request`` to the rest of the application and
                returns its response.

        Returns:
            The preflight response for ``OPTIONS`` requests, otherwise the
            downstream response (with CORS headers if the origin is allowed).
        """
        origin = request.headers.get("origin", "")
        origin_allowed = self.is_origin_allowed(origin)

        # Always respond to OPTIONS requests directly for preflight handling
        if request.method == "OPTIONS":
            if not origin_allowed:
                # No origin or not allowed: empty 204 without CORS headers.
                # The request does not error, but CORS is not granted either.
                return Response(status_code=204)

            response = Response(status_code=200)  # OK for successful preflight
            response.headers["Access-Control-Allow-Origin"] = origin
            response.headers["Access-Control-Allow-Credentials"] = "true"
            response.headers["Access-Control-Allow-Methods"] = self._ALLOWED_METHODS
            # Includes Content-Type so application/json requests are permitted
            response.headers["Access-Control-Allow-Headers"] = self._ALLOWED_HEADERS
            response.headers["Access-Control-Max-Age"] = self._PREFLIGHT_MAX_AGE
            # The reflected Allow-Origin depends on the request Origin, so
            # caches must key on it.
            self._add_vary_origin(response)
            return response

        # For regular requests, process normally then add CORS headers
        response = await call_next(request)

        if origin_allowed:
            response.headers["Access-Control-Allow-Origin"] = origin
            response.headers["Access-Control-Allow-Credentials"] = "true"
            # Expose-Headers only takes effect on actual responses (it is
            # ignored on preflights), so it must be set here.
            response.headers["Access-Control-Expose-Headers"] = self._EXPOSED_HEADERS
            # Indicate caching should consider Origin, without discarding any
            # Vary value set further down the stack.
            self._add_vary_origin(response)

        return response
|
|
|
|
app = FastAPI(
    title=settings.PROJECT_NAME,
    description=settings.PROJECT_DESCRIPTION,
    version=settings.PROJECT_VERSION,
    openapi_url="/openapi.json",
    docs_url="/docs",
    redoc_url="/redoc",
)

# Middleware registration order matters: every add_middleware() call wraps the
# middleware registered before it, so the LAST one added is the outermost layer
# and sees each request first. The standard CORSMiddleware is therefore added
# first (inner, backup) and CustomCORSMiddleware last (outer, higher priority).
# Otherwise CORSMiddleware would answer real preflight requests itself, using
# its shorter allow_headers list, before the custom middleware ever ran.

# Set up standard CORS middleware as a backup if not disabled
if not settings.USE_CUSTOM_CORS_ONLY:
    app.add_middleware(
        CORSMiddleware,
        allow_origins=settings.CORS_ORIGINS,
        allow_credentials=True,
        allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
        allow_headers=["Content-Type", "Authorization", "Accept", "Origin", "X-Requested-With", "X-CSRF-Token", "Access-Control-Allow-Credentials"],
        expose_headers=["Content-Length", "Content-Type"],
        max_age=600,  # 10 minutes cache for preflight requests
    )

# Add our custom CORS middleware last so it is the outermost layer (higher priority)
app.add_middleware(CustomCORSMiddleware)

# Include API router
app.include_router(api_router)
|
|
|
|
# Root health check endpoint
@app.get("/health", tags=["health"])
async def health_check():
    """Report that the service is up.

    Always returns the same constant payload; no other state is consulted.
    """
    return dict(status="healthy")
|
|
|
|
# CORS test endpoint
@app.options("/api/v1/cors-test", tags=["cors"])
async def cors_preflight_test():
    """Test endpoint for CORS preflight requests.

    Produces no payload.

    NOTE(review): CustomCORSMiddleware.dispatch returns its own response for
    every OPTIONS request without calling the next app, so this handler is not
    reached while that middleware is installed.
    """
    return
|
|
|
|
# Request headers that carry credentials. They are masked before being echoed:
# with credentialed CORS enabled, JavaScript on any allowed origin can read this
# response body, and echoing Cookie would leak even HttpOnly cookies.
_REDACTED_ECHO_HEADERS = frozenset({"authorization", "cookie", "proxy-authorization"})


def _echo_headers(request: Request) -> dict:
    """Return the request headers as a dict, masking credential-bearing ones."""
    return {
        name: ("[redacted]" if name.lower() in _REDACTED_ECHO_HEADERS else value)
        for name, value in dict(request.headers).items()
    }


@app.post("/api/v1/cors-test", tags=["cors"])
async def cors_test(request: Request):
    """Test endpoint for CORS POST requests with JSON.

    Parses the request body as JSON and echoes it back together with the
    request headers (credential headers masked).

    Returns:
        ``{"success": True, "message": ..., "received_data": <parsed body>,
        "headers": {...}}`` on success, or ``{"success": False, "message":
        "Error: ...", "headers": {...}}`` when the body is not valid JSON.
    """
    try:
        body = await request.json()
    except ValueError as e:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
        # Anything else (e.g. a client disconnect) is not a payload problem
        # and is left to propagate.
        return {
            "success": False,
            "message": f"Error: {str(e)}",
            "headers": _echo_headers(request),
        }

    return {
        "success": True,
        "message": "CORS is working correctly for POST requests with JSON",
        "received_data": body,
        "headers": _echo_headers(request),
    }
|
|
|
|
if __name__ == "__main__":
    # Serve the app on all interfaces, port 8000, with uvicorn auto-reload.
    # NOTE(review): the "main:app" import string assumes this module is main.py.
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
    )