"""Billing routes: pricing catalogue, Stripe Checkout, webhooks, subscriptions."""

import logging
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel
from sqlalchemy.orm import Session

from app.core.config import settings
from app.db.session import get_db
from app.models.user import User
from app.services.stripe_service import StripeService
from app.services.user_service import UserService

logger = logging.getLogger(__name__)

router = APIRouter()

# Single source of truth for the plan catalogue. GET /plans returns it verbatim
# and POST /checkout derives its plan_id -> Stripe price lookup from it, so the
# two can no longer drift apart.
_PLANS = [
    {
        "id": "starter",
        "name": "Starter Plan",
        "price": 9.99,
        "currency": "USD",
        "interval": "month",
        "features": [
            "Up to 10 projects",
            "Basic templates",
            "Email support",
            "5GB storage",
        ],
        "stripe_price_id": "price_starter_monthly",
    },
    {
        "id": "professional",
        "name": "Professional Plan",
        "price": 29.99,
        "currency": "USD",
        "interval": "month",
        "features": [
            "Unlimited projects",
            "Premium templates",
            "Priority support",
            "50GB storage",
            "Advanced analytics",
            "Custom branding",
        ],
        "stripe_price_id": "price_pro_monthly",
    },
    {
        "id": "business",
        "name": "Business Plan",
        "price": 99.99,
        "currency": "USD",
        "interval": "month",
        "features": [
            "Everything in Professional",
            "Team collaboration",
            "API access",
            "500GB storage",
            "White-label solution",
            "Dedicated account manager",
        ],
        "stripe_price_id": "price_business_monthly",
    },
    {
        "id": "enterprise",
        "name": "Enterprise Plan",
        "price": 299.99,
        "currency": "USD",
        "interval": "month",
        "features": [
            "Everything in Business",
            "Unlimited storage",
            "Custom integrations",
            "SLA guarantee",
            "On-premise deployment",
            "24/7 phone support",
        ],
        "stripe_price_id": "price_enterprise_monthly",
    },
]

# plan id -> Stripe price id, e.g. "starter" -> "price_starter_monthly".
_PRICE_ID_BY_PLAN = {plan["id"]: plan["stripe_price_id"] for plan in _PLANS}

# Webhook event types that update local subscription state; any other event is
# acknowledged with a success response and otherwise ignored.
_HANDLED_WEBHOOK_EVENTS = frozenset(
    {
        "checkout.session.completed",
        "customer.subscription.updated",
        "customer.subscription.deleted",
    }
)


class CheckoutRequest(BaseModel):
    """Body of POST /checkout: which plan to buy and for which user."""

    plan_id: str
    user_email: str
    # Redirect targets; default to frontend routes when omitted.
    success_url: Optional[str] = None
    cancel_url: Optional[str] = None


class SubscriptionUpdate(BaseModel):
    """Body of POST /subscription/manage."""

    subscription_id: str
    action: str  # "cancel", "pause", "resume" (only "cancel" is implemented)


def _get_user_by_customer_id(db: Session, customer_id: Optional[str]) -> Optional[User]:
    """Return the user linked to a Stripe customer id, or None.

    A missing/empty id short-circuits to None: SQLAlchemy renders
    ``column == None`` as ``IS NULL``, which would otherwise match an
    arbitrary user that has no Stripe customer yet.
    """
    if not customer_id:
        return None
    return db.query(User).filter(User.stripe_customer_id == customer_id).first()


@router.get("/plans")
async def get_pricing_plans():
    """Return the static pricing catalogue as ``{"plans": [...]}``."""
    return {"plans": _PLANS}


@router.post("/checkout")
async def create_checkout_session(
    checkout_request: CheckoutRequest,
    db: Session = Depends(get_db),
):
    """Create a Stripe Checkout session for an existing user and plan.

    Creates (and persists) the user's Stripe customer on first checkout.

    Returns:
        ``{"checkout_url": ..., "session_id": ...}`` from the created session.

    Raises:
        HTTPException: 404 if no user has ``user_email``; 400 if the plan id
            is unknown, or if the Stripe customer or session cannot be created.
    """
    # NOTE(review): the account is chosen purely by the email in the request
    # body; nothing here verifies the caller owns it — confirm auth upstream.
    stripe_service = StripeService()
    user_service = UserService(db)

    # Look up the user; this endpoint never creates users.
    user = user_service.get_user_by_email(checkout_request.user_email)
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    # Validate the plan before touching Stripe so an invalid request does not
    # leave a stray Stripe customer behind.
    stripe_price_id = _PRICE_ID_BY_PLAN.get(checkout_request.plan_id)
    if not stripe_price_id:
        raise HTTPException(status_code=400, detail="Invalid plan ID")

    # Lazily create the Stripe customer and remember its id on the user.
    if not user.stripe_customer_id:
        customer_result = await stripe_service.create_customer(
            email=user.email,
            name=user.full_name,
        )
        if not customer_result["success"]:
            raise HTTPException(status_code=400, detail="Failed to create customer")
        user.stripe_customer_id = customer_result["customer"]["id"]
        db.commit()

    success_url = (
        checkout_request.success_url or f"{settings.frontend_url}/subscription/success"
    )
    cancel_url = checkout_request.cancel_url or f"{settings.frontend_url}/pricing"

    session_result = await stripe_service.create_checkout_session(
        price_id=stripe_price_id,
        customer_id=user.stripe_customer_id,
        success_url=success_url,
        cancel_url=cancel_url,
    )
    if not session_result["success"]:
        raise HTTPException(status_code=400, detail="Failed to create checkout session")

    return {
        "checkout_url": session_result["session"]["url"],
        "session_id": session_result["session"]["id"],
    }


