"""Database setup: SQLite engine, session factory and FastAPI session dependency.

(Listing metadata: 122 lines, 3.8 KiB, Python.)
"""
import logging
import os
import time
from pathlib import Path

from sqlalchemy import create_engine, event, text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

# Import app config
from app.config import (
    PROJECT_ROOT,
    CONTAINER_DB_PATH,
    DB_CONNECT_RETRY,
    DB_CONNECT_RETRY_DELAY,
)


# Setup logging: configure the root logger at INFO (a no-op if the root
# logger already has handlers) and obtain this module's named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)


# Get paths for database storage.
# The container path is preferred; a project-local directory is the fallback.
CONTAINER_DB_PATH = Path(CONTAINER_DB_PATH)
# Path() is a no-op if PROJECT_ROOT is already a Path, and converts it if it
# is a plain string.
LOCAL_DB_PATH = Path(PROJECT_ROOT) / "storage" / "db"

# Use the container path only if it is an existing directory we can write to;
# otherwise use the local path. Existence alone is not enough: a read-only
# mount (or a plain file) at that location would be selected and the later
# mkdir / SQLite writes would fail.
DB_PATH = (
    CONTAINER_DB_PATH
    if CONTAINER_DB_PATH.is_dir() and os.access(CONTAINER_DB_PATH, os.W_OK)
    else LOCAL_DB_PATH
)
DB_PATH.mkdir(parents=True, exist_ok=True)

logger.info(f"Using database path: {DB_PATH}")

# SQLite database URL (the database file lives inside DB_PATH)
SQLALCHEMY_DATABASE_URL = f"sqlite:///{DB_PATH}/db.sqlite"

# Connection retry settings
MAX_RETRIES = DB_CONNECT_RETRY
RETRY_DELAY = DB_CONNECT_RETRY_DELAY  # seconds


# Create the SQLAlchemy engine with retry logic
def get_engine():
    """Create the SQLAlchemy engine and verify that it can connect, with retries.

    The engine is created once. Up to ``MAX_RETRIES`` connection attempts
    (at least one) run ``SELECT 1``, sleeping ``RETRY_DELAY`` seconds between
    failed attempts. If every attempt fails, the error is logged and the
    engine is still returned, so connection errors surface in the request
    handlers rather than at import time.

    Returns:
        The ``Engine`` bound to ``SQLALCHEMY_DATABASE_URL``.
    """
    # check_same_thread=False lets the sqlite3 connection be used from threads
    # other than the one that created it.
    engine = create_engine(
        SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
    )
    # Guard against MAX_RETRIES <= 0, which would otherwise skip the loop and
    # return nothing.
    attempts = max(MAX_RETRIES, 1)
    for attempt in range(1, attempts + 1):
        try:
            # Test connection. Textual SQL must be wrapped in text(); a bare
            # string is rejected by SQLAlchemy 2.0 (deprecated in 1.4).
            with engine.connect() as conn:
                conn.execute(text("SELECT 1"))
            return engine
        except Exception as e:
            if attempt < attempts:
                logger.warning(
                    f"Database connection attempt {attempt} failed: {e}. Retrying in {RETRY_DELAY}s..."
                )
                time.sleep(RETRY_DELAY)
            else:
                logger.error(
                    f"Failed to connect to database after {attempts} attempts: {e}"
                )
    # Still return the engine, we'll handle connection errors in the request handlers
    return engine


# Create engine
engine = get_engine()


# Add event listener for the pool "connect" event (fires once per new DBAPI
# connection, not on every checkout)
@event.listens_for(engine, "connect")
def ping_connection(dbapi_connection, connection_record):
    """Validate each newly created DBAPI connection with ``SELECT 1``.

    If the ping fails, the cursor is closed first. The connection record is
    then invalidated, which closes the raw connection, and the original error
    is re-raised so the connect attempt fails.

    Args:
        dbapi_connection: The raw DBAPI connection that was just created.
        connection_record: The pool's record wrapping ``dbapi_connection``.
    """
    try:
        cursor = dbapi_connection.cursor()
        try:
            cursor.execute("SELECT 1")
        finally:
            # Close the cursor before any invalidation. Closing a sqlite3
            # cursor on an already-closed connection raises and would mask
            # the original error.
            cursor.close()
    except Exception as e:
        logger.warning("Connection ping failed. Connection will be recycled.")
        # Public API: closes the raw connection and detaches it from the record.
        connection_record.invalidate(e)
        raise


# Declarative base that ORM model classes inherit from; its metadata is what
# create_tables() materialises.
Base = declarative_base()

# Session factory bound to the module-level engine, with autocommit=False and
# autoflush=False.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)


# Create tables (important for first run)
def create_tables():
    """Create every table registered on ``Base.metadata`` (best effort).

    Any failure is logged and swallowed, so the application can start even
    while the database is unavailable. Calling this again later retries the
    table creation.
    """
    try:
        Base.metadata.create_all(bind=engine)
    except Exception as e:
        # Deliberately not re-raised: startup must not fail on a DB outage.
        logger.error(f"Error creating database tables: {e}")
    else:
        logger.info("Database tables created successfully")


# Dependency to get a database session with improved error handling
def get_db():
    """Yield a database session for one request and always close it afterwards.

    Connectivity is checked with ``SELECT 1`` before the session is handed
    out. If that check fails, the error is logged and an HTTP 503 is raised.
    Exceptions raised by the request handler while it uses the session are
    NOT rewritten as 503; they propagate unchanged.

    Yields:
        A session created by ``SessionLocal``.

    Raises:
        fastapi.HTTPException: 503 when the connectivity check fails.
    """
    db = SessionLocal()
    try:
        try:
            # Test the connection. Textual SQL must be wrapped in text() for
            # SQLAlchemy 2.0.
            db.execute(text("SELECT 1"))
        except Exception as e:
            logger.error(f"Database connection error in get_db: {e}")
            # FastAPI is only imported on this error path.
            from fastapi import HTTPException, status

            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Database connection error. Please try again later.",
            ) from e
        # Kept outside the guarded block above, so endpoint errors are not
        # misreported as database outages.
        yield db
    finally:
        db.close()