checking system…
Docs / back / src/maf/arenas/mastermind/stream.py · line 1
Python · 252 lines
  1"""Redis-stream envelope + publisher for the ``mastermind`` arena.
  2
  3Mirrors the crowd_simulation envelope/publisher pair (T-0026 / iter-4) — but
  4arenas are independent: this module does **not** import from
  5``crowd_simulation``. Copy-with-tweaks keeps each arena's wire contract
  6self-contained so a future schema bump on one side cannot accidentally
  7ripple to the other.
  8
  9Wire shape
 10----------
 11Each XADD writes one Redis field named ``data`` whose value is
 12``envelope.model_dump_json()``. Consumers JSON-decode + ``model_validate_json``.
 13A single field per entry keeps consumers schema-naive: they don't need to know
 14which envelope key maps to which Redis field.
 15
 16Schema versioning
 17-----------------
 18:class:`DecisionEnvelope` pins ``schema_version: Literal["1"]``. The
 19:func:`decode_envelope` helper checks this **before** trying to validate the
 20rest of the body so a future v2 envelope skips cleanly with a WARN instead of
 21raising a ValidationError on every entry.
 22"""
 23
 24from __future__ import annotations
 25
 26import json as _json
 27import logging
 28import uuid
 29from collections.abc import AsyncIterator
 30from datetime import UTC, datetime
 31from typing import Any, Literal
 32
 33from pydantic import BaseModel, ConfigDict, Field, ValidationError
 34
 35from maf.arenas.mastermind.schema import Decision
 36
 37logger = logging.getLogger(__name__)
 38
 39
 40DEFAULT_STREAM_NAME = "maf:arena:mastermind:output"
 41
 42
 43def _utcnow() -> datetime:
 44    """Factory for ``published_at`` (separate so tests can monkey-patch)."""
 45    return datetime.now(UTC)
 46
 47
 48def _new_correlation_id() -> str:
 49    """Factory for ``correlation_id`` — 32-char uuid4 hex."""
 50    return uuid.uuid4().hex
 51
 52
 53class DecisionEnvelope(BaseModel):
 54    """Versioned envelope published to ``maf:arena:mastermind:output``.
 55
 56    Fields
 57    ------
 58    schema_version:
 59        Literal ``"1"``. Bump this and add a parallel decoder when the wire
 60        contract changes; v1 consumers MUST keep validating cleanly.
 61    arena:
 62        Literal ``"mastermind"``. Pinned so a misrouted envelope from another
 63        arena is rejected at decode time.
 64    decision:
 65        The full :class:`Decision` payload — recommendation, confidence,
 66        argument tree, votes, citations, etc.
 67    published_at:
 68        UTC timestamp at the moment the envelope was built.
 69    correlation_id:
 70        Stable id for downstream tracing / log joining. Defaults to a fresh
 71        uuid4 hex when callers don't pass one.
 72    meta:
 73        Operator-facing free-form bag (timings, citation counts, debate
 74        summary, model name, …). Consumers MAY ignore. Same convention as
 75        the crowd_simulation envelope's ``meta``.
 76    """
 77
 78    model_config = ConfigDict(frozen=True)
 79
 80    schema_version: Literal["1"] = "1"
 81    arena: Literal["mastermind"] = "mastermind"
 82    decision: Decision
 83    published_at: datetime = Field(default_factory=_utcnow)
 84    correlation_id: str = Field(default_factory=_new_correlation_id)
 85    meta: dict[str, Any] = Field(default_factory=dict)
 86
 87
 88# ---------------------------------------------------------------------------
 89# Redis publisher
 90# ---------------------------------------------------------------------------
 91
 92
 93async def publish_envelope(
 94    redis_client: Any,
 95    stream_name: str,
 96    envelope: DecisionEnvelope,
 97) -> str:
 98    """Publish ``envelope`` to ``stream_name`` as a single ``data`` field.
 99
100    Returns the assigned stream id (e.g. ``"1735580000000-0"``) decoded if
101    Redis returned bytes.
102    """
103    payload = envelope.model_dump_json()
104    raw_id = await redis_client.xadd(stream_name, {"data": payload})
105    if isinstance(raw_id, bytes):
106        return raw_id.decode("utf-8")
107    return str(raw_id)
108
109
110# ---------------------------------------------------------------------------
111# Decoder helpers (used by tests + a future tail CLI)
112# ---------------------------------------------------------------------------
113
114
def decode_envelope(payload: str | bytes) -> DecisionEnvelope | None:
    """Decode a stream payload into an envelope, or return ``None`` on a skip.

    Every failure mode below logs a WARN and returns ``None`` instead of
    raising, so a single bad entry cannot take down a consumer loop:

    * ``payload`` is ``bytes`` that are not valid UTF-8.
    * ``payload`` is not decodable JSON. This includes JSON nested so deeply
      that the stdlib parser raises ``RecursionError``.
    * The JSON object carries a ``schema_version`` other than ``"1"``. This is
      checked before full validation so a future v2 does not crash today's
      consumers.
    * The body does not validate against :class:`DecisionEnvelope`.

    Args:
        payload: The raw ``data`` field value, as ``str`` or UTF-8 ``bytes``.

    Returns:
        The validated envelope, or ``None`` if the entry was skipped.
    """
    if isinstance(payload, bytes):
        try:
            payload = payload.decode("utf-8")
        except UnicodeDecodeError as exc:
            logger.warning("decode_envelope: payload is not valid UTF-8: %s", exc)
            return None

    try:
        head = _json.loads(payload)
    except (ValueError, TypeError, RecursionError) as exc:
        # RecursionError: the stdlib decoder recurses once per nesting level,
        # so a corrupt/hostile entry like "[[[[..." must be skipped here rather
        # than propagating out of the consumer loop.
        logger.warning("decode_envelope: payload is not valid JSON: %s", exc)
        return None
    if isinstance(head, dict):
        # A missing schema_version falls through: the model default "1" applies.
        # Non-object JSON is left for model validation to reject.
        version = head.get("schema_version")
        if version is not None and version != "1":
            logger.warning(
                "decode_envelope: unsupported schema_version=%r — skipping envelope",
                version,
            )
            return None

    try:
        return DecisionEnvelope.model_validate_json(payload)
    except ValidationError as exc:
        logger.warning("decode_envelope: validation failed, skipping: %s", exc)
        return None