@router.post("/webhook")
async def stripe_webhook(request: Request, db: Session = Depends(get_db)):
    """Verify a Stripe webhook and mirror subscription state onto the user.

    Handles ``checkout.session.completed``, ``customer.subscription.updated``
    and ``customer.subscription.deleted``; every other verified event (and
    events whose customer matches no user) is acknowledged unchanged.

    Raises:
        HTTPException: 400 if the payload is malformed or the
            ``stripe-signature`` header fails verification.
    """
    # Kept as a function-local import, as in the original module.
    import stripe

    payload = await request.body()
    sig_header = request.headers.get("stripe-signature")

    try:
        event = stripe.Webhook.construct_event(
            payload, sig_header, settings.stripe_webhook_secret
        )
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid payload")
    except stripe.error.SignatureVerificationError:
        raise HTTPException(status_code=400, detail="Invalid signature")

    event_type = event["type"]
    if event_type not in _HANDLED_WEBHOOK_EVENTS:
        return {"status": "success"}

    # Checkout sessions and subscriptions both carry the Stripe customer id.
    stripe_object = event["data"]["object"]
    customer_id = stripe_object["customer"]
    user = _get_user_by_customer_id(db, customer_id)
    if user is None:
        logger.warning(
            "Stripe event %s: no user for customer %r", event_type, customer_id
        )
        return {"status": "success"}

    if event_type == "checkout.session.completed":
        user.subscription_status = "active"
        # TODO(review): placeholder — "premium" is not one of the plan ids in
        # _PLANS; derive the real plan from the session (e.g. its price).
        user.subscription_plan = "premium"
    elif event_type == "customer.subscription.updated":
        user.subscription_status = stripe_object["status"]
    else:  # "customer.subscription.deleted"
        user.subscription_status = "cancelled"
        user.subscription_plan = None
    db.commit()

    return {"status": "success"}


@router.get("/subscription/{user_id}")
async def get_user_subscription(user_id: int, db: Session = Depends(get_db)):
    """Return the stored subscription status, plan and Stripe customer id.

    Raises:
        HTTPException: 404 if no user has ``user_id``.
    """
    # NOTE(review): no ownership/auth check is visible here; any caller can
    # read any user's billing info — confirm this is enforced elsewhere.
    user_service = UserService(db)
    user = user_service.get_user(user_id)
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    return {
        "status": user.subscription_status,
        "plan": user.subscription_plan,
        "stripe_customer_id": user.stripe_customer_id,
    }


@router.post("/subscription/manage")
async def manage_subscription(
    subscription_update: SubscriptionUpdate,
    db: Session = Depends(get_db),
):
    """Apply a management action to a Stripe subscription.

    Only ``"cancel"`` is implemented; ``db`` is currently unused.

    Raises:
        HTTPException: 400 if cancellation fails or the action is unsupported.
    """
    # NOTE(review): any caller can act on any subscription id; there is no
    # check that it belongs to the requesting user.
    stripe_service = StripeService()

    if subscription_update.action == "cancel":
        result = await stripe_service.cancel_subscription(
            subscription_update.subscription_id
        )
        if result["success"]:
            return {"message": "Subscription cancelled successfully"}
        raise HTTPException(status_code=400, detail="Failed to cancel subscription")

    # "pause" / "resume" are accepted by the model but not implemented yet.
    raise HTTPException(status_code=400, detail="Invalid action")


@router.post("/setup-products")
async def setup_stripe_products():
    """Create the Stripe products/prices and report how many were created.

    Raises:
        HTTPException: 500 wrapping any error raised by the Stripe service.
    """
    # NOTE(review): this mutates the Stripe account and has no visible auth;
    # it should probably be admin-only.
    stripe_service = StripeService()
    try:
        products = await stripe_service.create_product_and_prices()
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error setting up products: {str(e)}"
        ) from e

    return {
        "message": "Products and prices created successfully",
        "products": len(products),
    }