diff --git a/frontend/src/views/RulesView.vue b/frontend/src/views/RulesView.vue
new file mode 100644
index 0000000..2f7366a
--- /dev/null
+++ b/frontend/src/views/RulesView.vue
@@ -0,0 +1,69 @@
+
+
+
+
规则管理
+
+
+
+
加载中...
+
+
+
暂无规则
+
添加主题匹配规则,系统将自动发现 MQTT 设备
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/mqtt_home/api/__init__.py b/src/mqtt_home/api/__init__.py
index bdf76f2..40f569a 100644
--- a/src/mqtt_home/api/__init__.py
+++ b/src/mqtt_home/api/__init__.py
@@ -2,8 +2,10 @@ from fastapi import APIRouter
from mqtt_home.api.devices import router as devices_router
from mqtt_home.api.broker import router as broker_router
from mqtt_home.api.dashboard import router as dashboard_router
+from mqtt_home.api.rules import router as rules_router
api_router = APIRouter(prefix="/api")
api_router.include_router(devices_router, prefix="/devices", tags=["devices"])
api_router.include_router(broker_router, prefix="/broker", tags=["broker"])
api_router.include_router(dashboard_router, prefix="/dashboard", tags=["dashboard"])
+api_router.include_router(rules_router, prefix="/rules", tags=["rules"])
diff --git a/src/mqtt_home/api/dashboard.py b/src/mqtt_home/api/dashboard.py
index 1b20462..0c9cdea 100644
--- a/src/mqtt_home/api/dashboard.py
+++ b/src/mqtt_home/api/dashboard.py
@@ -4,10 +4,42 @@ from sqlalchemy.ext.asyncio import AsyncSession
from mqtt_home.database import get_db
from mqtt_home.device_registry import get_dashboard_stats
from mqtt_home.schemas import DashboardStats
+from mqtt_home.emqx_api import EmqxApiClient
router = APIRouter()
-@router.get("", response_model=DashboardStats)
+@router.get("")
async def dashboard(db: AsyncSession = Depends(get_db)):
- return await get_dashboard_stats(db)
+ from mqtt_home.main import app
+
+ stats = await get_dashboard_stats(db)
+
+ # Add broker device counts
+ registry = getattr(app.state, "broker_registry", None)
+ broker_devices = registry.get_all() if registry else []
+ broker_online = sum(1 for d in broker_devices if d.is_online)
+
+ stats["total_devices"] += len(broker_devices)
+ stats["online_devices"] += broker_online
+ stats["offline_devices"] += len(broker_devices) - broker_online
+
+ # Add broker topics from EMQX API
+ emqx = getattr(app.state, "emqx_client", None)
+ broker_topics = []
+ if emqx:
+ try:
+ topics = await emqx.get_topics()
+ broker_topics = [t.get("topic", "") for t in topics]
+ except Exception:
+ pass
+
+ stats["broker_topics"] = broker_topics
+ stats["broker_device_count"] = len(broker_devices)
+ stats["broker_online_count"] = broker_online
+
+ # Add MQTT connection status
+ mqtt = getattr(app.state, "mqtt_client", None)
+ stats["mqtt_connected"] = mqtt.is_connected if mqtt else False
+
+ return stats
diff --git a/src/mqtt_home/api/devices.py b/src/mqtt_home/api/devices.py
index 40675d1..81536b1 100644
--- a/src/mqtt_home/api/devices.py
+++ b/src/mqtt_home/api/devices.py
@@ -6,6 +6,7 @@ from mqtt_home.device_registry import (
list_devices, get_device, create_device, update_device,
delete_device, send_command, get_device_logs,
)
+from mqtt_home.models import DeviceLog
from mqtt_home.schemas import (
DeviceCreate, DeviceUpdate, DeviceCommand,
DeviceResponse, DeviceLogResponse,
@@ -16,7 +17,29 @@ router = APIRouter()
@router.get("", response_model=list[DeviceResponse])
async def get_devices(db: AsyncSession = Depends(get_db)):
- return await list_devices(db)
+ db_devices = await list_devices(db)
+ result = [DeviceResponse.model_validate(d) for d in db_devices]
+
+ # Merge broker devices
+ from mqtt_home.main import app
+ registry = getattr(app.state, "broker_registry", None)
+ if registry:
+ for bd in registry.get_all():
+ result.append(DeviceResponse(
+ id=bd.id,
+ name=bd.name,
+ type=bd.type,
+ protocol=bd.protocol,
+ mqtt_topic=bd.mqtt_topic,
+ command_topic=bd.command_topic,
+ state=bd.state,
+ is_online=bd.is_online,
+ last_seen=bd.last_seen,
+ created_at=bd.created_at,
+ updated_at=bd.updated_at,
+ ))
+
+ return result
@router.post("", response_model=DeviceResponse, status_code=201)
@@ -56,6 +79,32 @@ async def command_device(
if not mqtt_client or not mqtt_client.is_connected:
raise HTTPException(status_code=503, detail="MQTT not connected")
+ # Handle broker devices
+ if device_id.startswith("broker:"):
+ registry = getattr(app.state, "broker_registry", None)
+ if not registry:
+ raise HTTPException(status_code=404, detail="Device not found")
+ bd = registry.get(device_id)
+ if not bd:
+ raise HTTPException(status_code=404, detail="Device not found")
+ if not bd.command_topic:
+ raise HTTPException(status_code=400, detail="Device has no command_topic configured")
+
+ await mqtt_client.publish(bd.command_topic, data.payload)
+
+ # Create a log entry in DB for broker device commands
+ log = DeviceLog(
+ device_id=device_id,
+ direction="tx",
+ topic=bd.command_topic,
+ payload=data.payload,
+ )
+ db.add(log)
+ await db.commit()
+ await db.refresh(log)
+ return log
+
+ # Original DB device handling
try:
log = await send_command(
db, device_id, data.payload,
diff --git a/src/mqtt_home/api/rules.py b/src/mqtt_home/api/rules.py
new file mode 100644
index 0000000..abc762d
--- /dev/null
+++ b/src/mqtt_home/api/rules.py
@@ -0,0 +1,65 @@
+import logging
+
+from fastapi import APIRouter, Depends, HTTPException
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from mqtt_home.database import get_db
+from mqtt_home.rule_registry import list_rules, get_rule, create_rule, update_rule, delete_rule
+from mqtt_home.schemas import RuleCreate, RuleUpdate, RuleResponse
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter()
+
+
+async def _refresh_rules_cache():
+ """Refresh the cached topic rules in app state."""
+ from mqtt_home.main import app
+ from mqtt_home.database import get_session_factory
+
+ session_factory = get_session_factory()
+ async with session_factory() as db:
+ app.state.topic_rules = [r for r in await list_rules(db) if r.is_enabled]
+
+
+@router.get("", response_model=list[RuleResponse])
+async def get_rules(db: AsyncSession = Depends(get_db)):
+ return await list_rules(db)
+
+
+@router.post("", response_model=RuleResponse, status_code=201)
+async def add_rule(data: RuleCreate, db: AsyncSession = Depends(get_db)):
+ rule = await create_rule(db, data)
+ await _refresh_rules_cache()
+ return rule
+
+
+@router.get("/{rule_id}", response_model=RuleResponse)
+async def get_rule_detail(rule_id: int, db: AsyncSession = Depends(get_db)):
+ rule = await get_rule(db, rule_id)
+ if not rule:
+ raise HTTPException(status_code=404, detail="Rule not found")
+ return rule
+
+
+@router.put("/{rule_id}", response_model=RuleResponse)
+async def patch_rule(rule_id: int, data: RuleUpdate, db: AsyncSession = Depends(get_db)):
+ rule = await update_rule(db, rule_id, data)
+ if not rule:
+ raise HTTPException(status_code=404, detail="Rule not found")
+ await _refresh_rules_cache()
+ return rule
+
+
+@router.delete("/{rule_id}", status_code=204)
+async def remove_rule(rule_id: int, db: AsyncSession = Depends(get_db)):
+ if not await delete_rule(db, rule_id):
+ raise HTTPException(status_code=404, detail="Rule not found")
+ # Clean up broker devices associated with this rule
+ from mqtt_home.main import app
+ registry = getattr(app.state, "broker_registry", None)
+ if registry:
+ removed = registry.remove_by_rule(rule_id)
+ if removed:
+ logger.info("Removed %d broker devices for rule %d", len(removed), rule_id)
+ await _refresh_rules_cache()
diff --git a/src/mqtt_home/broker_devices.py b/src/mqtt_home/broker_devices.py
new file mode 100644
index 0000000..3ec6ec5
--- /dev/null
+++ b/src/mqtt_home/broker_devices.py
@@ -0,0 +1,107 @@
+import json
+import logging
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from typing import Optional
+
+from mqtt_home.topic_matcher import match_topic, extract_device_id, build_command_topic, extract_state_value
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class BrokerDevice:
+ """In-memory broker device, not persisted to DB"""
+ id: str # "broker:{device_id}" e.g. "broker:fire"
+ name: str # Human-readable name, e.g. "fire"
+ type: str # Device type from rule, e.g. "switch"
+ protocol: str = "topic_rule"
+ mqtt_topic: str = "" # The matched topic, e.g. "home/fire"
+ command_topic: Optional[str] = None # Built from template, e.g. "home/fire/set"
+ state: Optional[str] = None # Latest payload (or extracted value)
+ is_online: bool = False
+ last_seen: Optional[datetime] = None
+ rule_id: int = 0
+ created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
+ updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
+
+
+class BrokerDeviceRegistry:
+ """In-memory registry for devices discovered via topic format rules"""
+
+ def __init__(self):
+ self._devices: dict[str, BrokerDevice] = {}
+
+ def update_or_create(self, topic: str, payload: str, rule) -> Optional[BrokerDevice]:
+ """Try to match topic against rule pattern, create/update device if matched.
+
+ Args:
+ topic: The MQTT topic that received a message
+ payload: The message payload string
+ rule: TopicFormatRule instance (SQLAlchemy model or dict-like)
+
+ Returns:
+ Updated or created BrokerDevice, or None if topic doesn't match rule
+ """
+ # 1. Check if topic matches the rule's pattern
+ pattern = rule.topic_pattern
+ if not match_topic(topic, pattern):
+ return None
+
+ # 2. Extract device_id from topic
+ device_id_raw = extract_device_id(topic, pattern)
+ if not device_id_raw:
+ return None
+
+ device_id = f"broker:{device_id_raw}"
+
+ # 3. Extract state value from payload
+ state = extract_state_value(payload, rule.state_value_path)
+
+ # 4. Build command topic from template
+ command_topic = build_command_topic(rule.command_template, device_id_raw)
+
+ now = datetime.now(timezone.utc)
+
+ # 5. Update existing or create new
+ if device_id in self._devices:
+ device = self._devices[device_id]
+ device.state = state
+ device.is_online = True
+ device.last_seen = now
+ device.updated_at = now
+ # Update command_topic in case rule changed
+ if command_topic:
+ device.command_topic = command_topic
+ else:
+ device = BrokerDevice(
+ id=device_id,
+ name=device_id_raw,
+ type=rule.device_type,
+ mqtt_topic=topic,
+ command_topic=command_topic,
+ state=state,
+ is_online=True,
+ last_seen=now,
+ rule_id=rule.id,
+ )
+ self._devices[device_id] = device
+ logger.info("Broker device discovered: %s (topic=%s, rule=%d)", device_id, topic, rule.id)
+
+ return device
+
+ def get_all(self) -> list[BrokerDevice]:
+ return list(self._devices.values())
+
+ def get(self, device_id: str) -> Optional[BrokerDevice]:
+ return self._devices.get(device_id)
+
+ def remove_by_rule(self, rule_id: int) -> list[str]:
+ """Remove all devices associated with a rule. Returns list of removed device IDs."""
+ to_remove = [did for did, d in self._devices.items() if d.rule_id == rule_id]
+ for did in to_remove:
+ del self._devices[did]
+ return to_remove
+
+ def clear(self):
+ self._devices.clear()
diff --git a/src/mqtt_home/main.py b/src/mqtt_home/main.py
index d9648dc..b41ad5f 100644
--- a/src/mqtt_home/main.py
+++ b/src/mqtt_home/main.py
@@ -19,12 +19,25 @@ from mqtt_home.mqtt_client import MqttClient
from mqtt_home.emqx_api import EmqxApiClient
from mqtt_home.discovery import handle_discovery_message
from mqtt_home.device_registry import handle_state_update
+from mqtt_home.broker_devices import BrokerDeviceRegistry
+from mqtt_home.rule_registry import list_rules
from mqtt_home.api import api_router
from mqtt_home.ws import websocket_endpoint, broadcast_device_update
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
logger = logging.getLogger(__name__)
+# Default topic rules created on first startup when no rules exist
+DEFAULT_RULES = [
+ {
+ "name": "通用设备",
+ "topic_pattern": "home/+",
+ "device_type": "switch",
+ "command_template": "home/{device_id}/set",
+ "state_value_path": None,
+ },
+]
+
@asynccontextmanager
async def lifespan(app: FastAPI):
@@ -42,11 +55,22 @@ async def lifespan(app: FastAPI):
session_factory = get_session_factory()
+ # Create default rules if none exist (first startup)
+ async with session_factory() as db:
+ from mqtt_home.rule_registry import list_rules, create_rule
+ existing = await list_rules(db)
+ if not existing:
+ from mqtt_home.schemas import RuleCreate
+ for default in DEFAULT_RULES:
+ await create_rule(db, RuleCreate(**default))
+ logger.info("Created default rule: %s", default["name"])
+
async def on_discovery(topic: str, payload: str):
async with session_factory() as db:
await handle_discovery_message(topic, payload, db, mqtt)
async def on_state(topic: str, payload: str):
+ # Original: update DB devices
async with session_factory() as db:
device = await handle_state_update(db, topic, payload)
if device:
@@ -56,12 +80,37 @@ async def lifespan(app: FastAPI):
"last_seen": device.last_seen.isoformat() if device.last_seen else None,
})
+ # New: match topic format rules, update broker device registry
+ rules = getattr(app.state, "topic_rules", [])
+ registry: BrokerDeviceRegistry = getattr(app.state, "broker_registry", None)
+ if registry and rules:
+ for rule in rules:
+ if not rule.is_enabled:
+ continue
+ updated = registry.update_or_create(topic, payload, rule)
+ if updated:
+ await broadcast_device_update(updated.id, {
+ "state": updated.state,
+ "is_online": updated.is_online,
+ "last_seen": updated.last_seen.isoformat() if updated.last_seen else None,
+ "source": "broker",
+ })
+
mqtt.on_message("homeassistant/#", on_discovery)
mqtt.on_message("home/#", on_state)
await mqtt.start()
logger.info("MQTT client started")
+ # Initialize broker device registry
+ broker_registry = BrokerDeviceRegistry()
+ app.state.broker_registry = broker_registry
+
+ # Load topic rules from DB
+ async with session_factory() as db:
+ app.state.topic_rules = [r for r in await list_rules(db) if r.is_enabled]
+ logger.info("Loaded %d topic format rules", len(app.state.topic_rules))
+
yield
await mqtt.stop()
diff --git a/src/mqtt_home/models.py b/src/mqtt_home/models.py
index bb3cd71..28ba54f 100644
--- a/src/mqtt_home/models.py
+++ b/src/mqtt_home/models.py
@@ -46,4 +46,17 @@ class DeviceLog(Base):
payload = Column(Text, nullable=False)
timestamp = Column(DateTime(timezone=True), nullable=False, default=_utcnow)
- device = relationship("Device", back_populates="logs")
\ No newline at end of file
+ device = relationship("Device", back_populates="logs")
+
+
+class TopicFormatRule(Base):
+ __tablename__ = "topic_format_rules"
+
+ id = Column(Integer, primary_key=True, autoincrement=True)
+ name = Column(String(200), nullable=False)
+ topic_pattern = Column(String(500), nullable=False) # MQTT pattern: "home/+/state"
+ device_type = Column(String(50), nullable=False, default="switch")
+ command_template = Column(String(500), nullable=True) # "home/{device_id}/set"
+ state_value_path = Column(String(200), nullable=True) # "state"
+ is_enabled = Column(Boolean, nullable=False, default=True)
+ created_at = Column(DateTime(timezone=True), nullable=False, default=_utcnow)
\ No newline at end of file
diff --git a/src/mqtt_home/rule_registry.py b/src/mqtt_home/rule_registry.py
new file mode 100644
index 0000000..5ea7e0f
--- /dev/null
+++ b/src/mqtt_home/rule_registry.py
@@ -0,0 +1,57 @@
+import logging
+from typing import Optional
+
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from mqtt_home.models import TopicFormatRule
+from mqtt_home.schemas import RuleCreate, RuleUpdate
+
+logger = logging.getLogger(__name__)
+
+
+async def list_rules(db: AsyncSession) -> list[TopicFormatRule]:
+ result = await db.execute(select(TopicFormatRule).order_by(TopicFormatRule.created_at.desc()))
+ return list(result.scalars().all())
+
+
+async def get_rule(db: AsyncSession, rule_id: int) -> Optional[TopicFormatRule]:
+ return await db.get(TopicFormatRule, rule_id)
+
+
+async def create_rule(db: AsyncSession, data: RuleCreate) -> TopicFormatRule:
+ rule = TopicFormatRule(
+ name=data.name,
+ topic_pattern=data.topic_pattern,
+ device_type=data.device_type,
+ command_template=data.command_template,
+ state_value_path=data.state_value_path,
+ is_enabled=data.is_enabled,
+ )
+ db.add(rule)
+ await db.commit()
+ await db.refresh(rule)
+ logger.info("Rule created: %s (id=%d, pattern=%s)", rule.name, rule.id, rule.topic_pattern)
+ return rule
+
+
+async def update_rule(db: AsyncSession, rule_id: int, data: RuleUpdate) -> Optional[TopicFormatRule]:
+ rule = await db.get(TopicFormatRule, rule_id)
+ if not rule:
+ return None
+ update_data = data.model_dump(exclude_unset=True)
+ for key, value in update_data.items():
+ setattr(rule, key, value)
+ await db.commit()
+ await db.refresh(rule)
+ return rule
+
+
+async def delete_rule(db: AsyncSession, rule_id: int) -> bool:
+ rule = await db.get(TopicFormatRule, rule_id)
+ if not rule:
+ return False
+ await db.delete(rule)
+ await db.commit()
+ logger.info("Rule deleted: id=%d", rule_id)
+ return True
diff --git a/src/mqtt_home/schemas.py b/src/mqtt_home/schemas.py
index 0c8d7d4..a2200e3 100644
--- a/src/mqtt_home/schemas.py
+++ b/src/mqtt_home/schemas.py
@@ -63,8 +63,55 @@ class BrokerTopic(BaseModel):
node: Optional[str] = None
+class BrokerDeviceResponse(BaseModel):
+ id: str
+ name: str
+ type: str
+ protocol: str = "topic_rule"
+ mqtt_topic: str
+ command_topic: str | None = None
+ state: str | None = None
+ is_online: bool = False
+ last_seen: datetime | None = None
+ rule_id: int = 0
+
+
class DashboardStats(BaseModel):
total_devices: int
online_devices: int
offline_devices: int
recent_logs: list[DeviceLogResponse]
+ broker_topics: list[str] = []
+ broker_device_count: int = 0
+ broker_online_count: int = 0
+ mqtt_connected: bool = False
+
+
+class RuleCreate(BaseModel):
+ name: str
+ topic_pattern: str
+ device_type: str = "switch"
+ command_template: str | None = None
+ state_value_path: str | None = None
+ is_enabled: bool = True
+
+
+class RuleUpdate(BaseModel):
+ name: str | None = None
+ topic_pattern: str | None = None
+ device_type: str | None = None
+ command_template: str | None = None
+ state_value_path: str | None = None
+ is_enabled: bool | None = None
+
+
+class RuleResponse(BaseModel):
+ id: int
+ name: str
+ topic_pattern: str
+ device_type: str
+ command_template: str | None = None
+ state_value_path: str | None = None
+ is_enabled: bool
+ created_at: datetime
+ model_config = {"from_attributes": True}
diff --git a/src/mqtt_home/topic_matcher.py b/src/mqtt_home/topic_matcher.py
new file mode 100644
index 0000000..7ac8c43
--- /dev/null
+++ b/src/mqtt_home/topic_matcher.py
@@ -0,0 +1,131 @@
+import json
+from typing import Optional
+
+
+def match_topic(topic: str, pattern: str) -> dict[str, str] | None:
+ """Match an MQTT topic against a pattern with + and # wildcards.
+
+ Converts MQTT pattern to regex: + -> ([^/]+), # -> (.+)
+ Returns dict of capture groups indexed by wildcard position, e.g. {"1": "fire"}
+ Returns None if no match.
+ """
+ topic_parts = topic.split("/")
+ pattern_parts = pattern.split("/")
+
+ captures: dict[str, str] = {}
+ t_idx = 0
+ p_idx = 0
+
+ while t_idx < len(topic_parts) and p_idx < len(pattern_parts):
+ pat = pattern_parts[p_idx]
+
+ if pat == "#":
+ # # matches everything remaining (must be last in pattern per MQTT spec)
+ captures[str(p_idx)] = "/".join(topic_parts[t_idx:])
+ t_idx = len(topic_parts)
+ p_idx += 1
+ elif pat == "+":
+ # + matches exactly one level
+ captures[str(p_idx)] = topic_parts[t_idx]
+ t_idx += 1
+ p_idx += 1
+ elif pat == topic_parts[t_idx]:
+ t_idx += 1
+ p_idx += 1
+ else:
+ return None
+
+ # After main loop, check remaining parts
+ if p_idx < len(pattern_parts):
+ # Remaining pattern parts must all be empty-string compatible or #
+ remaining_pattern = pattern_parts[p_idx:]
+ if remaining_pattern == ["#"] or remaining_pattern == []:
+ if remaining_pattern == ["#"]:
+ captures[str(p_idx)] = ""
+ p_idx = len(pattern_parts)
+ else:
+ return None
+
+ if t_idx < len(topic_parts):
+ return None
+
+ if captures:
+ return captures
+ return None
+
+
+def extract_device_id(topic: str, pattern: str) -> str | None:
+ """Extract the device identifier from a topic based on pattern.
+
+ For patterns with + wildcards, returns the last + match.
+ For patterns with #, returns the part matched by # (last segment).
+ Returns None if topic doesn't match pattern.
+
+ Examples:
+ extract_device_id("home/fire/state", "home/+/state") -> "fire"
+ extract_device_id("home/fire/brightness", "home/fire/#") -> "brightness"
+ """
+ result = match_topic(topic, pattern)
+ if result is None:
+ return None
+
+ # Check if # was used: # always captures everything remaining
+ # We need to check the pattern for # to identify its index
+ pattern_parts = pattern.split("/")
+ if "#" in pattern_parts:
+ hash_idx = str(pattern_parts.index("#"))
+ if hash_idx in result:
+ value = result[hash_idx]
+ # For #, return the last segment
+ if value:
+ return value.split("/")[-1]
+ return ""
+
+ # For + wildcards, return the last + match (highest index)
+ plus_indices = [
+ str(i) for i, p in enumerate(pattern_parts) if p == "+"
+ ]
+ if plus_indices:
+ last_plus_idx = plus_indices[-1]
+ return result[last_plus_idx]
+
+ return None
+
+
+def build_command_topic(command_template: str | None, device_id: str) -> str | None:
+ """Replace {device_id} in command template with actual value.
+
+ Example: build_command_topic("home/{device_id}/set", "fire") -> "home/fire/set"
+ Returns None if template is None or empty.
+ """
+ if not command_template:
+ return None
+ return command_template.replace("{device_id}", device_id)
+
+
+def extract_state_value(payload: str, state_value_path: str | None) -> str:
+ """Extract state value from JSON payload.
+
+ If state_value_path is set, extract that key from JSON payload.
+ If JSON parsing fails or path not found, return original payload.
+ If state_value_path is None, return entire payload string.
+
+ Examples:
+ extract_state_value('{"state":"on"}', "state") -> "on"
+ extract_state_value('{"state":"on"}', None) -> '{"state":"on"}'
+ extract_state_value("plain text", "state") -> "plain text"
+ extract_state_value('{"brightness":255}', "brightness") -> "255"
+ """
+ if state_value_path is None:
+ return payload
+
+ try:
+ data = json.loads(payload)
+ if not isinstance(data, dict):
+ return payload
+ value = data.get(state_value_path)
+ if value is None:
+ return payload
+ return str(value)
+ except (json.JSONDecodeError, TypeError, ValueError):
+ return payload
diff --git a/tests/test_broker_devices.py b/tests/test_broker_devices.py
new file mode 100644
index 0000000..edb11e0
--- /dev/null
+++ b/tests/test_broker_devices.py
@@ -0,0 +1,86 @@
+from types import SimpleNamespace
+from mqtt_home.broker_devices import BrokerDevice, BrokerDeviceRegistry
+
+
+def make_rule(id=1, topic_pattern="home/+/state", device_type="switch",
+ command_template="home/{device_id}/set", state_value_path="state",
+ is_enabled=True):
+ return SimpleNamespace(
+ id=id, topic_pattern=topic_pattern, device_type=device_type,
+ command_template=command_template, state_value_path=state_value_path,
+ is_enabled=is_enabled,
+ )
+
+
+def test_create_device_on_match():
+ registry = BrokerDeviceRegistry()
+ rule = make_rule()
+ device = registry.update_or_create("home/fire/state", '{"state":"on"}', rule)
+ assert device is not None
+ assert device.id == "broker:fire"
+ assert device.name == "fire"
+ assert device.type == "switch"
+ assert device.state == "on"
+ assert device.command_topic == "home/fire/set"
+ assert device.is_online is True
+ assert device.rule_id == 1
+
+
+def test_update_existing_device():
+ registry = BrokerDeviceRegistry()
+ rule = make_rule()
+ registry.update_or_create("home/fire/state", '{"state":"on"}', rule)
+ device = registry.update_or_create("home/fire/state", '{"state":"off"}', rule)
+ assert device.state == "off"
+ assert len(registry.get_all()) == 1
+
+
+def test_different_topics_different_devices():
+ registry = BrokerDeviceRegistry()
+ rule = make_rule()
+ registry.update_or_create("home/fire/state", '{"state":"on"}', rule)
+ registry.update_or_create("home/servo/state", '{"state":"off"}', rule)
+ assert len(registry.get_all()) == 2
+
+
+def test_no_match_returns_none():
+ registry = BrokerDeviceRegistry()
+ rule = make_rule(topic_pattern="home/+/state")
+ result = registry.update_or_create("living/light/status", '{"state":"on"}', rule)
+ assert result is None
+
+
+def test_remove_by_rule():
+ registry = BrokerDeviceRegistry()
+ rule1 = make_rule(id=1)
+ rule2 = make_rule(id=2, topic_pattern="sensor/+/data")
+ registry.update_or_create("home/fire/state", '{"state":"on"}', rule1)
+ registry.update_or_create("sensor/temp/data", '{"value":25}', rule2)
+ assert len(registry.get_all()) == 2
+ removed = registry.remove_by_rule(1)
+ assert removed == ["broker:fire"]
+ assert len(registry.get_all()) == 1
+ assert registry.get("broker:fire") is None
+
+
+def test_full_payload_as_state():
+ registry = BrokerDeviceRegistry()
+ rule = make_rule(state_value_path=None)
+ device = registry.update_or_create("home/fire/state", '{"state":"on","brightness":128}', rule)
+ assert device.state == '{"state":"on","brightness":128}'
+
+
+def test_get_device():
+ registry = BrokerDeviceRegistry()
+ rule = make_rule()
+ registry.update_or_create("home/fire/state", '{"state":"on"}', rule)
+ assert registry.get("broker:fire") is not None
+ assert registry.get("broker:nonexistent") is None
+
+
+def test_clear():
+ registry = BrokerDeviceRegistry()
+ rule = make_rule()
+ registry.update_or_create("home/fire/state", '{"state":"on"}', rule)
+ registry.clear()
+ assert len(registry.get_all()) == 0
diff --git a/tests/test_rule_registry.py b/tests/test_rule_registry.py
new file mode 100644
index 0000000..96dff28
--- /dev/null
+++ b/tests/test_rule_registry.py
@@ -0,0 +1,66 @@
+import pytest
+from mqtt_home.rule_registry import list_rules, get_rule, create_rule, update_rule, delete_rule
+from mqtt_home.schemas import RuleCreate, RuleUpdate
+
+
+@pytest.mark.asyncio
+async def test_create_rule(db_session):
+ rule = await create_rule(db_session, RuleCreate(
+ name="Test Rule",
+ topic_pattern="home/+/state",
+ device_type="switch",
+ command_template="home/{device_id}/set",
+ state_value_path="state",
+ ))
+ assert rule.id is not None
+ assert rule.name == "Test Rule"
+ assert rule.topic_pattern == "home/+/state"
+ assert rule.is_enabled is True
+
+
+@pytest.mark.asyncio
+async def test_list_rules(db_session):
+ await create_rule(db_session, RuleCreate(name="Rule 1", topic_pattern="a/+"))
+ await create_rule(db_session, RuleCreate(name="Rule 2", topic_pattern="b/+"))
+ rules = await list_rules(db_session)
+ assert len(rules) == 2
+
+
+@pytest.mark.asyncio
+async def test_get_rule(db_session):
+ created = await create_rule(db_session, RuleCreate(name="Test", topic_pattern="x/+"))
+ rule = await get_rule(db_session, created.id)
+ assert rule is not None
+ assert rule.name == "Test"
+
+
+@pytest.mark.asyncio
+async def test_get_rule_not_found(db_session):
+ rule = await get_rule(db_session, 99999)
+ assert rule is None
+
+
+@pytest.mark.asyncio
+async def test_update_rule(db_session):
+ created = await create_rule(db_session, RuleCreate(name="Original", topic_pattern="x/+"))
+ updated = await update_rule(db_session, created.id, RuleUpdate(name="Updated"))
+ assert updated.name == "Updated"
+ assert updated.topic_pattern == "x/+" # unchanged
+
+
+@pytest.mark.asyncio
+async def test_update_rule_not_found(db_session):
+ result = await update_rule(db_session, 99999, RuleUpdate(name="X"))
+ assert result is None
+
+
+@pytest.mark.asyncio
+async def test_delete_rule(db_session):
+ created = await create_rule(db_session, RuleCreate(name="To Delete", topic_pattern="x/+"))
+ assert await delete_rule(db_session, created.id) is True
+ assert await get_rule(db_session, created.id) is None
+
+
+@pytest.mark.asyncio
+async def test_delete_rule_not_found(db_session):
+ assert await delete_rule(db_session, 99999) is False
diff --git a/tests/test_topic_matcher.py b/tests/test_topic_matcher.py
new file mode 100644
index 0000000..9b383b8
--- /dev/null
+++ b/tests/test_topic_matcher.py
@@ -0,0 +1,74 @@
+from mqtt_home.topic_matcher import match_topic, extract_device_id, build_command_topic, extract_state_value
+
+
+def test_match_topic_single_plus():
+ assert match_topic("home/fire/state", "home/+/state") == {"1": "fire"}
+
+
+def test_match_topic_no_match():
+ assert match_topic("home/fire/state", "home/living/state") is None
+
+
+def test_match_topic_hash():
+ # # at end matches remaining segments
+ result = match_topic("home/fire/brightness", "home/fire/#")
+ assert result is not None
+ assert "2" in result # # is at index 2
+
+
+def test_match_topic_multiple_plus():
+ assert match_topic("home/fire/living/state", "home/+/+/state") == {"1": "fire", "2": "living"}
+
+
+def test_match_topic_hash_matches_empty():
+ assert match_topic("home/fire/", "home/fire/#") is not None
+
+
+def test_extract_device_id_single_plus():
+ assert extract_device_id("home/fire/state", "home/+/state") == "fire"
+
+
+def test_extract_device_id_no_match():
+ assert extract_device_id("home/fire/state", "home/living/state") is None
+
+
+def test_extract_device_id_hash():
+ assert extract_device_id("home/fire/brightness", "home/fire/#") == "brightness"
+
+
+def test_extract_device_id_multiple_plus():
+ # Returns last + match
+ assert extract_device_id("home/fire/living/state", "home/+/+/state") == "living"
+
+
+def test_build_command_topic_basic():
+ assert build_command_topic("home/{device_id}/set", "fire") == "home/fire/set"
+
+
+def test_build_command_topic_none():
+ assert build_command_topic(None, "fire") is None
+ assert build_command_topic("", "fire") is None
+
+
+def test_extract_state_value_json_path():
+ assert extract_state_value('{"state":"on"}', "state") == "on"
+
+
+def test_extract_state_value_no_path():
+ assert extract_state_value('{"state":"on"}', None) == '{"state":"on"}'
+
+
+def test_extract_state_value_plain_text():
+ assert extract_state_value("plain text", "state") == "plain text"
+
+
+def test_extract_state_value_nested_path():
+ assert extract_state_value('{"brightness":255}', "brightness") == "255"
+
+
+def test_extract_state_value_missing_key():
+ assert extract_state_value('{"temperature":22}', "humidity") == '{"temperature":22}'
+
+
+def test_extract_state_value_numeric_value():
+ assert extract_state_value('{"brightness":255}', "brightness") == "255"