feat: MQTT client with reconnection, topic matching, and callback dispatch
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
114
src/mqtt_home/mqtt_client.py
Normal file
114
src/mqtt_home/mqtt_client.py
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
import logging
|
||||||
|
from typing import Callable, Awaitable
|
||||||
|
|
||||||
|
import aiomqtt
|
||||||
|
|
||||||
|
from mqtt_home.config import Settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
MessageCallback = Callable[[str, str], Awaitable[None]]
|
||||||
|
|
||||||
|
|
||||||
|
class MqttClient:
|
||||||
|
def __init__(self, settings: Settings):
|
||||||
|
self._settings = settings
|
||||||
|
self._client: aiomqtt.Client | None = None
|
||||||
|
self._callbacks: dict[str, list[MessageCallback]] = {}
|
||||||
|
self._connected = False
|
||||||
|
self._task: asyncio.Task | None = None
|
||||||
|
self._client_id = f"mqtt-home-{uuid.uuid4().hex[:8]}"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
return self._connected
|
||||||
|
|
||||||
|
def on_message(self, topic_filter: str, callback: MessageCallback):
|
||||||
|
self._callbacks.setdefault(topic_filter, []).append(callback)
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
self._task = asyncio.create_task(self._run())
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
if self._task:
|
||||||
|
self._task.cancel()
|
||||||
|
try:
|
||||||
|
await self._task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def publish(self, topic: str, payload: str, qos: int = 1, retain: bool = False):
|
||||||
|
if not self._client:
|
||||||
|
raise RuntimeError("MQTT client not connected")
|
||||||
|
await self._client.publish(topic, payload, qos=qos, retain=retain)
|
||||||
|
|
||||||
|
async def subscribe(self, topic: str, qos: int = 1):
|
||||||
|
if not self._client:
|
||||||
|
raise RuntimeError("MQTT client not connected")
|
||||||
|
await self._client.subscribe(topic, qos=qos)
|
||||||
|
|
||||||
|
async def _run(self):
|
||||||
|
retry_delay = 1
|
||||||
|
max_delay = 60
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
async with aiomqtt.Client(
|
||||||
|
hostname=self._settings.mqtt_host,
|
||||||
|
port=self._settings.mqtt_port,
|
||||||
|
username=self._settings.mqtt_username or None,
|
||||||
|
password=self._settings.mqtt_password or None,
|
||||||
|
identifier=self._client_id,
|
||||||
|
clean_session=True,
|
||||||
|
) as client:
|
||||||
|
self._client = client
|
||||||
|
self._connected = True
|
||||||
|
logger.info("MQTT connected to %s:%d", self._settings.mqtt_host, self._settings.mqtt_port)
|
||||||
|
retry_delay = 1
|
||||||
|
|
||||||
|
for topic_filter in self._callbacks:
|
||||||
|
await client.subscribe(topic_filter, qos=1)
|
||||||
|
|
||||||
|
async for message in client.messages:
|
||||||
|
topic = str(message.topic)
|
||||||
|
payload = message.payload.decode("utf-8", errors="replace")
|
||||||
|
await self._dispatch(topic, payload)
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
self._connected = False
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self._connected = False
|
||||||
|
logger.warning("MQTT connection failed: %s, retrying in %ds", e, retry_delay)
|
||||||
|
await asyncio.sleep(retry_delay)
|
||||||
|
retry_delay = min(retry_delay * 2, max_delay)
|
||||||
|
|
||||||
|
async def _dispatch(self, topic: str, payload: str):
|
||||||
|
matched = False
|
||||||
|
for topic_filter, callbacks in self._callbacks.items():
|
||||||
|
if self._topic_matches(topic, topic_filter):
|
||||||
|
matched = True
|
||||||
|
for cb in callbacks:
|
||||||
|
try:
|
||||||
|
await cb(topic, payload)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("MQTT callback error for %s: %s", topic, e)
|
||||||
|
if not matched:
|
||||||
|
logger.debug("No callback for topic: %s", topic)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _topic_matches(topic: str, pattern: str) -> bool:
|
||||||
|
"""Simple MQTT wildcard matching (supports + and #)"""
|
||||||
|
topic_parts = topic.split("/")
|
||||||
|
pattern_parts = pattern.split("/")
|
||||||
|
pi = 0
|
||||||
|
for ti, tp in enumerate(pattern_parts):
|
||||||
|
if tp == "#":
|
||||||
|
return True
|
||||||
|
if ti >= len(topic_parts):
|
||||||
|
return False
|
||||||
|
if tp != "+" and tp != topic_parts[ti]:
|
||||||
|
return False
|
||||||
|
pi = ti
|
||||||
|
return len(topic_parts) == len(pattern_parts)
|
||||||
48
tests/test_mqtt_client.py
Normal file
48
tests/test_mqtt_client.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
import pytest
|
||||||
|
from mqtt_home.mqtt_client import MqttClient
|
||||||
|
from mqtt_home.config import Settings
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mqtt_client():
|
||||||
|
settings = Settings(
|
||||||
|
mqtt_host="localhost",
|
||||||
|
mqtt_port=1883,
|
||||||
|
emqx_api_url="http://localhost:18083/api/v5",
|
||||||
|
emqx_api_key="test-key",
|
||||||
|
emqx_api_secret="test-secret",
|
||||||
|
database_url="sqlite+aiosqlite:///:memory:",
|
||||||
|
)
|
||||||
|
return MqttClient(settings)
|
||||||
|
|
||||||
|
|
||||||
|
def test_topic_matches_exact():
|
||||||
|
assert MqttClient._topic_matches("home/light", "home/light") is True
|
||||||
|
assert MqttClient._topic_matches("home/light", "home/switch") is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_topic_matches_single_level_wildcard():
|
||||||
|
assert MqttClient._topic_matches("home/light/status", "home/+/status") is True
|
||||||
|
assert MqttClient._topic_matches("home/switch/status", "home/+/status") is True
|
||||||
|
assert MqttClient._topic_matches("home/light", "home/+") is True
|
||||||
|
assert MqttClient._topic_matches("home/light/extra", "home/+") is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_topic_matches_multi_level_wildcard():
|
||||||
|
assert MqttClient._topic_matches("homeassistant/light/abc/config", "homeassistant/#") is True
|
||||||
|
assert MqttClient._topic_matches("homeassistant/light/abc/config", "homeassistant/light/#") is True
|
||||||
|
assert MqttClient._topic_matches("other/topic", "homeassistant/#") is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_callback(mqtt_client):
|
||||||
|
async def cb(topic, payload):
|
||||||
|
pass
|
||||||
|
|
||||||
|
mqtt_client.on_message("home/#", cb)
|
||||||
|
assert "home/#" in mqtt_client._callbacks
|
||||||
|
assert len(mqtt_client._callbacks["home/#"]) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_initial_state(mqtt_client):
|
||||||
|
assert mqtt_client.is_connected is False
|
||||||
|
assert len(mqtt_client._client_id) > 0
|
||||||
Reference in New Issue
Block a user