1"""Convenience client for the control plane. 2 3Used by the CLI (``python -m maf trigger ...``) and tests. Wraps the wire 4shape behind a small async API:: 5 6 client = ControlClient() 7 ack = await client.run_arena("trading_intelligence", target={"ticker": "NVDA"}) 8 print(ack["result"]["synthesis_verdict"]) 9""" 10 11from __future__ import annotations 12 13import asyncio 14import json 15import logging 16import os 17import uuid 18from typing import Any 19 20logger = logging.getLogger(__name__) 21 22 23class ControlClient: 24 """Async client for ``maf:control:in`` + ``maf:control:out``.""" 25 26 def __init__( 27 self, 28 redis_url: str | None = None, 29 in_stream: str = "maf:control:in", 30 out_stream: str = "maf:control:out", 31 ) -> None: 32 self.redis_url = redis_url or os.environ.get( 33 "REDIS_URL", "redis://localhost:6379/0", 34 ) 35 self.in_stream = in_stream 36 self.out_stream = out_stream 37 self._redis: Any = None 38 39 async def _get_redis(self) -> Any: 40 if self._redis is None: 41 import redis.asyncio as aioredis 42 43 self._redis = aioredis.from_url(self.redis_url) 44 return self._redis 45 46 async def send( 47 self, 48 command: str, 49 args: dict[str, Any] | None = None, 50 *, 51 correlation_id: str | None = None, 52 wait_seconds: float = 60.0, 53 ) -> dict[str, Any]: 54 """Send a command and wait for its ack. 55 56 Waits up to ``wait_seconds`` for an ack with matching 57 ``correlation_id`` to appear on the out stream. Returns the ack 58 dict — or a synthetic timeout ack if nothing arrived. 59 """ 60 client = await self._get_redis() 61 correlation_id = correlation_id or uuid.uuid4().hex 62 body = { 63 "command": command, 64 "correlation_id": correlation_id, 65 "args": args or {}, 66 } 67 # Capture the current $-position on the out stream so we don't miss 68 # an ack that lands between our publish and our XREAD. 69 try: 70 last = await client.xrevrange(self.out_stream, count=1) 71 last_id = ( 72 (last[0][0].decode() if isinstance(last[0][0], bytes) else str(last[0][0])) 73 if last else "0" 74 ) 75 except Exception: 76 last_id = "0" 77 78 await client.xadd( 79 self.in_stream, 80 {"data": json.dumps(body, default=str)}, 81 ) 82 83 deadline = asyncio.get_event_loop().time() + wait_seconds 84 while True: 85 timeout = max(0.0, deadline - asyncio.get_event_loop().time()) 86 if timeout <= 0: 87 return { 88 "ok": False, 89 "correlation_id": correlation_id, 90 "command": command, 91 "error": f"timeout after {wait_seconds}s", 92 "result": None, 93 } 94 block_ms = int(min(timeout, 5.0) * 1000) 95 try: 96 resp = await client.xread( 97 {self.out_stream: last_id}, block=block_ms, count=50, 98 ) 99 except Exception as exc: 100 logger.warning("ControlClient: xread failed: %s", exc) 101 await asyncio.sleep(0.5) 102 continue 103 if not resp: 104 continue 105 for _stream, entries in resp: 106 for entry_id, fields in entries: 107 sid = ( 108 entry_id.decode("utf-8") 109 if isinstance(entry_id, bytes) 110 else str(entry_id) 111 ) 112 last_id = sid 113 raw = ( 114 fields.get(b"data") or fields.get("data") 115 if isinstance(fields, dict) else None 116 ) 117 if isinstance(raw, bytes): 118 raw = raw.decode("utf-8") 119 if not raw: 120 continue 121 try: 122 ack = json.loads(raw) 123 except (json.JSONDecodeError, TypeError): 124 continue 125 if ack.get("correlation_id") == correlation_id: 126 return ack 127 128 async def aclose(self) -> None: 129 if self._redis is None: 130 return 131 try: 132 aclose = getattr(self._redis, "aclose", None) 133 if aclose is not None: 134 await aclose() 135 else: 136 close = getattr(self._redis, "close", None) 137 if close is not None: 138 await close() 139 except Exception: 140 pass 141 142 # ── Convenience wrappers ──────────────────────────────────────────────── 143 144 async def run_arena( 145 self, 146 arena: str, 147 target: dict[str, Any] | None = None, 148 wait_seconds: float = 120.0, 149 ) -> dict[str, Any]: 150 return await self.send( 151 "run_arena", 152 {"arena": arena, "target": target or {}}, 153 wait_seconds=wait_seconds, 154 ) 155 156 async def health(self, wait_seconds: float = 5.0) -> dict[str, Any]: 157 return await self.send("health", {}, wait_seconds=wait_seconds) 158 159 160__all__ = ["ControlClient"]