From abb170ace60cd9605863568afe242cb9c91c6c81 Mon Sep 17 00:00:00 2001 From: walkpan Date: Sun, 29 Mar 2026 21:42:49 +0800 Subject: [PATCH] feat: HA Discovery protocol handler for auto-registering devices MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- src/mqtt_home/discovery.py | 84 ++++++++++++++++++++++++++++++++++++++ tests/test_discovery.py | 78 +++++++++++++++++++++++++++++++++++ 2 files changed, 162 insertions(+) create mode 100644 src/mqtt_home/discovery.py create mode 100644 tests/test_discovery.py diff --git a/src/mqtt_home/discovery.py b/src/mqtt_home/discovery.py new file mode 100644 index 0000000..66fced7 --- /dev/null +++ b/src/mqtt_home/discovery.py @@ -0,0 +1,84 @@ +import json +import logging +from sqlalchemy.ext.asyncio import AsyncSession + +from mqtt_home.models import Device +from mqtt_home.mqtt_client import MqttClient +from mqtt_home.device_registry import create_device, handle_state_update + +logger = logging.getLogger(__name__) + +DISCOVERY_TOPIC_PREFIX = "homeassistant/" + + +def parse_discovery_topic(topic: str) -> dict[str, str] | None: + """Parse HA Discovery topic, returns component and node_id""" + parts = topic.split("/") + if len(parts) < 4 or parts[0] != DISCOVERY_TOPIC_PREFIX.strip("/"): + return None + if parts[-1] != "config": + return None + return { + "component": parts[1], + "node_id": "/".join(parts[2:-1]), + } + + +async def handle_discovery_message( + topic: str, payload: str, db: AsyncSession, mqtt_client: MqttClient +) -> Device | None: + """Handle HA Discovery config message, auto-register device and subscribe to state topic""" + parsed = parse_discovery_topic(topic) + if not parsed: + return None + + try: + config = json.loads(payload) + except json.JSONDecodeError: + logger.warning("Invalid JSON in discovery payload for %s", topic) + return None + + state_topic = config.get("state_topic") + command_topic = config.get("command_topic") + device_name = config.get("name", config.get("device", {}).get("name", parsed["node_id"])) + device_class = config.get("device_class", "") + + if not state_topic: + logger.warning("Discovery config for %s has no state_topic", topic) + return None + + # Check if already exists (deduplicate by discovery_topic) + from sqlalchemy import select + result = await db.execute( + select(Device).where(Device.discovery_topic == topic) + ) + existing = result.scalar_one_or_none() + if existing: + logger.debug("Device already registered: %s", topic) + return existing + + from mqtt_home.schemas import DeviceCreate + device = await create_device(db, DeviceCreate( + name=device_name, + type=parsed["component"], + protocol="ha_discovery", + mqtt_topic=state_topic, + command_topic=command_topic, + )) + + # Update discovery-related fields + device.discovery_topic = topic + device.discovery_payload = payload + device.attributes = json.dumps({ + "device_class": device_class, + "unit_of_measurement": config.get("unit_of_measurement"), + "icon": config.get("icon"), + }) + await db.commit() + await db.refresh(device) + + # Subscribe to state topic + await mqtt_client.subscribe(state_topic, qos=1) + logger.info("HA Discovery device registered: %s -> %s", device.name, state_topic) + + return device diff --git a/tests/test_discovery.py b/tests/test_discovery.py new file mode 100644 index 0000000..3a1d7a7 --- /dev/null +++ b/tests/test_discovery.py @@ -0,0 +1,78 @@ +import json +import pytest +from unittest.mock import AsyncMock + +from mqtt_home.discovery import parse_discovery_topic, handle_discovery_message +from mqtt_home.config import Settings +from mqtt_home.mqtt_client import MqttClient + + +@pytest.fixture +def settings(): + return Settings( + mqtt_host="localhost", + emqx_api_url="http://localhost:18083/api/v5", + emqx_api_key="test-key", + emqx_api_secret="test-secret", + database_url="sqlite+aiosqlite:///:memory:", + ) + + +def test_parse_discovery_topic_valid(): + result = parse_discovery_topic("homeassistant/light/abc123/config") + assert result is not None + assert result["component"] == "light" + assert result["node_id"] == "abc123" + + +def test_parse_discovery_topic_nested_node(): + result = parse_discovery_topic("homeassistant/sensor/room/temperature/config") + assert result is not None + assert result["component"] == "sensor" + assert result["node_id"] == "room/temperature" + + +def test_parse_discovery_topic_invalid(): + assert parse_discovery_topic("other/topic/config") is None + assert parse_discovery_topic("homeassistant/light/abc") is None + assert parse_discovery_topic("homeassistant/light/abc/status") is None + + +async def test_handle_discovery_creates_device(db_session, settings): + mqtt_client = MqttClient(settings) + mqtt_client.subscribe = AsyncMock() + + payload = json.dumps({ + "name": "客厅灯", + "state_topic": "home/living/light/status", + "command_topic": "home/living/light/set", + "device_class": "light", + }) + + device = await handle_discovery_message( + "homeassistant/light/living_room/config", + payload, + db_session, + mqtt_client, + ) + assert device is not None + assert device.name == "客厅灯" + assert device.protocol == "ha_discovery" + assert device.mqtt_topic == "home/living/light/status" + mqtt_client.subscribe.assert_called_once_with("home/living/light/status", qos=1) + + +async def test_handle_discovery_duplicate(db_session, settings): + mqtt_client = MqttClient(settings) + mqtt_client.subscribe = AsyncMock() + + payload = json.dumps({ + "name": "灯", + "state_topic": "home/light", + }) + + await handle_discovery_message("homeassistant/light/test/config", payload, db_session, mqtt_client) + device2 = await handle_discovery_message("homeassistant/light/test/config", payload, db_session, mqtt_client) + # Same discovery_topic should not create duplicate + assert device2 is not None + mqtt_client.subscribe.assert_called_once() # Only called once