checking system…
Docs / back / src/maf/triggers/dispatcher.py · line 106
Python · 340 lines
  1"""TriggerDispatcher — fans Redis-Streams events into arena runs.
  2
  3Wire shape
  4----------
  5On startup the dispatcher collects every :class:`TriggerRule` registered
  6via :func:`register_arena_triggers`. It builds a multi-stream XREAD loop
  7over the *distinct* set of ``on_stream`` values and, for each new event,
  8runs every rule whose stream matches. Rule body:
  9
 10  ``when``         (str expression) — must be truthy in the rule's ctx
 11  ``target``       (dict template) — values may interpolate ``{payload.X}``
 12  ``cooldown_s``   (int) — minimum seconds between fires for (arena, target_key)
 13  ``action_mode``  ("auto"|"semi"|"manual") — passed to ControlInbox
 14
 15The dispatcher then XADDs to ``maf:control:in`` with command=run_arena.
 16:class:`maf.control.inbox.ControlInbox` picks it up and runs the arena.
 17
 18Cost gate
 19---------
 20Phase-2 default: **demote-only**. If a global cost-this-hour signal
 21exceeds the configured cap, ``auto`` action_modes are downgraded to
 22``semi`` so a runaway never auto-trades. The cost signal is read from
 23``maf:cost:hour`` (a simple Redis key the LLM-call accounting writes) —
 24absent or 0 means "no demotion" which is the right default until that
 25plumbing lands.
 26
 27Cooldown
 28--------
 29Per-(arena, target_key) cooldown is computed from the rendered target
 30dict. A `target_key` is the JSON-serialised target dict, so two rules
 31firing the same target inside cooldown_s are suppressed.
 32"""
 33
 34from __future__ import annotations
 35
 36import asyncio
 37import json
 38import logging
 39import os
 40import re
 41import time
 42import uuid
 43from dataclasses import dataclass, field
 44from typing import Any
 45
 46from maf.streaming import get_event_bus
 47from maf.triggers.safe_eval import SafeEvalError, safe_eval
 48
 49logger = logging.getLogger(__name__)
 50
 51
 52DEFAULT_COOLDOWN_S = 60
 53DEFAULT_CONTROL_IN = "maf:control:in"
 54COST_KEY = "maf:cost:hour"
 55
 56# {payload.symbol} or {payload.tickers[0]} — dotted/bracketed paths.
 57_TEMPLATE_PATTERN = re.compile(r"\{([^{}]+)\}")
 58
 59
 60@dataclass(frozen=True)
 61class TriggerRule:
 62    """One trigger declared in arena YAML."""
 63
 64    arena: str                                   # which arena to run
 65    on_stream: str                                # input stream to tail
 66    when: str = "True"                            # safe-eval expression
 67    target_template: dict[str, Any] = field(default_factory=dict)
 68    cooldown_s: int = DEFAULT_COOLDOWN_S
 69    action_mode: str = "manual"                   # auto|semi|manual
 70    name: str = ""                                # optional, for logs
 71
 72
 73def register_arena_triggers(
 74    arena_name: str, triggers_block: list[dict[str, Any]] | None,
 75) -> list[TriggerRule]:
 76    """Convert a YAML triggers list into :class:`TriggerRule` objects.
 77
 78    Tolerant of partial entries — missing ``when`` defaults to True (always),
 79    missing ``cooldown_s`` defaults to 60. ``target`` may be empty for
 80    arenas that take no target.
 81    """
 82    if not triggers_block:
 83        return []
 84    out: list[TriggerRule] = []
 85    for i, t in enumerate(triggers_block):
 86        try:
 87            rule = TriggerRule(
 88                arena=arena_name,
 89                on_stream=str(t["on_stream"]),
 90                when=str(t.get("when", "True")),
 91                target_template=dict(t.get("target") or {}),
 92                cooldown_s=int(t.get("cooldown_s", DEFAULT_COOLDOWN_S)),
 93                action_mode=str(t.get("action_mode", "manual")),
 94                name=str(t.get("name") or f"{arena_name}#{i}"),
 95            )
 96        except (KeyError, TypeError, ValueError) as exc:
 97            logger.warning(
 98                "register_arena_triggers: skipping %s rule %d (%s)",
 99                arena_name, i, exc,
100            )
101            continue
102        out.append(rule)
103    return out
104
105
class TriggerDispatcher:
    """Async multi-stream dispatcher. ``await dispatcher.run()`` to start.

    One XREAD loop tails the distinct ``on_stream`` values of all rules.
    For each event, every rule on that stream is evaluated. A rule fires when
    its ``when`` is truthy and its (arena, target) pair is not in cooldown.
    Firing XADDs a ``run_arena`` command to ``control_in_stream``, with the
    rule's ``action_mode`` passed through the cost gate. Call :meth:`stop`
    or :meth:`aclose` to end :meth:`run`.
    """

    def __init__(
        self,
        rules: list[TriggerRule],
        *,
        redis_url: str | None = None,
        control_in_stream: str = DEFAULT_CONTROL_IN,
        cost_cap_eur_per_hour: float = 0.0,
    ) -> None:
        """Create a dispatcher; no Redis connection is made until first use.

        Args:
            rules: Trigger rules to evaluate. With none, :meth:`run` returns
                immediately.
            redis_url: Redis URL. Defaults to ``$REDIS_URL``, then to
                ``redis://localhost:6379/0``.
            control_in_stream: Stream that receives ``run_arena`` commands.
            cost_cap_eur_per_hour: Demote ``auto`` to ``semi`` when the
                ``COST_KEY`` value exceeds this. ``<= 0`` disables the gate.
        """
        self.rules = list(rules)
        self.redis_url = redis_url or os.environ.get(
            "REDIS_URL", "redis://localhost:6379/0",
        )
        self.control_in_stream = control_in_stream
        self.cost_cap_eur_per_hour = float(cost_cap_eur_per_hour)
        self._redis: Any = None
        self._stop = asyncio.Event()
        # Last successful fire per (arena, target_key), in time.monotonic()
        # seconds, so wall-clock adjustments cannot shorten or stretch a
        # cooldown. Lives in-process; the dispatcher is a singleton, so this
        # is fine. Promote to Redis if you ever run multiple dispatchers
        # behind the same control plane.
        self._last_fire: dict[tuple[str, str], float] = {}
        # XREAD cursor per stream. "$" means "live tail"; run() pins it to
        # a concrete entry ID before the first read.
        self._cursors: dict[str, str] = {}
        for r in self.rules:
            self._cursors.setdefault(r.on_stream, "$")

    async def _get_redis(self) -> Any:
        """Return the shared ``redis.asyncio`` client, creating it lazily."""
        if self._redis is None:
            import redis.asyncio as aioredis
            self._redis = aioredis.from_url(self.redis_url)
        return self._redis

    def stop(self) -> None:
        """Ask :meth:`run` to exit.

        Takes effect once the in-flight XREAD (blocking up to 5 s) and any
        batch being processed have finished.
        """
        self._stop.set()

    async def aclose(self) -> None:
        """Stop :meth:`run` and close the Redis client, if one was created.

        Shutdown is best-effort: close errors are logged at debug level and
        otherwise ignored.
        """
        self._stop.set()
        if self._redis is None:
            return
        try:
            # Newer redis-py clients expose aclose(); older ones only close().
            closer = getattr(self._redis, "aclose", None)
            if closer:
                await closer()
            else:
                await self._redis.close()
        except Exception as exc:
            logger.debug("TriggerDispatcher: error closing redis client: %s", exc)

    # ── main loop ──────────────────────────────────────────────────────────

    async def run(self) -> None:
        """Tail every rule stream and dispatch matching triggers until stopped.

        Returns at once when no rules are configured. XREAD errors are
        logged and retried after 1 s. An exception raised while handling a
        single event is logged and that event is skipped, so one bad event
        or rule cannot kill the dispatcher.
        """
        if not self.rules:
            logger.info("TriggerDispatcher: no rules configured — idle exit")
            return
        client = await self._get_redis()
        await self._resolve_live_cursors(client)
        streams = sorted(self._cursors)
        logger.info(
            "TriggerDispatcher: %d rules across %d streams: %s",
            len(self.rules), len(streams), streams,
        )
        while not self._stop.is_set():
            try:
                resp = await client.xread(
                    self._cursors, block=5000, count=50,
                )
            except Exception as exc:
                logger.warning("TriggerDispatcher xread failed: %s", exc)
                await asyncio.sleep(1.0)
                continue
            if not resp:
                continue
            for stream_raw, entries in resp:
                stream = self._to_str(stream_raw)
                for entry_id, fields in entries:
                    # Advance the cursor before processing, so a failing
                    # event is skipped rather than re-read forever.
                    self._cursors[stream] = self._to_str(entry_id)
                    try:
                        await self._process(stream, _decode(fields))
                    except Exception:
                        logger.exception(
                            "TriggerDispatcher: error handling %s entry %s",
                            stream, self._cursors[stream],
                        )

    async def _resolve_live_cursors(self, client: Any) -> None:
        """Replace ``"$"`` cursors with each stream's current last entry ID.

        XREAD re-interprets ``"$"`` on every call as "newest ID right now".
        A stream left at ``"$"`` therefore silently drops entries added
        between two XREAD calls, e.g. while another stream's batch is being
        processed or during the retry sleep after an xread error. Pinning
        a concrete ID once at startup closes that gap.

        An empty or missing stream is pinned at ``"0-0"`` so its first entry
        is delivered. If the tail cannot be read, the cursor stays at ``"$"``.
        """
        for stream, cursor in list(self._cursors.items()):
            if cursor != "$":
                continue
            try:
                newest = await client.xrevrange(stream, count=1)
            except Exception as exc:
                logger.warning(
                    "TriggerDispatcher: could not resolve tail of %s (%s); "
                    "using '$'",
                    stream, exc,
                )
                continue
            self._cursors[stream] = self._to_str(newest[0][0]) if newest else "0-0"

    async def _process(self, stream: str, payload: dict[str, Any]) -> None:
        """Evaluate every rule whose on_stream matches ``stream``.

        For each rule whose ``when`` is truthy, the target is rendered from
        ``payload``. The command is dispatched unless the same
        (arena, rendered target) pair fired successfully within the rule's
        ``cooldown_s``. A ``when`` that fails safe-eval is logged and skipped.
        """
        matching_rules = [r for r in self.rules if r.on_stream == stream]
        if not matching_rules:
            return
        ctx = {"payload": payload}
        for rule in matching_rules:
            try:
                hit = bool(safe_eval(rule.when, ctx))
            except SafeEvalError as exc:
                logger.warning(
                    "trigger %s when=%r safe_eval error: %s",
                    rule.name, rule.when, exc,
                )
                continue
            if not hit:
                continue

            target = _render_target(rule.target_template, payload)
            target_key = json.dumps(target, sort_keys=True, default=str)
            fire_key = (rule.arena, target_key)
            now = time.monotonic()
            last = self._last_fire.get(fire_key)
            if last is not None and now - last < rule.cooldown_s:
                logger.debug(
                    "trigger %s suppressed by cooldown (%.1fs remaining)",
                    rule.name, rule.cooldown_s - (now - last),
                )
                continue

            action_mode = await self._apply_cost_gate(rule.action_mode)
            # Only a command that actually reached the control stream starts
            # the cooldown; a failed XADD must not suppress the next event.
            if await self._dispatch(rule, target, action_mode):
                self._last_fire[fire_key] = now

    async def _dispatch(
        self, rule: TriggerRule, target: dict[str, Any], action_mode: str,
    ) -> bool:
        """XADD a ``run_arena`` command for ``rule`` and announce it on the bus.

        Returns:
            True once the command has been XADDed, even if the follow-up
            event-bus publish fails. False if serialising or XADDing the
            command failed. Failures are logged, not raised.
        """
        client = await self._get_redis()
        correlation_id = uuid.uuid4().hex
        body = {
            "command": "run_arena",
            "correlation_id": correlation_id,
            "args": {
                "arena": rule.arena,
                "target": target,
                "action_mode": action_mode,
                "emit_action": True,
                "triggered_by": rule.name,
            },
        }
        try:
            await client.xadd(
                self.control_in_stream,
                {"data": json.dumps(body, default=str)},
            )
        except Exception as exc:
            logger.warning("trigger %s dispatch failed: %s", rule.name, exc)
            return False
        logger.info(
            "trigger %s fired: arena=%s target=%s mode=%s correlation=%s",
            rule.name, rule.arena, target, action_mode, correlation_id,
        )
        try:
            bus = get_event_bus()
            await bus.publish(
                "control.command",
                arena=rule.arena, correlation_id=correlation_id,
                payload={
                    "kind": "trigger.fired",
                    "rule": rule.name,
                    "target": target,
                    "action_mode": action_mode,
                },
            )
        except Exception as exc:
            logger.warning(
                "trigger %s fired but event-bus publish failed: %s",
                rule.name, exc,
            )
        return True

    @staticmethod
    def _to_str(value: Any) -> str:
        """Decode a Redis reply item (bytes without decode_responses) to str."""
        return value.decode() if isinstance(value, bytes) else str(value)

    # ── cost gate ──────────────────────────────────────────────────────────

    async def _apply_cost_gate(self, requested_mode: str) -> str:
        """Demote ``auto`` to ``semi`` when over the per-hour cap.

        Reads a single Redis key ``maf:cost:hour`` (float, EUR).
        The plumbing that writes that key is a future ticket — until then
        the gate is a no-op (cap=0 disables it). Any mode other than
        ``auto`` is returned unchanged, and read errors fail open.
        """
        if requested_mode != "auto" or self.cost_cap_eur_per_hour <= 0:
            return requested_mode
        try:
            client = await self._get_redis()
            raw = await client.get(COST_KEY)
            current = float(raw) if raw else 0.0
        except Exception as exc:
            # Fail open: a broken cost signal must not block triggers.
            logger.debug(
                "cost gate: could not read %s (%s); not demoting", COST_KEY, exc,
            )
            return requested_mode
        if current > self.cost_cap_eur_per_hour:
            logger.info(
                "cost gate: demoting auto→semi (current=%.2f cap=%.2f)",
                current, self.cost_cap_eur_per_hour,
            )
            return "semi"
        return requested_mode
