176 lines
5.7 KiB
Python
176 lines
5.7 KiB
Python
#!/usr/bin/env python3
|
|
"""WebSocket connection manager with per-client rate limiting and median aggregation."""
|
|
|
|
from __future__ import annotations

import asyncio
import json
import logging
import statistics
from collections import deque
from typing import Any

from starlette.websockets import WebSocket, WebSocketState
|
|
|
|
# Module-level logger; records are emitted under this module's dotted name.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ThrottledClient:
    """WebSocket client wrapper with rate limiting and median aggregation.

    Messages passed to :meth:`enqueue` are buffered. A background task started
    by :meth:`start` wakes at most ``rate_hz`` times per second, drains the
    buffer, reduces it to one message with :meth:`_aggregate_median` and sends
    that message as compact JSON text over the WebSocket.
    """

    def __init__(
        self, ws: WebSocket, rate_hz: float, max_buffer: int = 1000
    ) -> None:
        """Wrap *ws*.

        Args:
            ws: Accepted WebSocket that aggregated messages are sent on.
            rate_hz: Maximum sends per second. A value <= 0 disables the
                throttle; the sender then polls every 10 ms.
            max_buffer: Maximum number of pending samples kept between sends.
                When exceeded, the oldest samples are discarded.

        Raises:
            ValueError: If *max_buffer* is smaller than 1.
        """
        if max_buffer < 1:
            raise ValueError(f"max_buffer must be >= 1, got {max_buffer}")

        self.ws = ws
        self.rate_hz = rate_hz
        self.interval = 1.0 / rate_hz if rate_hz > 0 else 0.0
        self._last_send_time = 0.0

        # Samples accumulated during the current send window. A bounded deque
        # drops the oldest entry in O(1); list.pop(0) was O(n) per drop.
        self._buffer: deque[dict[str, Any]] = deque(maxlen=max_buffer)
        self._buffer_lock = asyncio.Lock()

        self._task: asyncio.Task[None] | None = None

    async def start(self) -> None:
        """Start the background sender task (no-op if it is already running)."""
        if self._task is not None and not self._task.done():
            # A second task would orphan the first and double the send rate.
            return
        self._task = asyncio.create_task(self._send_loop())

    async def stop(self) -> None:
        """Cancel the background sender task and wait for it to finish."""
        if self._task and not self._task.done():
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass

    async def enqueue(self, message: dict[str, Any]) -> None:
        """Add *message* to the aggregation buffer.

        When the buffer is full, the oldest pending sample is dropped. This
        keeps memory bounded when the send rate is very low.
        """
        async with self._buffer_lock:
            self._buffer.append(message)

    async def _wait_for_window(self, loop: asyncio.AbstractEventLoop) -> None:
        """Sleep until the next send window opens and record when it opened."""
        if self.interval <= 0:
            await asyncio.sleep(0.01)  # Unthrottled: just yield briefly.
            return
        elapsed = loop.time() - self._last_send_time
        if elapsed < self.interval:
            await asyncio.sleep(self.interval - elapsed)
        self._last_send_time = loop.time()

    async def _send_loop(self) -> None:
        """Background task: each send window, aggregate pending samples and send them.

        Returns when the WebSocket is no longer connected or a send fails.
        A batch that cannot be aggregated or serialized is logged and dropped;
        it does not end the loop.
        """
        loop = asyncio.get_running_loop()
        while True:
            await self._wait_for_window(loop)

            async with self._buffer_lock:
                if not self._buffer:
                    continue
                samples = list(self._buffer)
                self._buffer.clear()

            try:
                aggregated = self._aggregate_median(samples)
                payload = json.dumps(
                    aggregated, ensure_ascii=False, separators=(",", ":")
                )
            except (TypeError, ValueError, OverflowError):
                # e.g. non-numeric values, NaN/inf medians, unserializable data.
                logger.exception(
                    "WS aggregation failed; dropping %d samples", len(samples)
                )
                continue

            if self.ws.client_state != WebSocketState.CONNECTED:
                logger.debug("WS sender stopped: client no longer connected")
                return
            try:
                await self.ws.send_text(payload)
            except Exception:
                # Socket is unusable; end the task. Its completion is
                # observable through self._task.done().
                logger.debug("WS send failed; stopping sender", exc_info=True)
                return

    @staticmethod
    def _aggregate_median(samples: list[dict[str, Any]]) -> dict[str, Any]:
        """Reduce *samples* to one message holding the median rms_db and freq_hz.

        Args:
            samples: Raw messages collected during one send window, oldest first.

        Returns:
            ``{}`` for no samples. The single sample itself, unmodified, when
            there is exactly one. Otherwise a new dict containing:

            - ``time``: taken from the last sample.
            - ``rms_db``: the median rounded to 0.1.
            - ``freq_hz``: the median rounded to an int.

            Samples whose ``rms_db`` / ``freq_hz`` is missing or ``None`` are
            ignored for that field. A field with no values falls back to 0.
        """
        if not samples:
            return {}

        if len(samples) == 1:
            # Nothing to aggregate: pass the message through unchanged.
            return samples[0]

        rms_values = [s["rms_db"] for s in samples if s.get("rms_db") is not None]
        freq_values = [s["freq_hz"] for s in samples if s.get("freq_hz") is not None]

        rms_median = statistics.median(rms_values) if rms_values else 0.0
        freq_median = statistics.median(freq_values) if freq_values else 0

        return {
            # Timestamp of the most recent sample in the window.
            "time": samples[-1].get("time", ""),
            "rms_db": round(rms_median, 1),
            "freq_hz": int(round(freq_median)),
        }
