"""Database engine, session factories, and session dependencies.

Uses a SQLite file at ``/app/storage/db/db.sqlite``. A synchronous engine and
session factory are always created. An async engine, session factory and
async ``get_db`` dependency are created when the ``aiosqlite`` driver is
importable; otherwise ``get_db`` falls back to yielding a synchronous session.
"""

from contextlib import contextmanager
from pathlib import Path
from typing import AsyncGenerator, Generator

from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, sessionmaker

# Ensure the DB directory exists before any engine tries to open the file.
# NOTE: this runs at import time and requires write access under /app.
DB_DIR = Path("/app") / "storage" / "db"
DB_DIR.mkdir(parents=True, exist_ok=True)

# SQLite URL for the synchronous engine. DB_DIR is absolute, so the URL has
# four slashes ("sqlite:////app/..."), which SQLAlchemy reads as an absolute
# path.
# NOTE(review): foreign-key enforcement is not enabled here (no URI flag, no
# "PRAGMA foreign_keys=ON" listener) — confirm whether it is wanted.
SQLALCHEMY_DATABASE_URL = f"sqlite:///{DB_DIR}/db.sqlite"

# Engine for synchronous operations (like migrations). check_same_thread=False
# lets the sqlite3 driver use a connection from a thread other than the one
# that created it.
engine = create_engine(
    SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)

# Session factory for synchronous operations.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Base class for declarative models.
Base = declarative_base()


def get_db_sync() -> Generator[Session, None, None]:
    """Yield a synchronous database session and finalize it afterwards.

    The session is committed once the consumer resumes the generator without
    error, rolled back (and the exception re-raised) if an exception is
    thrown into it, and closed in every case.

    Yields:
        A ``Session`` created from ``SessionLocal``.
    """
    db = SessionLocal()
    try:
        yield db
        db.commit()
    except Exception:
        db.rollback()
        raise
    finally:
        db.close()


# Async operations need the aiosqlite driver. Only the import is guarded, so
# an unrelated ImportError during async setup is not silently turned into the
# synchronous fallback.
try:
    import aiosqlite  # noqa: F401 - needed for sqlalchemy dialect registration
except ImportError:
    # Fallback to synchronous operations if aiosqlite is not available.
    async_engine = None
    AsyncSessionLocal = None

    async def get_db() -> AsyncGenerator[Session, None]:
        """Compatibility wrapper yielding a synchronous session.

        Used when aiosqlite is not available. It keeps the async-generator
        interface of the aiosqlite-backed ``get_db`` but yields a plain
        ``Session`` with the same commit/rollback/close handling as
        ``get_db_sync``.
        """
        # ``yield from`` is a SyntaxError inside ``async def``. Delegate via a
        # context manager instead: on normal exit it resumes get_db_sync()
        # (commit + close); on error it throws the exception into it
        # (rollback + close) and the exception propagates.
        with contextmanager(get_db_sync)() as db:
            yield db

else:
    ASYNC_SQLALCHEMY_DATABASE_URL = f"sqlite+aiosqlite:///{DB_DIR}/db.sqlite"

    async_engine = create_async_engine(
        ASYNC_SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
    )

    AsyncSessionLocal = sessionmaker(
        autocommit=False,
        autoflush=False,
        bind=async_engine,
        class_=AsyncSession,
    )

    async def get_db() -> AsyncGenerator[AsyncSession, None]:
        """Dependency that yields an async database session.

        The session is committed once the consumer resumes the generator
        without error, and rolled back (with the exception re-raised) if an
        exception is thrown into it.

        Yields:
            An ``AsyncSession`` created from ``AsyncSessionLocal``.
        """
        # Leaving the ``async with`` block closes the session, so no explicit
        # close() is needed.
        async with AsyncSessionLocal() as session:
            try:
                yield session
                await session.commit()
            except Exception:
                await session.rollback()
                raise