feat: FastAPI app with REST API routes and WebSocket endpoint
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
9
src/mqtt_home/api/__init__.py
Normal file
9
src/mqtt_home/api/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from fastapi import APIRouter
|
||||
from mqtt_home.api.devices import router as devices_router
|
||||
from mqtt_home.api.broker import router as broker_router
|
||||
from mqtt_home.api.dashboard import router as dashboard_router
|
||||
|
||||
api_router = APIRouter(prefix="/api")
|
||||
api_router.include_router(devices_router, prefix="/devices", tags=["devices"])
|
||||
api_router.include_router(broker_router, prefix="/broker", tags=["broker"])
|
||||
api_router.include_router(dashboard_router, prefix="/dashboard", tags=["dashboard"])
|
||||
43
src/mqtt_home/api/broker.py
Normal file
43
src/mqtt_home/api/broker.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from mqtt_home.schemas import BrokerClient, BrokerTopic
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
async def broker_status():
|
||||
from mqtt_home.main import app
|
||||
emqx = getattr(app.state, "emqx_client", None)
|
||||
if not emqx:
|
||||
raise HTTPException(status_code=503, detail="EMQX API client not configured")
|
||||
try:
|
||||
status = await emqx.get_broker_status()
|
||||
metrics = await emqx.get_metrics()
|
||||
return {"status": status, "metrics": metrics}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=502, detail=f"EMQX API error: {e}")
|
||||
|
||||
|
||||
@router.get("/clients", response_model=list[BrokerClient])
|
||||
async def broker_clients(limit: int = 100):
|
||||
from mqtt_home.main import app
|
||||
emqx = getattr(app.state, "emqx_client", None)
|
||||
if not emqx:
|
||||
raise HTTPException(status_code=503, detail="EMQX API client not configured")
|
||||
try:
|
||||
return await emqx.get_clients(limit)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=502, detail=f"EMQX API error: {e}")
|
||||
|
||||
|
||||
@router.get("/topics", response_model=list[BrokerTopic])
|
||||
async def broker_topics(limit: int = 100):
|
||||
from mqtt_home.main import app
|
||||
emqx = getattr(app.state, "emqx_client", None)
|
||||
if not emqx:
|
||||
raise HTTPException(status_code=503, detail="EMQX API client not configured")
|
||||
try:
|
||||
return await emqx.get_topics(limit)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=502, detail=f"EMQX API error: {e}")
|
||||
13
src/mqtt_home/api/dashboard.py
Normal file
13
src/mqtt_home/api/dashboard.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from mqtt_home.database import get_db
|
||||
from mqtt_home.device_registry import get_dashboard_stats
|
||||
from mqtt_home.schemas import DashboardStats
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=DashboardStats)
|
||||
async def dashboard(db: AsyncSession = Depends(get_db)):
|
||||
return await get_dashboard_stats(db)
|
||||
73
src/mqtt_home/api/devices.py
Normal file
73
src/mqtt_home/api/devices.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from mqtt_home.database import get_db
|
||||
from mqtt_home.device_registry import (
|
||||
list_devices, get_device, create_device, update_device,
|
||||
delete_device, send_command, get_device_logs,
|
||||
)
|
||||
from mqtt_home.schemas import (
|
||||
DeviceCreate, DeviceUpdate, DeviceCommand,
|
||||
DeviceResponse, DeviceLogResponse,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=list[DeviceResponse])
|
||||
async def get_devices(db: AsyncSession = Depends(get_db)):
|
||||
return await list_devices(db)
|
||||
|
||||
|
||||
@router.post("", response_model=DeviceResponse, status_code=201)
|
||||
async def add_device(data: DeviceCreate, db: AsyncSession = Depends(get_db)):
|
||||
return await create_device(db, data)
|
||||
|
||||
|
||||
@router.get("/{device_id}", response_model=DeviceResponse)
|
||||
async def get_device_detail(device_id: str, db: AsyncSession = Depends(get_db)):
|
||||
device = await get_device(db, device_id)
|
||||
if not device:
|
||||
raise HTTPException(status_code=404, detail="Device not found")
|
||||
return device
|
||||
|
||||
|
||||
@router.put("/{device_id}", response_model=DeviceResponse)
|
||||
async def patch_device(device_id: str, data: DeviceUpdate, db: AsyncSession = Depends(get_db)):
|
||||
device = await update_device(db, device_id, data)
|
||||
if not device:
|
||||
raise HTTPException(status_code=404, detail="Device not found")
|
||||
return device
|
||||
|
||||
|
||||
@router.delete("/{device_id}", status_code=204)
|
||||
async def remove_device(device_id: str, db: AsyncSession = Depends(get_db)):
|
||||
if not await delete_device(db, device_id):
|
||||
raise HTTPException(status_code=404, detail="Device not found")
|
||||
|
||||
|
||||
@router.post("/{device_id}/command", response_model=DeviceLogResponse)
|
||||
async def command_device(
|
||||
device_id: str, data: DeviceCommand,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
from mqtt_home.main import app
|
||||
mqtt_client = getattr(app.state, "mqtt_client", None)
|
||||
if not mqtt_client or not mqtt_client.is_connected:
|
||||
raise HTTPException(status_code=503, detail="MQTT not connected")
|
||||
|
||||
try:
|
||||
log = await send_command(
|
||||
db, device_id, data.payload,
|
||||
publish_fn=mqtt_client.publish,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
if not log:
|
||||
raise HTTPException(status_code=404, detail="Device not found")
|
||||
return log
|
||||
|
||||
|
||||
@router.get("/{device_id}/logs", response_model=list[DeviceLogResponse])
|
||||
async def read_device_logs(device_id: str, limit: int = 20, db: AsyncSession = Depends(get_db)):
|
||||
return await get_device_logs(db, device_id, limit)
|
||||
84
src/mqtt_home/main.py
Normal file
84
src/mqtt_home/main.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from mqtt_home.config import get_settings
|
||||
from mqtt_home.database import init_db, get_session_factory, Base
|
||||
from mqtt_home.mqtt_client import MqttClient
|
||||
from mqtt_home.emqx_api import EmqxApiClient
|
||||
from mqtt_home.discovery import handle_discovery_message
|
||||
from mqtt_home.device_registry import handle_state_update
|
||||
from mqtt_home.api import api_router
|
||||
from mqtt_home.ws import websocket_endpoint, broadcast_device_update
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
settings = get_settings()
|
||||
|
||||
await init_db()
|
||||
logger.info("Database initialized")
|
||||
|
||||
emqx = EmqxApiClient(settings)
|
||||
app.state.emqx_client = emqx
|
||||
logger.info("EMQX API client initialized")
|
||||
|
||||
mqtt = MqttClient(settings)
|
||||
app.state.mqtt_client = mqtt
|
||||
|
||||
session_factory = get_session_factory()
|
||||
|
||||
async def on_discovery(topic: str, payload: str):
|
||||
async with session_factory() as db:
|
||||
await handle_discovery_message(topic, payload, db, mqtt)
|
||||
|
||||
async def on_state(topic: str, payload: str):
|
||||
async with session_factory() as db:
|
||||
device = await handle_state_update(db, topic, payload)
|
||||
if device:
|
||||
await broadcast_device_update(device.id, {
|
||||
"state": device.state,
|
||||
"is_online": device.is_online,
|
||||
"last_seen": device.last_seen.isoformat() if device.last_seen else None,
|
||||
})
|
||||
|
||||
mqtt.on_message("homeassistant/#", on_discovery)
|
||||
mqtt.on_message("home/#", on_state)
|
||||
|
||||
await mqtt.start()
|
||||
logger.info("MQTT client started")
|
||||
|
||||
yield
|
||||
|
||||
await mqtt.stop()
|
||||
await emqx.close()
|
||||
logger.info("Shutdown complete")
|
||||
|
||||
|
||||
app = FastAPI(title="MQTT Home", version="0.1.0", lifespan=lifespan)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.include_router(api_router)
|
||||
app.websocket("/ws/devices")(websocket_endpoint)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
mqtt = getattr(app.state, "mqtt_client", None)
|
||||
return {
|
||||
"status": "ok",
|
||||
"mqtt_connected": mqtt.is_connected if mqtt else False,
|
||||
}
|
||||
50
src/mqtt_home/ws.py
Normal file
50
src/mqtt_home/ws.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConnectionManager:
|
||||
def __init__(self):
|
||||
self._connections: list[WebSocket] = []
|
||||
|
||||
async def connect(self, ws: WebSocket):
|
||||
await ws.accept()
|
||||
self._connections.append(ws)
|
||||
logger.info("WebSocket connected, total: %d", len(self._connections))
|
||||
|
||||
def disconnect(self, ws: WebSocket):
|
||||
self._connections.remove(ws)
|
||||
logger.info("WebSocket disconnected, total: %d", len(self._connections))
|
||||
|
||||
async def broadcast(self, message: dict):
|
||||
dead = []
|
||||
for ws in self._connections:
|
||||
try:
|
||||
await ws.send_json(message)
|
||||
except Exception:
|
||||
dead.append(ws)
|
||||
for ws in dead:
|
||||
self.disconnect(ws)
|
||||
|
||||
|
||||
ws_manager = ConnectionManager()
|
||||
|
||||
|
||||
async def websocket_endpoint(ws: WebSocket):
|
||||
await ws_manager.connect(ws)
|
||||
try:
|
||||
while True:
|
||||
await ws.receive_text()
|
||||
except WebSocketDisconnect:
|
||||
ws_manager.disconnect(ws)
|
||||
|
||||
|
||||
async def broadcast_device_update(device_id: str, data: dict):
|
||||
await ws_manager.broadcast({
|
||||
"type": "device_update",
|
||||
"device_id": device_id,
|
||||
**data,
|
||||
})
|
||||
29
tests/test_api_broker.py
Normal file
29
tests/test_api_broker.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import pytest
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
|
||||
from mqtt_home.main import app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client():
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
yield ac
|
||||
|
||||
|
||||
async def test_broker_status_no_client(client):
|
||||
resp = await client.get("/api/broker/status")
|
||||
assert resp.status_code == 503
|
||||
|
||||
|
||||
async def test_broker_clients_no_client(client):
|
||||
resp = await client.get("/api/broker/clients")
|
||||
assert resp.status_code == 503
|
||||
|
||||
|
||||
async def test_health_endpoint(client):
|
||||
resp = await client.get("/health")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "status" in data
|
||||
assert "mqtt_connected" in data
|
||||
93
tests/test_api_devices.py
Normal file
93
tests/test_api_devices.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import pytest
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
|
||||
from mqtt_home.main import app
|
||||
from mqtt_home.database import Base, get_engine
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from mqtt_home.database import get_db
|
||||
|
||||
TEST_DB = "sqlite+aiosqlite:///:memory:"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client():
|
||||
engine = create_async_engine(TEST_DB, connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
test_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async def override_get_db():
|
||||
async with test_factory() as session:
|
||||
yield session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
yield ac
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
async def test_get_devices_empty(client):
|
||||
resp = await client.get("/api/devices")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
|
||||
async def test_create_device(client):
|
||||
resp = await client.post("/api/devices", json={
|
||||
"name": "客厅灯",
|
||||
"type": "light",
|
||||
"mqtt_topic": "home/light",
|
||||
"command_topic": "home/light/set",
|
||||
})
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["name"] == "客厅灯"
|
||||
assert data["type"] == "light"
|
||||
|
||||
|
||||
async def test_get_device_detail(client):
|
||||
create_resp = await client.post("/api/devices", json={
|
||||
"name": "灯", "type": "switch", "mqtt_topic": "t"
|
||||
})
|
||||
device_id = create_resp.json()["id"]
|
||||
|
||||
resp = await client.get(f"/api/devices/{device_id}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["name"] == "灯"
|
||||
|
||||
|
||||
async def test_get_device_not_found(client):
|
||||
resp = await client.get("/api/devices/nonexistent")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
async def test_update_device(client):
|
||||
create_resp = await client.post("/api/devices", json={
|
||||
"name": "灯", "type": "switch", "mqtt_topic": "t"
|
||||
})
|
||||
device_id = create_resp.json()["id"]
|
||||
|
||||
resp = await client.put(f"/api/devices/{device_id}", json={"name": "新名字"})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["name"] == "新名字"
|
||||
|
||||
|
||||
async def test_delete_device(client):
|
||||
create_resp = await client.post("/api/devices", json={
|
||||
"name": "灯", "type": "switch", "mqtt_topic": "t"
|
||||
})
|
||||
device_id = create_resp.json()["id"]
|
||||
|
||||
resp = await client.delete(f"/api/devices/{device_id}")
|
||||
assert resp.status_code == 204
|
||||
|
||||
|
||||
async def test_delete_device_not_found(client):
|
||||
resp = await client.delete("/api/devices/nonexistent")
|
||||
assert resp.status_code == 404
|
||||
Reference in New Issue
Block a user