feat: MQTT client with reconnection, topic matching, and callback dispatch
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
114
src/mqtt_home/mqtt_client.py
Normal file
114
src/mqtt_home/mqtt_client.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
import logging
|
||||
from typing import Callable, Awaitable
|
||||
|
||||
import aiomqtt
|
||||
|
||||
from mqtt_home.config import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MessageCallback = Callable[[str, str], Awaitable[None]]
|
||||
|
||||
|
||||
class MqttClient:
|
||||
def __init__(self, settings: Settings):
|
||||
self._settings = settings
|
||||
self._client: aiomqtt.Client | None = None
|
||||
self._callbacks: dict[str, list[MessageCallback]] = {}
|
||||
self._connected = False
|
||||
self._task: asyncio.Task | None = None
|
||||
self._client_id = f"mqtt-home-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._connected
|
||||
|
||||
def on_message(self, topic_filter: str, callback: MessageCallback):
|
||||
self._callbacks.setdefault(topic_filter, []).append(callback)
|
||||
|
||||
async def start(self):
|
||||
self._task = asyncio.create_task(self._run())
|
||||
|
||||
async def stop(self):
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
try:
|
||||
await self._task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def publish(self, topic: str, payload: str, qos: int = 1, retain: bool = False):
|
||||
if not self._client:
|
||||
raise RuntimeError("MQTT client not connected")
|
||||
await self._client.publish(topic, payload, qos=qos, retain=retain)
|
||||
|
||||
async def subscribe(self, topic: str, qos: int = 1):
|
||||
if not self._client:
|
||||
raise RuntimeError("MQTT client not connected")
|
||||
await self._client.subscribe(topic, qos=qos)
|
||||
|
||||
async def _run(self):
|
||||
retry_delay = 1
|
||||
max_delay = 60
|
||||
while True:
|
||||
try:
|
||||
async with aiomqtt.Client(
|
||||
hostname=self._settings.mqtt_host,
|
||||
port=self._settings.mqtt_port,
|
||||
username=self._settings.mqtt_username or None,
|
||||
password=self._settings.mqtt_password or None,
|
||||
identifier=self._client_id,
|
||||
clean_session=True,
|
||||
) as client:
|
||||
self._client = client
|
||||
self._connected = True
|
||||
logger.info("MQTT connected to %s:%d", self._settings.mqtt_host, self._settings.mqtt_port)
|
||||
retry_delay = 1
|
||||
|
||||
for topic_filter in self._callbacks:
|
||||
await client.subscribe(topic_filter, qos=1)
|
||||
|
||||
async for message in client.messages:
|
||||
topic = str(message.topic)
|
||||
payload = message.payload.decode("utf-8", errors="replace")
|
||||
await self._dispatch(topic, payload)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
self._connected = False
|
||||
raise
|
||||
except Exception as e:
|
||||
self._connected = False
|
||||
logger.warning("MQTT connection failed: %s, retrying in %ds", e, retry_delay)
|
||||
await asyncio.sleep(retry_delay)
|
||||
retry_delay = min(retry_delay * 2, max_delay)
|
||||
|
||||
async def _dispatch(self, topic: str, payload: str):
|
||||
matched = False
|
||||
for topic_filter, callbacks in self._callbacks.items():
|
||||
if self._topic_matches(topic, topic_filter):
|
||||
matched = True
|
||||
for cb in callbacks:
|
||||
try:
|
||||
await cb(topic, payload)
|
||||
except Exception as e:
|
||||
logger.error("MQTT callback error for %s: %s", topic, e)
|
||||
if not matched:
|
||||
logger.debug("No callback for topic: %s", topic)
|
||||
|
||||
@staticmethod
|
||||
def _topic_matches(topic: str, pattern: str) -> bool:
|
||||
"""Simple MQTT wildcard matching (supports + and #)"""
|
||||
topic_parts = topic.split("/")
|
||||
pattern_parts = pattern.split("/")
|
||||
pi = 0
|
||||
for ti, tp in enumerate(pattern_parts):
|
||||
if tp == "#":
|
||||
return True
|
||||
if ti >= len(topic_parts):
|
||||
return False
|
||||
if tp != "+" and tp != topic_parts[ti]:
|
||||
return False
|
||||
pi = ti
|
||||
return len(topic_parts) == len(pattern_parts)
|
||||
48
tests/test_mqtt_client.py
Normal file
48
tests/test_mqtt_client.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import pytest
|
||||
from mqtt_home.mqtt_client import MqttClient
|
||||
from mqtt_home.config import Settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mqtt_client():
|
||||
settings = Settings(
|
||||
mqtt_host="localhost",
|
||||
mqtt_port=1883,
|
||||
emqx_api_url="http://localhost:18083/api/v5",
|
||||
emqx_api_key="test-key",
|
||||
emqx_api_secret="test-secret",
|
||||
database_url="sqlite+aiosqlite:///:memory:",
|
||||
)
|
||||
return MqttClient(settings)
|
||||
|
||||
|
||||
def test_topic_matches_exact():
|
||||
assert MqttClient._topic_matches("home/light", "home/light") is True
|
||||
assert MqttClient._topic_matches("home/light", "home/switch") is False
|
||||
|
||||
|
||||
def test_topic_matches_single_level_wildcard():
|
||||
assert MqttClient._topic_matches("home/light/status", "home/+/status") is True
|
||||
assert MqttClient._topic_matches("home/switch/status", "home/+/status") is True
|
||||
assert MqttClient._topic_matches("home/light", "home/+") is True
|
||||
assert MqttClient._topic_matches("home/light/extra", "home/+") is False
|
||||
|
||||
|
||||
def test_topic_matches_multi_level_wildcard():
|
||||
assert MqttClient._topic_matches("homeassistant/light/abc/config", "homeassistant/#") is True
|
||||
assert MqttClient._topic_matches("homeassistant/light/abc/config", "homeassistant/light/#") is True
|
||||
assert MqttClient._topic_matches("other/topic", "homeassistant/#") is False
|
||||
|
||||
|
||||
def test_register_callback(mqtt_client):
|
||||
async def cb(topic, payload):
|
||||
pass
|
||||
|
||||
mqtt_client.on_message("home/#", cb)
|
||||
assert "home/#" in mqtt_client._callbacks
|
||||
assert len(mqtt_client._callbacks["home/#"]) == 1
|
||||
|
||||
|
||||
def test_initial_state(mqtt_client):
|
||||
assert mqtt_client.is_connected is False
|
||||
assert len(mqtt_client._client_id) > 0
|
||||
Reference in New Issue
Block a user