diff --git a/services/api/app/ws/broadcaster.py b/services/api/app/ws/broadcaster.py index fae9e00..c6cb430 100644 --- a/services/api/app/ws/broadcaster.py +++ b/services/api/app/ws/broadcaster.py @@ -3,24 +3,26 @@ from __future__ import annotations import asyncio from contextlib import suppress from datetime import timezone -from sqlalchemy.ext.asyncio import AsyncSession from app.db.session import SessionLocal from app.repositories.audio_repository import AudioRepository -from app.ws.manager import manager +from app.ws.router import manager # используем тот же manager, что и в ws/router.py def _iso_z(dt) -> str: - # dt ожидается timezone-aware return dt.astimezone(timezone.utc).isoformat().replace("+00:00", "Z") -async def audio_live_broadcaster(poll_interval_sec: float = 0.2) -> None: +async def audio_live_broadcaster(poll_interval_sec: float = 0.05) -> None: + """ + Poll latest row and broadcast only when a NEW row appears. + Throttling per client is handled by manager.broadcast_json(). + """ last_time = None while True: try: - async with SessionLocal() as db: # AsyncSession + async with SessionLocal() as db: repo = AudioRepository(db) rows = await repo.latest(1) if rows: @@ -35,7 +37,7 @@ async def audio_live_broadcaster(poll_interval_sec: float = 0.2) -> None: } ) except Exception: - # чтобы WS не умирал из-за временных проблем с БД + # не даём таске умереть при временных проблемах БД pass await asyncio.sleep(poll_interval_sec) diff --git a/services/api/app/ws/manager.py b/services/api/app/ws/manager.py index bc3c2c9..9901b46 100644 --- a/services/api/app/ws/manager.py +++ b/services/api/app/ws/manager.py @@ -1,39 +1,65 @@ from __future__ import annotations import asyncio +import time +from dataclasses import dataclass from typing import Any + from fastapi import WebSocket +@dataclass(slots=True) +class ClientConn: + ws: WebSocket + hz: int + min_interval: float + last_sent_monotonic: float + + class ConnectionManager: def __init__(self) -> None: - self._connections: set[WebSocket] = set() + self._conns: dict[WebSocket, ClientConn] = {} self._lock = asyncio.Lock() - async def connect(self, ws: WebSocket) -> None: + async def connect(self, ws: WebSocket, hz: int) -> None: await ws.accept() + now = time.monotonic() + client = ClientConn( + ws=ws, hz=hz, min_interval=1.0 / hz, last_sent_monotonic=0.0 + ) + async with self._lock: - self._connections.add(ws) + self._conns[ws] = client + + # Небольшой лог (можно заменить на structlog/loguru) + print(f"[ws] connected client={id(ws)} hz={hz} at={now:.3f}") async def disconnect(self, ws: WebSocket) -> None: async with self._lock: - self._connections.discard(ws) + existed = ws in self._conns + self._conns.pop(ws, None) + if existed: + print(f"[ws] disconnected client={id(ws)}") async def broadcast_json(self, payload: dict[str, Any]) -> None: + now = time.monotonic() + async with self._lock: - conns = list(self._connections) + clients = list(self._conns.values()) to_remove: list[WebSocket] = [] - for ws in conns: + for c in clients: + # throttling per connection + if c.last_sent_monotonic and (now - c.last_sent_monotonic) < c.min_interval: + continue + try: - await ws.send_json(payload) + await c.ws.send_json(payload) + c.last_sent_monotonic = now except Exception: - to_remove.append(ws) + to_remove.append(c.ws) if to_remove: async with self._lock: for ws in to_remove: - self._connections.discard(ws) - - -manager = ConnectionManager() + self._conns.pop(ws, None) diff --git a/services/api/app/ws/router.py b/services/api/app/ws/router.py index 785bb11..63e9bc3 100644 --- a/services/api/app/ws/router.py +++ b/services/api/app/ws/router.py @@ -1,19 +1,52 @@ -from fastapi import APIRouter, WebSocket +from __future__ import annotations + +from fastapi import APIRouter, WebSocket, status +from fastapi.exceptions import WebSocketException from starlette.websockets import WebSocketDisconnect -from app.ws.manager import manager +from app.ws.manager import ConnectionManager router = APIRouter() +manager = ConnectionManager() + +DEFAULT_HZ = 10 +MIN_HZ = 1 +MAX_HZ = 60 + + +def _parse_hz(ws: WebSocket) -> int: + raw = ws.query_params.get("hz") + if raw is None: + return DEFAULT_HZ + try: + hz = int(raw) + except ValueError: + raise WebSocketException( + code=status.WS_1008_POLICY_VIOLATION, reason="Invalid 'hz' (int expected)" + ) + if hz < MIN_HZ or hz > MAX_HZ: + raise WebSocketException( + code=status.WS_1008_POLICY_VIOLATION, + reason=f"Invalid 'hz' (allowed {MIN_HZ}..{MAX_HZ})", + ) + return hz @router.websocket("/ws/live") async def ws_live(ws: WebSocket) -> None: - await manager.connect(ws) + hz = _parse_hz(ws) + + await manager.connect(ws, hz=hz) try: - # Держим соединение + # Не обязательно принимать сообщения от клиента + # Но чтобы корректно ловить disconnect в некоторых клиентах - держим receive loop while True: await ws.receive_text() except WebSocketDisconnect: await manager.disconnect(ws) + except WebSocketException: + # если прилетит exception после accept — корректно удалим + await manager.disconnect(ws) + raise except Exception: await manager.disconnect(ws)