152
153
154def _extract_data_field(fields: Any) -> str | None:
155    """Pull the ``data`` field out of a Redis xrevrange/xread message body.
156
157    Tolerates both bytes and str keys/values (``redis.asyncio`` defaults to
158    bytes; ``decode_responses=True`` clients return str).
159    """
160    if not isinstance(fields, dict):
161        return None
162    for key, val in fields.items():
163        if isinstance(key, bytes):
164            key_s = key.decode("utf-8", errors="replace")
165        else:
166            key_s = str(key)
167        if key_s != "data":
168            continue
169        if isinstance(val, bytes):
170            return val.decode("utf-8", errors="replace")
171        return str(val)
172    return None
173
174
async def read_envelopes(
    redis_client: Any,
    stream_name: str,
    count: int,
) -> list[tuple[str, DecisionEnvelope]]:
    """Return the most recent envelopes on ``stream_name``, newest first.

    Issues one ``XREVRANGE`` limited to ``count`` entries. An entry with no
    ``data`` field is skipped with a WARN here. An entry that
    :func:`decode_envelope` rejects (bad UTF-8/JSON, unknown
    ``schema_version``, failed validation) is skipped without an extra log
    line, since :func:`decode_envelope` already warned about it.

    Returns:
        ``(stream_id, envelope)`` pairs with ``stream_id`` as ``str``.
    """
    entries = await redis_client.xrevrange(stream_name, count=count)
    decoded: list[tuple[str, DecisionEnvelope]] = []
    for entry_id, body in entries:
        if isinstance(entry_id, bytes):
            sid = entry_id.decode("utf-8")
        else:
            sid = str(entry_id)
        data = _extract_data_field(body)
        if data is not None:
            envelope = decode_envelope(data)
            if envelope is not None:
                decoded.append((sid, envelope))
        else:
            logger.warning(
                "read_envelopes: stream entry %s missing 'data' field — skipping",
                sid,
            )
    return decoded
201
202
async def follow_envelopes(
    redis_client: Any,
    stream_name: str,
    *,
    block_ms: int = 0,
    last_id: str = "$",
) -> AsyncIterator[tuple[str, DecisionEnvelope]]:
    """Yield ``(stream_id, envelope)`` pairs forever via ``XREAD BLOCK``.

    Args:
        redis_client: Async Redis client. Must provide ``xread``, and also
            ``xrevrange`` when ``last_id == "$"`` is combined with a non-zero
            ``block_ms``.
        stream_name: Stream key to follow.
        block_ms: ``BLOCK`` timeout passed to each ``XREAD`` call, in
            milliseconds. ``0`` (default) waits until an entry is available.
        last_id: Only entries with an id greater than this are yielded.
            ``"$"`` (default) means only entries arriving after the call.

    Yields:
        ``(stream_id, envelope)`` with ``stream_id`` as ``str``. Entries with
        no ``data`` field are skipped with a WARN. Entries rejected by
        :func:`decode_envelope` are skipped (it logs its own WARN).

    The caller controls termination — interrupt via ``KeyboardInterrupt`` /
    cancellation, or stop iterating.
    """
    cursor = last_id
    if cursor == "$" and block_ms:
        # Redis re-resolves "$" on every XREAD. With a finite BLOCK, a timed-out
        # call loops back and re-sends "$", silently dropping anything appended
        # in the gap between the two calls. Pin "$" to a concrete id up front.
        # (With block_ms=0 the first XREAD only returns once it has entries,
        # so the cursor becomes concrete before any gap can open.)
        newest = await redis_client.xrevrange(stream_name, count=1)
        if newest:
            head_id = newest[0][0]
            cursor = head_id.decode("utf-8") if isinstance(head_id, bytes) else str(head_id)
        else:
            # Empty or missing stream: every future entry sorts after "0-0".
            cursor = "0-0"
    while True:
        resp = await redis_client.xread(
            {stream_name: cursor},
            count=100,
            block=block_ms,
        )
        if not resp:
            # BLOCK timed out with nothing new; cursor is unchanged.
            continue
        for _stream, entries in resp:
            for raw_id, fields in entries:
                stream_id = (
                    raw_id.decode("utf-8") if isinstance(raw_id, bytes) else str(raw_id)
                )
                # Advance before any skip so bad entries are not re-read.
                cursor = stream_id
                payload = _extract_data_field(fields)
                if payload is None:
                    logger.warning(
                        "follow_envelopes: stream entry %s missing 'data' field — skipping",
                        stream_id,
                    )
                    continue
                env = decode_envelope(payload)
                if env is None:
                    continue
                yield stream_id, env
242
243
# Public surface: the stream key and envelope model first, then the
# publish / decode / read / follow helpers.
__all__ = [
    "DEFAULT_STREAM_NAME", "DecisionEnvelope",
    "decode_envelope", "follow_envelopes", "publish_envelope", "read_envelopes",
]