feat/frontend #1
@@ -1,12 +1,16 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""FastAPI WebSocket endpoint for live audio streaming."""
|
"""FastAPI WebSocket endpoint with rate limiting support."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
import logging
|
||||||
|
|
||||||
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Query
|
||||||
|
|
||||||
from ws_manager import ConnectionManager
|
from ws_manager import ConnectionManager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
app = FastAPI(title="Audio Analyzer WebSocket")
|
app = FastAPI(title="Audio Analyzer WebSocket")
|
||||||
manager = ConnectionManager()
|
manager = ConnectionManager()
|
||||||
|
|
||||||
@@ -18,16 +22,28 @@ async def health() -> dict[str, str]:
|
|||||||
|
|
||||||
|
|
||||||
@app.websocket("/ws/live")
|
@app.websocket("/ws/live")
|
||||||
async def ws_live(websocket: WebSocket) -> None:
|
async def ws_live(
|
||||||
"""WebSocket endpoint for real-time audio data streaming."""
|
websocket: WebSocket,
|
||||||
await manager.connect(websocket)
|
hz: float = Query(default=10.0, ge=0.1, le=100.0, description="Update rate in Hz"),
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
WebSocket endpoint for real-time audio data streaming.
|
||||||
|
|
||||||
|
Query parameters:
|
||||||
|
hz: Update rate in Hz (0.1 - 100.0, default: 10.0)
|
||||||
|
Examples:
|
||||||
|
- ws://localhost:8001/ws/live?hz=10 → 10 messages/sec
|
||||||
|
- ws://localhost:8001/ws/live?hz=1 → 1 message/sec
|
||||||
|
- ws://localhost:8001/ws/live?hz=30 → 30 messages/sec
|
||||||
|
"""
|
||||||
|
await manager.connect(websocket, rate_hz=hz)
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
# Keep connection alive
|
# Keep connection alive; client can send pings
|
||||||
await websocket.receive_text()
|
await websocket.receive_text()
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
pass
|
pass
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
logger.debug("WS connection error: %s", e)
|
||||||
finally:
|
finally:
|
||||||
await manager.disconnect(websocket)
|
await manager.disconnect(websocket)
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""WebSocket connection manager for broadcasting live audio data."""
|
"""WebSocket connection manager with per-client rate limiting and median aggregation."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import statistics
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from starlette.websockets import WebSocket, WebSocketState
|
from starlette.websockets import WebSocket, WebSocketState
|
||||||
@@ -13,48 +14,162 @@ from starlette.websockets import WebSocket, WebSocketState
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ThrottledClient:
|
||||||
|
"""WebSocket client wrapper with rate limiting and median aggregation."""
|
||||||
|
|
||||||
|
def __init__(self, ws: WebSocket, rate_hz: float) -> None:
|
||||||
|
self.ws = ws
|
||||||
|
self.rate_hz = rate_hz
|
||||||
|
self.interval = 1.0 / rate_hz if rate_hz > 0 else 0.0
|
||||||
|
self._last_send_time = 0.0
|
||||||
|
|
||||||
|
# Accumulator for aggregation (list of samples within current window)
|
||||||
|
self._buffer: list[dict[str, Any]] = []
|
||||||
|
self._buffer_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
self._task: asyncio.Task[None] | None = None
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
"""Start background sender task."""
|
||||||
|
self._task = asyncio.create_task(self._send_loop())
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
"""Stop background sender task."""
|
||||||
|
if self._task and not self._task.done():
|
||||||
|
self._task.cancel()
|
||||||
|
try:
|
||||||
|
await self._task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def enqueue(self, message: dict[str, Any]) -> None:
|
||||||
|
"""Add message to aggregation buffer."""
|
||||||
|
async with self._buffer_lock:
|
||||||
|
self._buffer.append(message)
|
||||||
|
# Limit buffer size (prevent memory issues if rate is very low)
|
||||||
|
if len(self._buffer) > 1000:
|
||||||
|
self._buffer.pop(0)
|
||||||
|
|
||||||
|
async def _send_loop(self) -> None:
|
||||||
|
"""Background task: aggregate and send messages respecting rate limit."""
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
# Wait for next send window
|
||||||
|
if self.interval > 0:
|
||||||
|
now = asyncio.get_event_loop().time()
|
||||||
|
elapsed = now - self._last_send_time
|
||||||
|
if elapsed < self.interval:
|
||||||
|
await asyncio.sleep(self.interval - elapsed)
|
||||||
|
self._last_send_time = asyncio.get_event_loop().time()
|
||||||
|
else:
|
||||||
|
await asyncio.sleep(0.01) # Minimal delay
|
||||||
|
|
||||||
|
# Get accumulated samples
|
||||||
|
async with self._buffer_lock:
|
||||||
|
if not self._buffer:
|
||||||
|
continue
|
||||||
|
|
||||||
|
samples = self._buffer[:]
|
||||||
|
self._buffer.clear()
|
||||||
|
|
||||||
|
# Aggregate: compute median
|
||||||
|
aggregated = self._aggregate_median(samples)
|
||||||
|
|
||||||
|
# Send aggregated message
|
||||||
|
try:
|
||||||
|
if self.ws.client_state == WebSocketState.CONNECTED:
|
||||||
|
payload = json.dumps(
|
||||||
|
aggregated, ensure_ascii=False, separators=(",", ":")
|
||||||
|
)
|
||||||
|
await self.ws.send_text(payload)
|
||||||
|
else:
|
||||||
|
break # Connection closed
|
||||||
|
except Exception:
|
||||||
|
break # Send failed
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _aggregate_median(samples: list[dict[str, Any]]) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Aggregate samples by computing median of rms_db and freq_hz.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
samples: List of raw messages
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Aggregated message with median values
|
||||||
|
"""
|
||||||
|
if not samples:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
if len(samples) == 1:
|
||||||
|
return samples[0]
|
||||||
|
|
||||||
|
# Extract numeric values
|
||||||
|
rms_values = [s["rms_db"] for s in samples if "rms_db" in s]
|
||||||
|
freq_values = [s["freq_hz"] for s in samples if "freq_hz" in s]
|
||||||
|
|
||||||
|
# Compute medians
|
||||||
|
rms_median = statistics.median(rms_values) if rms_values else 0.0
|
||||||
|
freq_median = statistics.median(freq_values) if freq_values else 0
|
||||||
|
|
||||||
|
# Use timestamp from last sample (most recent)
|
||||||
|
time_value = samples[-1].get("time", "")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"time": time_value,
|
||||||
|
"rms_db": round(rms_median, 1),
|
||||||
|
"freq_hz": int(round(freq_median)),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class ConnectionManager:
|
class ConnectionManager:
|
||||||
"""Manages WebSocket connections and broadcasts messages to all clients."""
|
"""Manages WebSocket connections with per-client rate limiting."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._clients: set[WebSocket] = set()
|
self._clients: dict[WebSocket, ThrottledClient] = {}
|
||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
async def connect(self, ws: WebSocket) -> None:
|
async def connect(self, ws: WebSocket, rate_hz: float = 10.0) -> None:
|
||||||
"""Accept and register new WebSocket connection."""
|
"""Accept and register new WebSocket connection with rate limiting."""
|
||||||
await ws.accept()
|
await ws.accept()
|
||||||
|
|
||||||
|
client = ThrottledClient(ws, rate_hz)
|
||||||
|
await client.start()
|
||||||
|
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
self._clients.add(ws)
|
self._clients[ws] = client
|
||||||
logger.info("WS client connected (total=%d)", len(self._clients))
|
|
||||||
|
logger.info(
|
||||||
|
"WS client connected (rate=%.1fHz, total=%d)", rate_hz, len(self._clients)
|
||||||
|
)
|
||||||
|
|
||||||
async def disconnect(self, ws: WebSocket) -> None:
|
async def disconnect(self, ws: WebSocket) -> None:
|
||||||
"""Remove WebSocket connection."""
|
"""Remove WebSocket connection and stop its sender."""
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
self._clients.discard(ws)
|
client = self._clients.pop(ws, None)
|
||||||
|
|
||||||
|
if client:
|
||||||
|
await client.stop()
|
||||||
|
|
||||||
logger.info("WS client disconnected (total=%d)", len(self._clients))
|
logger.info("WS client disconnected (total=%d)", len(self._clients))
|
||||||
|
|
||||||
async def broadcast_json(self, message: dict[str, Any]) -> None:
|
async def broadcast_json(self, message: dict[str, Any]) -> None:
|
||||||
"""Broadcast JSON message to all connected clients."""
|
"""Broadcast JSON message to all connected clients (respecting per-client rates)."""
|
||||||
if not self._clients:
|
|
||||||
return
|
|
||||||
|
|
||||||
payload = json.dumps(message, ensure_ascii=False, separators=(",", ":"))
|
|
||||||
|
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
clients = list(self._clients)
|
clients = list(self._clients.values())
|
||||||
|
|
||||||
dead: list[WebSocket] = []
|
dead: list[WebSocket] = []
|
||||||
for ws in clients:
|
for client in clients:
|
||||||
try:
|
try:
|
||||||
if ws.client_state == WebSocketState.CONNECTED:
|
await client.enqueue(message)
|
||||||
await ws.send_text(payload)
|
|
||||||
else:
|
|
||||||
dead.append(ws)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
dead.append(ws)
|
dead.append(client.ws)
|
||||||
|
|
||||||
if dead:
|
if dead:
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
for ws in dead:
|
for ws in dead:
|
||||||
self._clients.discard(ws)
|
self._clients.pop(ws, None)
|
||||||
logger.debug("WS cleanup: removed %d dead clients", len(dead))
|
logger.debug("WS cleanup: removed %d dead clients", len(dead))
|
||||||
|
|||||||
Reference in New Issue
Block a user