1"""TriggerDispatcher — fans Redis-Streams events into arena runs. 2 3Wire shape 4---------- 5On startup the dispatcher collects every :class:`TriggerRule` registered 6via :func:`register_arena_triggers`. It builds a multi-stream XREAD loop 7over the *distinct* set of ``on_stream`` values and, for each new event, 8runs every rule whose stream matches. Rule body: 9 10 ``when`` (str expression) — must be truthy in the rule's ctx 11 ``target`` (dict template) — values may interpolate ``{payload.X}`` 12 ``cooldown_s`` (int) — minimum seconds between fires for (arena, target_key) 13 ``action_mode`` ("auto"|"semi"|"manual") — passed to ControlInbox 14 15The dispatcher then XADDs to ``maf:control:in`` with command=run_arena. 16:class:`maf.control.inbox.ControlInbox` picks it up and runs the arena. 17 18Cost gate 19--------- 20Phase-2 default: **demote-only**. If a global cost-this-hour signal 21exceeds the configured cap, ``auto`` action_modes are downgraded to 22``semi`` so a runaway never auto-trades. The cost signal is read from 23``maf:cost:hour`` (a simple Redis key the LLM-call accounting writes) — 24absent or 0 means "no demotion" which is the right default until that 25plumbing lands. 26 27Cooldown 28-------- 29Per-(arena, target_key) cooldown is computed from the rendered target 30dict. A `target_key` is the JSON-serialised target dict, so two rules 31firing the same target inside cooldown_s are suppressed. 32""" 33 34from __future__ import annotations 35 36import asyncio 37import json 38import logging 39import os 40import re 41import time 42import uuid 43from dataclasses import dataclass, field 44from typing import Any 45 46from maf.streaming import get_event_bus 47from maf.triggers.safe_eval import SafeEvalError, safe_eval 48 49logger = logging.getLogger(__name__) 50 51 52DEFAULT_COOLDOWN_S = 60 53DEFAULT_CONTROL_IN = "maf:control:in" 54COST_KEY = "maf:cost:hour" 55 56# {payload.symbol} or {payload.tickers[0]} — dotted/bracketed paths. 
_TEMPLATE_PATTERN = re.compile(r"\{([^{}]+)\}")

# Accepted values for a rule's action_mode (after case/whitespace normalisation).
_VALID_ACTION_MODES = frozenset({"auto", "semi", "manual"})

# Initial size at which the per-(arena, target) cooldown map is pruned.
_LAST_FIRE_PRUNE_THRESHOLD = 10_000


@dataclass(frozen=True)
class TriggerRule:
    """One trigger declared in arena YAML."""

    arena: str  # which arena to run
    on_stream: str  # input stream to tail
    when: str = "True"  # safe-eval expression
    target_template: dict[str, Any] = field(default_factory=dict)
    cooldown_s: int = DEFAULT_COOLDOWN_S
    action_mode: str = "manual"  # auto|semi|manual
    name: str = ""  # optional, for logs


def register_arena_triggers(
    arena_name: str, triggers_block: list[dict[str, Any]] | None,
) -> list[TriggerRule]:
    """Convert a YAML ``triggers`` list into :class:`TriggerRule` objects.

    Tolerant of partial entries: a missing ``when`` defaults to ``"True"``
    (always fire), a missing or null ``cooldown_s`` defaults to
    ``DEFAULT_COOLDOWN_S``, a missing or null ``action_mode`` defaults to
    ``"manual"``, and ``target`` may be empty for arenas that take no target.
    ``action_mode`` is matched case-insensitively.

    Malformed entries (not a mapping, no/empty ``on_stream``, explicit null
    ``when``, non-integer ``cooldown_s``, unknown ``action_mode``) are skipped
    with a warning instead of aborting registration of the remaining rules.

    Args:
        arena_name: Arena every produced rule will run.
        triggers_block: Raw list parsed from the arena YAML, or ``None``.

    Returns:
        One :class:`TriggerRule` per valid entry, in input order.
    """
    if not triggers_block:
        return []
    out: list[TriggerRule] = []
    for i, t in enumerate(triggers_block):
        try:
            rule = _build_rule(arena_name, i, t)
        except (KeyError, TypeError, ValueError) as exc:
            logger.warning(
                "register_arena_triggers: skipping %s rule %d (%s)",
                arena_name, i, exc,
            )
            continue
        out.append(rule)
    return out


