85 lines
2.7 KiB
Python

from contextlib import contextmanager
from pathlib import Path
from typing import AsyncGenerator, Generator

from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, sessionmaker
# Ensure the database directory exists before any engine opens the file.
# NOTE: this is module-level code, so the directory is created at import time.
DB_DIR = Path("/app") / "storage" / "db"
DB_DIR.mkdir(parents=True, exist_ok=True)

# Plain file-based SQLite URL. It carries no URI flag or query parameters, and
# nothing in this module issues ``PRAGMA foreign_keys=ON``, so foreign-key
# enforcement stays at SQLite's default (off in stock builds).
# NOTE(review): enable it per connection if FK constraints are relied upon.
SQLALCHEMY_DATABASE_URL = f"sqlite:///{DB_DIR}/db.sqlite"

# Engine for synchronous operations (e.g. migrations). ``check_same_thread``
# is disabled so a pooled connection may be used from a thread other than the
# one that created it, which the sqlite3 driver forbids by default.
engine = create_engine(
    SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)

# Factory for synchronous sessions bound to ``engine``; commits are explicit
# and pending changes are not automatically flushed before queries.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Base class for declarative ORM models.
Base = declarative_base()
# Async sessions need the aiosqlite driver. Probe for it once; if it is not
# installed, fall back to synchronous sessions exposed under the same
# ``get_db`` name.
try:
    import aiosqlite  # noqa: F401 - imported only to detect whether the driver is installed
except ImportError:
    # aiosqlite is unavailable: no async engine or async session factory.
    async_engine = None
    AsyncSessionLocal = None

    def get_db_sync() -> Generator[Session, None, None]:
        """
        Synchronous dependency function for getting a database session.

        Yields one ``Session``. If the consumer resumes the generator without
        error, the session is committed; if an exception is thrown into it,
        the session is rolled back and the exception re-raised. The session
        is always closed afterwards. For use when aiosqlite is not available.
        """
        db = SessionLocal()
        try:
            yield db
            db.commit()
        except Exception:
            db.rollback()
            raise
        finally:
            db.close()

    async def get_db() -> AsyncGenerator[Session, None]:
        """
        Compatibility wrapper that provides a synchronous session
        when aiosqlite is not available.

        NOTE(review): this yields a synchronous ``Session``, and its
        commit/rollback run synchronously inside the event loop; callers that
        ``await`` session methods will not work with it -- confirm callers
        handle this fallback.
        """
        # ``yield from`` is a SyntaxError inside ``async def``. Driving the
        # sync generator through ``contextmanager`` forwards a normal exit
        # (commit) and any thrown exception (rollback + re-raise) into it.
        with contextmanager(get_db_sync)() as db:
            yield db
else:
    ASYNC_SQLALCHEMY_DATABASE_URL = f"sqlite+aiosqlite:///{DB_DIR}/db.sqlite"
    async_engine = create_async_engine(
        ASYNC_SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
    )
    AsyncSessionLocal = sessionmaker(
        autocommit=False,
        autoflush=False,
        bind=async_engine,
        class_=AsyncSession,
    )

    async def get_db() -> AsyncGenerator[AsyncSession, None]:
        """
        Dependency function that yields an async database session.

        Commits when the consumer finishes without error; on an exception,
        rolls back and re-raises. Exiting the ``async with`` block closes the
        session, so no explicit ``close()`` is needed.
        """
        async with AsyncSessionLocal() as session:
            try:
                yield session
                await session.commit()
            except Exception:
                await session.rollback()
                raise