feat(collector): add websocket

This commit is contained in:
2025-12-28 22:24:23 +03:00
parent 7334855ba2
commit e6f361def4
6 changed files with 252 additions and 49 deletions

View File

@@ -36,6 +36,10 @@ services:
DB_NAME: ${DB_NAME:-audio_analyzer} DB_NAME: ${DB_NAME:-audio_analyzer}
DB_USER: ${DB_USER:-postgres} DB_USER: ${DB_USER:-postgres}
DB_PASSWORD: ${DB_PASSWORD:-postgres} DB_PASSWORD: ${DB_PASSWORD:-postgres}
WS_HOST: 0.0.0.0
WS_PORT: 8000
ports:
- "8001:8000"
devices: devices:
- "${SERIAL_PORT:-/dev/ttyACM0}:${SERIAL_PORT:-/dev/ttyACM0}" - "${SERIAL_PORT:-/dev/ttyACM0}:${SERIAL_PORT:-/dev/ttyACM0}"
networks: networks:

View File

@@ -1,13 +1,13 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" """
FR-2.3: Database Writer with Batch Processing FR-2.3: Database Writer with Batch Processing
Buffers audio metrics and writes in batches (50 records or 5 seconds) Buffers audio metrics and writes in batches
""" """
import asyncio import asyncio
import logging import logging
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import List, Optional from typing import List, Optional, final
from dataclasses import dataclass from dataclasses import dataclass
import asyncpg import asyncpg
@@ -32,8 +32,8 @@ class DatabaseWriter:
""" """
BATCH_SIZE = 50 BATCH_SIZE = 50
BATCH_TIMEOUT = 5.0 # seconds BATCH_TIMEOUT: float = 5.0 # seconds
SILENCE_THRESHOLD_DB = -30.0 # dB below = silence SILENCE_THRESHOLD_DB: float = -30.0 # dB below = silence
def __init__(self, db_url: str): def __init__(self, db_url: str):
self.db_url = db_url self.db_url = db_url
@@ -124,7 +124,6 @@ class DatabaseWriter:
f"silence={is_silence} (buffer={len(self.buffer)})" f"silence={is_silence} (buffer={len(self.buffer)})"
) )
# Flush if batch size reached
if len(self.buffer) >= self.BATCH_SIZE: if len(self.buffer) >= self.BATCH_SIZE:
await self.flush() await self.flush()

View File

@@ -1,33 +1,46 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" """
FR-2: Audio Data Collector Service FR-2: Audio Data Collector Service with WebSocket Live Streaming
Reads audio metrics from STM32, validates, and writes to TimescaleDB Reads audio metrics from STM32, validates, writes to DB, and streams via WebSocket.
""" """
from __future__ import annotations
import asyncio import asyncio
import logging import logging
import os import os
import signal import signal
import sys import sys
from contextlib import suppress
from datetime import datetime, timezone
from typing import Callable, Optional
import uvicorn
from serial_reader import SerialReader
from audio_validator import AudioValidator from audio_validator import AudioValidator
from db_writer import DatabaseWriter from db_writer import DatabaseWriter
from protocol_parser import AudioMetrics from protocol_parser import AudioMetrics
from serial_reader import SerialReader
from ws_app import app as ws_app
from ws_app import manager
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
datefmt="%Y-%m-%d %H:%M:%S", datefmt="%Y-%m-%d %H:%M:%S",
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CollectorService: def _iso_z(dt: datetime) -> str:
"""Main collector service orchestrating serial reading and database writing""" """Format datetime as ISO8601 with 'Z' suffix (UTC)."""
return dt.astimezone(timezone.utc).isoformat().replace("+00:00", "Z")
def __init__(self, serial_port: str, db_url: str, baudrate: int = 115200):
class CollectorService:
"""Main collector service: serial → validate → DB + WebSocket."""
def __init__(self, serial_port: str, db_url: str, baudrate: int = 115200) -> None:
self.serial_reader = SerialReader( self.serial_reader = SerialReader(
port=serial_port, baudrate=baudrate, on_packet=self._handle_packet port=serial_port, baudrate=baudrate, on_packet=self._handle_packet
) )
@@ -36,23 +49,94 @@ class CollectorService:
self._shutdown_event = asyncio.Event() self._shutdown_event = asyncio.Event()
async def _handle_packet(self, packet: AudioMetrics): # WebSocket broadcast queue (bounded to prevent memory issues)
self._ws_queue: asyncio.Queue[dict] = asyncio.Queue(maxsize=200)
self._ws_broadcast_task: Optional[asyncio.Task[None]] = None
# Uvicorn server for WebSocket endpoint
self._uvicorn_server: Optional[uvicorn.Server] = None
self._ws_server_task: Optional[asyncio.Task[None]] = None
def shutdown(self) -> None:
    """Request a graceful stop of the service.

    Safe to use as a signal-handler callback: it only logs and sets the
    shutdown event.
    """
    logger.info("Shutdown requested")
    # Setting the event is the whole shutdown request; nothing is torn down here.
    self._shutdown_event.set()
