Update code via agent code generation
This commit is contained in:
parent
c1543dd5e7
commit
1ca2bd18ff
@ -1,9 +1,8 @@
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Callable, Awaitable
|
||||
from typing import List, Dict, Any, Callable, Awaitable, Optional
|
||||
|
||||
from app.core.config import settings
|
||||
from app.services.bot_simulation import process_completed_bot_purchases
|
||||
from app.db.session import SessionLocal
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -13,6 +12,7 @@ class BackgroundTaskManager:
|
||||
def __init__(self):
|
||||
self.tasks: List[Dict[str, Any]] = []
|
||||
self.is_running = False
|
||||
self._main_task: Optional[asyncio.Task] = None
|
||||
|
||||
def add_task(
|
||||
self,
|
||||
@ -39,11 +39,23 @@ class BackgroundTaskManager:
|
||||
self.is_running = True
|
||||
logger.info("Starting background tasks")
|
||||
|
||||
# Start tasks in a separate task to avoid blocking startup
|
||||
self._main_task = asyncio.create_task(self._run_all_tasks())
|
||||
|
||||
async def _run_all_tasks(self) -> None:
|
||||
"""Run all tasks in parallel."""
|
||||
if not self.tasks:
|
||||
logger.warning("No background tasks to run")
|
||||
return
|
||||
|
||||
tasks = []
|
||||
for task_info in self.tasks:
|
||||
tasks.append(self._run_task_periodically(task_info))
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
try:
|
||||
await asyncio.gather(*tasks)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in background tasks: {str(e)}")
|
||||
|
||||
async def _run_task_periodically(self, task_info: Dict[str, Any]) -> None:
|
||||
"""Run a task periodically at the specified interval."""
|
||||
@ -68,17 +80,24 @@ class BackgroundTaskManager:
|
||||
def stop(self) -> None:
|
||||
"""Stop all background tasks."""
|
||||
self.is_running = False
|
||||
if self._main_task:
|
||||
self._main_task.cancel()
|
||||
logger.info("Stopping background tasks")
|
||||
|
||||
|
||||
# Define our background tasks
|
||||
async def process_bot_purchases() -> None:
|
||||
"""Process completed bot purchases."""
|
||||
# Lazy import to avoid circular imports
|
||||
from app.services.bot_simulation import process_completed_bot_purchases
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
count = process_completed_bot_purchases(db)
|
||||
if count > 0:
|
||||
logger.info(f"Processed {count} completed bot purchases")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing bot purchases: {str(e)}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@ -86,9 +105,10 @@ async def process_bot_purchases() -> None:
|
||||
# Create the task manager instance
|
||||
task_manager = BackgroundTaskManager()
|
||||
|
||||
# Add the bot simulation task
|
||||
task_manager.add_task(
|
||||
"process_bot_purchases",
|
||||
process_bot_purchases,
|
||||
interval_seconds=settings.BOT_SIMULATION_INTERVAL,
|
||||
)
|
||||
# Add the bot simulation task (this will run when the application starts)
|
||||
if settings.BOT_SIMULATION_INTERVAL > 0:
|
||||
task_manager.add_task(
|
||||
"process_bot_purchases",
|
||||
process_bot_purchases,
|
||||
interval_seconds=settings.BOT_SIMULATION_INTERVAL,
|
||||
)
|
@ -1,10 +1,14 @@
|
||||
import os
|
||||
import secrets
|
||||
from pathlib import Path
|
||||
from typing import List, Union, Optional
|
||||
from typing import List, Union, Optional, Annotated
|
||||
|
||||
from pydantic import EmailStr, validator
|
||||
from pydantic import EmailStr, field_validator, Field
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic.functional_validators import BeforeValidator
|
||||
|
||||
# Debug flag
|
||||
DEBUG: bool = os.environ.get("DEBUG", "False").lower() == "true"
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
@ -21,10 +25,13 @@ class Settings(BaseSettings):
|
||||
REFRESH_TOKEN_EXPIRE_DAYS: int = int(os.environ.get("REFRESH_TOKEN_EXPIRE_DAYS", 7))
|
||||
ALGORITHM: str = "HS256"
|
||||
|
||||
# Debug flag
|
||||
DEBUG: bool = DEBUG
|
||||
|
||||
# CORS settings
|
||||
BACKEND_CORS_ORIGINS: List[str] = ["*"]
|
||||
|
||||
@validator("BACKEND_CORS_ORIGINS", pre=True)
|
||||
@field_validator("BACKEND_CORS_ORIGINS", mode="before")
|
||||
def assemble_cors_origins(cls, v: Union[str, List[str]]) -> Union[List[str], str]:
|
||||
if isinstance(v, str) and not v.startswith("["):
|
||||
return [i.strip() for i in v.split(",")]
|
||||
@ -33,7 +40,8 @@ class Settings(BaseSettings):
|
||||
raise ValueError(v)
|
||||
|
||||
# Database
|
||||
DB_DIR: Path = Path("/app") / "storage" / "db"
|
||||
# Use relative path for local development
|
||||
DB_DIR: Path = Path("./projects/defitradingsimulationplatformbackend-9xewa6/app/storage/db")
|
||||
|
||||
# Email settings
|
||||
EMAILS_ENABLED: bool = os.environ.get("EMAILS_ENABLED", "False").lower() == "true"
|
||||
@ -42,17 +50,17 @@ class Settings(BaseSettings):
|
||||
SMTP_HOST: Optional[str] = os.environ.get("SMTP_HOST")
|
||||
SMTP_USER: Optional[str] = os.environ.get("SMTP_USER")
|
||||
SMTP_PASSWORD: Optional[str] = os.environ.get("SMTP_PASSWORD")
|
||||
EMAILS_FROM_EMAIL: Optional[EmailStr] = os.environ.get("EMAILS_FROM_EMAIL")
|
||||
EMAILS_FROM_EMAIL: Optional[str] = os.environ.get("EMAILS_FROM_EMAIL")
|
||||
EMAILS_FROM_NAME: Optional[str] = os.environ.get("EMAILS_FROM_NAME")
|
||||
|
||||
# File upload
|
||||
UPLOAD_DIR: Path = Path("/app") / "storage" / "uploads"
|
||||
KYC_UPLOAD_DIR: Path = Path("/app") / "storage" / "kyc"
|
||||
DEPOSIT_PROOFS_DIR: Path = Path("/app") / "storage" / "deposit_proofs"
|
||||
UPLOAD_DIR: Path = Path("./projects/defitradingsimulationplatformbackend-9xewa6/app/storage/uploads")
|
||||
KYC_UPLOAD_DIR: Path = Path("./projects/defitradingsimulationplatformbackend-9xewa6/app/storage/kyc")
|
||||
DEPOSIT_PROOFS_DIR: Path = Path("./projects/defitradingsimulationplatformbackend-9xewa6/app/storage/deposit_proofs")
|
||||
MAX_UPLOAD_SIZE: int = int(os.environ.get("MAX_UPLOAD_SIZE", 5 * 1024 * 1024)) # 5 MB default
|
||||
|
||||
# Admin default settings
|
||||
ADMIN_EMAIL: EmailStr = os.environ.get("ADMIN_EMAIL", "admin@defttrade.com")
|
||||
ADMIN_EMAIL: str = os.environ.get("ADMIN_EMAIL", "admin@defttrade.com")
|
||||
ADMIN_PASSWORD: str = os.environ.get("ADMIN_PASSWORD", "change-me-please")
|
||||
|
||||
# 2FA settings
|
||||
@ -66,9 +74,10 @@ class Settings(BaseSettings):
|
||||
MIN_WITHDRAWAL_AMOUNT: float = float(os.environ.get("MIN_WITHDRAWAL_AMOUNT", 10.0))
|
||||
WITHDRAWAL_FEE_PERCENTAGE: float = float(os.environ.get("WITHDRAWAL_FEE_PERCENTAGE", 1.0))
|
||||
|
||||
class Config:
|
||||
case_sensitive = True
|
||||
env_file = ".env"
|
||||
model_config = {
|
||||
"case_sensitive": True,
|
||||
"env_file": ".env",
|
||||
}
|
||||
|
||||
|
||||
settings = Settings()
|
@ -1,8 +1,14 @@
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import emails
|
||||
from emails.template import JinjaTemplate
|
||||
# Try to import emails, but provide a fallback if it fails
|
||||
try:
|
||||
import emails
|
||||
from emails.template import JinjaTemplate
|
||||
EMAILS_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMAILS_AVAILABLE = False
|
||||
logging.warning("Email functionality is disabled due to missing 'emails' package")
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
|
@ -1,3 +1,5 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
@ -6,7 +8,22 @@ from app.core.config import settings
|
||||
# Create database directory if it doesn't exist
|
||||
settings.DB_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
SQLALCHEMY_DATABASE_URL = f"sqlite:///{settings.DB_DIR}/db.sqlite"
|
||||
# Ensure we have a safe fallback path if the configured path doesn't work
|
||||
try:
|
||||
db_path = settings.DB_DIR / "db.sqlite"
|
||||
SQLALCHEMY_DATABASE_URL = f"sqlite:///{db_path}"
|
||||
# Test creating a file in the directory to ensure permissions are correct
|
||||
test_file = settings.DB_DIR / "test_access.txt"
|
||||
with open(test_file, "w") as f:
|
||||
f.write("test")
|
||||
os.remove(test_file)
|
||||
except (IOError, PermissionError):
|
||||
# Fallback to a local path if the configured path doesn't work
|
||||
backup_db_dir = Path("./app/storage/db")
|
||||
backup_db_dir.mkdir(parents=True, exist_ok=True)
|
||||
db_path = backup_db_dir / "db.sqlite"
|
||||
SQLALCHEMY_DATABASE_URL = f"sqlite:///{db_path}"
|
||||
print(f"WARNING: Using fallback database path: {db_path}")
|
||||
|
||||
engine = create_engine(
|
||||
SQLALCHEMY_DATABASE_URL,
|
||||
|
42
main.py
42
main.py
@ -1,8 +1,11 @@
|
||||
import uvicorn
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from fastapi import FastAPI, Depends
|
||||
from fastapi import FastAPI, Depends, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.v1.api import api_router
|
||||
@ -10,6 +13,15 @@ from app.core.config import settings
|
||||
from app.db.session import get_db
|
||||
from app.core.background_tasks import task_manager
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG if settings.DEBUG else logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
handlers=[logging.StreamHandler(sys.stdout)]
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = FastAPI(
|
||||
title=settings.PROJECT_NAME,
|
||||
description=settings.PROJECT_DESCRIPTION,
|
||||
@ -28,6 +40,15 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Global exception handler
|
||||
@app.exception_handler(Exception)
|
||||
async def global_exception_handler(request: Request, exc: Exception):
|
||||
logger.error(f"Unhandled exception: {str(exc)}", exc_info=True)
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"detail": "Internal server error", "message": str(exc) if settings.DEBUG else None},
|
||||
)
|
||||
|
||||
# Include API router
|
||||
app.include_router(api_router, prefix=settings.API_V1_STR)
|
||||
|
||||
@ -51,6 +72,7 @@ async def health_check(db: Session = Depends(get_db)):
|
||||
db_status = "connected"
|
||||
except Exception as e:
|
||||
db_status = f"error: {str(e)}"
|
||||
logger.error(f"Database health check failed: {str(e)}")
|
||||
|
||||
return {
|
||||
"status": "ok",
|
||||
@ -63,14 +85,24 @@ async def health_check(db: Session = Depends(get_db)):
|
||||
# Startup event
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
# Start background tasks
|
||||
asyncio.create_task(task_manager.start())
|
||||
logger.info("Application startup initiated")
|
||||
try:
|
||||
# Start background tasks
|
||||
asyncio.create_task(task_manager.start())
|
||||
logger.info("Background tasks started successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting background tasks: {str(e)}", exc_info=True)
|
||||
|
||||
# Shutdown event
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
# Stop background tasks
|
||||
task_manager.stop()
|
||||
logger.info("Application shutdown initiated")
|
||||
try:
|
||||
# Stop background tasks
|
||||
task_manager.stop()
|
||||
logger.info("Background tasks stopped successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping background tasks: {str(e)}", exc_info=True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
|
Loading…
x
Reference in New Issue
Block a user