149 lines
5.6 KiB
Python
149 lines
5.6 KiB
Python
import uvicorn
|
|
from fastapi import FastAPI, Request, Response
|
|
from starlette.types import ASGIApp, Receive, Scope, Send
|
|
|
|
from app.api.v1.api import api_router
|
|
from app.core.config import settings
|
|
|
|
|
|
class CORSMiddlewareASGI:
    """Raw ASGI middleware that applies CORS handling to every HTTP request.

    Behaviour:

    * Non-HTTP scopes (``websocket``, ``lifespan``) are passed straight through
      to the wrapped application.
    * Every ``OPTIONS`` request is answered here as a CORS preflight with an
      empty ``200`` response; the wrapped application never sees it.
    * Every other HTTP request is forwarded, and CORS headers are appended to
      the ``http.response.start`` message when the request carries an allowed
      ``Origin`` header.

    Args:
        app: Downstream ASGI application.
        allow_origins: Optional collection of permitted origins, e.g.
            ``["https://app.example.com"]`` (a single string is also accepted).
            Values are compared byte-for-byte against the request's ``Origin``
            header after latin-1 encoding. When ``None`` -- the default, kept
            for backward compatibility -- any non-empty ``Origin`` is reflected
            back together with ``Access-Control-Allow-Credentials: true``,
            which lets any website make credentialed requests and read the
            responses; pass an explicit list outside local development.
    """

    _ALLOW_METHODS = b"GET, POST, PUT, DELETE, OPTIONS, PATCH"
    _ALLOW_HEADERS = (
        b"Authorization, Content-Type, Accept, Accept-Language, "
        b"Content-Language, Content-Length, Origin, X-Requested-With, "
        b"X-CSRF-Token, Access-Control-Allow-Origin, "
        b"Access-Control-Allow-Credentials, X-HTTP-Method-Override"
    )
    _MAX_AGE = b"3600"
    # Response headers this middleware owns. Copies emitted by the wrapped app
    # are dropped so the browser never receives duplicate CORS values (which it
    # treats as a CORS failure).
    _OWNED_HEADERS = frozenset(
        (b"access-control-allow-origin", b"access-control-allow-credentials")
    )

    def __init__(self, app: "ASGIApp", allow_origins=None):
        self.app = app
        if isinstance(allow_origins, str):
            # Guard against iterating a lone origin string character by character.
            allow_origins = [allow_origins]
        # None means "reflect any origin"; otherwise match raw header bytes.
        self.allow_origins = (
            None
            if allow_origins is None
            else frozenset(o.encode("latin-1") for o in allow_origins)
        )

    def _origin_allowed(self, origin):
        """Return True when the raw ``Origin`` header bytes may receive CORS headers."""
        if not origin:
            return False
        return self.allow_origins is None or origin in self.allow_origins

    @staticmethod
    def _credential_headers(origin):
        """Return the allow-origin / allow-credentials pair for *origin*."""
        return [
            (b"access-control-allow-origin", origin),
            (b"access-control-allow-credentials", b"true"),
        ]

    async def _send_preflight(self, send, origin):
        """Send an empty 200 preflight response; add CORS headers if *origin* is set."""
        headers = [
            (b"content-type", b"text/plain"),
            (b"content-length", b"0"),
        ]
        if origin is not None:
            headers.extend(self._credential_headers(origin))
            headers.extend(
                [
                    (b"access-control-allow-methods", self._ALLOW_METHODS),
                    (b"access-control-allow-headers", self._ALLOW_HEADERS),
                    (b"access-control-max-age", self._MAX_AGE),
                    (b"vary", b"Origin"),
                ]
            )
        await send({"type": "http.response.start", "status": 200, "headers": headers})
        await send({"type": "http.response.body", "body": b""})

    async def __call__(self, scope: "Scope", receive: "Receive", send: "Send"):
        if scope["type"] != "http":
            # WebSocket and lifespan scopes are passed through untouched.
            await self.app(scope, receive, send)
            return

        # ASGI header names are lower-cased bytes. The origin is kept as raw
        # bytes and echoed back unchanged, so a non-UTF-8 header value can no
        # longer crash the request with a decode error.
        origin = dict(scope.get("headers", [])).get(b"origin")
        allowed = self._origin_allowed(origin)

        if scope.get("method", "") == "OPTIONS":
            # Every OPTIONS request is treated as a CORS preflight and answered
            # here; the wrapped application is never invoked for it.
            await self._send_preflight(send, origin if allowed else None)
            return

        async def cors_send(message):
            if message["type"] == "http.response.start" and allowed:
                headers = [
                    (name, value)
                    for name, value in message.get("headers", [])
                    if name.lower() not in self._OWNED_HEADERS
                ]
                headers.extend(self._credential_headers(origin))
                headers.append((b"vary", b"Origin"))
                # Build a new message rather than mutating the app's dict.
                message = {**message, "headers": headers}
            await send(message)

        await self.app(scope, receive, cors_send)
