#!/usr/bin/env python3
"""WebSocket connection manager with per-client rate limiting and median aggregation."""

from __future__ import annotations

import asyncio
import json
import logging
import statistics
from collections import deque
from typing import Any

from starlette.websockets import WebSocket, WebSocketState

logger = logging.getLogger(__name__)


class ThrottledClient:
    """WebSocket client wrapper with rate limiting and median aggregation.

    Messages passed to :meth:`enqueue` are buffered. A background task started
    by :meth:`start` wakes at most ``rate_hz`` times per second, aggregates the
    buffered samples with :meth:`_aggregate_median`, and sends the result as
    compact JSON text. The task ends when the socket is no longer connected or
    a send fails; :attr:`closed` reports that so the owner can drop the client.
    """

    # Default cap on buffered samples; the oldest are discarded beyond it.
    DEFAULT_MAX_BUFFER = 1000

    def __init__(
        self,
        ws: WebSocket,
        rate_hz: float,
        max_buffer: int = DEFAULT_MAX_BUFFER,
    ) -> None:
        """Wrap ``ws``; nothing is sent until :meth:`start` is awaited.

        Args:
            ws: Accepted WebSocket to send aggregated messages on.
            rate_hz: Maximum sends per second; ``<= 0`` disables throttling
                (the loop then polls every 10 ms).
            max_buffer: Maximum samples held between sends (must be >= 1).

        Raises:
            ValueError: If ``max_buffer`` is less than 1.
        """
        if max_buffer < 1:
            raise ValueError(f"max_buffer must be >= 1, got {max_buffer}")
        self.ws = ws
        self.rate_hz = rate_hz
        self.interval = 1.0 / rate_hz if rate_hz > 0 else 0.0
        self._last_send_time = 0.0
        # Samples accumulated in the current window. A bounded deque drops the
        # oldest entry in O(1) when full (list.pop(0) is O(n)).
        self._buffer: deque[dict[str, Any]] = deque(maxlen=max_buffer)
        self._buffer_lock = asyncio.Lock()
        self._task: asyncio.Task[None] | None = None

    @property
    def closed(self) -> bool:
        """True once the background sender task has finished for any reason."""
        return self._task is not None and self._task.done()

    async def start(self) -> None:
        """Start the background sender task."""
        self._task = asyncio.create_task(self._send_loop())

    async def stop(self) -> None:
        """Cancel the sender task if it is running and wait for it to end.

        A task that already finished is awaited too, so an unexpected failure
        inside it is logged instead of being reported as never retrieved.
        """
        task = self._task
        if task is None:
            return
        if not task.done():
            task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        except Exception:
            logger.exception("WS sender task failed")

    async def enqueue(self, message: dict[str, Any]) -> None:
        """Add ``message`` to the aggregation buffer (oldest dropped when full)."""
        async with self._buffer_lock:
            self._buffer.append(message)

    async def _send_loop(self) -> None:
        """Background task: aggregate and send buffered messages at the rate limit.

        Returns when the socket is no longer connected or a send fails.
        Cancellation propagates to the awaiting :meth:`stop`.
        """
        loop = asyncio.get_running_loop()
        while True:
            # Wait for the next send window.
            if self.interval > 0:
                elapsed = loop.time() - self._last_send_time
                if elapsed < self.interval:
                    await asyncio.sleep(self.interval - elapsed)
                self._last_send_time = loop.time()
            else:
                await asyncio.sleep(0.01)  # Minimal delay

            # Take every sample accumulated during this window.
            async with self._buffer_lock:
                if not self._buffer:
                    continue
                samples = list(self._buffer)
                self._buffer.clear()

            payload = self._encode(samples)
            if payload is None:
                continue  # Bad batch dropped; keep serving this client.
            if not await self._send(payload):
                break

    def _encode(self, samples: list[dict[str, Any]]) -> str | None:
        """Aggregate ``samples`` and serialize them as compact JSON.

        Returns:
            The JSON text, or ``None`` if aggregation or serialization failed
            (e.g. non-numeric values or a non-serializable message).
        """
        try:
            aggregated = self._aggregate_median(samples)
            return json.dumps(aggregated, ensure_ascii=False, separators=(",", ":"))
        except (TypeError, ValueError, OverflowError):
            logger.warning(
                "WS aggregation failed; dropping %d samples", len(samples),
                exc_info=True,
            )
            return None

    async def _send(self, payload: str) -> bool:
        """Send ``payload``; return False when this client should stop sending."""
        if self.ws.client_state != WebSocketState.CONNECTED:
            return False  # Connection closed
        try:
            await self.ws.send_text(payload)
        except Exception:
            logger.debug("WS send failed; stopping sender", exc_info=True)
            return False
        return True

    @staticmethod
    def _aggregate_median(samples: list[dict[str, Any]]) -> dict[str, Any]:
        """
        Aggregate samples by computing the median of rms_db and freq_hz.

        Args:
            samples: List of raw messages.

        Returns:
            ``{}`` for no samples. For a single sample, that sample unchanged.
            Otherwise a message with keys ``time`` (from the last sample),
            ``rms_db`` (median, rounded to 0.1) and ``freq_hz`` (median,
            rounded to int). Missing or ``None`` values are ignored; if none
            remain, the field defaults to 0.
        """
        if not samples:
            return {}
        if len(samples) == 1:
            return samples[0]

        # Extract numeric values; None (e.g. "no reading") must not reach median().
        rms_values = [s["rms_db"] for s in samples if s.get("rms_db") is not None]
        freq_values = [s["freq_hz"] for s in samples if s.get("freq_hz") is not None]

        rms_median = statistics.median(rms_values) if rms_values else 0.0
        freq_median = statistics.median(freq_values) if freq_values else 0

        return {
            # Use the timestamp of the most recent sample.
            "time": samples[-1].get("time", ""),
            "rms_db": round(rms_median, 1),
            "freq_hz": int(round(freq_median)),
        }