async def _ws_broadcast_loop(self) -> None:
    """Forward queued messages to the WebSocket clients until cancelled.

    Runs as a background task. A broadcast that raises is logged and the
    loop moves on to the next message; cancellation ends the loop quietly.
    """
    pending = self._ws_queue
    try:
        while True:
            item = await pending.get()
            try:
                await manager.broadcast_json(item)
            except Exception as exc:
                logger.error("WS broadcast error: %s", exc)
            finally:
                # Mark the item as handled whether or not the broadcast worked.
                pending.task_done()
    except asyncio.CancelledError:
        logger.debug("WS broadcast loop cancelled")
async def _start_ws_server(self) -> None:
    """Run the uvicorn server that exposes the WebSocket endpoint.

    Host and port come from the WS_HOST / WS_PORT environment variables
    (defaults: 0.0.0.0 and 8001). The coroutine stays inside ``serve()``
    for the lifetime of the server.

    Every way the server can end other than task cancellation requests a
    service shutdown, so the collector never keeps running without its
    WebSocket endpoint:

    * WS_PORT is not an integer,
    * uvicorn exits through ``SystemExit`` (logged as a probable port
      conflict),
    * ``serve()`` raises any other exception,
    * ``serve()`` returns on its own while no shutdown has been requested.
    """
    host = os.getenv("WS_HOST", "0.0.0.0")
    raw_port = os.getenv("WS_PORT", "8001")
    try:
        port = int(raw_port)
    except ValueError:
        # Without this the task would die with an exception nobody logs,
        # because stop() suppresses the task's exception when awaiting it.
        logger.error("Invalid WS_PORT %r: expected an integer", raw_port)
        self.shutdown()
        return

    config = uvicorn.Config(
        ws_app,
        host=host,
        port=port,
        log_level="warning",
        loop="asyncio",
        access_log=False,
    )
    self._uvicorn_server = uvicorn.Server(config)

    try:
        logger.info("Starting WebSocket server on ws://%s:%d/ws/live", host, port)
        await self._uvicorn_server.serve()
    except SystemExit as e:
        logger.error("WS server failed (port %d already in use?): %s", port, e)
        self.shutdown()
    except Exception as e:
        logger.exception("WS server crashed: %s", e)
        self.shutdown()
    else:
        # serve() came back without raising. If nobody asked the service to
        # stop, the server ended by itself, so treat it like the failure
        # paths above. Cancellation from stop() never reaches this branch.
        if not self._shutdown_event.is_set():
            logger.warning("WS server exited unexpectedly")
            self.shutdown()