def _build_rule(arena_name: str, index: int, entry: Any) -> TriggerRule:
    """Validate one raw ``triggers`` entry and build its :class:`TriggerRule`.

    Raises:
        KeyError: ``on_stream`` is missing.
        TypeError: ``entry`` is not subscriptable by key, or a field has an
            unusable type.
        ValueError: ``on_stream`` is empty, ``when`` is explicitly null,
            ``cooldown_s`` is not an integer, or ``action_mode`` is unknown.
    """
    on_stream = entry["on_stream"]
    if on_stream is None or not str(on_stream).strip():
        raise ValueError("on_stream must be a non-empty stream name")

    when = entry.get("when", "True")
    if when is None:
        # ``when:`` with no value would stringify to "None", which never
        # fires. Reject it loudly rather than guessing what was meant.
        raise ValueError("'when' is null; omit it to mean 'always'")

    raw_cooldown = entry.get("cooldown_s")
    raw_mode = entry.get("action_mode")
    # Normalise so e.g. "Auto" cannot slip past the cost gate, which only
    # recognises the exact string "auto".
    action_mode = "manual" if raw_mode is None else str(raw_mode).strip().lower()
    if action_mode not in _VALID_ACTION_MODES:
        raise ValueError(
            f"unknown action_mode {raw_mode!r} (expected auto|semi|manual)"
        )

    return TriggerRule(
        arena=arena_name,
        on_stream=str(on_stream),
        when=str(when),
        target_template=dict(entry.get("target") or {}),
        cooldown_s=(
            DEFAULT_COOLDOWN_S if raw_cooldown is None else int(raw_cooldown)
        ),
        action_mode=action_mode,
        name=str(entry.get("name") or f"{arena_name}#{index}"),
    )


