feat: MQTT client with reconnection, topic matching, and callback dispatch

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
walkpan
2026-03-29 21:37:03 +08:00
parent 79b2878b3d
commit afe9de51c5
2 changed files with 162 additions and 0 deletions

View File

@@ -0,0 +1,114 @@
import asyncio
import uuid
import logging
from typing import Callable, Awaitable
import aiomqtt
from mqtt_home.config import Settings
logger = logging.getLogger(__name__)
MessageCallback = Callable[[str, str], Awaitable[None]]
class MqttClient:
def __init__(self, settings: Settings):
self._settings = settings
self._client: aiomqtt.Client | None = None
self._callbacks: dict[str, list[MessageCallback]] = {}
self._connected = False
self._task: asyncio.Task | None = None
self._client_id = f"mqtt-home-{uuid.uuid4().hex[:8]}"
@property
def is_connected(self) -> bool:
return self._connected
def on_message(self, topic_filter: str, callback: MessageCallback):
self._callbacks.setdefault(topic_filter, []).append(callback)
async def start(self):
self._task = asyncio.create_task(self._run())
async def stop(self):
if self._task:
self._task.cancel()
try:
await self._task
except asyncio.CancelledError:
pass
async def publish(self, topic: str, payload: str, qos: int = 1, retain: bool = False):
if not self._client:
raise RuntimeError("MQTT client not connected")
await self._client.publish(topic, payload, qos=qos, retain=retain)
async def subscribe(self, topic: str, qos: int = 1):
if not self._client:
raise RuntimeError("MQTT client not connected")
await self._client.subscribe(topic, qos=qos)
async def _run(self):
retry_delay = 1
max_delay = 60
while True:
try:
async with aiomqtt.Client(
hostname=self._settings.mqtt_host,
port=self._settings.mqtt_port,
username=self._settings.mqtt_username or None,
password=self._settings.mqtt_password or None,
identifier=self._client_id,
clean_session=True,
) as client:
self._client = client
self._connected = True
logger.info("MQTT connected to %s:%d", self._settings.mqtt_host, self._settings.mqtt_port)
retry_delay = 1
for topic_filter in self._callbacks:
await client.subscribe(topic_filter, qos=1)
async for message in client.messages:
topic = str(message.topic)
payload = message.payload.decode("utf-8", errors="replace")
await self._dispatch(topic, payload)
except asyncio.CancelledError:
self._connected = False
raise
except Exception as e:
self._connected = False
logger.warning("MQTT connection failed: %s, retrying in %ds", e, retry_delay)
await asyncio.sleep(retry_delay)
retry_delay = min(retry_delay * 2, max_delay)
async def _dispatch(self, topic: str, payload: str):
matched = False
for topic_filter, callbacks in self._callbacks.items():
if self._topic_matches(topic, topic_filter):
matched = True
for cb in callbacks:
try:
await cb(topic, payload)
except Exception as e:
logger.error("MQTT callback error for %s: %s", topic, e)
if not matched:
logger.debug("No callback for topic: %s", topic)
@staticmethod
def _topic_matches(topic: str, pattern: str) -> bool:
"""Simple MQTT wildcard matching (supports + and #)"""
topic_parts = topic.split("/")
pattern_parts = pattern.split("/")
pi = 0
for ti, tp in enumerate(pattern_parts):
if tp == "#":
return True
if ti >= len(topic_parts):
return False
if tp != "+" and tp != topic_parts[ti]:
return False
pi = ti
return len(topic_parts) == len(pattern_parts)

48
tests/test_mqtt_client.py Normal file
View File

@@ -0,0 +1,48 @@
import pytest
from mqtt_home.mqtt_client import MqttClient
from mqtt_home.config import Settings
@pytest.fixture
def mqtt_client():
settings = Settings(
mqtt_host="localhost",
mqtt_port=1883,
emqx_api_url="http://localhost:18083/api/v5",
emqx_api_key="test-key",
emqx_api_secret="test-secret",
database_url="sqlite+aiosqlite:///:memory:",
)
return MqttClient(settings)
def test_topic_matches_exact():
assert MqttClient._topic_matches("home/light", "home/light") is True
assert MqttClient._topic_matches("home/light", "home/switch") is False
def test_topic_matches_single_level_wildcard():
assert MqttClient._topic_matches("home/light/status", "home/+/status") is True
assert MqttClient._topic_matches("home/switch/status", "home/+/status") is True
assert MqttClient._topic_matches("home/light", "home/+") is True
assert MqttClient._topic_matches("home/light/extra", "home/+") is False
def test_topic_matches_multi_level_wildcard():
assert MqttClient._topic_matches("homeassistant/light/abc/config", "homeassistant/#") is True
assert MqttClient._topic_matches("homeassistant/light/abc/config", "homeassistant/light/#") is True
assert MqttClient._topic_matches("other/topic", "homeassistant/#") is False
def test_register_callback(mqtt_client):
async def cb(topic, payload):
pass
mqtt_client.on_message("home/#", cb)
assert "home/#" in mqtt_client._callbacks
assert len(mqtt_client._callbacks["home/#"]) == 1
def test_initial_state(mqtt_client):
assert mqtt_client.is_connected is False
assert len(mqtt_client._client_id) > 0