async def _handle_packet(self, packet: AudioMetrics) -> None:
""" """
Process received audio packet: validate and write to database. Process received audio packet: validate, write to DB, push to WebSocket.
Args: Args:
packet: Parsed audio metrics packet packet: Parsed audio metrics from STM32
""" """
# Validate packet # Validate packet
validation = self.validator.validate_packet(packet.rms_db, packet.freq_hz) validation = self.validator.validate_packet(packet.rms_db, packet.freq_hz)
if not validation.valid: if not validation.valid:
logger.warning( logger.warning(
f"Invalid packet: {validation.error} " "Invalid packet: %s (rms=%.1fdB freq=%dHz)",
f"(rms={packet.rms_db:.1f}dB freq={packet.freq_hz}Hz)" validation.error,
packet.rms_db,
packet.freq_hz,
) )
return return
# Push to WebSocket queue (non-blocking)
msg = {
"time": _iso_z(datetime.now(timezone.utc)),
"rms_db": float(packet.rms_db),
"freq_hz": int(packet.freq_hz),
}
try:
self._ws_queue.put_nowait(msg)
except asyncio.QueueFull:
# Drop oldest message if queue full
try:
_ = self._ws_queue.get_nowait()
self._ws_queue.task_done()
except asyncio.QueueEmpty:
pass
self._ws_queue.put_nowait(msg)
# Write to database # Write to database
try: try:
await self.db_writer.add_record( await self.db_writer.add_record(
@@ -61,10 +145,10 @@ class CollectorService:
freq_hz=packet.freq_hz, freq_hz=packet.freq_hz,
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to add record to database: {e}") logger.error("Failed to write to database: %s", e)
async def start(self): async def start(self) -> None:
"""Start collector service""" """Start collector service: DB, WS, serial reader."""
logger.info("Starting Audio Data Collector Service") logger.info("Starting Audio Data Collector Service")
try: try:
@@ -72,6 +156,13 @@ class CollectorService:
await self.db_writer.connect() await self.db_writer.connect()
await self.db_writer.start_auto_flush() await self.db_writer.start_auto_flush()
# Start WebSocket server and broadcaster
self._ws_broadcast_task = asyncio.create_task(self._ws_broadcast_loop())
self._ws_server_task = asyncio.create_task(self._start_ws_server())
# Give uvicorn a moment to bind (avoid race on port check)
await asyncio.sleep(0.5)
# Connect to serial port # Connect to serial port
await self.serial_reader.connect() await self.serial_reader.connect()
await self.serial_reader.start_reading() await self.serial_reader.start_reading()
@@ -82,62 +173,75 @@ class CollectorService:
await self._shutdown_event.wait() await self._shutdown_event.wait()
except Exception as e: except Exception as e:
logger.error(f"Service startup failed: {e}") logger.error("Service startup failed: %s", e)
raise raise
finally: finally:
await self.stop() await self.stop()
async def stop(self): async def stop(self) -> None:
"""Stop collector service gracefully""" """Stop collector service gracefully."""
logger.info("Stopping Audio Data Collector Service") logger.info("Stopping Audio Data Collector Service")
# Disconnect serial reader # Stop serial reader
await self.serial_reader.disconnect() await self.serial_reader.disconnect()
# Stop WebSocket server
if self._uvicorn_server is not None:
self._uvicorn_server.should_exit = True
if self._ws_server_task is not None:
self._ws_server_task.cancel()
with suppress(asyncio.CancelledError, SystemExit, Exception):
await self._ws_server_task
# Stop WebSocket broadcaster
if self._ws_broadcast_task is not None:
self._ws_broadcast_task.cancel()
with suppress(asyncio.CancelledError):
await self._ws_broadcast_task
# Close database writer (flushes remaining data) # Close database writer (flushes remaining data)
await self.db_writer.close() await self.db_writer.close()
logger.info("Service stopped") logger.info("Service stopped")
def shutdown(self):
"""Trigger graceful shutdown"""
logger.info("Shutdown requested")
self._shutdown_event.set()
async def _amain() -> None:
def main(): """Async main entry point."""
"""Main entry point"""
# Read configuration from environment # Read configuration from environment
SERIAL_PORT = os.getenv("SERIAL_PORT", "/dev/ttyACM0") serial_port = os.getenv("SERIAL_PORT", "/dev/ttyACM0")
BAUDRATE = int(os.getenv("BAUDRATE", "115200")) baudrate = int(os.getenv("BAUDRATE", "115200"))
DB_HOST = os.getenv("DB_HOST", "localhost")
DB_PORT = os.getenv("DB_PORT", "5432")
DB_NAME = os.getenv("DB_NAME", "audio_analyzer")
DB_USER = os.getenv("DB_USER", "postgres")
DB_PASSWORD = os.getenv("DB_PASSWORD", "postgres")
db_url = f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}" db_host = os.getenv("DB_HOST", "localhost")
db_port = os.getenv("DB_PORT", "5432")
db_name = os.getenv("DB_NAME", "audio_analyzer")
db_user = os.getenv("DB_USER", "postgres")
db_password = os.getenv("DB_PASSWORD", "postgres")
db_url = f"postgresql://{db_user}:{db_password}@{db_host}:{db_port}/{db_name}"
# Create service # Create service
service = CollectorService( service = CollectorService(
serial_port=SERIAL_PORT, db_url=db_url, baudrate=BAUDRATE serial_port=serial_port, db_url=db_url, baudrate=baudrate
) )
# Setup signal handlers for graceful shutdown # Setup signal handlers for graceful shutdown
loop = asyncio.get_event_loop() loop = asyncio.get_running_loop()
shutdown_callback: Callable[[], None] = service.shutdown
for sig in (signal.SIGTERM, signal.SIGINT): for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, service.shutdown) loop.add_signal_handler(sig, shutdown_callback)
await service.start()
def main() -> None:
"""Main entry point."""
try: try:
# Run service asyncio.run(_amain())
loop.run_until_complete(service.start())
except KeyboardInterrupt: except KeyboardInterrupt:
logger.info("Interrupted by user") logger.info("Interrupted by user")
except Exception as e: except Exception:
logger.error(f"Service error: {e}") logger.exception("Service error")
sys.exit(1) sys.exit(1)
finally:
loop.close()
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -3,3 +3,6 @@ asyncpg
numpy numpy
pytest pytest
pytest-asyncio pytest-asyncio
fastapi
uvicorn
websockets

