55 lines
1.6 KiB
Python

from pathlib import Path
from typing import AsyncGenerator
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
# Directory that holds the SQLite database file. It is created (with any
# missing parents) at import time, so the file can be opened without a
# separate setup step; an already-existing directory is not an error.
DB_DIR = Path("/app", "storage", "db")
DB_DIR.mkdir(exist_ok=True, parents=True)
# Single on-disk database file shared by the sync and async engines. Both URLs
# below are derived from this one path so they cannot drift apart.
DB_PATH = DB_DIR / "db.sqlite"

# Synchronous SQLite URL: "sqlite:///" followed by an absolute path, giving four
# slashes in total (e.g. sqlite:////app/storage/db/db.sqlite).
# NOTE(review): SQLite only enforces foreign keys when each connection runs
# "PRAGMA foreign_keys=ON". Nothing in this module issues that pragma, so FK
# constraints are NOT enforced. If enforcement is wanted, add a "connect" event
# listener on both `engine` and `async_engine.sync_engine`, but first check that
# existing data does not already violate the constraints.
SQLALCHEMY_DATABASE_URL = f"sqlite:///{DB_PATH}"

# Engine for synchronous work (presumably migrations / schema setup; confirm
# against callers). sqlite3 by default refuses to use a connection from any
# thread other than the one that created it. SQLAlchemy's pool can hand a
# connection to a different thread, so that check is disabled here.
engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},
)

# Async engine for the same file via the aiosqlite driver. It mirrors the sync
# engine's thread-check setting.
ASYNC_SQLALCHEMY_DATABASE_URL = f"sqlite+aiosqlite:///{DB_PATH}"
async_engine = create_async_engine(
    ASYNC_SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},
)
# Session factories. Neither autoflushes before queries (autoflush=False) nor
# uses legacy autocommit mode, so callers decide when to flush and commit.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Async session factory producing AsyncSession objects bound to `async_engine`.
# expire_on_commit=False keeps already-loaded attribute values usable after
# `await session.commit()`. With the default (True), every attribute is expired
# on commit, and reading one afterwards triggers an implicit lazy load.
# AsyncSession cannot perform implicit IO and raises MissingGreenlet instead.
# Use `await session.refresh(obj)` when fresh database values are needed.
AsyncSessionLocal = sessionmaker(
    autocommit=False,
    autoflush=False,
    expire_on_commit=False,
    bind=async_engine,
    class_=AsyncSession,
)
# Declarative base class for ORM models: model classes inherit from it so their
# tables are collected on `Base.metadata`. "Base" is the generated class name
# (the library default, spelled out here for clarity).
# NOTE(review): sqlalchemy.ext.declarative.declarative_base is deprecated since
# SQLAlchemy 1.4 in favour of sqlalchemy.orm.declarative_base (or DeclarativeBase
# in 2.0); consider switching the import.
Base = declarative_base(name="Base")
# Dependency function for getting a database session
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield an AsyncSession covering one unit of work, then commit or roll back.

    Written as an async-generator dependency (presumably for a framework such
    as FastAPI; confirm against callers).

    Transaction handling:
    * When the consumer finishes and the generator resumes normally, the
      session is committed.
    * When an ``Exception`` is thrown into the generator, or the commit itself
      fails, the session is rolled back and the exception is re-raised.
    * Non-``Exception`` interruptions (e.g. ``asyncio.CancelledError``) skip
      the explicit rollback. Closing the session still discards any open
      transaction.

    The ``async with`` block closes the session on every exit path, so no
    separate ``close()`` call is needed.

    Yields:
        AsyncSession: a new session created by ``AsyncSessionLocal``.
    """
    async with AsyncSessionLocal() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise