feat: FastAPI app with REST API routes and WebSocket endpoint

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
walkpan
2026-03-29 21:46:00 +08:00
parent abb170ace6
commit 957f489a0d
8 changed files with 394 additions and 0 deletions

View File

@@ -0,0 +1,9 @@
from fastapi import APIRouter
from mqtt_home.api.devices import router as devices_router
from mqtt_home.api.broker import router as broker_router
from mqtt_home.api.dashboard import router as dashboard_router
api_router = APIRouter(prefix="/api")
api_router.include_router(devices_router, prefix="/devices", tags=["devices"])
api_router.include_router(broker_router, prefix="/broker", tags=["broker"])
api_router.include_router(dashboard_router, prefix="/dashboard", tags=["dashboard"])

View File

@@ -0,0 +1,43 @@
from fastapi import APIRouter, HTTPException
from mqtt_home.schemas import BrokerClient, BrokerTopic
router = APIRouter()
@router.get("/status")
async def broker_status():
from mqtt_home.main import app
emqx = getattr(app.state, "emqx_client", None)
if not emqx:
raise HTTPException(status_code=503, detail="EMQX API client not configured")
try:
status = await emqx.get_broker_status()
metrics = await emqx.get_metrics()
return {"status": status, "metrics": metrics}
except Exception as e:
raise HTTPException(status_code=502, detail=f"EMQX API error: {e}")
@router.get("/clients", response_model=list[BrokerClient])
async def broker_clients(limit: int = 100):
from mqtt_home.main import app
emqx = getattr(app.state, "emqx_client", None)
if not emqx:
raise HTTPException(status_code=503, detail="EMQX API client not configured")
try:
return await emqx.get_clients(limit)
except Exception as e:
raise HTTPException(status_code=502, detail=f"EMQX API error: {e}")
@router.get("/topics", response_model=list[BrokerTopic])
async def broker_topics(limit: int = 100):
from mqtt_home.main import app
emqx = getattr(app.state, "emqx_client", None)
if not emqx:
raise HTTPException(status_code=503, detail="EMQX API client not configured")
try:
return await emqx.get_topics(limit)
except Exception as e:
raise HTTPException(status_code=502, detail=f"EMQX API error: {e}")

View File

@@ -0,0 +1,13 @@
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from mqtt_home.database import get_db
from mqtt_home.device_registry import get_dashboard_stats
from mqtt_home.schemas import DashboardStats
router = APIRouter()
@router.get("", response_model=DashboardStats)
async def dashboard(db: AsyncSession = Depends(get_db)):
return await get_dashboard_stats(db)

View File

