From 734c65253de8fc8bfb03defee8bd37d02f1d9889 Mon Sep 17 00:00:00 2001 From: Iwwww Date: Sun, 28 Dec 2025 23:07:49 +0300 Subject: [PATCH] =?UTF-8?q?feat(collector):=20=D0=B4=D0=BE=D0=B1=D0=B0?= =?UTF-8?q?=D0=B2=D0=BB=D0=B5=D0=BD=20=D1=81=D0=B1=D0=BE=D1=80=20=D0=BC?= =?UTF-8?q?=D0=B5=D0=B4=D0=B8=D0=B0=D0=BD=D0=BD=D0=BE=D0=B9=20=D1=87=D0=B0?= =?UTF-8?q?=D1=81=D1=82=D0=BE=D1=82=D1=8B=20=D0=B8=20=D0=B3=D1=80=D0=BE?= =?UTF-8?q?=D0=BC=D0=BA=D0=BE=D1=81=D1=82=D0=B8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- services/collector/ws_app.py | 32 ++++-- services/collector/ws_manager.py | 161 ++++++++++++++++++++++++++----- 2 files changed, 162 insertions(+), 31 deletions(-) diff --git a/services/collector/ws_app.py b/services/collector/ws_app.py index 0bad49f..223bb0f 100644 --- a/services/collector/ws_app.py +++ b/services/collector/ws_app.py @@ -1,12 +1,16 @@ #!/usr/bin/env python3 -"""FastAPI WebSocket endpoint for live audio streaming.""" +"""FastAPI WebSocket endpoint with rate limiting support.""" from __future__ import annotations -from fastapi import FastAPI, WebSocket, WebSocketDisconnect +import logging + +from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Query from ws_manager import ConnectionManager +logger = logging.getLogger(__name__) + app = FastAPI(title="Audio Analyzer WebSocket") manager = ConnectionManager() @@ -18,16 +22,28 @@ async def health() -> dict[str, str]: @app.websocket("/ws/live") -async def ws_live(websocket: WebSocket) -> None: - """WebSocket endpoint for real-time audio data streaming.""" - await manager.connect(websocket) +async def ws_live( + websocket: WebSocket, + hz: float = Query(default=10.0, ge=0.1, le=100.0, description="Update rate in Hz"), +) -> None: + """ + WebSocket endpoint for real-time audio data streaming. + + Query parameters: + hz: Update rate in Hz (0.1 - 100.0, default: 10.0) + Examples: + - ws://localhost:8001/ws/live?hz=10 → 10 messages/sec + - ws://localhost:8001/ws/live?hz=1 → 1 message/sec + - ws://localhost:8001/ws/live?hz=30 → 30 messages/sec + """ + await manager.connect(websocket, rate_hz=hz) try: while True: - # Keep connection alive + # Keep connection alive; client can send pings await websocket.receive_text() except WebSocketDisconnect: pass - except Exception: - pass + except Exception as e: + logger.debug("WS connection error: %s", e) finally: await manager.disconnect(websocket) diff --git a/services/collector/ws_manager.py b/services/collector/ws_manager.py index f6038cf..5d56450 100644 --- a/services/collector/ws_manager.py +++ b/services/collector/ws_manager.py @@ -1,11 +1,12 @@ #!/usr/bin/env python3 -"""WebSocket connection manager for broadcasting live audio data.""" +"""WebSocket connection manager with per-client rate limiting and median aggregation.""" from __future__ import annotations import asyncio import json import logging +import statistics from typing import Any from starlette.websockets import WebSocket, WebSocketState @@ -13,48 +14,162 @@ from starlette.websockets import WebSocket, WebSocketState logger = logging.getLogger(__name__) +class ThrottledClient: + """WebSocket client wrapper with rate limiting and median aggregation.""" + + def __init__(self, ws: WebSocket, rate_hz: float) -> None: + self.ws = ws + self.rate_hz = rate_hz + self.interval = 1.0 / rate_hz if rate_hz > 0 else 0.0 + self._last_send_time = 0.0 + + # Accumulator for aggregation (list of samples within current window) + self._buffer: list[dict[str, Any]] = [] + self._buffer_lock = asyncio.Lock() + + self._task: asyncio.Task[None] | None = None + + async def start(self) -> None: + """Start background sender task.""" + self._task = asyncio.create_task(self._send_loop()) + + async def stop(self) -> None: + """Stop background sender task.""" + if self._task and not self._task.done(): + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + + async def enqueue(self, message: dict[str, Any]) -> None: + """Add message to aggregation buffer.""" + async with self._buffer_lock: + self._buffer.append(message) + # Limit buffer size (prevent memory issues if rate is very low) + if len(self._buffer) > 1000: + self._buffer.pop(0) + + async def _send_loop(self) -> None: + """Background task: aggregate and send messages respecting rate limit.""" + try: + while True: + # Wait for next send window + if self.interval > 0: + now = asyncio.get_event_loop().time() + elapsed = now - self._last_send_time + if elapsed < self.interval: + await asyncio.sleep(self.interval - elapsed) + self._last_send_time = asyncio.get_event_loop().time() + else: + await asyncio.sleep(0.01) # Minimal delay + + # Get accumulated samples + async with self._buffer_lock: + if not self._buffer: + continue + + samples = self._buffer[:] + self._buffer.clear() + + # Aggregate: compute median + aggregated = self._aggregate_median(samples) + + # Send aggregated message + try: + if self.ws.client_state == WebSocketState.CONNECTED: + payload = json.dumps( + aggregated, ensure_ascii=False, separators=(",", ":") + ) + await self.ws.send_text(payload) + else: + break # Connection closed + except Exception: + break # Send failed + + except asyncio.CancelledError: + pass + + @staticmethod + def _aggregate_median(samples: list[dict[str, Any]]) -> dict[str, Any]: + """ + Aggregate samples by computing median of rms_db and freq_hz. + + Args: + samples: List of raw messages + + Returns: + Aggregated message with median values + """ + if not samples: + return {} + + if len(samples) == 1: + return samples[0] + + # Extract numeric values + rms_values = [s["rms_db"] for s in samples if "rms_db" in s] + freq_values = [s["freq_hz"] for s in samples if "freq_hz" in s] + + # Compute medians + rms_median = statistics.median(rms_values) if rms_values else 0.0 + freq_median = statistics.median(freq_values) if freq_values else 0 + + # Use timestamp from last sample (most recent) + time_value = samples[-1].get("time", "") + + return { + "time": time_value, + "rms_db": round(rms_median, 1), + "freq_hz": int(round(freq_median)), + } + + class ConnectionManager: - """Manages WebSocket connections and broadcasts messages to all clients.""" + """Manages WebSocket connections with per-client rate limiting.""" def __init__(self) -> None: - self._clients: set[WebSocket] = set() + self._clients: dict[WebSocket, ThrottledClient] = {} self._lock = asyncio.Lock() - async def connect(self, ws: WebSocket) -> None: - """Accept and register new WebSocket connection.""" + async def connect(self, ws: WebSocket, rate_hz: float = 10.0) -> None: + """Accept and register new WebSocket connection with rate limiting.""" await ws.accept() + + client = ThrottledClient(ws, rate_hz) + await client.start() + async with self._lock: - self._clients.add(ws) - logger.info("WS client connected (total=%d)", len(self._clients)) + self._clients[ws] = client + + logger.info( + "WS client connected (rate=%.1fHz, total=%d)", rate_hz, len(self._clients) + ) async def disconnect(self, ws: WebSocket) -> None: - """Remove WebSocket connection.""" + """Remove WebSocket connection and stop its sender.""" async with self._lock: - self._clients.discard(ws) + client = self._clients.pop(ws, None) + + if client: + await client.stop() + logger.info("WS client disconnected (total=%d)", len(self._clients)) async def broadcast_json(self, message: dict[str, Any]) -> None: - """Broadcast JSON message to all connected clients.""" - if not self._clients: - return - - payload = json.dumps(message, ensure_ascii=False, separators=(",", ":")) - + """Broadcast JSON message to all connected clients (respecting per-client rates).""" async with self._lock: - clients = list(self._clients) + clients = list(self._clients.values()) dead: list[WebSocket] = [] - for ws in clients: + for client in clients: try: - if ws.client_state == WebSocketState.CONNECTED: - await ws.send_text(payload) - else: - dead.append(ws) + await client.enqueue(message) except Exception: - dead.append(ws) + dead.append(client.ws) if dead: async with self._lock: for ws in dead: - self._clients.discard(ws) + self._clients.pop(ws, None) logger.debug("WS cleanup: removed %d dead clients", len(dead))