class TriggerDispatcher:
    """Async multi-stream dispatcher. ``await dispatcher.run()`` to start.

    Tails every distinct ``on_stream`` referenced by ``rules`` with one
    blocking XREAD, evaluates the matching rules for each entry, applies the
    per-(arena, target) cooldown and the cost gate, then XADDs a
    ``run_arena`` command onto ``control_in_stream``.
    """

    def __init__(
        self,
        rules: list[TriggerRule],
        *,
        redis_url: str | None = None,
        control_in_stream: str = DEFAULT_CONTROL_IN,
        cost_cap_eur_per_hour: float = 0.0,
    ) -> None:
        """Configure the dispatcher; the Redis client is created lazily.

        Args:
            rules: Trigger rules to evaluate (the list is copied).
            redis_url: Redis URL; falls back to ``$REDIS_URL`` and then
                ``redis://localhost:6379/0``.
            control_in_stream: Stream that receives ``run_arena`` commands.
            cost_cap_eur_per_hour: Cost-gate cap; ``<= 0`` disables the gate.
        """
        self.rules = list(rules)
        self.redis_url = redis_url or os.environ.get(
            "REDIS_URL", "redis://localhost:6379/0",
        )
        self.control_in_stream = control_in_stream
        self.cost_cap_eur_per_hour = float(cost_cap_eur_per_hour)
        self._redis: Any = None
        self._stop = asyncio.Event()
        # Monotonic time of the last *successful* fire per (arena, target_key).
        # Lives in-process — the dispatcher is meant to be a singleton. Promote
        # to Redis if multiple dispatchers ever share one control plane.
        self._last_fire: dict[tuple[str, str], float] = {}
        # Map size at which _last_fire is next pruned (see _prune_last_fire).
        self._prune_at = _LAST_FIRE_PRUNE_THRESHOLD
        # XREAD cursor per stream. Starts at "$" (live tail); run() pins each
        # "$" to a concrete entry ID before the first read.
        self._cursors: dict[str, str] = {}
        for r in self.rules:
            self._cursors.setdefault(r.on_stream, "$")

    async def _get_redis(self) -> Any:
        """Return the shared ``redis.asyncio`` client, creating it on first use."""
        if self._redis is None:
            import redis.asyncio as aioredis
            self._redis = aioredis.from_url(self.redis_url)
        return self._redis

    def stop(self) -> None:
        """Ask :meth:`run` to exit (after the in-flight XREAD, ≤ 5 s block)."""
        self._stop.set()

    async def aclose(self) -> None:
        """Stop the loop and close the Redis client (best-effort)."""
        self._stop.set()
        if self._redis is None:
            return
        try:
            closer = getattr(self._redis, "aclose", None)
            if closer:
                await closer()
            else:
                await self._redis.close()
        except Exception as exc:  # shutdown is best-effort, but leave a trace
            logger.debug("TriggerDispatcher: redis close failed: %s", exc)

    # ── main loop ──────────────────────────────────────────────────────────

    async def run(self) -> None:
        """Tail all configured streams until :meth:`stop` / :meth:`aclose`.

        XREAD errors are logged and retried after 1 s. Errors while handling
        an individual entry are contained per rule (see :meth:`_process`).
        """
        if not self.rules:
            logger.info("TriggerDispatcher: no rules configured — idle exit")
            return
        client = await self._get_redis()
        streams = sorted(self._cursors)
        logger.info(
            "TriggerDispatcher: %d rules across %d streams: %s",
            len(self.rules), len(streams), streams,
        )
        await self._pin_live_cursors(client)
        while not self._stop.is_set():
            try:
                resp = await client.xread(self._cursors, block=5000, count=50)
            except Exception as exc:
                logger.warning("TriggerDispatcher xread failed: %s", exc)
                await asyncio.sleep(1.0)
                continue
            if not resp:
                continue
            for stream_raw, entries in resp:
                stream = _as_str(stream_raw)
                for entry_id, fields in entries:
                    # Advance first so a failing entry is never re-read.
                    self._cursors[stream] = _as_str(entry_id)
                    await self._process(stream, _decode(fields))

    async def _pin_live_cursors(self, client: Any) -> None:
        """Replace every ``"$"`` cursor with a concrete entry ID.

        ``"$"`` means "entries added after *this* XREAD call". Re-sending it
        on each loop iteration for a quiet stream drops any entry that arrives
        on that stream while the loop is busy with another stream's batch.
        Pinning the cursor to the stream's current last ID — or ``"0-0"`` when
        the stream is empty or does not exist yet — keeps live-tail semantics
        without that gap. If the lookup fails, the cursor stays ``"$"``.
        """
        for stream, cursor in list(self._cursors.items()):
            if cursor != "$":
                continue
            try:
                last = await client.xrevrange(stream, count=1)
            except Exception as exc:
                logger.warning(
                    "TriggerDispatcher: could not pin cursor for %s (%s); "
                    "falling back to '$'", stream, exc,
                )
                continue
            self._cursors[stream] = _as_str(last[0][0]) if last else "0-0"

    async def _process(self, stream: str, payload: dict[str, Any]) -> None:
        """Evaluate every rule whose ``on_stream`` matches ``stream``.

        Each rule is isolated: an unexpected error while evaluating or firing
        one rule is logged with its traceback and neither stops the other
        rules for this entry nor the dispatcher loop.
        """
        matching_rules = [r for r in self.rules if r.on_stream == stream]
        if not matching_rules:
            return
        ctx = {"payload": payload}
        for rule in matching_rules:
            try:
                await self._maybe_fire(rule, payload, ctx)
            except Exception:
                logger.exception(
                    "trigger %s: unexpected error handling event on %s",
                    rule.name, stream,
                )

    async def _maybe_fire(
        self, rule: TriggerRule, payload: dict[str, Any], ctx: dict[str, Any],
    ) -> None:
        """Fire ``rule`` if its ``when`` holds and it is outside its cooldown."""
        try:
            hit = bool(safe_eval(rule.when, ctx))
        except SafeEvalError as exc:
            logger.warning(
                "trigger %s when=%r safe_eval error: %s",
                rule.name, rule.when, exc,
            )
            return
        if not hit:
            return

        target = _render_target(rule.target_template, payload)
        key = (rule.arena, json.dumps(target, sort_keys=True, default=str))
        # Monotonic clock: wall-clock jumps must not stretch or skip cooldowns.
        now = time.monotonic()
        last = self._last_fire.get(key)
        if last is not None and now - last < rule.cooldown_s:
            logger.debug(
                "trigger %s suppressed by cooldown (%.1fs remaining)",
                rule.name, rule.cooldown_s - (now - last),
            )
            return

        action_mode = await self._apply_cost_gate(rule.action_mode)
        # Only a command that actually reached the control stream starts the
        # cooldown; a failed XADD must not suppress the next attempt.
        if await self._dispatch(rule, target, action_mode):
            self._last_fire[key] = now
            self._prune_last_fire(now)

    def _prune_last_fire(self, now: float) -> None:
        """Keep ``_last_fire`` bounded by forgetting stamps that cannot suppress.

        A stamp older than the longest configured cooldown can never block a
        fire again, so dropping it is safe. Pruning only runs once the map
        reaches ``self._prune_at`` entries; the threshold is then reset to
        twice the surviving size, keeping the cost amortised O(1) per fire.
        """
        if len(self._last_fire) < self._prune_at:
            return
        horizon = now - max((r.cooldown_s for r in self.rules), default=0)
        self._last_fire = {
            k: ts for k, ts in self._last_fire.items() if ts > horizon
        }
        self._prune_at = max(
            _LAST_FIRE_PRUNE_THRESHOLD, 2 * len(self._last_fire),
        )

    async def _dispatch(
        self, rule: TriggerRule, target: dict[str, Any], action_mode: str,
    ) -> bool:
        """XADD a ``run_arena`` command for ``rule``.

        Returns:
            ``True`` when the command was written to the control stream,
            ``False`` when the XADD failed (the failure is logged). The
            follow-up event-bus notification is best-effort: its failure is
            logged but does not turn a successful dispatch into a failed one.
        """
        correlation_id = uuid.uuid4().hex
        body = {
            "command": "run_arena",
            "correlation_id": correlation_id,
            "args": {
                "arena": rule.arena,
                "target": target,
                "action_mode": action_mode,
                "emit_action": True,
                "triggered_by": rule.name,
            },
        }
        try:
            client = await self._get_redis()
            await client.xadd(
                self.control_in_stream,
                {"data": json.dumps(body, default=str)},
            )
        except Exception as exc:
            logger.warning("trigger %s dispatch failed: %s", rule.name, exc)
            return False

        logger.info(
            "trigger %s fired: arena=%s target=%s mode=%s correlation=%s",
            rule.name, rule.arena, target, action_mode, correlation_id,
        )
        try:
            bus = get_event_bus()
            await bus.publish(
                "control.command",
                arena=rule.arena, correlation_id=correlation_id,
                payload={
                    "kind": "trigger.fired",
                    "rule": rule.name,
                    "target": target,
                    "action_mode": action_mode,
                },
            )
        except Exception as exc:
            logger.warning(
                "trigger %s event-bus publish failed: %s", rule.name, exc,
            )
        return True

    # ── cost gate ──────────────────────────────────────────────────────────

    async def _apply_cost_gate(self, requested_mode: str) -> str:
        """Demote ``auto`` to ``semi`` when over the per-hour cap.

        Reads a single Redis key ``maf:cost:hour`` (float, EUR). The plumbing
        that writes that key is a future ticket — until then the gate is a
        no-op (cap=0 disables it). Read/parse errors fail open (the requested
        mode is kept) but are logged.
        """
        if requested_mode != "auto" or self.cost_cap_eur_per_hour <= 0:
            return requested_mode
        try:
            client = await self._get_redis()
            raw = await client.get(COST_KEY)
            current = float(raw) if raw else 0.0
        except Exception as exc:
            logger.warning(
                "cost gate: could not read %s (%s) — failing open",
                COST_KEY, exc,
            )
            return requested_mode
        if current > self.cost_cap_eur_per_hour:
            logger.info(
                "cost gate: demoting auto→semi (current=%.2f cap=%.2f)",
                current, self.cost_cap_eur_per_hour,
            )
            return "semi"
        return requested_mode


