"""Dashboard WebSocket endpoints — live arena event feed.

Two channels:

``/ws/events``
    Global firehose. Forwards every event from the events Redis Stream
    (``maf:events`` unless the MAF app config names another stream).

``/ws/arenas/{name}``
    Filtered feed — only events whose ``arena`` field matches ``name``.
    Events whose ``arena`` is the empty string are forwarded to every
    arena feed as well.

Each WebSocket message is a JSON object: either an event published via
:class:`maf.streaming.bus.EventBus` (see the bus docstring for its shape)
with a ``stream_id`` field added, or a control frame whose ``kind`` is one
of ``ws.hello``, ``ws.heartbeat``, ``ws.error`` or ``ws.malformed``.

Lifecycle
---------
On connect, the endpoint sends a ``ws.hello`` frame, pins the ID of the
newest entry already in the stream, and then XREADs forward from that ID
(a live tail — no backlog replay). A ``ws.heartbeat`` is sent whenever the
client has received nothing for one XREAD block window. A client disconnect
is noticed on the next send; the loop then ends and the Redis connection is
closed. Transient XREAD failures are reported as ``ws.error`` frames and
retried; after several consecutive failures the socket is closed so the
client can show a "realtime offline" state. Errors are logged but never
crash the server.
"""

from __future__ import annotations

import asyncio
import json
import logging
import os
from typing import Any

from fastapi import APIRouter, WebSocket, WebSocketDisconnect

logger = logging.getLogger(__name__)


router = APIRouter()

# XREAD loop tuning.
_XREAD_BLOCK_MS = 5000  # longest a single XREAD waits for new entries
_XREAD_COUNT = 50  # most entries fetched per XREAD call
_XREAD_RETRY_DELAY_S = 1.0  # pause after a failed XREAD before retrying
_MAX_XREAD_FAILURES = 5  # consecutive XREAD failures before closing the socket
# Longest the client may go without receiving any frame before we heartbeat.
_HEARTBEAT_INTERVAL_S = _XREAD_BLOCK_MS / 1000


def _redis_url() -> str:
    """Resolve the Redis URL — prefer the MAF app config, fall back to env.

    Order: ``_maf_app.config.redis_url`` when an app is registered and the
    lookup succeeds, else ``$REDIS_URL``, else ``redis://localhost:6379/0``.
    """
    # Imported lazily — presumably to avoid an import cycle with the API module.
    from maf.dashboard.api import _maf_app

    if _maf_app is not None:
        try:
            return _maf_app.config.redis_url
        except Exception as exc:  # best-effort: any config problem -> fallback
            logger.debug("redis_url unavailable in MAF config (%r); using env", exc)
    return os.environ.get("REDIS_URL", "redis://localhost:6379/0")


def _events_stream() -> str:
    """Resolve the events stream name — prefer the MAF app config.

    Falls back to ``"maf:events"`` when no app is registered or the config
    lookup fails.
    """
    # Imported lazily — presumably to avoid an import cycle with the API module.
    from maf.dashboard.api import _maf_app

    if _maf_app is not None:
        try:
            return _maf_app.config.streams.events_stream
        except Exception as exc:  # best-effort: any config problem -> fallback
            logger.debug("events_stream unavailable in MAF config (%r)", exc)
    return "maf:events"


@router.websocket("/ws/events")
async def ws_events(websocket: WebSocket) -> None:
    """Live tail of every MAF event."""
    await websocket.accept()
    await _pump(websocket, arena_filter=None)


@router.websocket("/ws/arenas/{name}")
async def ws_arena(websocket: WebSocket, name: str) -> None:
    """Live tail filtered to a single arena."""
    await websocket.accept()
    await _pump(websocket, arena_filter=name)


async def _pump(websocket: WebSocket, arena_filter: str | None) -> None:
    """Forward decoded events from the Redis stream to an accepted WebSocket.

    Args:
        websocket: An already-accepted WebSocket.
        arena_filter: If set, only events whose ``arena`` equals this value
            (or is ``""``) are forwarded; ``None`` forwards everything.

    Returns when the client disconnects, or after closing the socket because
    redis-py is missing, the Redis client could not be created, or XREAD
    failed ``_MAX_XREAD_FAILURES`` times in a row. We deliberately don't
    retry forever: the client gets ``ws.error`` frames followed by a close,
    so it can show a "realtime offline" banner instead of spinning silently.
    The Redis connection, once created, is always closed on exit.
    """
    try:
        import redis.asyncio as aioredis
    except ImportError:
        await _try_send(
            websocket, {"kind": "ws.error", "error": "redis-py is not installed"}
        )
        await _try_close(websocket)
        return

    try:
        client = aioredis.from_url(_redis_url())
    except Exception as exc:  # e.g. a malformed REDIS_URL
        logger.warning("ws_pump could not create redis client: %s", exc)
        await _try_send(websocket, {"kind": "ws.error", "error": str(exc)})
        await _try_close(websocket)
        return

    try:
        stream = _events_stream()
        # Tell the client which stream / filter we're tailing — useful for debugging.
        hello = {"kind": "ws.hello", "stream": stream, "arena_filter": arena_filter}
        if not await _try_send(websocket, hello):
            return
        start_id = await _resolve_start_id(client, stream)
        await _forward(websocket, client, stream, start_id, arena_filter)
    finally:
        # Single cleanup point: every exit path above goes through here once.
        await _close_redis(client)


