feat: HA Discovery protocol handler for auto-registering devices
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
84
src/mqtt_home/discovery.py
Normal file
84
src/mqtt_home/discovery.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import json
|
||||
import logging
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from mqtt_home.models import Device
|
||||
from mqtt_home.mqtt_client import MqttClient
|
||||
from mqtt_home.device_registry import create_device, handle_state_update
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DISCOVERY_TOPIC_PREFIX = "homeassistant/"
|
||||
|
||||
|
||||
def parse_discovery_topic(topic: str) -> dict[str, str] | None:
|
||||
"""Parse HA Discovery topic, returns component and node_id"""
|
||||
parts = topic.split("/")
|
||||
if len(parts) < 4 or parts[0] != DISCOVERY_TOPIC_PREFIX.strip("/"):
|
||||
return None
|
||||
if parts[-1] != "config":
|
||||
return None
|
||||
return {
|
||||
"component": parts[1],
|
||||
"node_id": "/".join(parts[2:-1]),
|
||||
}
|
||||
|
||||
|
||||
async def handle_discovery_message(
|
||||
topic: str, payload: str, db: AsyncSession, mqtt_client: MqttClient
|
||||
) -> Device | None:
|
||||
"""Handle HA Discovery config message, auto-register device and subscribe to state topic"""
|
||||
parsed = parse_discovery_topic(topic)
|
||||
if not parsed:
|
||||
return None
|
||||
|
||||
try:
|
||||
config = json.loads(payload)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Invalid JSON in discovery payload for %s", topic)
|
||||
return None
|
||||
|
||||
state_topic = config.get("state_topic")
|
||||
command_topic = config.get("command_topic")
|
||||
device_name = config.get("name", config.get("device", {}).get("name", parsed["node_id"]))
|
||||
device_class = config.get("device_class", "")
|
||||
|
||||
if not state_topic:
|
||||
logger.warning("Discovery config for %s has no state_topic", topic)
|
||||
return None
|
||||
|
||||
# Check if already exists (deduplicate by discovery_topic)
|
||||
from sqlalchemy import select
|
||||
result = await db.execute(
|
||||
select(Device).where(Device.discovery_topic == topic)
|
||||
)
|
||||
existing = result.scalar_one_or_none()
|
||||
if existing:
|
||||
logger.debug("Device already registered: %s", topic)
|
||||
return existing
|
||||
|
||||
from mqtt_home.schemas import DeviceCreate
|
||||
device = await create_device(db, DeviceCreate(
|
||||
name=device_name,
|
||||
type=parsed["component"],
|
||||
protocol="ha_discovery",
|
||||
mqtt_topic=state_topic,
|
||||
command_topic=command_topic,
|
||||
))
|
||||
|
||||
# Update discovery-related fields
|
||||
device.discovery_topic = topic
|
||||
device.discovery_payload = payload
|
||||
device.attributes = json.dumps({
|
||||
"device_class": device_class,
|
||||
"unit_of_measurement": config.get("unit_of_measurement"),
|
||||
"icon": config.get("icon"),
|
||||
})
|
||||
await db.commit()
|
||||
await db.refresh(device)
|
||||
|
||||
# Subscribe to state topic
|
||||
await mqtt_client.subscribe(state_topic, qos=1)
|
||||
logger.info("HA Discovery device registered: %s -> %s", device.name, state_topic)
|
||||
|
||||
return device
|
||||
78
tests/test_discovery.py
Normal file
78
tests/test_discovery.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from mqtt_home.discovery import parse_discovery_topic, handle_discovery_message
|
||||
from mqtt_home.config import Settings
|
||||
from mqtt_home.mqtt_client import MqttClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def settings():
|
||||
return Settings(
|
||||
mqtt_host="localhost",
|
||||
emqx_api_url="http://localhost:18083/api/v5",
|
||||
emqx_api_key="test-key",
|
||||
emqx_api_secret="test-secret",
|
||||
database_url="sqlite+aiosqlite:///:memory:",
|
||||
)
|
||||
|
||||
|
||||
def test_parse_discovery_topic_valid():
|
||||
result = parse_discovery_topic("homeassistant/light/abc123/config")
|
||||
assert result is not None
|
||||
assert result["component"] == "light"
|
||||
assert result["node_id"] == "abc123"
|
||||
|
||||
|
||||
def test_parse_discovery_topic_nested_node():
|
||||
result = parse_discovery_topic("homeassistant/sensor/room/temperature/config")
|
||||
assert result is not None
|
||||
assert result["component"] == "sensor"
|
||||
assert result["node_id"] == "room/temperature"
|
||||
|
||||
|
||||
def test_parse_discovery_topic_invalid():
|
||||
assert parse_discovery_topic("other/topic/config") is None
|
||||
assert parse_discovery_topic("homeassistant/light/abc") is None
|
||||
assert parse_discovery_topic("homeassistant/light/abc/status") is None
|
||||
|
||||
|
||||
async def test_handle_discovery_creates_device(db_session, settings):
|
||||
mqtt_client = MqttClient(settings)
|
||||
mqtt_client.subscribe = AsyncMock()
|
||||
|
||||
payload = json.dumps({
|
||||
"name": "客厅灯",
|
||||
"state_topic": "home/living/light/status",
|
||||
"command_topic": "home/living/light/set",
|
||||
"device_class": "light",
|
||||
})
|
||||
|
||||
device = await handle_discovery_message(
|
||||
"homeassistant/light/living_room/config",
|
||||
payload,
|
||||
db_session,
|
||||
mqtt_client,
|
||||
)
|
||||
assert device is not None
|
||||
assert device.name == "客厅灯"
|
||||
assert device.protocol == "ha_discovery"
|
||||
assert device.mqtt_topic == "home/living/light/status"
|
||||
mqtt_client.subscribe.assert_called_once_with("home/living/light/status", qos=1)
|
||||
|
||||
|
||||
async def test_handle_discovery_duplicate(db_session, settings):
|
||||
mqtt_client = MqttClient(settings)
|
||||
mqtt_client.subscribe = AsyncMock()
|
||||
|
||||
payload = json.dumps({
|
||||
"name": "灯",
|
||||
"state_topic": "home/light",
|
||||
})
|
||||
|
||||
await handle_discovery_message("homeassistant/light/test/config", payload, db_session, mqtt_client)
|
||||
device2 = await handle_discovery_message("homeassistant/light/test/config", payload, db_session, mqtt_client)
|
||||
# Same discovery_topic should not create duplicate
|
||||
assert device2 is not None
|
||||
mqtt_client.subscribe.assert_called_once() # Only called once
|
||||
Reference in New Issue
Block a user