# ── helpers ────────────────────────────────────────────────────────────────


def _as_str(value: Any) -> str:
    """Return ``value`` as text, UTF-8-decoding ``bytes`` from redis-py."""
    return value.decode() if isinstance(value, bytes) else str(value)


def _decode(fields: Any) -> dict[str, Any]:
    """Pull the JSON ``data`` payload from a Redis-Streams entry.

    Returns ``{}`` when the entry has no ``data`` field, or when the field
    is not valid JSON or does not decode to an object.
    """
    if not isinstance(fields, dict):
        return {}
    raw = fields.get(b"data") or fields.get("data")
    if isinstance(raw, bytes):
        raw = raw.decode("utf-8", errors="replace")
    if not isinstance(raw, str):
        return {}
    try:
        body = json.loads(raw)
    except (json.JSONDecodeError, TypeError):
        return {}
    if not isinstance(body, dict):
        return {}
    return body


def _render_target(template: dict[str, Any], payload: dict[str, Any]) -> dict[str, Any]:
    """Interpolate ``{payload.X}`` placeholders in template values.

    Only top-level string values get interpolated; other values pass through
    unchanged. The placeholder syntax supports dotted paths (``payload.x.y``)
    and integer subscripts (``payload.list[0]``).
    """
    out: dict[str, Any] = {}
    ctx = {"payload": payload}
    for k, v in template.items():
        if isinstance(v, str) and "{" in v and "}" in v:
            out[k] = _render_string(v, ctx)
        else:
            out[k] = v
    return out


def _render_string(template: str, ctx: dict[str, Any]) -> str:
    """Substitute each ``{expr}`` in ``template`` with ``safe_eval(expr, ctx)``.

    A placeholder whose expression fails to evaluate is left verbatim; one
    that evaluates to ``None`` is replaced with an empty string.
    """
    def _sub(m: re.Match[str]) -> str:
        expr = m.group(1).strip()
        try:
            val = safe_eval(expr, ctx)
        except SafeEvalError:
            return m.group(0)  # leave untouched on error
        if val is None:
            return ""
        return str(val)
    return _TEMPLATE_PATTERN.sub(_sub, template)