async def _resolve_start_id(client: Any, stream: str) -> str:
    """Return a concrete stream ID to start tailing from.

    XREAD with ``$`` only sees entries added after *that particular call*
    begins, so re-issuing ``$`` after an idle (empty) XREAD silently drops
    anything published between the two calls. Pinning the newest existing
    ID once, up front, closes that gap. Use ``"0"`` instead for a full
    replay. Falls back to ``"$"`` if Redis can't be queried — the XREAD
    loop will then surface the failure to the client.
    """
    try:
        latest = await client.xrevrange(stream, count=1)
    except Exception as exc:
        logger.warning("ws_pump could not resolve tail of %s: %s", stream, exc)
        return "$"
    if not latest:
        # Empty or missing stream: every entry added from now on is new.
        return "0-0"
    return _as_str(latest[0][0])


async def _forward(
    websocket: WebSocket,
    client: Any,
    stream: str,
    last_id: str,
    arena_filter: str | None,
) -> None:
    """Run the XREAD -> filter -> send loop until the client or Redis gives up."""
    loop = asyncio.get_running_loop()
    last_sent = loop.time()  # when the client last received any frame
    failures = 0  # consecutive XREAD failures
    while True:
        try:
            resp = await client.xread(
                {stream: last_id}, block=_XREAD_BLOCK_MS, count=_XREAD_COUNT
            )
        except Exception as exc:
            failures += 1
            logger.warning(
                "ws_pump xread failed (%d/%d): %s",
                failures,
                _MAX_XREAD_FAILURES,
                exc,
            )
            # Non-fatal error frame; give up only on persistent failure.
            if not await _try_send(websocket, {"kind": "ws.error", "error": str(exc)}):
                return
            last_sent = loop.time()
            if failures >= _MAX_XREAD_FAILURES:
                await _try_close(websocket)
                return
            await asyncio.sleep(_XREAD_RETRY_DELAY_S)
            continue
        failures = 0

        # resp pairs each stream name with its (entry_id, fields) entries;
        # it is empty when the block window elapsed with nothing new.
        for _stream_name, entries in resp or ():
            for entry_id, fields in entries:
                last_id = _as_str(entry_id)
                evt = _decode(last_id, fields)
                if arena_filter and evt.get("arena") not in (arena_filter, ""):
                    continue
                if not await _try_send(websocket, evt):
                    return
                last_sent = loop.time()

        # Heartbeat after an idle block window so the client can show
        # "connected, idle" instead of "stuck" — and also when a filtered feed
        # has skipped everything for that long, because a failed send is the
        # only way this loop notices that the client has gone away.
        if not resp or loop.time() - last_sent >= _HEARTBEAT_INTERVAL_S:
            if not await _try_send(websocket, {"kind": "ws.heartbeat"}):
                return
            last_sent = loop.time()


def _decode(stream_id: str, fields: Any) -> dict[str, Any]:
    """Convert one stream entry into the dict sent to the client.

    The entry's ``data`` field (looked up under ``b"data"``, then ``"data"``)
    is expected to hold a UTF-8 JSON object; the result is that object with
    ``stream_id`` added. Anything else — missing field, undecodable bytes,
    invalid JSON, or JSON that is not an object — yields a ``ws.malformed``
    frame instead of raising, so one bad producer can't kill the pump.
    """
    raw = None
    if isinstance(fields, dict):
        # Explicit None checks: an empty value is present, not missing.
        raw = fields.get(b"data")
        if raw is None:
            raw = fields.get("data")
    if raw is None:
        return _malformed(stream_id, "no_data")
    try:
        if isinstance(raw, bytes):
            raw = raw.decode("utf-8")
        body = json.loads(raw)
    except (ValueError, TypeError) as exc:
        # ValueError covers both UnicodeDecodeError and json.JSONDecodeError.
        return _malformed(stream_id, str(exc))
    if not isinstance(body, dict):
        return _malformed(
            stream_id, f"expected a JSON object, got {type(body).__name__}"
        )
    body["stream_id"] = stream_id
    return body


def _malformed(stream_id: str, error: str) -> dict[str, Any]:
    """Build the ``ws.malformed`` frame reported for an undecodable entry."""
    return {"stream_id": stream_id, "kind": "ws.malformed", "error": error}


def _as_str(value: Any) -> str:
    """Normalise a Redis reply value (``bytes`` or other) to ``str``."""
    return value.decode("utf-8") if isinstance(value, bytes) else str(value)


async def _try_send(websocket: WebSocket, payload: dict[str, Any]) -> bool:
    """Send one JSON frame; return ``False`` if the client can't be reached."""
    try:
        await websocket.send_json(payload)
    except WebSocketDisconnect:
        return False  # the normal way a client leaves
    except Exception as exc:
        # e.g. RuntimeError when sending after close, or a server-specific
        # "connection closed" error; either way the client is gone.
        logger.debug("ws send failed: %r", exc)
        return False
    return True


async def _try_close(websocket: WebSocket) -> None:
    """Close *websocket*, ignoring errors if it is already closed."""
    try:
        await websocket.close()
    except Exception as exc:
        logger.debug("ws close failed (already closed?): %r", exc)


async def _close_redis(client: Any) -> None:
    """Best-effort close of a redis client; prefers ``aclose`` over ``close``."""
    try:
        aclose = getattr(client, "aclose", None)
        if aclose is not None:
            await aclose()
        else:
            close = getattr(client, "close", None)
            if close is not None:
                await close()
    except Exception as exc:
        logger.debug("redis client close failed: %r", exc)


__all__ = ["router"]