2025-06-02 18:32:05 +00:00

122 lines
3.8 KiB
Python

import logging
import os
import time
from pathlib import Path

from sqlalchemy import create_engine, event, text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

# Import app config
from app.config import (
    PROJECT_ROOT,
    CONTAINER_DB_PATH,
    DB_CONNECT_RETRY,
    DB_CONNECT_RETRY_DELAY,
)
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Database storage location: use the container path only when it is an
# existing directory this process can write to; otherwise fall back to the
# project-local path (created below if missing).
CONTAINER_DB_PATH = Path(CONTAINER_DB_PATH)
LOCAL_DB_PATH = PROJECT_ROOT / "storage" / "db"
DB_PATH = (
    CONTAINER_DB_PATH
    if CONTAINER_DB_PATH.is_dir() and os.access(CONTAINER_DB_PATH, os.W_OK)
    else LOCAL_DB_PATH
)
DB_PATH.mkdir(parents=True, exist_ok=True)
logger.info(f"Using database path: {DB_PATH}")

# SQLite database URL
SQLALCHEMY_DATABASE_URL = f"sqlite:///{DB_PATH}/db.sqlite"

# Connection retry settings
MAX_RETRIES = DB_CONNECT_RETRY
RETRY_DELAY = DB_CONNECT_RETRY_DELAY  # seconds
# Create the SQLAlchemy engine with retry logic
def get_engine():
    """Create the SQLAlchemy engine and verify connectivity, retrying on failure.

    The engine is built once from ``SQLALCHEMY_DATABASE_URL``. A ``SELECT 1``
    probe is then attempted up to ``MAX_RETRIES`` times, sleeping
    ``RETRY_DELAY`` seconds between failed attempts.

    Returns:
        The engine. It is returned even when every probe failed, or when
        ``MAX_RETRIES`` is zero or negative and no probe is made. This lets the
        application start; connection errors are left to request handlers.

    Raises:
        Whatever ``create_engine`` itself raises (e.g. an invalid URL).
    """
    engine = create_engine(
        SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
    )
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            # Test connection. Textual SQL must be wrapped in text(); a bare
            # string is rejected by SQLAlchemy 2.0.
            with engine.connect() as conn:
                conn.execute(text("SELECT 1"))
            return engine
        except Exception as e:
            if attempt < MAX_RETRIES:
                logger.warning(
                    f"Database connection attempt {attempt} failed: {e}. Retrying in {RETRY_DELAY}s..."
                )
                time.sleep(RETRY_DELAY)
            else:
                logger.error(
                    f"Failed to connect to database after {MAX_RETRIES} attempts: {e}"
                )
    # Still return the engine, we'll handle connection errors in the request handlers
    return engine
# Create the module-level engine used by the session factory and helpers.
engine = get_engine()


# Listener for the pool "connect" event: it runs once each time a NEW DBAPI
# connection is opened, not on every pool checkout.
@event.listens_for(engine, "connect")
def ping_connection(dbapi_connection, connection_record):
    """Verify a freshly opened DBAPI connection by executing ``SELECT 1``.

    Args:
        dbapi_connection: The raw DBAPI connection that was just opened.
        connection_record: The pool's record wrapping that connection.

    Raises:
        Re-raises whatever the ping raised, after logging a warning and
        clearing ``connection_record.connection``.
    """
    try:
        cursor = dbapi_connection.cursor()
        try:
            cursor.execute("SELECT 1")
        finally:
            # Close the cursor even when execute() fails, so it is not leaked.
            cursor.close()
    except Exception:
        # Reconnect if the connection is invalid
        logger.warning("Connection ping failed. Connection will be recycled.")
        # NOTE(review): `.connection` appears to be a deprecated alias of
        # `dbapi_connection` in newer SQLAlchemy releases - confirm it is
        # still settable in the version this project pins.
        connection_record.connection = None
        raise
# Session factory bound to the module-level engine. autoflush and autocommit
# are both off, so pending changes are flushed/committed only when the caller
# (or an explicit query flush) asks for it.
SessionLocal = sessionmaker(
    bind=engine,
    autoflush=False,
    autocommit=False,
)

# Declarative base class that ORM models derive from.
Base = declarative_base()
# Create tables (important for first run)
def create_tables():
    """Create the tables registered on ``Base.metadata`` using ``engine``.

    This is best-effort: any failure is logged together with its traceback
    and swallowed, so the application can still start. Nothing here retries
    automatically; to create the tables later, call this function again.

    Returns:
        True if ``create_all`` completed, False if it raised.
    """
    try:
        Base.metadata.create_all(bind=engine)
    except Exception as e:
        # Deliberately not re-raised - let the application start even if
        # tables can't be created. logger.exception keeps the traceback.
        logger.exception(f"Error creating database tables: {e}")
        return False
    logger.info("Database tables created successfully")
    return True
# Dependency to get a database session with improved error handling
def get_db():
    """Yield a ``SessionLocal`` session for one request, closing it afterwards.

    Connectivity is probed with ``SELECT 1`` before the session is handed
    out. If the probe fails, the error is logged and an HTTP 503 is raised.
    Exceptions raised by the consumer while it holds the session are NOT
    converted to a 503; they propagate unchanged.

    Yields:
        An open database session.

    Raises:
        fastapi.HTTPException: status 503 when the connectivity probe fails.
    """
    db = SessionLocal()
    try:
        try:
            # Test the connection. Textual SQL must be wrapped in text().
            db.execute(text("SELECT 1"))
        except Exception as e:
            # Log the error
            logger.error(f"Database connection error in get_db: {e}")
            # Imported lazily; only needed on this failure path.
            from fastapi import HTTPException, status

            # Provide a user-friendly error, chained to the real cause.
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Database connection error. Please try again later.",
            ) from e
        # Kept outside the except above so errors thrown in at this yield by
        # the request handler (e.g. a 404) are not masked as a database 503.
        yield db
    finally:
        db.close()