checking system…
Docs / back / src/maf/scheduler/kronos_refresher.py · line 140
Python · 444 lines
  1"""Kronos refresher — periodic worker that fills the forecast cache.
  2
  3For every watched symbol (``WatchList.members(kind="symbol")``) and every
  4configured timeframe (default: 1m, 1h), this worker:
  5
  6  1. Pulls the most recent ``lookback`` OHLCV bars from
  7     ``trtools2:bars:{timeframe}`` (already streamed by the engine).
  8  2. POSTs them to the kronos sidecar's ``/forecast`` endpoint.
  3. Writes the response summary + (optional) per-bar forecast to
     ``kronos:forecast:{symbol}:{timeframe}`` with a TTL of 6× the refresh
     cadence (never less than 60 s).
 12  4. XADDs a compact event to ``kronos:forecasts:emitted`` so trigger
 13     dispatchers can react.
 14
 15Design notes
 16------------
*Per-timeframe cadence.* The 1m forecast refreshes every 60 s; the 1h
forecast every 5 min. The adapter's freshness budget is keyed off this
cadence — if the refresher stalls, the adapter flags
``stale_kronos_forecast`` and ReplanAgent forces a re-run.
 21
*Per-symbol delta-emit.* We only XADD when the summary changes meaningfully
(direction flipped or ``prob_up`` moved by at least 0.05); the first forecast
seen for a symbol/timeframe always emits. This keeps the emit stream
signal-rich for trigger rules like "fire arena when prob_up jumps".
 25
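*Backoff on sidecar errors.* If a refresh fails (non-200 response, timeout,
or any other error), the next attempt for that symbol/timeframe is delayed by
one cadence, doubling on each further failure up to a 10 min cap, so we don't
hammer a dying service. A single successful refresh resets the backoff.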
 29"""
 30
 31from __future__ import annotations
 32
 33import asyncio
 34import json
 35import logging
 36import os
 37import time
 38from dataclasses import dataclass, field
 39from datetime import UTC, datetime
 40from typing import Any
 41
 42import httpx
 43
 44from maf.streaming import get_event_bus
 45from maf.watch.list import KIND_SYMBOL, WatchList
 46
 47logger = logging.getLogger(__name__)
 48
 49
 50DEFAULT_SIDECAR_URL = "http://localhost:5102"
 51DEFAULT_BARS_STREAM = "trtools2:bars:{timeframe}"
 52DEFAULT_FORECAST_KEY = "kronos:forecast:{symbol}:{timeframe}"
 53DEFAULT_EMIT_STREAM = "kronos:forecasts:emitted"
 54
 55# Per-timeframe defaults. Override via the constructor's ``profiles`` arg.
 56DEFAULT_PROFILES: dict[str, dict[str, Any]] = {
 57    "1m": {
 58        "cadence_s": 60,
 59        "lookback_bars": 256,
 60        "pred_len": 60,
 61        "sample_count": 5,
 62    },
 63    "1h": {
 64        "cadence_s": 300,
 65        "lookback_bars": 256,
 66        "pred_len": 24,
 67        "sample_count": 5,
 68    },
 69}
 70
 71_MIN_DELTA_PROB_UP = 0.05  # emit event only when |Δprob_up| ≥ 5pp
 72_BACKOFF_MAX_S = 600       # cap per-symbol backoff at 10 min
 73_HTTP_TIMEOUT_S = 30
 74
 75
 76@dataclass
 77class _SymbolState:
 78    """Per-symbol tracking — last forecast emitted + backoff."""
 79
 80    last_summary: dict[str, Any] = field(default_factory=dict)
 81    backoff_s: float = 0.0
 82    next_due_at: float = 0.0
 83    consecutive_failures: int = 0
 84
 85
 86class KronosRefresher:
 87    """Long-running refresher. Run via :meth:`run`; stop with :meth:`stop`."""
 88
 89    def __init__(
 90        self,
 91        *,
 92        sidecar_url: str | None = None,
 93        redis_url: str | None = None,
 94        watch_list: WatchList | None = None,
 95        profiles: dict[str, dict[str, Any]] | None = None,
 96        forecast_key_template: str = DEFAULT_FORECAST_KEY,
 97        emit_stream: str = DEFAULT_EMIT_STREAM,
 98        bars_stream_template: str = DEFAULT_BARS_STREAM,
 99        include_forecast: bool = False,
100        max_concurrent: int = 4,
101    ) -> None:
102        self.sidecar_url = (sidecar_url
103                             or os.environ.get("KRONOS_SVC_URL")
104                             or DEFAULT_SIDECAR_URL).rstrip("/")
105        self.redis_url = redis_url or os.environ.get(
106            "REDIS_URL", "redis://localhost:6379/0",
107        )
108        self.watch = watch_list or WatchList(redis_url=self.redis_url)
109        self.profiles = dict(profiles or DEFAULT_PROFILES)
110        self.forecast_key_template = forecast_key_template
111        self.emit_stream = emit_stream
112        self.bars_stream_template = bars_stream_template
113        self.include_forecast = include_forecast
114        self._sem = asyncio.Semaphore(max(1, max_concurrent))
115        self._redis: Any = None
116        self._stop = asyncio.Event()
117        # (symbol, timeframe) → _SymbolState
118        self._state: dict[tuple[str, str], _SymbolState] = {}
119
120    # ── lifecycle ──────────────────────────────────────────────────────────
121
122    async def _get_redis(self) -> Any:
123        if self._redis is None:
124            import redis.asyncio as aioredis
125            self._redis = aioredis.from_url(self.redis_url)
126        return self._redis
127
128    async def aclose(self) -> None:
129        self._stop.set()
130        if self._redis is not None:
131            try:
132                ac = getattr(self._redis, "aclose", None)
133                if ac:
134                    await ac()
135                else:
136                    await self._redis.close()
137            except Exception:
138                pass
139
140    async def _write_heartbeat(self, watched_count: int) -> None:
141        """Drop a heartbeat key the dashboard reads to render the status pill.
142
143        Stores a JSON blob with last tick timestamp, watched symbols count,
144        and minimum cadence so the UI knows how soon to expect the next tick.
145        """
146        try:
147            client = await self._get_redis()
148            min_cadence = min(p["cadence_s"] for p in self.profiles.values())
149            payload = json.dumps({
150                "ts": time.time(),
151                "watched": watched_count,
152                "min_cadence_s": min_cadence,
153                "timeframes": list(self.profiles.keys()),
154            })
155            # Set a generous TTL so a crashed refresher is detected as down
156            # within ~3× the min cadence.
157            await client.set(
158                "maf:refresher:kronos:heartbeat",
159                payload,
160                ex=max(min_cadence * 3, 180),
161            )
162        except Exception as exc:
163            logger.debug("kronos heartbeat write failed: %s", exc)
164
165    def stop(self) -> None:
166        self._stop.set()
167
168    async def run(self) -> None:
169        """Main loop. Wakes every ``min_cadence`` and refreshes anything due."""
170        min_cadence = min(p["cadence_s"] for p in self.profiles.values())
171        logger.info(
172            "KronosRefresher started — sidecar=%s timeframes=%s min_cadence=%ds",
173            self.sidecar_url, list(self.profiles), min_cadence,
174        )
175        async with httpx.AsyncClient(timeout=_HTTP_TIMEOUT_S) as http:
176            while not self._stop.is_set():
177                await self._tick(http)
178                # Sleep in small chunks so stop() exits quickly.
179                for _ in range(min_cadence):
180                    if self._stop.is_set():
181                        break
182                    await asyncio.sleep(1)
183
184    async def _tick(self, http: httpx.AsyncClient) -> None:
185        """One pass over the watch list × configured timeframes."""
186        try:
187            symbols = await self.watch.members(kind=KIND_SYMBOL)
188        except Exception as exc:
189            logger.warning("WatchList read failed: %s", exc)
190            await self._write_heartbeat(0)
191            return
192
193        await self._write_heartbeat(len(symbols))
194
195        if not symbols:
196            logger.debug("KronosRefresher: no watched symbols")
197            return
198
199        now = time.time()
200        tasks: list[asyncio.Task[None]] = []
201        for entry in symbols:
202            for tf, profile in self.profiles.items():
203                key = (entry.target_id.upper(), tf)
204                st = self._state.setdefault(key, _SymbolState())
205                if now < st.next_due_at:
206                    continue
207                tasks.append(asyncio.create_task(
208                    self._refresh_one(http, entry.target_id.upper(), tf, profile, st),
209                ))
210        if tasks:
211            await asyncio.gather(*tasks, return_exceptions=True)
212
213    async def _refresh_one(
214        self,
215        http: httpx.AsyncClient,
216        symbol: str,
217        timeframe: str,
218        profile: dict[str, Any],
219        state: _SymbolState,
220    ) -> None:
221        async with self._sem:
222            cadence = float(profile["cadence_s"])
223            try:
224                bars = await self._fetch_bars(symbol, timeframe, profile["lookback_bars"])
225                if len(bars) < 8:
226                    logger.debug(
227                        "skip refresh: %s %s has only %d bars (need ≥8)",
228                        symbol, timeframe, len(bars),
229                    )
230                    state.next_due_at = time.time() + cadence
231                    return
232
233                req_body = {
234                    "symbol": symbol,
235                    "timeframe": timeframe,
236                    "history": bars,
237                    "pred_len": int(profile["pred_len"]),
238                    "sample_count": int(profile["sample_count"]),
239                    "include_forecast": self.include_forecast,
240                }
241                t0 = time.perf_counter()
242                resp = await http.post(
243                    f"{self.sidecar_url}/forecast", json=req_body,
244                )
245                if resp.status_code != 200:
246                    raise RuntimeError(
247                        f"sidecar {resp.status_code}: {resp.text[:200]}",
248                    )
249                payload = resp.json()
250                elapsed = time.perf_counter() - t0
251
252                await self._write_cache(symbol, timeframe, payload, cadence)
253                await self._maybe_emit(symbol, timeframe, payload, state)
254                state.consecutive_failures = 0
255                state.backoff_s = 0.0
256                state.next_due_at = time.time() + cadence
257                logger.info(
258                    "kronos forecast refreshed: %s %s dir=%s prob_up=%.2f elapsed=%.1fs",
259                    symbol, timeframe,
260                    (payload.get("summary") or {}).get("direction", "?"),
261                    (payload.get("summary") or {}).get("prob_up", 0.0),
262                    elapsed,
263                )
264            except Exception as exc:
265                state.consecutive_failures += 1
266                # Exponential backoff capped at _BACKOFF_MAX_S
267                state.backoff_s = min(
268                    _BACKOFF_MAX_S,
269                    (state.backoff_s * 2) if state.backoff_s else cadence,
270                )
271                state.next_due_at = time.time() + state.backoff_s
272                logger.warning(
273                    "kronos refresh failed (%s %s, attempt=%d, backoff=%.0fs): %s",
274                    symbol, timeframe, state.consecutive_failures,
275                    state.backoff_s, exc,
276                )
277
278    # ── data helpers ───────────────────────────────────────────────────────
279
280    async def _fetch_bars(
281        self, symbol: str, timeframe: str, lookback: int,
282    ) -> list[dict[str, Any]]:
283        """Read the last N bars for ``symbol`` from the trtools2 stream.
284
285        The stream entries are dicts; the symbol may live under ``symbol``
286        or ``ticker``. Volume / amount are optional. We grab a generous
287        window (lookback × 4) to account for multi-symbol interleaving,
288        then filter client-side.
289        """
290        client = await self._get_redis()
291        stream = self.bars_stream_template.format(timeframe=timeframe)
292        try:
293            rows = await client.xrevrange(stream, count=lookback * 4)
294        except Exception as exc:
295            logger.warning(
296                "bars stream %s read failed: %s — kronos refresh skipped",
297                stream, exc,
298            )
299            return []
300
301        out: list[dict[str, Any]] = []
302        for _id, fields in rows:
303            bar = _decode_bar(fields)
304            if bar is None:
305                continue
306            bar_symbol = str(bar.get("symbol") or bar.get("ticker") or "").upper()
307            if bar_symbol != symbol:
308                continue
309            out.append({
310                "ts": bar.get("ts") or bar.get("timestamp") or "",
311                "open": _f(bar.get("open")),
312                "high": _f(bar.get("high")),
313                "low":  _f(bar.get("low")),
314                "close": _f(bar.get("close")),
315                "volume": _f(bar.get("volume")) if bar.get("volume") is not None else None,
316                "amount": _f(bar.get("amount")) if bar.get("amount") is not None else None,
317            })
318            if len(out) >= lookback:
319                break
320        # xrevrange returns newest-first; predictor wants oldest-first.
321        out.reverse()
322        return out
323
324    async def _write_cache(
325        self, symbol: str, timeframe: str,
326        payload: dict[str, Any], cadence_s: float,
327    ) -> None:
328        """Write the forecast summary under the canonical key with a TTL."""
329        client = await self._get_redis()
330        key = self.forecast_key_template.format(
331            symbol=symbol, timeframe=timeframe,
332        )
333        # Always carry a generated_at so the adapter's staleness check works
334        # regardless of what the sidecar wrote.
335        body = dict(payload)
336        body.setdefault("generated_at", datetime.now(UTC).isoformat())
337        body.setdefault("symbol", symbol)
338        body.setdefault("timeframe", timeframe)
339        # TTL = 6× cadence so a brief sidecar outage still lets arenas read
340        # a (now-stale) forecast and have ReplanAgent mark it accordingly.
341        await client.set(
342            key, json.dumps(body, default=str),
343            ex=int(max(60, cadence_s * 6)),
344        )
345
346    async def _maybe_emit(
347        self, symbol: str, timeframe: str,
348        payload: dict[str, Any], state: _SymbolState,
349    ) -> None:
350        """Compact emit on prob_up jump / direction flip."""
351        summary = payload.get("summary") or {}
352        prev = state.last_summary
353        # Don't use ``x or 0.5`` — a genuine 0.0 prob_up is falsy and would
354        # be silently coerced to 0.5, killing the emit on a strong bearish
355        # forecast. Use explicit ``is None`` checks instead.
356        prev_prob_raw = prev.get("prob_up")
357        new_prob_raw = summary.get("prob_up")
358        first_observation = prev_prob_raw is None
359        prev_prob = float(prev_prob_raw if prev_prob_raw is not None else 0.5)
360        new_prob = float(new_prob_raw if new_prob_raw is not None else 0.5)
361        prev_dir = str(prev.get("direction") or "")
362        new_dir = str(summary.get("direction") or "")
363
364        delta_prob_up = new_prob - prev_prob
365        flipped = (prev_dir and new_dir and prev_dir != new_dir)
366        # First observation always emits (baseline → first forecast is the
367        # most informative event), even when prob_up is exactly 0.5.
368        if not first_observation and (
369            abs(delta_prob_up) < _MIN_DELTA_PROB_UP and not flipped
370        ):
371            state.last_summary = dict(summary)
372            return
373
374        client = await self._get_redis()
375        evt = {
376            "schema_version": "1",
377            "symbol": symbol,
378            "timeframe": timeframe,
379            "direction": new_dir,
380            "prob_up": new_prob,
381            "prob_up_delta": round(delta_prob_up, 4),
382            "direction_flipped": flipped,
383            "exp_return_pct": float(summary.get("exp_return_pct") or 0.0),
384            "vol_estimate_pct": float(summary.get("vol_estimate_pct") or 0.0),
385            "model": payload.get("model", ""),
386            "horizon_min": payload.get("horizon_min"),
387            "generated_at": payload.get("generated_at", ""),
388        }
389        try:
390            await client.xadd(
391                self.emit_stream,
392                {"data": json.dumps(evt, default=str)},
393                maxlen=10_000, approximate=True,
394            )
395            bus = get_event_bus()
396            await bus.publish(
397                "system.status",
398                arena="", phase="",
399                payload={"kind": "kronos.forecast", **evt},
400            )
401        except Exception as exc:
402            logger.warning("emit failed: %s", exc)
403        state.last_summary = dict(summary)
404
405
406# ── helpers ────────────────────────────────────────────────────────────────
407
408
409def _f(v: Any) -> float:
410    try:
411        return float(v)
412    except (TypeError, ValueError):
413        return 0.0
414
415
416def _decode_bar(fields: Any) -> dict[str, Any] | None:
417    """Decode one Redis stream entry into a dict (per-field JSON-aware)."""
418    if not isinstance(fields, dict):
419        return None
420    out: dict[str, Any] = {}
421    for k, v in fields.items():
422        key = k.decode("utf-8") if isinstance(k, bytes) else str(k)
423        if isinstance(v, bytes):
424            try:
425                v = v.decode("utf-8")
426            except UnicodeDecodeError:
427                continue
428        if isinstance(v, str):
429            try:
430                v = json.loads(v)
431            except (json.JSONDecodeError, TypeError):
432                pass
433        out[key] = v
434    return out
435
436
# Public API: the worker plus every configurable default it is built from.
# DEFAULT_BARS_STREAM is exported alongside its sibling DEFAULT_* constants
# because it is the default for the constructor's ``bars_stream_template``.
__all__ = [
    "DEFAULT_BARS_STREAM",
    "DEFAULT_EMIT_STREAM",
    "DEFAULT_FORECAST_KEY",
    "DEFAULT_PROFILES",
    "DEFAULT_SIDECAR_URL",
    "KronosRefresher",
]