@@ -0,0 +1,73 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from mqtt_home.database import get_db
from mqtt_home.device_registry import (
list_devices, get_device, create_device, update_device,
delete_device, send_command, get_device_logs,
)
from mqtt_home.schemas import (
DeviceCreate, DeviceUpdate, DeviceCommand,
DeviceResponse, DeviceLogResponse,
)
router = APIRouter()
@router.get("", response_model=list[DeviceResponse])
async def get_devices(db: AsyncSession = Depends(get_db)):
return await list_devices(db)
@router.post("", response_model=DeviceResponse, status_code=201)
async def add_device(data: DeviceCreate, db: AsyncSession = Depends(get_db)):
return await create_device(db, data)
@router.get("/{device_id}", response_model=DeviceResponse)
async def get_device_detail(device_id: str, db: AsyncSession = Depends(get_db)):
device = await get_device(db, device_id)
if not device:
raise HTTPException(status_code=404, detail="Device not found")
return device
@router.put("/{device_id}", response_model=DeviceResponse)
async def patch_device(device_id: str, data: DeviceUpdate, db: AsyncSession = Depends(get_db)):
device = await update_device(db, device_id, data)
if not device:
raise HTTPException(status_code=404, detail="Device not found")
return device
@router.delete("/{device_id}", status_code=204)
async def remove_device(device_id: str, db: AsyncSession = Depends(get_db)):
if not await delete_device(db, device_id):
raise HTTPException(status_code=404, detail="Device not found")
@router.post("/{device_id}/command", response_model=DeviceLogResponse)
async def command_device(
device_id: str, data: DeviceCommand,
db: AsyncSession = Depends(get_db),
):
from mqtt_home.main import app
mqtt_client = getattr(app.state, "mqtt_client", None)
if not mqtt_client or not mqtt_client.is_connected:
raise HTTPException(status_code=503, detail="MQTT not connected")
try:
log = await send_command(
db, device_id, data.payload,
publish_fn=mqtt_client.publish,
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
if not log:
raise HTTPException(status_code=404, detail="Device not found")
return log
@router.get("/{device_id}/logs", response_model=list[DeviceLogResponse])
async def read_device_logs(device_id: str, limit: int = 20, db: AsyncSession = Depends(get_db)):
return await get_device_logs(db, device_id, limit)

84
src/mqtt_home/main.py Normal file
View File

@@ -0,0 +1,84 @@
import asyncio
import logging
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from mqtt_home.config import get_settings
from mqtt_home.database import init_db, get_session_factory, Base
from mqtt_home.mqtt_client import MqttClient
from mqtt_home.emqx_api import EmqxApiClient
from mqtt_home.discovery import handle_discovery_message
from mqtt_home.device_registry import handle_state_update
from mqtt_home.api import api_router
from mqtt_home.ws import websocket_endpoint, broadcast_device_update
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
settings = get_settings()
await init_db()
logger.info("Database initialized")
emqx = EmqxApiClient(settings)
app.state.emqx_client = emqx
logger.info("EMQX API client initialized")
mqtt = MqttClient(settings)
app.state.mqtt_client = mqtt
session_factory = get_session_factory()
async def on_discovery(topic: str, payload: str):
async with session_factory() as db:
await handle_discovery_message(topic, payload, db, mqtt)
async def on_state(topic: str, payload: str):
async with session_factory() as db:
device = await handle_state_update(db, topic, payload)
if device:
await broadcast_device_update(device.id, {
"state": device.state,
"is_online": device.is_online,
"last_seen": device.last_seen.isoformat() if device.last_seen else None,
})
mqtt.on_message("homeassistant/#", on_discovery)
mqtt.on_message("home/#", on_state)
await mqtt.start()
logger.info("MQTT client started")
yield
await mqtt.stop()
await emqx.close()
logger.info("Shutdown complete")
app = FastAPI(title="MQTT Home", version="0.1.0", lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(api_router)
app.websocket("/ws/devices")(websocket_endpoint)
@app.get("/health")
async def health():
mqtt = getattr(app.state, "mqtt_client", None)
return {
"status": "ok",
"mqtt_connected": mqtt.is_connected if mqtt else False,
}

50
src/mqtt_home/ws.py Normal file
View File

@@ -0,0 +1,50 @@
import asyncio
import json
import logging
from fastapi import WebSocket, WebSocketDisconnect
logger = logging.getLogger(__name__)
class ConnectionManager:
def __init__(self):
self._connections: list[WebSocket] = []
async def connect(self, ws: WebSocket):
await ws.accept()
self._connections.append(ws)
logger.info("WebSocket connected, total: %d", len(self._connections))
def disconnect(self, ws: WebSocket):
self._connections.remove(ws)
logger.info("WebSocket disconnected, total: %d", len(self._connections))
async def broadcast(self, message: dict):
dead = []
for ws in self._connections:
try:
await ws.send_json(message)
except Exception:
dead.append(ws)
for ws in dead:
self.disconnect(ws)
ws_manager = ConnectionManager()
async def websocket_endpoint(ws: WebSocket):
await ws_manager.connect(ws)
try:
while True:
await ws.receive_text()
except WebSocketDisconnect:
ws_manager.disconnect(ws)
async def broadcast_device_update(device_id: str, data: dict):
await ws_manager.broadcast({
"type": "device_update",
"device_id": device_id,
**data,
})

29
tests/test_api_broker.py Normal file
View File

@@ -0,0 +1,29 @@
import pytest
from httpx import AsyncClient, ASGITransport
from mqtt_home.main import app
@pytest.fixture
async def client():
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
yield ac
async def test_broker_status_no_client(client):
resp = await client.get("/api/broker/status")
assert resp.status_code == 503
async def test_broker_clients_no_client(client):
resp = await client.get("/api/broker/clients")
assert resp.status_code == 503
async def test_health_endpoint(client):
resp = await client.get("/health")
assert resp.status_code == 200
data = resp.json()
assert "status" in data
assert "mqtt_connected" in data

93
tests/test_api_devices.py Normal file
View File

@@ -0,0 +1,93 @@
import pytest
from httpx import AsyncClient, ASGITransport
from mqtt_home.main import app
from mqtt_home.database import Base, get_engine
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from sqlalchemy.pool import StaticPool
from mqtt_home.database import get_db
TEST_DB = "sqlite+aiosqlite:///:memory:"
@pytest.fixture
async def client():
engine = create_async_engine(TEST_DB, connect_args={"check_same_thread": False}, poolclass=StaticPool)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
test_factory = async_sessionmaker(engine, expire_on_commit=False)
async def override_get_db():
async with test_factory() as session:
yield session
app.dependency_overrides[get_db] = override_get_db
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
yield ac
app.dependency_overrides.clear()
await engine.dispose()
async def test_get_devices_empty(client):
resp = await client.get("/api/devices")
assert resp.status_code == 200
assert resp.json() == []
async def test_create_device(client):
resp = await client.post("/api/devices", json={
"name": "客厅灯",
"type": "light",
"mqtt_topic": "home/light",
"command_topic": "home/light/set",
})
assert resp.status_code == 201
data = resp.json()
assert data["name"] == "客厅灯"
assert data["type"] == "light"
async def test_get_device_detail(client):
create_resp = await client.post("/api/devices", json={
"name": "", "type": "switch", "mqtt_topic": "t"
})
device_id = create_resp.json()["id"]
resp = await client.get(f"/api/devices/{device_id}")
assert resp.status_code == 200
assert resp.json()["name"] == ""
async def test_get_device_not_found(client):
resp = await client.get("/api/devices/nonexistent")
assert resp.status_code == 404
async def test_update_device(client):
create_resp = await client.post("/api/devices", json={
"name": "", "type": "switch", "mqtt_topic": "t"
})
device_id = create_resp.json()["id"]
resp = await client.put(f"/api/devices/{device_id}", json={"name": "新名字"})
assert resp.status_code == 200
assert resp.json()["name"] == "新名字"
async def test_delete_device(client):
create_resp = await client.post("/api/devices", json={
"name": "", "type": "switch", "mqtt_topic": "t"
})
device_id = create_resp.json()["id"]
resp = await client.delete(f"/api/devices/{device_id}")
assert resp.status_code == 204
async def test_delete_device_not_found(client):
resp = await client.delete("/api/devices/nonexistent")
assert resp.status_code == 404