289
290
291# ── helpers ────────────────────────────────────────────────────────────────
292
293
def _decode(fields: Any) -> dict[str, Any]:
    """Return the JSON object stored under a stream entry's ``data`` field.

    ``fields`` is the entry's field map; keys and values may be bytes or
    str. Anything unexpected yields an empty dict instead of raising: a
    non-dict entry, no ``data`` field, a non-text value, invalid JSON, or
    JSON that is not an object.
    """
    if isinstance(fields, dict):
        blob = fields.get(b"data") or fields.get("data")
        if isinstance(blob, bytes):
            blob = blob.decode("utf-8", errors="replace")
        if isinstance(blob, str):
            try:
                decoded = json.loads(blob)
            except (json.JSONDecodeError, TypeError):
                decoded = None
            if isinstance(decoded, dict):
                return decoded
    return {}
310
311
def _render_target(template: dict[str, Any], payload: dict[str, Any]) -> dict[str, Any]:
    """Build a concrete target dict by filling ``{payload.X}`` placeholders.

    Only top-level string values containing both ``{`` and ``}`` are
    rendered, via :func:`_render_string` with ``payload`` as the only name
    in scope. Every other value, including nested dicts and lists, is
    copied through unchanged. The placeholder syntax supports dotted paths
    (``payload.x.y``) and integer subscripts (``payload.list[0]``).
    """
    scope = {"payload": payload}

    def _fill(value: Any) -> Any:
        needs_render = isinstance(value, str) and "{" in value and "}" in value
        return _render_string(value, scope) if needs_render else value

    return {key: _fill(value) for key, value in template.items()}
327
328
def _render_string(template: str, ctx: dict[str, Any]) -> str:
    """Replace each ``{expr}`` in *template* with ``safe_eval(expr, ctx)``.

    The expression inside the braces is stripped before evaluation. A
    result of ``None`` renders as the empty string; any other value is
    passed through ``str()``. If the expression raises
    :class:`SafeEvalError`, the placeholder is left in the output verbatim,
    braces included.
    """

    def _replacement(match: re.Match[str]) -> str:
        placeholder = match.group(0)
        try:
            value = safe_eval(match.group(1).strip(), ctx)
        except SafeEvalError:
            return placeholder
        return "" if value is None else str(value)

    return _TEMPLATE_PATTERN.sub(_replacement, template)