diff --git a/src/mqtt_home/mqtt_client.py b/src/mqtt_home/mqtt_client.py new file mode 100644 index 0000000..f0fae06 --- /dev/null +++ b/src/mqtt_home/mqtt_client.py @@ -0,0 +1,114 @@ +import asyncio +import uuid +import logging +from typing import Callable, Awaitable + +import aiomqtt + +from mqtt_home.config import Settings + +logger = logging.getLogger(__name__) + +MessageCallback = Callable[[str, str], Awaitable[None]] + + +class MqttClient: + def __init__(self, settings: Settings): + self._settings = settings + self._client: aiomqtt.Client | None = None + self._callbacks: dict[str, list[MessageCallback]] = {} + self._connected = False + self._task: asyncio.Task | None = None + self._client_id = f"mqtt-home-{uuid.uuid4().hex[:8]}" + + @property + def is_connected(self) -> bool: + return self._connected + + def on_message(self, topic_filter: str, callback: MessageCallback): + self._callbacks.setdefault(topic_filter, []).append(callback) + + async def start(self): + self._task = asyncio.create_task(self._run()) + + async def stop(self): + if self._task: + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + + async def publish(self, topic: str, payload: str, qos: int = 1, retain: bool = False): + if not self._client: + raise RuntimeError("MQTT client not connected") + await self._client.publish(topic, payload, qos=qos, retain=retain) + + async def subscribe(self, topic: str, qos: int = 1): + if not self._client: + raise RuntimeError("MQTT client not connected") + await self._client.subscribe(topic, qos=qos) + + async def _run(self): + retry_delay = 1 + max_delay = 60 + while True: + try: + async with aiomqtt.Client( + hostname=self._settings.mqtt_host, + port=self._settings.mqtt_port, + username=self._settings.mqtt_username or None, + password=self._settings.mqtt_password or None, + identifier=self._client_id, + clean_session=True, + ) as client: + self._client = client + self._connected = True + logger.info("MQTT connected to %s:%d", self._settings.mqtt_host, self._settings.mqtt_port) + retry_delay = 1 + + for topic_filter in self._callbacks: + await client.subscribe(topic_filter, qos=1) + + async for message in client.messages: + topic = str(message.topic) + payload = message.payload.decode("utf-8", errors="replace") + await self._dispatch(topic, payload) + + except asyncio.CancelledError: + self._connected = False + raise + except Exception as e: + self._connected = False + logger.warning("MQTT connection failed: %s, retrying in %ds", e, retry_delay) + await asyncio.sleep(retry_delay) + retry_delay = min(retry_delay * 2, max_delay) + + async def _dispatch(self, topic: str, payload: str): + matched = False + for topic_filter, callbacks in self._callbacks.items(): + if self._topic_matches(topic, topic_filter): + matched = True + for cb in callbacks: + try: + await cb(topic, payload) + except Exception as e: + logger.error("MQTT callback error for %s: %s", topic, e) + if not matched: + logger.debug("No callback for topic: %s", topic) + + @staticmethod + def _topic_matches(topic: str, pattern: str) -> bool: + """Simple MQTT wildcard matching (supports + and #)""" + topic_parts = topic.split("/") + pattern_parts = pattern.split("/") + pi = 0 + for ti, tp in enumerate(pattern_parts): + if tp == "#": + return True + if ti >= len(topic_parts): + return False + if tp != "+" and tp != topic_parts[ti]: + return False + pi = ti + return len(topic_parts) == len(pattern_parts) diff --git a/tests/test_mqtt_client.py b/tests/test_mqtt_client.py new file mode 100644 index 0000000..3ac9c9f --- /dev/null +++ b/tests/test_mqtt_client.py @@ -0,0 +1,48 @@ +import pytest +from mqtt_home.mqtt_client import MqttClient +from mqtt_home.config import Settings + + +@pytest.fixture +def mqtt_client(): + settings = Settings( + mqtt_host="localhost", + mqtt_port=1883, + emqx_api_url="http://localhost:18083/api/v5", + emqx_api_key="test-key", + emqx_api_secret="test-secret", + database_url="sqlite+aiosqlite:///:memory:", + ) + return MqttClient(settings) + + +def test_topic_matches_exact(): + assert MqttClient._topic_matches("home/light", "home/light") is True + assert MqttClient._topic_matches("home/light", "home/switch") is False + + +def test_topic_matches_single_level_wildcard(): + assert MqttClient._topic_matches("home/light/status", "home/+/status") is True + assert MqttClient._topic_matches("home/switch/status", "home/+/status") is True + assert MqttClient._topic_matches("home/light", "home/+") is True + assert MqttClient._topic_matches("home/light/extra", "home/+") is False + + +def test_topic_matches_multi_level_wildcard(): + assert MqttClient._topic_matches("homeassistant/light/abc/config", "homeassistant/#") is True + assert MqttClient._topic_matches("homeassistant/light/abc/config", "homeassistant/light/#") is True + assert MqttClient._topic_matches("other/topic", "homeassistant/#") is False + + +def test_register_callback(mqtt_client): + async def cb(topic, payload): + pass + + mqtt_client.on_message("home/#", cb) + assert "home/#" in mqtt_client._callbacks + assert len(mqtt_client._callbacks["home/#"]) == 1 + + +def test_initial_state(mqtt_client): + assert mqtt_client.is_connected is False + assert len(mqtt_client._client_id) > 0