View File

@@ -0,0 +1,33 @@
#!/usr/bin/env python3
"""FastAPI WebSocket endpoint for live audio streaming."""
from __future__ import annotations
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from ws_manager import ConnectionManager
app = FastAPI(title="Audio Analyzer WebSocket")
manager = ConnectionManager()
@app.get("/health")
async def health() -> dict[str, str]:
    """Liveness probe: always answers with a fixed ``status: ok`` payload."""
    return dict(status="ok")
@app.websocket("/ws/live")
async def ws_live(websocket: WebSocket) -> None:
    """WebSocket endpoint for real-time audio data streaming.

    Registers the connection with the shared manager, then waits on inbound
    text frames until the connection ends. Received text is discarded. The
    connection is always unregistered on the way out.

    A ``WebSocketDisconnect`` is the normal way for the loop to end and is
    ignored. Any other exception is logged with its traceback rather than
    being dropped silently, and the connection is then unregistered.
    """
    # Function-scope import: this module defines no module-level logger.
    import logging

    await manager.connect(websocket)
    try:
        while True:
            # Inbound text is unused; awaiting it keeps the handler alive and
            # lets a client disconnect surface as WebSocketDisconnect.
            await websocket.receive_text()
    except WebSocketDisconnect:
        pass  # normal client close
    except Exception:
        logging.getLogger(__name__).warning(
            "WS connection closed after unexpected error", exc_info=True
        )
    finally:
        await manager.disconnect(websocket)

View File

@@ -0,0 +1,60 @@
#!/usr/bin/env python3
"""WebSocket connection manager for broadcasting live audio data."""
from __future__ import annotations
import asyncio
import json
import logging
from typing import Any
from starlette.websockets import WebSocket, WebSocketState
logger = logging.getLogger(__name__)
class ConnectionManager:
    """Registry of WebSocket connections with fan-out text broadcasting."""

    def __init__(self) -> None:
        # The lock guards every mutation and snapshot of the client set.
        self._lock = asyncio.Lock()
        self._clients: set[WebSocket] = set()

    async def connect(self, ws: WebSocket) -> None:
        """Accept the handshake, then add the connection to the registry."""
        await ws.accept()
        async with self._lock:
            self._clients.add(ws)
        logger.info("WS client connected (total=%d)", len(self._clients))

    async def disconnect(self, ws: WebSocket) -> None:
        """Drop the connection from the registry (no-op if it is not there)."""
        async with self._lock:
            self._clients.discard(ws)
        logger.info("WS client disconnected (total=%d)", len(self._clients))

    async def broadcast_json(self, message: dict[str, Any]) -> None:
        """Serialize ``message`` once and send it to every registered client.

        Clients that are not in the CONNECTED state, or whose send raises,
        are removed from the registry after the pass.
        """
        if not self._clients:
            return

        payload = json.dumps(message, ensure_ascii=False, separators=(",", ":"))

        # Send from a snapshot so the lock is not held across network writes.
        async with self._lock:
            snapshot = tuple(self._clients)

        stale: list[WebSocket] = []
        for client in snapshot:
            try:
                alive = client.client_state == WebSocketState.CONNECTED
                if alive:
                    await client.send_text(payload)
            except Exception:
                alive = False
            if not alive:
                stale.append(client)

        if not stale:
            return

        async with self._lock:
            self._clients.difference_update(stale)
        logger.debug("WS cleanup: removed %d dead clients", len(stale))