diff --git a/src/mqtt_home/database.py b/src/mqtt_home/database.py index 1c2dcc4..4d43c7a 100644 --- a/src/mqtt_home/database.py +++ b/src/mqtt_home/database.py @@ -1,5 +1,33 @@ +from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, async_session from sqlalchemy.orm import DeclarativeBase +from mqtt_home.config import get_settings class Base(DeclarativeBase): - pass \ No newline at end of file + pass + + +def get_engine(): + settings = get_settings() + return create_async_engine( + settings.database_url, + echo=False, + ) + + +def get_session_factory(engine=None): + _engine = engine or get_engine() + return async_sessionmaker(_engine, expire_on_commit=False) + + +async def init_db(): + engine = get_engine() + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + await engine.dispose() + + +async def get_db(): + factory = get_session_factory() + async with factory() as session: + yield session \ No newline at end of file diff --git a/src/mqtt_home/models.py b/src/mqtt_home/models.py new file mode 100644 index 0000000..bb3cd71 --- /dev/null +++ b/src/mqtt_home/models.py @@ -0,0 +1,49 @@ +import uuid +from datetime import datetime, timezone + +from sqlalchemy import Column, String, Boolean, DateTime, Text, Integer, ForeignKey +from sqlalchemy.orm import relationship + +from mqtt_home.database import Base + + +def _utcnow(): + return datetime.now(timezone.utc) + + +def _new_id(): + return str(uuid.uuid4()) + + +class Device(Base): + __tablename__ = "devices" + + id = Column(String(36), primary_key=True, default=_new_id) + name = Column(String(200), nullable=False) + type = Column(String(50), nullable=False, default="switch") + protocol = Column(String(20), nullable=False, default="custom") + mqtt_topic = Column(String(500), nullable=False) + command_topic = Column(String(500), nullable=True) + discovery_topic = Column(String(500), nullable=True) + discovery_payload = Column(Text, nullable=True) + attributes = Column(Text, nullable=True, default="{}") + state = Column(String(500), nullable=True, default=None) + is_online = Column(Boolean, nullable=False, default=False) + last_seen = Column(DateTime(timezone=True), nullable=True) + created_at = Column(DateTime(timezone=True), nullable=False, default=_utcnow) + updated_at = Column(DateTime(timezone=True), nullable=False, default=_utcnow, onupdate=_utcnow) + + logs = relationship("DeviceLog", back_populates="device", lazy="select") + + +class DeviceLog(Base): + __tablename__ = "device_logs" + + id = Column(Integer, primary_key=True, autoincrement=True) + device_id = Column(String(36), ForeignKey("devices.id", ondelete="CASCADE"), nullable=False) + direction = Column(String(10), nullable=False) + topic = Column(String(500), nullable=False) + payload = Column(Text, nullable=False) + timestamp = Column(DateTime(timezone=True), nullable=False, default=_utcnow) + + device = relationship("Device", back_populates="logs") \ No newline at end of file diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..16a88f9 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,49 @@ +import pytest +from mqtt_home.models import Device, DeviceLog + + +async def test_create_device(db_session): + device = Device( + name="客厅灯", + type="light", + protocol="custom", + mqtt_topic="home/living/light", + command_topic="home/living/light/set", + ) + db_session.add(device) + await db_session.commit() + await db_session.refresh(device) + + assert device.id is not None + assert len(device.id) == 36 + assert device.name == "客厅灯" + assert device.state is None + assert device.is_online is False + + +async def test_create_device_log(db_session): + device = Device( + name="温度传感器", + type="sensor", + protocol="custom", + mqtt_topic="home/temperature", + ) + db_session.add(device) + await db_session.flush() + + log = DeviceLog( + device_id=device.id, + direction="rx", + topic="home/temperature", + payload='{"temperature": 25.5}', + ) + db_session.add(log) + await db_session.commit() + + # Query logs directly instead of using relationship.count() + from sqlalchemy import select + result = await db_session.execute( + select(DeviceLog).where(DeviceLog.device_id == device.id) + ) + logs = result.scalars().all() + assert len(logs) == 1