|
|
|
|
# FastAPI application instance; metadata comes from the project settings.
app = FastAPI(
    title=settings.PROJECT_NAME,
    description=settings.PROJECT_DESCRIPTION,
    version=settings.PROJECT_VERSION,
    openapi_url="/openapi.json",
    docs_url="/docs",
    redoc_url="/redoc",
)

# Install the CORS middleware through FastAPI instead of rebinding ``app`` to
# the wrapper. CORSMiddlewareASGI only defines __init__/__call__, so the former
# ``app = CORSMiddlewareASGI(app)`` made ``app.include_router`` and every
# ``@app.get/post/options`` decorator below raise AttributeError at import.
# Starlette instantiates the class with the inner app as its ``app`` argument.
# NOTE: middleware added this way sits inside Starlette's ServerErrorMiddleware,
# so 500 responses for unhandled exceptions are generated outside it and carry
# no CORS headers.
app.add_middleware(CORSMiddlewareASGI)

# Mount the versioned API routes.
app.include_router(api_router)
|
|
|
|
# Health check endpoint at the application root.
@app.get("/health", tags=["health"])
async def health_check():
    """Return a static health payload.

    No downstream dependencies are checked; a response only shows that the
    process is serving requests.
    """
    payload = {"status": "healthy"}
    return payload
|
|
|
|
# Endpoints for manually checking CORS behaviour.
@app.options("/api/v1/cors-test", tags=["cors"])
async def cors_preflight_test():
    """Test endpoint for CORS preflight requests.

    Replies with an empty 200. ``CORSMiddlewareASGI`` in this module answers
    every HTTP ``OPTIONS`` request itself before routing, so this handler only
    runs if that middleware is not installed in front of the app.
    """
    ok = Response(status_code=200)
    return ok
|
|
|
|
# Header names whose values are masked when the CORS test endpoint echoes the
# request back. The CORS middleware in this module can reflect any Origin with
# ``Access-Control-Allow-Credentials: true``; echoing these verbatim could let a
# third-party page read the caller's cookies (HttpOnly included) or bearer token.
_REDACTED_ECHO_HEADERS = frozenset({"authorization", "cookie", "proxy-authorization"})


def _echo_headers(request: Request) -> dict:
    """Return the request headers as a plain dict with credential values masked."""
    return {
        name: "<redacted>" if name.lower() in _REDACTED_ECHO_HEADERS else value
        for name, value in dict(request.headers).items()
    }


@app.post("/api/v1/cors-test", tags=["cors"])
async def cors_test(request: Request):
    """Test endpoint for CORS POST requests with JSON.

    Echoes the parsed JSON body and the request headers (credentials masked).
    A body that is not valid JSON yields ``success: False`` with the parse error
    message instead of raising.
    """
    try:
        body = await request.json()
    except ValueError as exc:
        # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError;
        # anything else (e.g. a client disconnect) is left to propagate.
        return {
            "success": False,
            "message": f"Error: {exc}",
            "headers": _echo_headers(request),
        }
    return {
        "success": True,
        "message": "CORS is working correctly for POST requests with JSON",
        "received_data": body,
        "headers": _echo_headers(request),
    }
|
|
|
|
# Explicit OPTIONS handlers for critical endpoints.
@app.options("/api/v1/auth/register", include_in_schema=False)
async def auth_register_options():
    """Reply with an empty 200 to OPTIONS on the auth register route.

    Only reached if ``CORSMiddlewareASGI`` (which answers every OPTIONS request
    itself) is not installed in front of the app.
    """
    ok = Response(status_code=200)
    return ok
|
|
|
|
@app.options("/api/v1/auth/login", include_in_schema=False)
async def auth_login_options():
    """Reply with an empty 200 to OPTIONS on the auth login route.

    Only reached if ``CORSMiddlewareASGI`` (which answers every OPTIONS request
    itself) is not installed in front of the app.
    """
    ok = Response(status_code=200)
    return ok
|
|
|
|
@app.options("/api/v1/users/me", include_in_schema=False)
async def users_me_options():
    """Reply with an empty 200 to OPTIONS on the users/me route.

    Only reached if ``CORSMiddlewareASGI`` (which answers every OPTIONS request
    itself) is not installed in front of the app.
    """
    ok = Response(status_code=200)
    return ok
|
|
|
|
if __name__ == "__main__":
    # Serve ``app`` with uvicorn's auto-reload; the import string assumes this
    # file is importable as the module ``main``.
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
    )