feat(collector): добавлен сбор медианной частоты и громкости
This commit is contained in:
@@ -1,11 +1,12 @@
|
||||
#!/usr/bin/env python3
|
||||
"""WebSocket connection manager for broadcasting live audio data."""
|
||||
"""WebSocket connection manager with per-client rate limiting and median aggregation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import statistics
|
||||
from typing import Any
|
||||
|
||||
from starlette.websockets import WebSocket, WebSocketState
|
||||
@@ -13,48 +14,162 @@ from starlette.websockets import WebSocket, WebSocketState
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ThrottledClient:
|
||||
"""WebSocket client wrapper with rate limiting and median aggregation."""
|
||||
|
||||
def __init__(self, ws: WebSocket, rate_hz: float) -> None:
|
||||
self.ws = ws
|
||||
self.rate_hz = rate_hz
|
||||
self.interval = 1.0 / rate_hz if rate_hz > 0 else 0.0
|
||||
self._last_send_time = 0.0
|
||||
|
||||
# Accumulator for aggregation (list of samples within current window)
|
||||
self._buffer: list[dict[str, Any]] = []
|
||||
self._buffer_lock = asyncio.Lock()
|
||||
|
||||
self._task: asyncio.Task[None] | None = None
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start background sender task."""
|
||||
self._task = asyncio.create_task(self._send_loop())
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop background sender task."""
|
||||
if self._task and not self._task.done():
|
||||
self._task.cancel()
|
||||
try:
|
||||
await self._task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def enqueue(self, message: dict[str, Any]) -> None:
|
||||
"""Add message to aggregation buffer."""
|
||||
async with self._buffer_lock:
|
||||
self._buffer.append(message)
|
||||
# Limit buffer size (prevent memory issues if rate is very low)
|
||||
if len(self._buffer) > 1000:
|
||||
self._buffer.pop(0)
|
||||
|
||||
async def _send_loop(self) -> None:
|
||||
"""Background task: aggregate and send messages respecting rate limit."""
|
||||
try:
|
||||
while True:
|
||||
# Wait for next send window
|
||||
if self.interval > 0:
|
||||
now = asyncio.get_event_loop().time()
|
||||
elapsed = now - self._last_send_time
|
||||
if elapsed < self.interval:
|
||||
await asyncio.sleep(self.interval - elapsed)
|
||||
self._last_send_time = asyncio.get_event_loop().time()
|
||||
else:
|
||||
await asyncio.sleep(0.01) # Minimal delay
|
||||
|
||||
# Get accumulated samples
|
||||
async with self._buffer_lock:
|
||||
if not self._buffer:
|
||||
continue
|
||||
|
||||
samples = self._buffer[:]
|
||||
self._buffer.clear()
|
||||
|
||||
# Aggregate: compute median
|
||||
aggregated = self._aggregate_median(samples)
|
||||
|
||||
# Send aggregated message
|
||||
try:
|
||||
if self.ws.client_state == WebSocketState.CONNECTED:
|
||||
payload = json.dumps(
|
||||
aggregated, ensure_ascii=False, separators=(",", ":")
|
||||
)
|
||||
await self.ws.send_text(payload)
|
||||
else:
|
||||
break # Connection closed
|
||||
except Exception:
|
||||
break # Send failed
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _aggregate_median(samples: list[dict[str, Any]]) -> dict[str, Any]:
|
||||
"""
|
||||
Aggregate samples by computing median of rms_db and freq_hz.
|
||||
|
||||
Args:
|
||||
samples: List of raw messages
|
||||
|
||||
Returns:
|
||||
Aggregated message with median values
|
||||
"""
|
||||
if not samples:
|
||||
return {}
|
||||
|
||||
if len(samples) == 1:
|
||||
return samples[0]
|
||||
|
||||
# Extract numeric values
|
||||
rms_values = [s["rms_db"] for s in samples if "rms_db" in s]
|
||||
freq_values = [s["freq_hz"] for s in samples if "freq_hz" in s]
|
||||
|
||||
# Compute medians
|
||||
rms_median = statistics.median(rms_values) if rms_values else 0.0
|
||||
freq_median = statistics.median(freq_values) if freq_values else 0
|
||||
|
||||
# Use timestamp from last sample (most recent)
|
||||
time_value = samples[-1].get("time", "")
|
||||
|
||||
return {
|
||||
"time": time_value,
|
||||
"rms_db": round(rms_median, 1),
|
||||
"freq_hz": int(round(freq_median)),
|
||||
}
|
||||
|
||||
|
||||
class ConnectionManager:
|
||||
"""Manages WebSocket connections and broadcasts messages to all clients."""
|
||||
"""Manages WebSocket connections with per-client rate limiting."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._clients: set[WebSocket] = set()
|
||||
self._clients: dict[WebSocket, ThrottledClient] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def connect(self, ws: WebSocket) -> None:
|
||||
"""Accept and register new WebSocket connection."""
|
||||
async def connect(self, ws: WebSocket, rate_hz: float = 10.0) -> None:
|
||||
"""Accept and register new WebSocket connection with rate limiting."""
|
||||
await ws.accept()
|
||||
|
||||
client = ThrottledClient(ws, rate_hz)
|
||||
await client.start()
|
||||
|
||||
async with self._lock:
|
||||
self._clients.add(ws)
|
||||
logger.info("WS client connected (total=%d)", len(self._clients))
|
||||
self._clients[ws] = client
|
||||
|
||||
logger.info(
|
||||
"WS client connected (rate=%.1fHz, total=%d)", rate_hz, len(self._clients)
|
||||
)
|
||||
|
||||
async def disconnect(self, ws: WebSocket) -> None:
|
||||
"""Remove WebSocket connection."""
|
||||
"""Remove WebSocket connection and stop its sender."""
|
||||
async with self._lock:
|
||||
self._clients.discard(ws)
|
||||
client = self._clients.pop(ws, None)
|
||||
|
||||
if client:
|
||||
await client.stop()
|
||||
|
||||
logger.info("WS client disconnected (total=%d)", len(self._clients))
|
||||
|
||||
async def broadcast_json(self, message: dict[str, Any]) -> None:
|
||||
"""Broadcast JSON message to all connected clients."""
|
||||
if not self._clients:
|
||||
return
|
||||
|
||||
payload = json.dumps(message, ensure_ascii=False, separators=(",", ":"))
|
||||
|
||||
"""Broadcast JSON message to all connected clients (respecting per-client rates)."""
|
||||
async with self._lock:
|
||||
clients = list(self._clients)
|
||||
clients = list(self._clients.values())
|
||||
|
||||
dead: list[WebSocket] = []
|
||||
for ws in clients:
|
||||
for client in clients:
|
||||
try:
|
||||
if ws.client_state == WebSocketState.CONNECTED:
|
||||
await ws.send_text(payload)
|
||||
else:
|
||||
dead.append(ws)
|
||||
await client.enqueue(message)
|
||||
except Exception:
|
||||
dead.append(ws)
|
||||
dead.append(client.ws)
|
||||
|
||||
if dead:
|
||||
async with self._lock:
|
||||
for ws in dead:
|
||||
self._clients.discard(ws)
|
||||
self._clients.pop(ws, None)
|
||||
logger.debug("WS cleanup: removed %d dead clients", len(dead))
|
||||
|
||||
Reference in New Issue
Block a user