From 957f489a0d239414dec6a3e838b408163a583064 Mon Sep 17 00:00:00 2001 From: walkpan Date: Sun, 29 Mar 2026 21:46:00 +0800 Subject: [PATCH] feat: FastAPI app with REST API routes and WebSocket endpoint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- src/mqtt_home/api/__init__.py | 9 ++++ src/mqtt_home/api/broker.py | 43 ++++++++++++++++ src/mqtt_home/api/dashboard.py | 13 +++++ src/mqtt_home/api/devices.py | 73 ++++++++++++++++++++++++++ src/mqtt_home/main.py | 84 ++++++++++++++++++++++++++++++ src/mqtt_home/ws.py | 50 ++++++++++++++++++ tests/test_api_broker.py | 29 +++++++++++ tests/test_api_devices.py | 93 ++++++++++++++++++++++++++++++++++ 8 files changed, 394 insertions(+) create mode 100644 src/mqtt_home/api/__init__.py create mode 100644 src/mqtt_home/api/broker.py create mode 100644 src/mqtt_home/api/dashboard.py create mode 100644 src/mqtt_home/api/devices.py create mode 100644 src/mqtt_home/main.py create mode 100644 src/mqtt_home/ws.py create mode 100644 tests/test_api_broker.py create mode 100644 tests/test_api_devices.py diff --git a/src/mqtt_home/api/__init__.py b/src/mqtt_home/api/__init__.py new file mode 100644 index 0000000..bdf76f2 --- /dev/null +++ b/src/mqtt_home/api/__init__.py @@ -0,0 +1,9 @@ +from fastapi import APIRouter +from mqtt_home.api.devices import router as devices_router +from mqtt_home.api.broker import router as broker_router +from mqtt_home.api.dashboard import router as dashboard_router + +api_router = APIRouter(prefix="/api") +api_router.include_router(devices_router, prefix="/devices", tags=["devices"]) +api_router.include_router(broker_router, prefix="/broker", tags=["broker"]) +api_router.include_router(dashboard_router, prefix="/dashboard", tags=["dashboard"]) diff --git a/src/mqtt_home/api/broker.py b/src/mqtt_home/api/broker.py new file mode 100644 index 0000000..7e96a90 --- /dev/null +++ b/src/mqtt_home/api/broker.py @@ -0,0 +1,43 @@ +from fastapi import APIRouter, HTTPException + +from mqtt_home.schemas import BrokerClient, BrokerTopic + +router = APIRouter() + + +@router.get("/status") +async def broker_status(): + from mqtt_home.main import app + emqx = getattr(app.state, "emqx_client", None) + if not emqx: + raise HTTPException(status_code=503, detail="EMQX API client not configured") + try: + status = await emqx.get_broker_status() + metrics = await emqx.get_metrics() + return {"status": status, "metrics": metrics} + except Exception as e: + raise HTTPException(status_code=502, detail=f"EMQX API error: {e}") + + +@router.get("/clients", response_model=list[BrokerClient]) +async def broker_clients(limit: int = 100): + from mqtt_home.main import app + emqx = getattr(app.state, "emqx_client", None) + if not emqx: + raise HTTPException(status_code=503, detail="EMQX API client not configured") + try: + return await emqx.get_clients(limit) + except Exception as e: + raise HTTPException(status_code=502, detail=f"EMQX API error: {e}") + + +@router.get("/topics", response_model=list[BrokerTopic]) +async def broker_topics(limit: int = 100): + from mqtt_home.main import app + emqx = getattr(app.state, "emqx_client", None) + if not emqx: + raise HTTPException(status_code=503, detail="EMQX API client not configured") + try: + return await emqx.get_topics(limit) + except Exception as e: + raise HTTPException(status_code=502, detail=f"EMQX API error: {e}") diff --git a/src/mqtt_home/api/dashboard.py b/src/mqtt_home/api/dashboard.py new file mode 100644 index 0000000..1b20462 --- /dev/null +++ b/src/mqtt_home/api/dashboard.py @@ -0,0 +1,13 @@ +from fastapi import APIRouter, Depends +from sqlalchemy.ext.asyncio import AsyncSession + +from mqtt_home.database import get_db +from mqtt_home.device_registry import get_dashboard_stats +from mqtt_home.schemas import DashboardStats + +router = APIRouter() + + +@router.get("", response_model=DashboardStats) +async def dashboard(db: AsyncSession = Depends(get_db)): + return await get_dashboard_stats(db) diff --git a/src/mqtt_home/api/devices.py b/src/mqtt_home/api/devices.py new file mode 100644 index 0000000..40675d1 --- /dev/null +++ b/src/mqtt_home/api/devices.py @@ -0,0 +1,73 @@ +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.ext.asyncio import AsyncSession + +from mqtt_home.database import get_db +from mqtt_home.device_registry import ( + list_devices, get_device, create_device, update_device, + delete_device, send_command, get_device_logs, +) +from mqtt_home.schemas import ( + DeviceCreate, DeviceUpdate, DeviceCommand, + DeviceResponse, DeviceLogResponse, +) + +router = APIRouter() + + +@router.get("", response_model=list[DeviceResponse]) +async def get_devices(db: AsyncSession = Depends(get_db)): + return await list_devices(db) + + +@router.post("", response_model=DeviceResponse, status_code=201) +async def add_device(data: DeviceCreate, db: AsyncSession = Depends(get_db)): + return await create_device(db, data) + + +@router.get("/{device_id}", response_model=DeviceResponse) +async def get_device_detail(device_id: str, db: AsyncSession = Depends(get_db)): + device = await get_device(db, device_id) + if not device: + raise HTTPException(status_code=404, detail="Device not found") + return device + + +@router.put("/{device_id}", response_model=DeviceResponse) +async def patch_device(device_id: str, data: DeviceUpdate, db: AsyncSession = Depends(get_db)): + device = await update_device(db, device_id, data) + if not device: + raise HTTPException(status_code=404, detail="Device not found") + return device + + +@router.delete("/{device_id}", status_code=204) +async def remove_device(device_id: str, db: AsyncSession = Depends(get_db)): + if not await delete_device(db, device_id): + raise HTTPException(status_code=404, detail="Device not found") + + +@router.post("/{device_id}/command", response_model=DeviceLogResponse) +async def command_device( + device_id: str, data: DeviceCommand, + db: AsyncSession = Depends(get_db), +): + from mqtt_home.main import app + mqtt_client = getattr(app.state, "mqtt_client", None) + if not mqtt_client or not mqtt_client.is_connected: + raise HTTPException(status_code=503, detail="MQTT not connected") + + try: + log = await send_command( + db, device_id, data.payload, + publish_fn=mqtt_client.publish, + ) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + if not log: + raise HTTPException(status_code=404, detail="Device not found") + return log + + +@router.get("/{device_id}/logs", response_model=list[DeviceLogResponse]) +async def read_device_logs(device_id: str, limit: int = 20, db: AsyncSession = Depends(get_db)): + return await get_device_logs(db, device_id, limit) diff --git a/src/mqtt_home/main.py b/src/mqtt_home/main.py new file mode 100644 index 0000000..c5447ee --- /dev/null +++ b/src/mqtt_home/main.py @@ -0,0 +1,84 @@ +import asyncio +import logging +from contextlib import asynccontextmanager + +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware + +from mqtt_home.config import get_settings +from mqtt_home.database import init_db, get_session_factory, Base +from mqtt_home.mqtt_client import MqttClient +from mqtt_home.emqx_api import EmqxApiClient +from mqtt_home.discovery import handle_discovery_message +from mqtt_home.device_registry import handle_state_update +from mqtt_home.api import api_router +from mqtt_home.ws import websocket_endpoint, broadcast_device_update + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s") +logger = logging.getLogger(__name__) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + settings = get_settings() + + await init_db() + logger.info("Database initialized") + + emqx = EmqxApiClient(settings) + app.state.emqx_client = emqx + logger.info("EMQX API client initialized") + + mqtt = MqttClient(settings) + app.state.mqtt_client = mqtt + + session_factory = get_session_factory() + + async def on_discovery(topic: str, payload: str): + async with session_factory() as db: + await handle_discovery_message(topic, payload, db, mqtt) + + async def on_state(topic: str, payload: str): + async with session_factory() as db: + device = await handle_state_update(db, topic, payload) + if device: + await broadcast_device_update(device.id, { + "state": device.state, + "is_online": device.is_online, + "last_seen": device.last_seen.isoformat() if device.last_seen else None, + }) + + mqtt.on_message("homeassistant/#", on_discovery) + mqtt.on_message("home/#", on_state) + + await mqtt.start() + logger.info("MQTT client started") + + yield + + await mqtt.stop() + await emqx.close() + logger.info("Shutdown complete") + + +app = FastAPI(title="MQTT Home", version="0.1.0", lifespan=lifespan) + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +app.include_router(api_router) +app.websocket("/ws/devices")(websocket_endpoint) + + +@app.get("/health") +async def health(): + mqtt = getattr(app.state, "mqtt_client", None) + return { + "status": "ok", + "mqtt_connected": mqtt.is_connected if mqtt else False, + } diff --git a/src/mqtt_home/ws.py b/src/mqtt_home/ws.py new file mode 100644 index 0000000..4ceafb8 --- /dev/null +++ b/src/mqtt_home/ws.py @@ -0,0 +1,50 @@ +import asyncio +import json +import logging +from fastapi import WebSocket, WebSocketDisconnect + +logger = logging.getLogger(__name__) + + +class ConnectionManager: + def __init__(self): + self._connections: list[WebSocket] = [] + + async def connect(self, ws: WebSocket): + await ws.accept() + self._connections.append(ws) + logger.info("WebSocket connected, total: %d", len(self._connections)) + + def disconnect(self, ws: WebSocket): + self._connections.remove(ws) + logger.info("WebSocket disconnected, total: %d", len(self._connections)) + + async def broadcast(self, message: dict): + dead = [] + for ws in self._connections: + try: + await ws.send_json(message) + except Exception: + dead.append(ws) + for ws in dead: + self.disconnect(ws) + + +ws_manager = ConnectionManager() + + +async def websocket_endpoint(ws: WebSocket): + await ws_manager.connect(ws) + try: + while True: + await ws.receive_text() + except WebSocketDisconnect: + ws_manager.disconnect(ws) + + +async def broadcast_device_update(device_id: str, data: dict): + await ws_manager.broadcast({ + "type": "device_update", + "device_id": device_id, + **data, + }) diff --git a/tests/test_api_broker.py b/tests/test_api_broker.py new file mode 100644 index 0000000..120cef7 --- /dev/null +++ b/tests/test_api_broker.py @@ -0,0 +1,29 @@ +import pytest +from httpx import AsyncClient, ASGITransport + +from mqtt_home.main import app + + +@pytest.fixture +async def client(): + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + yield ac + + +async def test_broker_status_no_client(client): + resp = await client.get("/api/broker/status") + assert resp.status_code == 503 + + +async def test_broker_clients_no_client(client): + resp = await client.get("/api/broker/clients") + assert resp.status_code == 503 + + +async def test_health_endpoint(client): + resp = await client.get("/health") + assert resp.status_code == 200 + data = resp.json() + assert "status" in data + assert "mqtt_connected" in data diff --git a/tests/test_api_devices.py b/tests/test_api_devices.py new file mode 100644 index 0000000..49651aa --- /dev/null +++ b/tests/test_api_devices.py @@ -0,0 +1,93 @@ +import pytest +from httpx import AsyncClient, ASGITransport + +from mqtt_home.main import app +from mqtt_home.database import Base, get_engine +from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker +from sqlalchemy.pool import StaticPool +from mqtt_home.database import get_db + +TEST_DB = "sqlite+aiosqlite:///:memory:" + + +@pytest.fixture +async def client(): + engine = create_async_engine(TEST_DB, connect_args={"check_same_thread": False}, poolclass=StaticPool) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + test_factory = async_sessionmaker(engine, expire_on_commit=False) + + async def override_get_db(): + async with test_factory() as session: + yield session + + app.dependency_overrides[get_db] = override_get_db + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + yield ac + + app.dependency_overrides.clear() + await engine.dispose() + + +async def test_get_devices_empty(client): + resp = await client.get("/api/devices") + assert resp.status_code == 200 + assert resp.json() == [] + + +async def test_create_device(client): + resp = await client.post("/api/devices", json={ + "name": "客厅灯", + "type": "light", + "mqtt_topic": "home/light", + "command_topic": "home/light/set", + }) + assert resp.status_code == 201 + data = resp.json() + assert data["name"] == "客厅灯" + assert data["type"] == "light" + + +async def test_get_device_detail(client): + create_resp = await client.post("/api/devices", json={ + "name": "灯", "type": "switch", "mqtt_topic": "t" + }) + device_id = create_resp.json()["id"] + + resp = await client.get(f"/api/devices/{device_id}") + assert resp.status_code == 200 + assert resp.json()["name"] == "灯" + + +async def test_get_device_not_found(client): + resp = await client.get("/api/devices/nonexistent") + assert resp.status_code == 404 + + +async def test_update_device(client): + create_resp = await client.post("/api/devices", json={ + "name": "灯", "type": "switch", "mqtt_topic": "t" + }) + device_id = create_resp.json()["id"] + + resp = await client.put(f"/api/devices/{device_id}", json={"name": "新名字"}) + assert resp.status_code == 200 + assert resp.json()["name"] == "新名字" + + +async def test_delete_device(client): + create_resp = await client.post("/api/devices", json={ + "name": "灯", "type": "switch", "mqtt_topic": "t" + }) + device_id = create_resp.json()["id"] + + resp = await client.delete(f"/api/devices/{device_id}") + assert resp.status_code == 204 + + +async def test_delete_device_not_found(client): + resp = await client.delete("/api/devices/nonexistent") + assert resp.status_code == 404