From 2614ae8880b9edbaa93a6de4cdc08773cdfc1570 Mon Sep 17 00:00:00 2001 From: walkpan Date: Sun, 29 Mar 2026 21:40:52 +0800 Subject: [PATCH] feat: device registry with CRUD, state tracking, command sending, and log management MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- src/mqtt_home/device_registry.py | 147 +++++++++++++++++++++++++++++++ tests/test_device_registry.py | 98 +++++++++++++++++++++ 2 files changed, 245 insertions(+) create mode 100644 src/mqtt_home/device_registry.py create mode 100644 tests/test_device_registry.py diff --git a/src/mqtt_home/device_registry.py b/src/mqtt_home/device_registry.py new file mode 100644 index 0000000..36bb721 --- /dev/null +++ b/src/mqtt_home/device_registry.py @@ -0,0 +1,147 @@ +import json +import logging +from datetime import datetime, timezone +from typing import Optional + +from sqlalchemy import select, func, desc +from sqlalchemy.ext.asyncio import AsyncSession + +from mqtt_home.models import Device, DeviceLog +from mqtt_home.schemas import DeviceCreate, DeviceUpdate + +logger = logging.getLogger(__name__) + +MAX_LOGS_PER_DEVICE = 100 + + +async def list_devices(db: AsyncSession) -> list[Device]: + result = await db.execute(select(Device).order_by(Device.created_at.desc())) + return list(result.scalars().all()) + + +async def get_device(db: AsyncSession, device_id: str) -> Optional[Device]: + return await db.get(Device, device_id) + + +async def create_device(db: AsyncSession, data: DeviceCreate) -> Device: + device = Device( + name=data.name, + type=data.type, + protocol=data.protocol, + mqtt_topic=data.mqtt_topic, + command_topic=data.command_topic, + ) + db.add(device) + await db.commit() + await db.refresh(device) + logger.info("Device created: %s (%s)", device.name, device.id) + return device + + +async def update_device(db: AsyncSession, device_id: str, data: DeviceUpdate) -> Optional[Device]: + device = await db.get(Device, device_id) + if not device: + return None + update_data = data.model_dump(exclude_unset=True) + for key, value in update_data.items(): + setattr(device, key, value) + await db.commit() + await db.refresh(device) + return device + + +async def delete_device(db: AsyncSession, device_id: str) -> bool: + device = await db.get(Device, device_id) + if not device: + return False + await db.delete(device) + await db.commit() + logger.info("Device deleted: %s", device_id) + return True + + +async def send_command(db: AsyncSession, device_id: str, payload: str, publish_fn=None) -> Optional[DeviceLog]: + device = await db.get(Device, device_id) + if not device: + return None + if not device.command_topic: + raise ValueError(f"Device {device_id} has no command_topic configured") + + log = DeviceLog( + device_id=device_id, + direction="tx", + topic=device.command_topic, + payload=payload, + ) + db.add(log) + + if publish_fn: + await publish_fn(device.command_topic, payload) + + await db.commit() + await db.refresh(log) + logger.info("Command sent to %s: %s", device.command_topic, payload) + return log + + +async def handle_state_update(db: AsyncSession, topic: str, payload: str) -> Optional[Device]: + result = await db.execute(select(Device).where(Device.mqtt_topic == topic)) + device = result.scalar_one_or_none() + if not device: + return None + + device.state = payload + device.last_seen = datetime.now(timezone.utc) + device.is_online = True + + log = DeviceLog( + device_id=device.id, + direction="rx", + topic=topic, + payload=payload, + ) + db.add(log) + + # Clean old logs, keep MAX_LOGS_PER_DEVICE + count_result = await db.execute( + select(func.count()).select_from(DeviceLog).where(DeviceLog.device_id == device.id) + ) + count = count_result.scalar() or 0 + if count >= MAX_LOGS_PER_DEVICE: + oldest = await db.execute( + select(DeviceLog).where(DeviceLog.device_id == device.id) + .order_by(DeviceLog.timestamp.asc()) + .limit(count - MAX_LOGS_PER_DEVICE + 1) + ) + for old_log in oldest.scalars().all(): + await db.delete(old_log) + + await db.commit() + await db.refresh(device) + return device + + +async def get_device_logs(db: AsyncSession, device_id: str, limit: int = 20) -> list[DeviceLog]: + result = await db.execute( + select(DeviceLog) + .where(DeviceLog.device_id == device_id) + .order_by(desc(DeviceLog.timestamp)) + .limit(limit) + ) + return list(result.scalars().all()) + + +async def get_dashboard_stats(db: AsyncSession) -> dict: + total_result = await db.execute(select(func.count()).select_from(Device)) + total = total_result.scalar() or 0 + online_result = await db.execute(select(func.count()).select_from(Device).where(Device.is_online == True)) + online = online_result.scalar() or 0 + recent_logs_result = await db.execute( + select(DeviceLog).order_by(desc(DeviceLog.timestamp)).limit(10) + ) + return { + "total_devices": total, + "online_devices": online, + "offline_devices": total - online, + "recent_logs": list(recent_logs_result.scalars().all()), + } diff --git a/tests/test_device_registry.py b/tests/test_device_registry.py new file mode 100644 index 0000000..9c39678 --- /dev/null +++ b/tests/test_device_registry.py @@ -0,0 +1,98 @@ +import pytest +from mqtt_home.device_registry import ( + create_device, get_device, list_devices, delete_device, + update_device, send_command, handle_state_update, get_device_logs, + get_dashboard_stats, +) +from mqtt_home.schemas import DeviceCreate, DeviceUpdate + + +async def test_create_and_get_device(db_session): + data = DeviceCreate(name="客厅灯", type="light", mqtt_topic="home/light") + device = await create_device(db_session, data) + assert device.name == "客厅灯" + + fetched = await get_device(db_session, device.id) + assert fetched is not None + assert fetched.id == device.id + + +async def test_list_devices(db_session): + await create_device(db_session, DeviceCreate(name="设备1", type="switch", mqtt_topic="t1")) + await create_device(db_session, DeviceCreate(name="设备2", type="sensor", mqtt_topic="t2")) + devices = await list_devices(db_session) + assert len(devices) == 2 + + +async def test_update_device(db_session): + device = await create_device(db_session, DeviceCreate(name="灯", type="light", mqtt_topic="t")) + updated = await update_device(db_session, device.id, DeviceUpdate(name="新名字")) + assert updated.name == "新名字" + + +async def test_delete_device(db_session): + device = await create_device(db_session, DeviceCreate(name="灯", type="light", mqtt_topic="t")) + assert await delete_device(db_session, device.id) is True + assert await get_device(db_session, device.id) is None + + +async def test_delete_nonexistent_device(db_session): + assert await delete_device(db_session, "nonexistent") is False + + +async def test_send_command(db_session): + device = await create_device( + db_session, DeviceCreate(name="灯", type="light", mqtt_topic="t", command_topic="t/set") + ) + published = {} + + async def mock_publish(topic, payload): + published[topic] = payload + + log = await send_command(db_session, device.id, '{"state":"on"}', mock_publish) + assert log is not None + assert log.direction == "tx" + assert published["t/set"] == '{"state":"on"}' + + +async def test_send_command_no_command_topic(db_session): + device = await create_device(db_session, DeviceCreate(name="传感器", type="sensor", mqtt_topic="t")) + with pytest.raises(ValueError, match="no command_topic"): + await send_command(db_session, device.id, '{"value":1}') + + +async def test_handle_state_update(db_session): + device = await create_device(db_session, DeviceCreate(name="灯", type="light", mqtt_topic="home/light")) + updated = await handle_state_update(db_session, "home/light", '{"state":"on"}') + assert updated is not None + assert updated.state == '{"state":"on"}' + assert updated.is_online is True + + +async def test_handle_state_update_unknown_topic(db_session): + result = await handle_state_update(db_session, "unknown/topic", "on") + assert result is None + + +async def test_get_device_logs(db_session): + device = await create_device( + db_session, DeviceCreate(name="灯", type="light", mqtt_topic="t", command_topic="t/set") + ) + async def noop(t, p): + pass + await send_command(db_session, device.id, '{"state":"on"}', noop) + await send_command(db_session, device.id, '{"state":"off"}', noop) + + logs = await get_device_logs(db_session, device.id, limit=10) + assert len(logs) == 2 + + +async def test_dashboard_stats(db_session): + await create_device(db_session, DeviceCreate(name="在线设备", type="switch", mqtt_topic="t1")) + await create_device(db_session, DeviceCreate(name="离线设备", type="sensor", mqtt_topic="t2")) + await handle_state_update(db_session, "t1", "online") + + stats = await get_dashboard_stats(db_session) + assert stats["total_devices"] == 2 + assert stats["online_devices"] == 1 + assert stats["offline_devices"] == 1