diff --git a/backend/app/database.py b/backend/app/database.py new file mode 100644 index 0000000..59d0f2e --- /dev/null +++ b/backend/app/database.py @@ -0,0 +1,73 @@ +import sqlalchemy +from sqlalchemy import create_engine, text +from sqlalchemy.orm import DeclarativeBase, sessionmaker, Session +from pathlib import Path +from contextlib import contextmanager +from typing import Generator, Optional + + +class Base(DeclarativeBase): + """SQLAlchemy 声明基类.""" + pass + + +_engine = None +_session_factory = None + + +def init_db(db_path: Path) -> None: + """ + 初始化数据库引擎和会话工厂. + + Args: + db_path: SQLite 数据库文件路径 + """ + global _engine, _session_factory + + # 确保父目录存在 + db_path.parent.mkdir(parents=True, exist_ok=True) + + # 创建引擎 + _engine = create_engine( + f'sqlite:///{db_path}', + connect_args={'check_same_thread': False} + ) + _session_factory = sessionmaker(bind=_engine, autocommit=False, autoflush=False) + + # 确保 SQLite 数据库文件被创建(SQLite 是惰性创建的) + with _engine.connect() as conn: + conn.execute(text("SELECT 1")) + + +def get_engine(): + """获取数据库引擎.""" + return _engine + + +def get_session_factory(): + """获取会话工厂.""" + return _session_factory + + +@contextmanager +def get_db(db_path: Optional[Path] = None) -> Generator[Session, None, None]: + """ + 获取数据库会话. + + Args: + db_path: 可选,用于初始化数据库 + + Yields: + SQLAlchemy 会话 + """ + if db_path and _engine is None: + init_db(db_path) + + if _session_factory is None: + raise RuntimeError("Database not initialized. Call init_db() first.") + + session = _session_factory() + try: + yield session + finally: + session.close() diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index e0974f3..57056e3 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -1,11 +1,3 @@ -"""ORM Models. +from app.database import Base -NOTE: This module is a placeholder until Task 2.1. -The Base class is needed by conftest.py for database fixtures. -""" - -from sqlalchemy.orm import DeclarativeBase - -class Base(DeclarativeBase): - """Base class for all ORM models.""" - pass +__all__ = ['Base'] diff --git a/backend/init_db.py b/backend/init_db.py new file mode 100644 index 0000000..c0d92af --- /dev/null +++ b/backend/init_db.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +""" +数据库初始化脚本. +创建所有表和必要的目录. +""" +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent)) + +from app.config import get_settings +from app.database import init_db +from app.models import Base + + +def main(): + """初始化数据库.""" + settings = get_settings() + print(f"初始化数据库: {settings.db_path}") + + # 创建目录 + settings.data_dir.mkdir(parents=True, exist_ok=True) + settings.ssh_keys_dir.mkdir(parents=True, exist_ok=True) + settings.repos_dir.mkdir(parents=True, exist_ok=True) + + # 初始化数据库 + init_db(settings.db_path) + + # 创建所有表 + from app.database import get_engine + Base.metadata.create_all(get_engine()) + + print("数据库初始化成功!") + print(f" - 数据库: {settings.db_path}") + print(f" - SSH 密钥: {settings.ssh_keys_dir}") + print(f" - 仓库: {settings.repos_dir}") + + +if __name__ == '__main__': + main() diff --git a/backend/tests/test_database.py b/backend/tests/test_database.py new file mode 100644 index 0000000..a2e6bb1 --- /dev/null +++ b/backend/tests/test_database.py @@ -0,0 +1,26 @@ +from pathlib import Path + +def test_database_initialization(db_path): + """测试数据库初始化.""" + from app.database import init_db, get_engine, Base + + init_db(db_path) + + assert db_path.exists() + + engine = get_engine() + assert engine is not None + + # 创建所有表 + Base.metadata.create_all(engine) + assert True # 如果没有异常则成功 + + +def test_get_session(db_path): + """测试获取数据库会话.""" + from app.database import init_db, get_db + + init_db(db_path) + + with get_db(db_path) as session: + assert session is not None