|
|
|
|
|
|
class ConnectionManager:
    """Manages WebSocket connections with per-client rate limiting.

    Each accepted socket gets its own ThrottledClient. Broadcasting only
    enqueues into each client's buffer; the actual sends happen in each
    client's own background task.
    """

    def __init__(self) -> None:
        self._clients: dict[WebSocket, ThrottledClient] = {}
        self._lock = asyncio.Lock()  # Guards self._clients.

    async def connect(self, ws: WebSocket, rate_hz: float = 10.0) -> None:
        """Accept *ws* and register it with a sender throttled to *rate_hz*."""
        await ws.accept()

        client = ThrottledClient(ws, rate_hz)
        await client.start()

        async with self._lock:
            self._clients[ws] = client
            total = len(self._clients)

        logger.info("WS client connected (rate=%.1fHz, total=%d)", rate_hz, total)

    async def disconnect(self, ws: WebSocket) -> None:
        """Unregister *ws* and stop its sender task, if it was registered."""
        async with self._lock:
            client = self._clients.pop(ws, None)
            total = len(self._clients)

        if client is not None:
            await client.stop()

        logger.info("WS client disconnected (total=%d)", total)

    async def broadcast_json(self, message: dict[str, Any]) -> None:
        """Queue *message* for every connected client (each sends at its own rate).

        A client is treated as dead, unregistered and stopped in two cases:
        its sender task has already finished (socket closed or a send failed),
        or enqueueing into it raises.
        """
        async with self._lock:
            clients = list(self._clients.values())

        dead: list[ThrottledClient] = []
        for client in clients:
            # enqueue() only appends to a buffer and does not report a dead
            # socket, so a finished sender task is the real liveness signal.
            task = client._task
            if task is not None and task.done():
                dead.append(client)
                continue
            try:
                await client.enqueue(message)
            except Exception:
                dead.append(client)

        if not dead:
            return

        removed: list[ThrottledClient] = []
        async with self._lock:
            for client in dead:
                # Skip entries a concurrent disconnect() already handled.
                if self._clients.get(client.ws) is client:
                    del self._clients[client.ws]
                    removed.append(client)

        # Stop outside the lock; popping without stopping leaked the sender task.
        for client in removed:
            await client.stop()

        if removed:
            logger.debug("WS cleanup: removed %d dead clients", len(removed))
|