class ConnectionManager:
    """Manages WebSocket connections with per-client rate limiting."""

    def __init__(self) -> None:
        self._clients: dict[WebSocket, ThrottledClient] = {}
        self._lock = asyncio.Lock()

    async def connect(self, ws: WebSocket, rate_hz: float = 10.0) -> None:
        """Accept and register a new WebSocket connection with rate limiting."""
        await ws.accept()
        client = ThrottledClient(ws, rate_hz)
        await client.start()
        async with self._lock:
            self._clients[ws] = client
            total = len(self._clients)
        logger.info("WS client connected (rate=%.1fHz, total=%d)", rate_hz, total)

    async def disconnect(self, ws: WebSocket) -> None:
        """Remove a WebSocket connection and stop its sender."""
        async with self._lock:
            client = self._clients.pop(ws, None)
        if client:
            await client.stop()
        logger.info("WS client disconnected (total=%d)", len(self._clients))

    async def broadcast_json(self, message: dict[str, Any]) -> None:
        """Broadcast a JSON message to all connected clients (respecting per-client rates).

        Clients whose sender task has ended (socket closed or send failed) or
        whose enqueue raises are unregistered and their tasks reaped.
        """
        async with self._lock:
            clients = list(self._clients.values())

        dead: list[ThrottledClient] = []
        for client in clients:
            # A finished sender means nobody drains this buffer any more.
            if client.closed:
                dead.append(client)
                continue
            try:
                await client.enqueue(message)
            except Exception:
                logger.debug("WS enqueue failed", exc_info=True)
                dead.append(client)

        if not dead:
            return
        async with self._lock:
            for client in dead:
                # Only remove the entry if it still refers to this client.
                if self._clients.get(client.ws) is client:
                    del self._clients[client.ws]
        for client in dead:
            await client.stop()
        logger.debug("WS cleanup: removed %d dead clients", len(dead))