checking system…
Docs / back / src/maf/consumers/action_consumer.py · line 76
Python · 232 lines
  1"""ActionConsumer — reads ``maf:actions:out`` and publishes execution decisions.
  2
  3What it does
  4------------
  5Tails the actions stream. For each TradingAction:
  6
  71. Decodes via :class:`TradingAction.model_validate_json`.
  82. Hands the dict to a :class:`RiskGate`.
  93. Builds an :class:`ExecutionEnvelope` containing the gate's decision +
 10   echoes of the relevant fields from the action.
 114. Publishes the envelope to ``maf:executions:out``.
 12
 13Why this lives in MAF rather than trtools2
 14------------------------------------------
 15The wire contract (action shape, execution envelope shape, correlation by
 16arena_id) is owned by MAF. The risk policy + actual order placement should
 17sit in the engine — but having a *reference implementation* of the consumer
 18in MAF means anyone integrating downstream has a working starting point and
 19the tests cover the round-trip.
 20
 21Real engines that already speak their own protocol can adapt this as a
 22sidecar: subscribe to ``maf:executions:out``, push to internal order router.
 23"""
 24
 25from __future__ import annotations
 26
 27import asyncio
 28import json
 29import logging
 30import os
 31import time
 32from datetime import UTC, datetime
 33from typing import Any
 34
 35from pydantic import BaseModel, ConfigDict, Field, ValidationError
 36
 37from maf.actions.outbox import TradingAction
 38from maf.consumers.risk_gate import GateDecision, RiskGate
 39
logger = logging.getLogger(__name__)


# Default stream ActionConsumer tails for serialized TradingAction entries.
DEFAULT_ACTIONS_STREAM = "maf:actions:out"
# Default stream ActionConsumer publishes ExecutionEnvelope entries to.
DEFAULT_EXECUTIONS_STREAM = "maf:executions:out"
 45
 46
 47def _utcnow_iso() -> str:
 48    return datetime.now(UTC).isoformat()
 49
 50
class ExecutionEnvelope(BaseModel):
    """The consumer's decision, correlated back to the originating action.

    Published on ``maf:executions:out`` as ``{"data": envelope.model_dump_json()}``.
    Mastermind's outcome harvester ingests this to update DecisionMemory.

    Instances are frozen (attribute assignment is rejected), although the
    ``extras`` dict itself remains a mutable object.

    NOTE(review): the order in which fields are declared below is the key order
    of the serialized JSON payload; reorder with care.
    """

    model_config = ConfigDict(frozen=True)

    schema_version: str = "1"      # envelope schema version string
    arena_id: str                  # ActionConsumer copies TradingAction.arena_id
    action_correlation_id: str     # ActionConsumer copies TradingAction.correlation_id
    ticker: str                    # ActionConsumer copies action.target.ticker
    verdict: str                   # ActionConsumer copies action.verdict
    mode_requested: str            # action.mode, i.e. the mode before the gate
    mode_final: str                # GateDecision.final_mode, i.e. the mode after the gate
    gate_action: str       # "execute" | "queue" | "log" | "reject"
    size_fraction: float           # GateDecision.size_fraction
    confidence: float              # ActionConsumer copies action.sizing.confidence
    reason: str                    # GateDecision.reason
    consumer: str = "maf-reference"  # ActionConsumer passes its consumer_name here
    ts: str = Field(default_factory=_utcnow_iso)  # creation time, UTC ISO-8601
    # Free-form extra context; ActionConsumer fills arena, ensemble_score,
    # horizon and total_exposure_after_gate.
    extras: dict[str, Any] = Field(default_factory=dict)
 74
 75
 76class ActionConsumer:
 77    """Async tail of ``maf:actions:out`` with risk-gated publish to executions."""
 78
 79    def __init__(
 80        self,
 81        risk_gate: RiskGate | None = None,
 82        *,
 83        redis_url: str | None = None,
 84        actions_stream: str = DEFAULT_ACTIONS_STREAM,
 85        executions_stream: str = DEFAULT_EXECUTIONS_STREAM,
 86        group: str = "maf-action-consumer",
 87        consumer: str = "maf-ref-1",
 88        consumer_name: str = "maf-reference",
 89    ) -> None:
 90        self.risk_gate = risk_gate or RiskGate()
 91        self.redis_url = redis_url or os.environ.get(
 92            "REDIS_URL", "redis://localhost:6379/0",
 93        )
 94        self.actions_stream = actions_stream
 95        self.executions_stream = executions_stream
 96        self.group = group
 97        self.consumer = consumer
 98        self.consumer_name = consumer_name
 99        self._redis: Any = None
