diff --git a/README.md b/README.md index c35c265..be7fcf8 100644 --- a/README.md +++ b/README.md @@ -220,23 +220,26 @@ The API has robust CORS (Cross-Origin Resource Sharing) enabled with the followi ### Custom CORS Handling -This application implements a custom CORS middleware that properly handles preflight OPTIONS requests for all endpoints, including authentication routes. The middleware includes: +This application implements a low-level ASGI CORS middleware that properly handles preflight OPTIONS requests for all endpoints, including authentication routes. The implementation includes: -1. Direct handling of OPTIONS requests for all endpoints -2. Proper header handling for preflight responses -3. Explicit support for POST requests with JSON content-type +1. Low-level ASGI middleware that directly handles HTTP requests before FastAPI routing +2. Special handling for OPTIONS preflight requests for all routes +3. Explicit support for POST requests with JSON content-type 4. Full support for Authorization headers for authenticated endpoints -5. Pattern matching for wildcard domains (e.g., *.vercel.app) +5. Dedicated OPTIONS route handlers for critical endpoints like authentication -### CORS Test Endpoint +The CORS system is implemented at multiple levels to ensure maximum compatibility: -The API includes a special endpoint for testing CORS functionality: +1. **ASGI Middleware**: Intercepts all requests at the ASGI protocol level before FastAPI processing +2. **Dedicated OPTIONS Handlers**: Specific route handlers for authentication endpoints +3. **Response Header Injection**: Adds proper CORS headers to all responses +### Critical Endpoints with Special CORS Support + +The API includes dedicated OPTIONS handlers for these critical endpoints: + +- `OPTIONS /api/v1/auth/register` - Register endpoint preflight support +- `OPTIONS /api/v1/auth/login` - Login endpoint preflight support +- `OPTIONS /api/v1/users/me` - User profile endpoint preflight support - `OPTIONS /api/v1/cors-test` - Test preflight requests -- `POST /api/v1/cors-test` - Test POST requests with JSON body - -### Environment Variables - -| Variable | Description | Default | -|----------|-------------|---------| -| USE_CUSTOM_CORS_ONLY | Whether to use only the custom CORS middleware | True | \ No newline at end of file +- `POST /api/v1/cors-test` - Test POST requests with JSON body \ No newline at end of file diff --git a/main.py b/main.py index 36ac595..d72cd9e 100644 --- a/main.py +++ b/main.py @@ -1,88 +1,89 @@ import uvicorn from fastapi import FastAPI, Request, Response -from fastapi.middleware.cors import CORSMiddleware -from starlette.middleware.base import BaseHTTPMiddleware +from starlette.types import ASGIApp, Receive, Scope, Send from app.api.v1.api import api_router from app.core.config import settings -class CustomCORSMiddleware(BaseHTTPMiddleware): - """Custom middleware to ensure CORS headers are set correctly for all responses.""" +class CORSMiddlewareASGI: + """A lower-level ASGI middleware for CORS that intercepts all requests.""" - def is_origin_allowed(self, origin: str) -> bool: - """Check if the origin is allowed based on the CORS_ORIGINS settings.""" - if not origin: - return False + def __init__(self, app: ASGIApp): + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send): + if scope["type"] != "http": + # Pass through other types of requests (WebSocket, lifespan) + await self.app(scope, receive, send) + return + + # Get request info + method = scope.get("method", "") + headers = dict(scope.get("headers", [])) + + # Convert byte headers to strings + origin = None + if b'origin' in headers: + origin = headers[b'origin'].decode('utf-8') - # Direct match - if origin in settings.CORS_ORIGINS: - return True - - # Wildcard match - if "*" is in the allowed origins list - if "*" in settings.CORS_ORIGINS: - return True - - # Check for pattern matching (e.g., https://*.vercel.app) - for allowed_origin in settings.CORS_ORIGINS: - if "*" in allowed_origin and not allowed_origin == "*": - pattern_parts = allowed_origin.split("*") - if len(pattern_parts) == 2: - if origin.startswith(pattern_parts[0]) and origin.endswith(pattern_parts[1]): - return True + # Handle OPTIONS requests for CORS preflight + if method == "OPTIONS": + async def send_preflight_response(message): + if message["type"] == "http.response.start": + # Create a custom response for preflight + headers = [ + (b"content-type", b"text/plain"), + (b"content-length", b"0"), + ] - return False - - async def dispatch(self, request: Request, call_next): - origin = request.headers.get("origin", "") - - # Always respond to OPTIONS requests directly for preflight handling - if request.method == "OPTIONS": - # Create a new response for preflight - response = Response(status_code=204) # No content needed for preflight + # Add CORS headers + if origin: + headers.extend([ + (b"access-control-allow-origin", origin.encode()), + (b"access-control-allow-credentials", b"true"), + (b"access-control-allow-methods", b"GET, POST, PUT, DELETE, OPTIONS, PATCH".encode()), + (b"access-control-allow-headers", b"Authorization, Content-Type, Accept, Accept-Language, Content-Language, Content-Length, Origin, X-Requested-With, X-CSRF-Token, Access-Control-Allow-Origin, Access-Control-Allow-Credentials, X-HTTP-Method-Override"), + (b"access-control-max-age", b"3600"), + (b"vary", b"Origin"), + ]) + + # Send the response + await send({ + "type": "http.response.start", + "status": 200, + "headers": headers + }) + else: + await send(message) + + # Handle the preflight request with our custom response + await send_preflight_response({"type": "http.response.start"}) + await send({"type": "http.response.body", "body": b""}) + return - # If no origin or not allowed, return 204 with minimal headers - # This will not block the request but won't allow CORS either - if not origin or not self.is_origin_allowed(origin): - return response + # For non-OPTIONS requests, wrap the send function to add CORS headers + async def cors_send(message): + if message["type"] == "http.response.start": + # Get original headers + headers = list(message.get("headers", [])) + + # Add CORS headers if origin is present + if origin: + # Add CORS headers + headers.extend([ + (b"access-control-allow-origin", origin.encode()), + (b"access-control-allow-credentials", b"true"), + (b"vary", b"Origin"), + ]) + + # Send modified response + message["headers"] = headers + + await send(message) - # If origin is allowed, set the full CORS headers - response.headers["Access-Control-Allow-Origin"] = origin - - # Include all possible headers that might be used by the frontend - # Make sure Content-Type is included to support application/json - response.headers["Access-Control-Allow-Headers"] = ( - "Authorization, Content-Type, Accept, Accept-Language, " + - "Content-Language, Content-Length, Origin, X-Requested-With, " + - "X-CSRF-Token, Access-Control-Allow-Origin, Access-Control-Allow-Credentials, " + - "X-Requested-With, X-HTTP-Method-Override" - ) - - # Expose headers that frontend might need to access - response.headers["Access-Control-Expose-Headers"] = ( - "Content-Length, Content-Type, Authorization" - ) - - response.headers["Access-Control-Allow-Credentials"] = "true" - response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS, PATCH" - response.headers["Access-Control-Max-Age"] = "3600" # 1 hour cache - response.status_code = 200 # OK for successful preflight - - return response - - # For regular requests, process normally then add CORS headers - response = await call_next(request) - - # Add CORS headers to all responses if origin is allowed - if self.is_origin_allowed(origin): - # Set required CORS headers - response.headers["Access-Control-Allow-Origin"] = origin - response.headers["Access-Control-Allow-Credentials"] = "true" - - # Add Vary header to indicate caching should consider Origin - response.headers["Vary"] = "Origin" - - return response + # Process the request with CORS headers added to response + await self.app(scope, receive, cors_send) app = FastAPI( title=settings.PROJECT_NAME, @@ -93,20 +94,8 @@ app = FastAPI( redoc_url="/redoc", ) -# Add our custom CORS middleware first (higher priority) -app.add_middleware(CustomCORSMiddleware) - -# Set up standard CORS middleware as a backup if not disabled -if not settings.USE_CUSTOM_CORS_ONLY: - app.add_middleware( - CORSMiddleware, - allow_origins=settings.CORS_ORIGINS, - allow_credentials=True, - allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"], - allow_headers=["Content-Type", "Authorization", "Accept", "Origin", "X-Requested-With", "X-CSRF-Token", "Access-Control-Allow-Credentials"], - expose_headers=["Content-Length", "Content-Type"], - max_age=600, # 10 minutes cache for preflight requests - ) +# Remove all middleware and mount the main app with our ASGI middleware +app = CORSMiddlewareASGI(app) # Include API router app.include_router(api_router) @@ -116,11 +105,11 @@ app.include_router(api_router) async def health_check(): return {"status": "healthy"} -# CORS test endpoint +# CORS test endpoints @app.options("/api/v1/cors-test", tags=["cors"]) async def cors_preflight_test(): """Test endpoint for CORS preflight requests.""" - return None + return Response(status_code=200) @app.post("/api/v1/cors-test", tags=["cors"]) async def cors_test(request: Request): @@ -140,5 +129,21 @@ async def cors_test(request: Request): "headers": dict(request.headers) } +# Additional OPTIONS handlers for critical endpoints +@app.options("/api/v1/auth/register", include_in_schema=False) +async def auth_register_options(): + """Handle OPTIONS preflight requests for auth register endpoint.""" + return Response(status_code=200) + +@app.options("/api/v1/auth/login", include_in_schema=False) +async def auth_login_options(): + """Handle OPTIONS preflight requests for auth login endpoint.""" + return Response(status_code=200) + +@app.options("/api/v1/users/me", include_in_schema=False) +async def users_me_options(): + """Handle OPTIONS preflight requests for users/me endpoint.""" + return Response(status_code=200) + if __name__ == "__main__": uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True) \ No newline at end of file