From e6f361def4f8ead9e79f035e52d4c35158ce0210 Mon Sep 17 00:00:00 2001
From: Iwwww
Date: Sun, 28 Dec 2025 22:24:23 +0300
Subject: [PATCH] feat(collector): add websocket

---
 docker-compose.yml                  |   4 +
 services/collector/db_writer.py     |   7 +-
 services/collector/main.py          | 192 +++++++++++++++++++++-------
 services/collector/requirements.txt |   3 +
 services/collector/ws_app.py        |  37 ++++++
 services/collector/ws_manager.py    |  60 +++++++++
 6 files changed, 255 insertions(+), 48 deletions(-)
 create mode 100644 services/collector/ws_app.py
 create mode 100644 services/collector/ws_manager.py

diff --git a/docker-compose.yml b/docker-compose.yml
index 574166f..24aac82 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -36,6 +36,10 @@ services:
       DB_NAME: ${DB_NAME:-audio_analyzer}
       DB_USER: ${DB_USER:-postgres}
       DB_PASSWORD: ${DB_PASSWORD:-postgres}
+      WS_HOST: 0.0.0.0
+      WS_PORT: 8000
+    ports:
+      - "8001:8000"
     devices:
       - "${SERIAL_PORT:-/dev/ttyACM0}:${SERIAL_PORT:-/dev/ttyACM0}"
     networks:
diff --git a/services/collector/db_writer.py b/services/collector/db_writer.py
index a0c4095..23dec9c 100644
--- a/services/collector/db_writer.py
+++ b/services/collector/db_writer.py
@@ -1,13 +1,13 @@
 #!/usr/bin/env python3
 """
 FR-2.3: Database Writer with Batch Processing
-Buffers audio metrics and writes in batches (50 records or 5 seconds)
+Buffers audio metrics and writes in batches
 """
 
 import asyncio
 import logging
 from datetime import datetime, timezone
 from typing import List, Optional
 from dataclasses import dataclass
 
 import asyncpg
@@ -32,8 +32,8 @@ class DatabaseWriter:
     """
 
     BATCH_SIZE = 50
-    BATCH_TIMEOUT = 5.0  # seconds
-    SILENCE_THRESHOLD_DB = -30.0  # dB below = silence
+    BATCH_TIMEOUT: float = 5.0  # seconds
+    SILENCE_THRESHOLD_DB: float = -30.0  # dB below = silence
 
     def __init__(self, db_url: str):
         self.db_url = db_url
@@ -124,7 +124,6 @@ class DatabaseWriter:
             f"silence={is_silence} (buffer={len(self.buffer)})"
         )
 
-        # Flush if batch size reached
         if len(self.buffer) >= self.BATCH_SIZE:
             await self.flush()
 
diff --git a/services/collector/main.py b/services/collector/main.py
index cb84782..d7ac5df 100644
--- a/services/collector/main.py
+++ b/services/collector/main.py
@@ -1,33 +1,46 @@
 #!/usr/bin/env python3
 """