100        self._stop = asyncio.Event()
101
102    async def _get_redis(self) -> Any:
103        if self._redis is None:
104            import redis.asyncio as aioredis
105            self._redis = aioredis.from_url(self.redis_url)
106        return self._redis
107
108    async def _ensure_group(self, client: Any) -> None:
109        try:
110            await client.xgroup_create(
111                self.actions_stream, self.group, id="$", mkstream=True,
112            )
113        except Exception as exc:
114            if "BUSYGROUP" in str(exc):
115                return
116            logger.warning("ActionConsumer: xgroup_create failed: %s", exc)
117
118    async def run(self) -> None:
119        """Consume actions until :meth:`stop` is called."""
120        client = await self._get_redis()
121        await self._ensure_group(client)
122        logger.info(
123            "ActionConsumer: listening on %s (group=%s) → %s",
124            self.actions_stream, self.group, self.executions_stream,
125        )
126        while not self._stop.is_set():
127            try:
128                resp = await client.xreadgroup(
129                    self.group, self.consumer,
130                    {self.actions_stream: ">"},
131                    block=5000, count=10,
132                )
133            except Exception as exc:
134                logger.warning("ActionConsumer: xreadgroup failed: %s", exc)
135                await asyncio.sleep(1.0)
136                continue
137            if not resp:
138                continue
139            for _stream, entries in resp:
140                for entry_id, fields in entries:
141                    try:
142                        await self._handle_one(client, entry_id, fields)
143                    except Exception:
144                        logger.exception("ActionConsumer: handler crashed")
145                    finally:
146                        try:
147                            await client.xack(
148                                self.actions_stream, self.group, entry_id,
149                            )
150                        except Exception as exc:
151                            logger.warning("ActionConsumer: xack failed: %s", exc)
152
153    def stop(self) -> None:
154        self._stop.set()
155
156    async def aclose(self) -> None:
157        self._stop.set()
158        if self._redis is None:
159            return
160        try:
161            ac = getattr(self._redis, "aclose", None)
162            if ac:
163                await ac()
164            else:
165                await self._redis.close()
166        except Exception:
167            pass
168
169    async def _handle_one(self, client: Any, _entry_id: Any, fields: Any) -> None:
170        raw = None
171        if isinstance(fields, dict):
172            raw = fields.get(b"data") or fields.get("data")
173        if isinstance(raw, bytes):
174            raw = raw.decode("utf-8")
175        if not raw:
176            return
177        try:
178            action = TradingAction.model_validate_json(raw)
179        except ValidationError as exc:
180            logger.warning("ActionConsumer: malformed action: %s", exc)
181            return
182
183        decision: GateDecision = self.risk_gate.evaluate(action.model_dump(mode="json"))
184
185        envelope = ExecutionEnvelope(
186            arena_id=action.arena_id,
187            action_correlation_id=action.correlation_id,
188            ticker=action.target.ticker,
189            verdict=action.verdict,
190            mode_requested=action.mode,
191            mode_final=decision.final_mode,
192            gate_action=decision.action,
193            size_fraction=decision.size_fraction,
194            confidence=action.sizing.confidence,
195            reason=decision.reason,
196            consumer=self.consumer_name,
197            extras={
198                "arena": action.arena,
199                "ensemble_score": action.sizing.ensemble_score,
200                "horizon": action.sizing.horizon,
201                "total_exposure_after_gate": (
202                    self.risk_gate.total_exposure +
203                    (decision.size_fraction if decision.action == "execute" else 0.0)
204                ),
205            },
206        )
207
208        # In a real engine this is where the order would be placed for
209        # ``execute`` and a queue/notification dispatched for ``queue``. The
210        # reference consumer just publishes the envelope and (for execute)
211        # registers a synthetic fill so the exposure tracker stays current.
212        if decision.action == "execute":
213            self.risk_gate.register_fill(action.target.ticker, decision.size_fraction)
214
215        try:
216            await client.xadd(
217                self.executions_stream,
218                {"data": envelope.model_dump_json()},
219                maxlen=50_000,
220                approximate=True,
221            )
222        except Exception:
223            logger.exception("ActionConsumer: publish failed")
224
225
# Names exported by ``from maf.consumers.action_consumer import *``; the
# ``_utcnow_iso`` helper is intentionally not listed.
__all__ = [
    "ActionConsumer",
    "DEFAULT_ACTIONS_STREAM",
    "DEFAULT_EXECUTIONS_STREAM",
    "ExecutionEnvelope",
]