checking system…
Docs / back / src/maf/control/client.py · line 23
Python · 161 lines
  1"""Convenience client for the control plane.
  2
  3Used by the CLI (``python -m maf trigger ...``) and tests. Wraps the wire
  4shape behind a small async API::
  5
  6    client = ControlClient()
  7    ack = await client.run_arena("trading_intelligence", target={"ticker": "NVDA"})
  8    print(ack["result"]["synthesis_verdict"])
  9"""
 10
 11from __future__ import annotations
 12
 13import asyncio
 14import json
 15import logging
 16import os
 17import uuid
 18from typing import Any
 19
# Module-level logger; ControlClient reports XREAD failures through it.
logger = logging.getLogger(__name__)
 21
 22
 23class ControlClient:
 24    """Async client for ``maf:control:in`` + ``maf:control:out``."""
 25
 26    def __init__(
 27        self,
 28        redis_url: str | None = None,
 29        in_stream: str = "maf:control:in",
 30        out_stream: str = "maf:control:out",
 31    ) -> None:
 32        self.redis_url = redis_url or os.environ.get(
 33            "REDIS_URL", "redis://localhost:6379/0",
 34        )
 35        self.in_stream = in_stream
 36        self.out_stream = out_stream
 37        self._redis: Any = None
 38
 39    async def _get_redis(self) -> Any:
 40        if self._redis is None:
 41            import redis.asyncio as aioredis
 42
 43            self._redis = aioredis.from_url(self.redis_url)
 44        return self._redis
 45
 46    async def send(
 47        self,
 48        command: str,
 49        args: dict[str, Any] | None = None,
 50        *,
 51        correlation_id: str | None = None,
 52        wait_seconds: float = 60.0,
 53    ) -> dict[str, Any]:
 54        """Send a command and wait for its ack.
 55
 56        Waits up to ``wait_seconds`` for an ack with matching
 57        ``correlation_id`` to appear on the out stream. Returns the ack
 58        dict — or a synthetic timeout ack if nothing arrived.
 59        """
 60        client = await self._get_redis()
 61        correlation_id = correlation_id or uuid.uuid4().hex
 62        body = {
 63            "command": command,
 64            "correlation_id": correlation_id,
 65            "args": args or {},
 66        }
 67        # Capture the current $-position on the out stream so we don't miss
 68        # an ack that lands between our publish and our XREAD.
 69        try:
 70            last = await client.xrevrange(self.out_stream, count=1)
 71            last_id = (
 72                (last[0][0].decode() if isinstance(last[0][0], bytes) else str(last[0][0]))
 73                if last else "0"
 74            )
 75        except Exception:
 76            last_id = "0"
 77
 78        await client.xadd(
 79            self.in_stream,
 80            {"data": json.dumps(body, default=str)},
 81        )
 82
 83        deadline = asyncio.get_event_loop().time() + wait_seconds
 84        while True:
 85            timeout = max(0.0, deadline - asyncio.get_event_loop().time())
 86            if timeout <= 0:
 87                return {
 88                    "ok": False,
 89                    "correlation_id": correlation_id,
 90                    "command": command,
 91                    "error": f"timeout after {wait_seconds}s",
 92                    "result": None,
 93                }
 94            block_ms = int(min(timeout, 5.0) * 1000)
 95            try:
 96                resp = await client.xread(
 97                    {self.out_stream: last_id}, block=block_ms, count=50,
 98                )
 99            except Exception as exc:
100                logger.warning("ControlClient: xread failed: %s", exc)
101                await asyncio.sleep(0.5)
102                continue
103            if not resp:
104                continue
105            for _stream, entries in resp:
106                for entry_id, fields in entries:
107                    sid = (
108                        entry_id.decode("utf-8")
109                        if isinstance(entry_id, bytes)
110                        else str(entry_id)
111                    )
112                    last_id = sid
113                    raw = (
114                        fields.get(b"data") or fields.get("data")
115                        if isinstance(fields, dict) else None
116                    )
117                    if isinstance(raw, bytes):
118                        raw = raw.decode("utf-8")
119                    if not raw:
120                        continue
121                    try:
122                        ack = json.loads(raw)
123                    except (json.JSONDecodeError, TypeError):
124                        continue
125                    if ack.get("correlation_id") == correlation_id:
126                        return ack
127
128    async def aclose(self) -> None:
129        if self._redis is None:
130            return
131        try:
132            aclose = getattr(self._redis, "aclose", None)
133            if aclose is not None:
134                await aclose()
135            else:
136                close = getattr(self._redis, "close", None)
137                if close is not None:
138                    await close()
139        except Exception:
140            pass
141
142    # ── Convenience wrappers ────────────────────────────────────────────────
143
144    async def run_arena(
145        self,
146        arena: str,
147        target: dict[str, Any] | None = None,
148        wait_seconds: float = 120.0,
149    ) -> dict[str, Any]:
150        return await self.send(
151            "run_arena",
152            {"arena": arena, "target": target or {}},
153            wait_seconds=wait_seconds,
154        )
155
156    async def health(self, wait_seconds: float = 5.0) -> dict[str, Any]:
157        return await self.send("health", {}, wait_seconds=wait_seconds)
158
159
# Public API of this module: only the client class is exported.
__all__ = ["ControlClient"]