-FR-2: Audio Data Collector Service
-Reads audio metrics from STM32, validates, and writes to TimescaleDB
+FR-2: Audio Data Collector Service with WebSocket Live Streaming
+Reads audio metrics from STM32, validates, writes to DB, and streams via WebSocket.
 """
 
+from __future__ import annotations
+
 import asyncio
 import logging
 import os
 import signal
 import sys
+from contextlib import suppress
+from datetime import datetime, timezone
+from typing import Callable, Optional
+
+import uvicorn
 
-from serial_reader import SerialReader
 from audio_validator import AudioValidator
 from db_writer import DatabaseWriter
 from protocol_parser import AudioMetrics
+from serial_reader import SerialReader
+from ws_app import app as ws_app
+from ws_app import manager
 
 logging.basicConfig(
     level=logging.INFO,
     format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
     datefmt="%Y-%m-%d %H:%M:%S",
 )
-
 logger = logging.getLogger(__name__)
 
 
-class CollectorService:
-    """Main collector service orchestrating serial reading and database writing"""
+def _iso_z(dt: datetime) -> str:
+    """Format datetime as ISO8601 with 'Z' suffix (UTC)."""
+    return dt.astimezone(timezone.utc).isoformat().replace("+00:00", "Z")
 
-    def __init__(self, serial_port: str, db_url: str, baudrate: int = 115200):
+
+class CollectorService:
+    """Main collector service: serial → validate → DB + WebSocket."""
+
+    def __init__(self, serial_port: str, db_url: str, baudrate: int = 115200) -> None:
         self.serial_reader = SerialReader(
             port=serial_port, baudrate=baudrate, on_packet=self._handle_packet
         )
@@ -36,23 +49,94 @@ class CollectorService:
 
         self._shutdown_event = asyncio.Event()
 
-    async def _handle_packet(self, packet: AudioMetrics):
+        # WebSocket broadcast queue (bounded to prevent memory issues)
+        self._ws_queue: asyncio.Queue[dict] = asyncio.Queue(maxsize=200)
+        self._ws_broadcast_task: Optional[asyncio.Task[None]] = None
+
+        # Uvicorn server for WebSocket endpoint
+        self._uvicorn_server: Optional[uvicorn.Server] = None
+        self._ws_server_task: Optional[asyncio.Task[None]] = None
+
+    def shutdown(self) -> None:
+        """Trigger graceful shutdown (called from signal handler)."""
+        logger.info("Shutdown requested")
+        self._shutdown_event.set()
+
+    async def _ws_broadcast_loop(self) -> None:
+        """Background task: consume queue and broadcast to WebSocket clients."""
+        try:
+            while True:
+                msg = await self._ws_queue.get()
+                try:
+                    await manager.broadcast_json(msg)
+                except Exception as e:
+                    logger.error("WS broadcast error: %s", e)
+                finally:
+                    self._ws_queue.task_done()
+        except asyncio.CancelledError:
+            logger.debug("WS broadcast loop cancelled")
+
+    async def _start_ws_server(self) -> None:
+        """Start uvicorn server for WebSocket endpoint."""
+        host = os.getenv("WS_HOST", "0.0.0.0")
+        port = int(os.getenv("WS_PORT", "8001"))
+
+        config = uvicorn.Config(
+            ws_app,
+            host=host,
+            port=port,
+            log_level="warning",
+            loop="asyncio",
+            access_log=False,
+        )
+        self._uvicorn_server = uvicorn.Server(config)
+
+        try:
+            logger.info("Starting WebSocket server on ws://%s:%d/ws/live", host, port)
+            await self._uvicorn_server.serve()
+        except SystemExit as e:
+            logger.error("WS server failed (port %d already in use?): %s", port, e)
+            self.shutdown()
+        except Exception as e:
+            logger.exception("WS server crashed: %s", e)
+            self.shutdown()
+
+    async def _handle_packet(self, packet: AudioMetrics) -> None:
         """
-        Process received audio packet: validate and write to database.
+        Process received audio packet: validate, push to WebSocket, write to DB.
 
         Args:
-            packet: Parsed audio metrics packet
+            packet: Parsed audio metrics from STM32
         """
         # Validate packet
         validation = self.validator.validate_packet(packet.rms_db, packet.freq_hz)
-
         if not validation.valid:
             logger.warning(
-                f"Invalid packet: {validation.error} "
-                f"(rms={packet.rms_db:.1f}dB freq={packet.freq_hz}Hz)"
+                "Invalid packet: %s (rms=%.1fdB freq=%dHz)",
+                validation.error,
+                packet.rms_db,
+                packet.freq_hz,
             )
             return
 
+        # Push to WebSocket queue (non-blocking)
+        msg = {
+            "time": _iso_z(datetime.now(timezone.utc)),
+            "rms_db": float(packet.rms_db),
+            "freq_hz": int(packet.freq_hz),
+        }
+
+        try:
+            self._ws_queue.put_nowait(msg)
+        except asyncio.QueueFull:
+            # Drop oldest message if queue full
+            try:
+                _ = self._ws_queue.get_nowait()
+                self._ws_queue.task_done()
+            except asyncio.QueueEmpty:
+                pass
+            self._ws_queue.put_nowait(msg)
+
         # Write to database
         try:
             await self.db_writer.add_record(
@@ -61,10 +145,10 @@ class CollectorService:
                 freq_hz=packet.freq_hz,
             )
         except Exception as e:
-            logger.error(f"Failed to add record to database: {e}")
+            logger.error("Failed to write to database: %s", e)
 
-    async def start(self):
-        """Start collector service"""
+    async def start(self) -> None:
+        """Start collector service: DB, WS, serial reader."""
         logger.info("Starting Audio Data Collector Service")
 
         try:
@@ -72,6 +156,13 @@ class CollectorService:
             await self.db_writer.connect()
             await self.db_writer.start_auto_flush()
 
+            # Start WebSocket server and broadcaster
+            self._ws_broadcast_task = asyncio.create_task(self._ws_broadcast_loop())
+            self._ws_server_task = asyncio.create_task(self._start_ws_server())
+
+            # Give uvicorn a moment to bind (avoid race on port check)
+            await asyncio.sleep(0.5)
+
             # Connect to serial port
             await self.serial_reader.connect()
             await self.serial_reader.start_reading()
@@ -82,62 +173,75 @@ class CollectorService:
             await self._shutdown_event.wait()
 
         except Exception as e:
-            logger.error(f"Service startup failed: {e}")
+            logger.error("Service startup failed: %s", e)
             raise
         finally:
             await self.stop()
 
-    async def stop(self):
-        """Stop collector service gracefully"""
+    async def stop(self) -> None:
+        """Stop collector service gracefully."""
         logger.info("Stopping Audio Data Collector Service")
 
-        # Disconnect serial reader
+        # Stop serial reader
         await self.serial_reader.disconnect()
 
+        # Stop WebSocket server
+        if self._uvicorn_server is not None:
+            self._uvicorn_server.should_exit = True
+
+        if self._ws_server_task is not None:
+            # Let uvicorn exit gracefully; wait_for cancels the task on timeout
+            with suppress(asyncio.CancelledError, SystemExit, Exception):
+                await asyncio.wait_for(self._ws_server_task, timeout=5.0)
+
+        # Stop WebSocket broadcaster
+        if self._ws_broadcast_task is not None:
+            self._ws_broadcast_task.cancel()
+            with suppress(asyncio.CancelledError):
+                await self._ws_broadcast_task
+
         # Close database writer (flushes remaining data)
         await self.db_writer.close()
 
         logger.info("Service stopped")
 
-    def shutdown(self):
-        """Trigger graceful shutdown"""
-        logger.info("Shutdown requested")
-        self._shutdown_event.set()
-
 
-def main():
-    """Main entry point"""
+async def _amain() -> None:
+    """Async main entry point."""
     # Read configuration from environment
-    SERIAL_PORT = os.getenv("SERIAL_PORT", "/dev/ttyACM0")
-    BAUDRATE = int(os.getenv("BAUDRATE", "115200"))
-    DB_HOST = os.getenv("DB_HOST", "localhost")
-    DB_PORT = os.getenv("DB_PORT", "5432")
-    DB_NAME = os.getenv("DB_NAME", "audio_analyzer")
-    DB_USER = os.getenv("DB_USER", "postgres")
-    DB_PASSWORD = os.getenv("DB_PASSWORD", "postgres")
+    serial_port = os.getenv("SERIAL_PORT", "/dev/ttyACM0")
+    baudrate = int(os.getenv("BAUDRATE", "115200"))
 
-    db_url = f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
+    db_host = os.getenv("DB_HOST", "localhost")
+    db_port = os.getenv("DB_PORT", "5432")
+    db_name = os.getenv("DB_NAME", "audio_analyzer")
+    db_user = os.getenv("DB_USER", "postgres")
+    db_password = os.getenv("DB_PASSWORD", "postgres")
+    db_url = f"postgresql://{db_user}:{db_password}@{db_host}:{db_port}/{db_name}"
 
     # Create service
     service = CollectorService(
-        serial_port=SERIAL_PORT, db_url=db_url, baudrate=BAUDRATE
+        serial_port=serial_port, db_url=db_url, baudrate=baudrate
     )
 
     # Setup signal handlers for graceful shutdown
-    loop = asyncio.get_event_loop()
+    loop = asyncio.get_running_loop()
+    shutdown_callback: Callable[[], None] = service.shutdown
     for sig in (signal.SIGTERM, signal.SIGINT):
-        loop.add_signal_handler(sig, service.shutdown)
+        loop.add_signal_handler(sig, shutdown_callback)
 
+    await service.start()
+
+
+def main() -> None:
+    """Main entry point."""
     try:
-        # Run service
-        loop.run_until_complete(service.start())
+        asyncio.run(_amain())
     except KeyboardInterrupt:
         logger.info("Interrupted by user")
-    except Exception as e:
-        logger.error(f"Service error: {e}")
+    except Exception:
+        logger.exception("Service error")
         sys.exit(1)
-    finally:
-        loop.close()
 
 
 if __name__ == "__main__":
diff --git a/services/collector/requirements.txt b/services/collector/requirements.txt
index d07db00..b1f0b1c 100644
--- a/services/collector/requirements.txt
+++ b/services/collector/requirements.txt
@@ -3,3 +3,6 @@ asyncpg
 numpy
 pytest
 pytest-asyncio
+fastapi
+uvicorn
+websockets
diff --git a/services/collector/ws_app.py b/services/collector/ws_app.py
new file mode 100644
index 0000000..0bad49f
--- /dev/null
+++ b/services/collector/ws_app.py
@@ -0,0 +1,37 @@
+#!/usr/bin/env python3
+"""FastAPI WebSocket endpoint for live audio streaming."""
+
+from __future__ import annotations
+
+import logging
+
+from fastapi import FastAPI, WebSocket, WebSocketDisconnect
+
+from ws_manager import ConnectionManager
+
+logger = logging.getLogger(__name__)
+
+app = FastAPI(title="Audio Analyzer WebSocket")
+manager = ConnectionManager()
+
+
+@app.get("/health")
+async def health() -> dict[str, str]:
+    """Health check endpoint."""
+    return {"status": "ok"}
+
+
+@app.websocket("/ws/live")
+async def ws_live(websocket: WebSocket) -> None:
+    """WebSocket endpoint for real-time audio data streaming."""
+    await manager.connect(websocket)
+    try:
+        while True:
+            # Inbound frames are ignored; awaiting them detects client disconnects
+            await websocket.receive_text()
+    except WebSocketDisconnect:
+        pass
+    except Exception:
+        logger.exception("WS client handler failed")
+    finally:
+        await manager.disconnect(websocket)
diff --git a/services/collector/ws_manager.py b/services/collector/ws_manager.py
new file mode 100644
index 0000000..f6038cf
--- /dev/null
+++ b/services/collector/ws_manager.py
@@ -0,0 +1,60 @@
+#!/usr/bin/env python3
+"""WebSocket connection manager for broadcasting live audio data."""
+
+from __future__ import annotations
+
+import asyncio
+import json
+import logging
+from typing import Any
+
+from starlette.websockets import WebSocket, WebSocketState
+
+logger = logging.getLogger(__name__)
+
+
+class ConnectionManager:
+    """Manages WebSocket connections and broadcasts messages to all clients."""
+
+    def __init__(self) -> None:
+        self._clients: set[WebSocket] = set()
+        self._lock = asyncio.Lock()
+
+    async def connect(self, ws: WebSocket) -> None:
+        """Accept and register new WebSocket connection."""
+        await ws.accept()
+        async with self._lock:
+            self._clients.add(ws)
+        logger.info("WS client connected (total=%d)", len(self._clients))
+
+    async def disconnect(self, ws: WebSocket) -> None:
+        """Remove WebSocket connection."""
+        async with self._lock:
+            self._clients.discard(ws)
+        logger.info("WS client disconnected (total=%d)", len(self._clients))
+
+    async def broadcast_json(self, message: dict[str, Any]) -> None:
+        """Send message as compact JSON text to every client; prune failed ones."""
+        if not self._clients:
+            return
+
+        payload = json.dumps(message, ensure_ascii=False, separators=(",", ":"))
+
+        async with self._lock:
+            clients = list(self._clients)
+
+        dead: list[WebSocket] = []
+        for ws in clients:
+            try:
+                if ws.client_state == WebSocketState.CONNECTED:
+                    await ws.send_text(payload)
+                else:
+                    dead.append(ws)
+            except Exception:
+                dead.append(ws)
+
+        if dead:
+            async with self._lock:
+                for ws in dead:
+                    self._clients.discard(ws)
+            logger.debug("WS cleanup: removed %d dead clients", len(dead))