feat: HA Discovery protocol handler for auto-registering devices
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
84
src/mqtt_home/discovery.py
Normal file
84
src/mqtt_home/discovery.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from mqtt_home.models import Device
|
||||||
|
from mqtt_home.mqtt_client import MqttClient
|
||||||
|
from mqtt_home.device_registry import create_device, handle_state_update
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DISCOVERY_TOPIC_PREFIX = "homeassistant/"
|
||||||
|
|
||||||
|
|
||||||
|
def parse_discovery_topic(topic: str) -> dict[str, str] | None:
|
||||||
|
"""Parse HA Discovery topic, returns component and node_id"""
|
||||||
|
parts = topic.split("/")
|
||||||
|
if len(parts) < 4 or parts[0] != DISCOVERY_TOPIC_PREFIX.strip("/"):
|
||||||
|
return None
|
||||||
|
if parts[-1] != "config":
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"component": parts[1],
|
||||||
|
"node_id": "/".join(parts[2:-1]),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_discovery_message(
|
||||||
|
topic: str, payload: str, db: AsyncSession, mqtt_client: MqttClient
|
||||||
|
) -> Device | None:
|
||||||
|
"""Handle HA Discovery config message, auto-register device and subscribe to state topic"""
|
||||||
|
parsed = parse_discovery_topic(topic)
|
||||||
|
if not parsed:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
config = json.loads(payload)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning("Invalid JSON in discovery payload for %s", topic)
|
||||||
|
return None
|
||||||
|
|
||||||
|
state_topic = config.get("state_topic")
|
||||||
|
command_topic = config.get("command_topic")
|
||||||
|
device_name = config.get("name", config.get("device", {}).get("name", parsed["node_id"]))
|
||||||
|
device_class = config.get("device_class", "")
|
||||||
|
|
||||||
|
if not state_topic:
|
||||||
|
logger.warning("Discovery config for %s has no state_topic", topic)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Check if already exists (deduplicate by discovery_topic)
|
||||||
|
from sqlalchemy import select
|
||||||
|
result = await db.execute(
|
||||||
|
select(Device).where(Device.discovery_topic == topic)
|
||||||
|
)
|
||||||
|
existing = result.scalar_one_or_none()
|
||||||
|
if existing:
|
||||||
|
logger.debug("Device already registered: %s", topic)
|
||||||
|
return existing
|
||||||
|
|
||||||
|
from mqtt_home.schemas import DeviceCreate
|
||||||
|
device = await create_device(db, DeviceCreate(
|
||||||
|
name=device_name,
|
||||||
|
type=parsed["component"],
|
||||||
|
protocol="ha_discovery",
|
||||||
|
mqtt_topic=state_topic,
|
||||||
|
command_topic=command_topic,
|
||||||
|
))
|
||||||
|
|
||||||
|
# Update discovery-related fields
|
||||||
|
device.discovery_topic = topic
|
||||||
|
device.discovery_payload = payload
|
||||||
|
device.attributes = json.dumps({
|
||||||
|
"device_class": device_class,
|
||||||
|
"unit_of_measurement": config.get("unit_of_measurement"),
|
||||||
|
"icon": config.get("icon"),
|
||||||
|
})
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(device)
|
||||||
|
|
||||||
|
# Subscribe to state topic
|
||||||
|
await mqtt_client.subscribe(state_topic, qos=1)
|
||||||
|
logger.info("HA Discovery device registered: %s -> %s", device.name, state_topic)
|
||||||
|
|
||||||
|
return device
|
||||||
78
tests/test_discovery.py
Normal file
78
tests/test_discovery.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
import json
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
from mqtt_home.discovery import parse_discovery_topic, handle_discovery_message
|
||||||
|
from mqtt_home.config import Settings
|
||||||
|
from mqtt_home.mqtt_client import MqttClient
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def settings():
|
||||||
|
return Settings(
|
||||||
|
mqtt_host="localhost",
|
||||||
|
emqx_api_url="http://localhost:18083/api/v5",
|
||||||
|
emqx_api_key="test-key",
|
||||||
|
emqx_api_secret="test-secret",
|
||||||
|
database_url="sqlite+aiosqlite:///:memory:",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_discovery_topic_valid():
|
||||||
|
result = parse_discovery_topic("homeassistant/light/abc123/config")
|
||||||
|
assert result is not None
|
||||||
|
assert result["component"] == "light"
|
||||||
|
assert result["node_id"] == "abc123"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_discovery_topic_nested_node():
|
||||||
|
result = parse_discovery_topic("homeassistant/sensor/room/temperature/config")
|
||||||
|
assert result is not None
|
||||||
|
assert result["component"] == "sensor"
|
||||||
|
assert result["node_id"] == "room/temperature"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_discovery_topic_invalid():
|
||||||
|
assert parse_discovery_topic("other/topic/config") is None
|
||||||
|
assert parse_discovery_topic("homeassistant/light/abc") is None
|
||||||
|
assert parse_discovery_topic("homeassistant/light/abc/status") is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_handle_discovery_creates_device(db_session, settings):
|
||||||
|
mqtt_client = MqttClient(settings)
|
||||||
|
mqtt_client.subscribe = AsyncMock()
|
||||||
|
|
||||||
|
payload = json.dumps({
|
||||||
|
"name": "客厅灯",
|
||||||
|
"state_topic": "home/living/light/status",
|
||||||
|
"command_topic": "home/living/light/set",
|
||||||
|
"device_class": "light",
|
||||||
|
})
|
||||||
|
|
||||||
|
device = await handle_discovery_message(
|
||||||
|
"homeassistant/light/living_room/config",
|
||||||
|
payload,
|
||||||
|
db_session,
|
||||||
|
mqtt_client,
|
||||||
|
)
|
||||||
|
assert device is not None
|
||||||
|
assert device.name == "客厅灯"
|
||||||
|
assert device.protocol == "ha_discovery"
|
||||||
|
assert device.mqtt_topic == "home/living/light/status"
|
||||||
|
mqtt_client.subscribe.assert_called_once_with("home/living/light/status", qos=1)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_handle_discovery_duplicate(db_session, settings):
|
||||||
|
mqtt_client = MqttClient(settings)
|
||||||
|
mqtt_client.subscribe = AsyncMock()
|
||||||
|
|
||||||
|
payload = json.dumps({
|
||||||
|
"name": "灯",
|
||||||
|
"state_topic": "home/light",
|
||||||
|
})
|
||||||
|
|
||||||
|
await handle_discovery_message("homeassistant/light/test/config", payload, db_session, mqtt_client)
|
||||||
|
device2 = await handle_discovery_message("homeassistant/light/test/config", payload, db_session, mqtt_client)
|
||||||
|
# Same discovery_topic should not create duplicate
|
||||||
|
assert device2 is not None
|
||||||
|
mqtt_client.subscribe.assert_called_once() # Only called once
|
||||||
Reference in New Issue
Block a user