feat: add database module
- Add SQLAlchemy database module with DeclarativeBase - Implement engine and session factory management - Add context manager for database sessions - Add database initialization script - Update models/__init__.py to import Base from database - Fix Python 3.8 compatibility issues (use Optional instead of |) - Ensure SQLite database file is created on init_db Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
73
backend/app/database.py
Normal file
73
backend/app/database.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
import sqlalchemy
|
||||||
|
from sqlalchemy import create_engine, text
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, sessionmaker, Session
|
||||||
|
from pathlib import Path
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Generator, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class Base(DeclarativeBase):
|
||||||
|
"""SQLAlchemy 声明基类."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
_engine = None
|
||||||
|
_session_factory = None
|
||||||
|
|
||||||
|
|
||||||
|
def init_db(db_path: Path) -> None:
|
||||||
|
"""
|
||||||
|
初始化数据库引擎和会话工厂.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: SQLite 数据库文件路径
|
||||||
|
"""
|
||||||
|
global _engine, _session_factory
|
||||||
|
|
||||||
|
# 确保父目录存在
|
||||||
|
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# 创建引擎
|
||||||
|
_engine = create_engine(
|
||||||
|
f'sqlite:///{db_path}',
|
||||||
|
connect_args={'check_same_thread': False}
|
||||||
|
)
|
||||||
|
_session_factory = sessionmaker(bind=_engine, autocommit=False, autoflush=False)
|
||||||
|
|
||||||
|
# 确保 SQLite 数据库文件被创建(SQLite 是惰性创建的)
|
||||||
|
with _engine.connect() as conn:
|
||||||
|
conn.execute(text("SELECT 1"))
|
||||||
|
|
||||||
|
|
||||||
|
def get_engine():
|
||||||
|
"""获取数据库引擎."""
|
||||||
|
return _engine
|
||||||
|
|
||||||
|
|
||||||
|
def get_session_factory():
|
||||||
|
"""获取会话工厂."""
|
||||||
|
return _session_factory
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def get_db(db_path: Optional[Path] = None) -> Generator[Session, None, None]:
|
||||||
|
"""
|
||||||
|
获取数据库会话.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: 可选,用于初始化数据库
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
SQLAlchemy 会话
|
||||||
|
"""
|
||||||
|
if db_path and _engine is None:
|
||||||
|
init_db(db_path)
|
||||||
|
|
||||||
|
if _session_factory is None:
|
||||||
|
raise RuntimeError("Database not initialized. Call init_db() first.")
|
||||||
|
|
||||||
|
session = _session_factory()
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
@@ -1,11 +1,3 @@
|
|||||||
"""ORM Models.
|
from app.database import Base
|
||||||
|
|
||||||
NOTE: This module is a placeholder until Task 2.1.
|
__all__ = ['Base']
|
||||||
The Base class is needed by conftest.py for database fixtures.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
|
||||||
|
|
||||||
class Base(DeclarativeBase):
|
|
||||||
"""Base class for all ORM models."""
|
|
||||||
pass
|
|
||||||
|
|||||||
40
backend/init_db.py
Normal file
40
backend/init_db.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
数据库初始化脚本.
|
||||||
|
创建所有表和必要的目录.
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
|
from app.config import get_settings
|
||||||
|
from app.database import init_db
|
||||||
|
from app.models import Base
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""初始化数据库."""
|
||||||
|
settings = get_settings()
|
||||||
|
print(f"初始化数据库: {settings.db_path}")
|
||||||
|
|
||||||
|
# 创建目录
|
||||||
|
settings.data_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
settings.ssh_keys_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
settings.repos_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# 初始化数据库
|
||||||
|
init_db(settings.db_path)
|
||||||
|
|
||||||
|
# 创建所有表
|
||||||
|
from app.database import get_engine
|
||||||
|
Base.metadata.create_all(get_engine())
|
||||||
|
|
||||||
|
print("数据库初始化成功!")
|
||||||
|
print(f" - 数据库: {settings.db_path}")
|
||||||
|
print(f" - SSH 密钥: {settings.ssh_keys_dir}")
|
||||||
|
print(f" - 仓库: {settings.repos_dir}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
26
backend/tests/test_database.py
Normal file
26
backend/tests/test_database.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
def test_database_initialization(db_path):
|
||||||
|
"""测试数据库初始化."""
|
||||||
|
from app.database import init_db, get_engine, Base
|
||||||
|
|
||||||
|
init_db(db_path)
|
||||||
|
|
||||||
|
assert db_path.exists()
|
||||||
|
|
||||||
|
engine = get_engine()
|
||||||
|
assert engine is not None
|
||||||
|
|
||||||
|
# 创建所有表
|
||||||
|
Base.metadata.create_all(engine)
|
||||||
|
assert True # 如果没有异常则成功
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_session(db_path):
|
||||||
|
"""测试获取数据库会话."""
|
||||||
|
from app.database import init_db, get_db
|
||||||
|
|
||||||
|
init_db(db_path)
|
||||||
|
|
||||||
|
with get_db(db_path) as session:
|
||||||
|
assert session is not None
|
||||||
Reference in New Issue
Block a user