import sqlalchemy
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine
from sqlalchemy.orm import DeclarativeBase, sessionmaker, Session
from pathlib import Path
from contextlib import contextmanager
from typing import Generator, Optional, Union


class Base(DeclarativeBase):
    """Declarative base class for SQLAlchemy ORM models."""

    pass


# Module-level singletons populated by init_db(); both stay None until then.
_engine: Optional[Engine] = None
_session_factory: Optional[sessionmaker] = None


def init_db(db_path: Union[Path, str]) -> None:
    """
    Initialize the database engine and session factory.

    Calling this again replaces the current engine. The previous engine is
    disposed first so its pooled connections are released instead of leaked.

    Args:
        db_path: Path to the SQLite database file. A ``str`` is also
            accepted and converted to ``Path``.
    """
    global _engine, _session_factory

    db_path = Path(db_path)

    # Make sure the parent directory exists before SQLite tries to open the file.
    db_path.parent.mkdir(parents=True, exist_ok=True)

    # Release connections held by an engine from a previous init_db() call.
    if _engine is not None:
        _engine.dispose()

    # check_same_thread=False allows a connection to be used from a thread
    # other than the one that created it.
    _engine = create_engine(
        f'sqlite:///{db_path}',
        connect_args={'check_same_thread': False}
    )
    _session_factory = sessionmaker(bind=_engine, autocommit=False, autoflush=False)

    # SQLite creates the database file lazily; open a connection now so the
    # file exists on disk once init_db() returns.
    with _engine.connect() as conn:
        conn.execute(text("SELECT 1"))


def get_engine() -> Optional[Engine]:
    """Return the database engine, or None if init_db() has not been called."""
    return _engine


def get_session_factory() -> Optional[sessionmaker]:
    """Return the session factory, or None if init_db() has not been called."""
    return _session_factory


@contextmanager
def get_db(db_path: Optional[Union[Path, str]] = None) -> Generator[Session, None, None]:
    """
    Provide a database session as a context manager.

    The session is never committed automatically; callers must call
    ``session.commit()`` to persist changes. The session is always closed
    on exit, including when the body raises.

    Args:
        db_path: Optional database path. It is used to call init_db() only
            when no engine exists yet, and is ignored otherwise.

    Yields:
        A SQLAlchemy session.

    Raises:
        RuntimeError: If the database has not been initialized and no
            db_path was supplied.
    """
    if db_path and _engine is None:
        init_db(db_path)

    if _session_factory is None:
        raise RuntimeError("Database not initialized. Call init_db() first.")

    session = _session_factory()
    try:
        yield session
    finally:
        session.close()