1"""Kronos refresher — periodic worker that fills the forecast cache. 2 3For every watched symbol (``WatchList.members(kind="symbol")``) and every 4configured timeframe (default: 1m, 1h), this worker: 5 6 1. Pulls the most recent ``lookback`` OHLCV bars from 7 ``trtools2:bars:{timeframe}`` (already streamed by the engine). 8 2. POSTs them to the kronos sidecar's ``/forecast`` endpoint. 9 3. Writes the response summary + (optional) per-bar forecast to 10 ``kronos:forecast:{symbol}:{timeframe}`` with a TTL slightly longer 11 than the refresh cadence. 12 4. XADDs a compact event to ``kronos:forecasts:emitted`` so trigger 13 dispatchers can react. 14 15Design notes 16------------ 17*Per-timeframe cadence.* The 1m forecast refreshes every 60 s; the 1h 18forecast every 5 min. The cadence is *what* the freshness budget on the 19adapter side keys off — if the refresher stalls, the adapter flags 20``stale_kronos_forecast`` and ReplanAgent forces a re-run. 21 22*Per-symbol delta-emit.* We only XADD when the summary changes meaningfully 23(direction flipped or ``prob_up`` moved > 0.05). This keeps the emit 24stream signal-rich for trigger rules like "fire arena when prob_up jumps". 25 26*Backoff on sidecar errors.* If the sidecar returns 5xx or times out we 27double the wait for that symbol with a cap, so we don't hammer a dying 28service. A consecutive-success resets the backoff. 29""" 30 31from __future__ import annotations 32 33import asyncio 34import json 35import logging 36import os 37import time 38from dataclasses import dataclass, field 39from datetime import UTC, datetime 40from typing import Any 41 42import httpx 43 44from maf.streaming import get_event_bus 45from maf.watch.list import KIND_SYMBOL, WatchList 46 47logger = logging.getLogger(__name__) 48 49 50DEFAULT_SIDECAR_URL = "http://localhost:5102" 51DEFAULT_BARS_STREAM = "trtools2:bars:{timeframe}" 52DEFAULT_FORECAST_KEY = "kronos:forecast:{symbol}:{timeframe}" 53DEFAULT_EMIT_STREAM = "kronos:forecasts:emitted" 54 55# Per-timeframe defaults. Override via the constructor's ``profiles`` arg. 56DEFAULT_PROFILES: dict[str, dict[str, Any]] = { 57 "1m": { 58 "cadence_s": 60, 59 "lookback_bars": 256, 60 "pred_len": 60, 61 "sample_count": 5, 62 }, 63 "1h": { 64 "cadence_s": 300, 65 "lookback_bars": 256, 66 "pred_len": 24, 67 "sample_count": 5, 68 }, 69} 70 71_MIN_DELTA_PROB_UP = 0.05 # emit event only when |Δprob_up| ≥ 5pp 72_BACKOFF_MAX_S = 600 # cap per-symbol backoff at 10 min 73_HTTP_TIMEOUT_S = 30 74 75 76@dataclass 77class _SymbolState: 78 """Per-symbol tracking — last forecast emitted + backoff.""" 79 80 last_summary: dict[str, Any] = field(default_factory=dict) 81 backoff_s: float = 0.0 82 next_due_at: float = 0.0 83 consecutive_failures: int = 0 84 85 86class KronosRefresher: 87 """Long-running refresher. Run via :meth:`run`; stop with :meth:`stop`.""" 88 89 def __init__( 90 self, 91 *, 92 sidecar_url: str | None = None, 93 redis_url: str | None = None, 94 watch_list: WatchList | None = None, 95 profiles: dict[str, dict[str, Any]] | None = None, 96 forecast_key_template: str = DEFAULT_FORECAST_KEY, 97 emit_stream: str = DEFAULT_EMIT_STREAM, 98 bars_stream_template: str = DEFAULT_BARS_STREAM, 99 include_forecast: bool = False, 100 max_concurrent: int = 4, 101 ) -> None: 102 self.sidecar_url = (sidecar_url 103 or os.environ.get("KRONOS_SVC_URL") 104 or DEFAULT_SIDECAR_URL).rstrip("/") 105 self.redis_url = redis_url or os.environ.get( 106 "REDIS_URL", "redis://localhost:6379/0", 107 ) 108 self.watch = watch_list or WatchList(redis_url=self.redis_url) 109 self.profiles = dict(profiles or DEFAULT_PROFILES) 110 self.forecast_key_template = forecast_key_template 111 self.emit_stream = emit_stream 112 self.bars_stream_template = bars_stream_template 113 self.include_forecast = include_forecast 114 self._sem = asyncio.Semaphore(max(1, max_concurrent)) 115 self._redis: Any = None 116 self._stop = asyncio.Event() 117 # (symbol, timeframe) → _SymbolState 118 self._state: dict[tuple[str, str], _SymbolState] = {} 119 120 # ── lifecycle ────────────────────────────────────────────────────────── 121 122 async def _get_redis(self) -> Any: 123 if self._redis is None: 124 import redis.asyncio as aioredis 125 self._redis = aioredis.from_url(self.redis_url) 126 return self._redis 127 128 async def aclose(self) -> None: 129 self._stop.set() 130 if self._redis is not None: 131 try: 132 ac = getattr(self._redis, "aclose", None) 133 if ac: 134 await ac() 135 else: 136 await self._redis.close() 137 except Exception: 138 pass 139 140 async def _write_heartbeat(self, watched_count: int) -> None: 141 """Drop a heartbeat key the dashboard reads to render the status pill. 142 143 Stores a JSON blob with last tick timestamp, watched symbols count, 144 and minimum cadence so the UI knows how soon to expect the next tick. 145 """ 146 try: 147 client = await self._get_redis() 148 min_cadence = min(p["cadence_s"] for p in self.profiles.values()) 149 payload = json.dumps({ 150 "ts": time.time(), 151 "watched": watched_count, 152 "min_cadence_s": min_cadence, 153 "timeframes": list(self.profiles.keys()), 154 }) 155 # Set a generous TTL so a crashed refresher is detected as down 156 # within ~3× the min cadence. 157 await client.set( 158 "maf:refresher:kronos:heartbeat", 159 payload, 160 ex=max(min_cadence * 3, 180), 161 ) 162 except Exception as exc: 163 logger.debug("kronos heartbeat write failed: %s", exc) 164 165 def stop(self) -> None: 166 self._stop.set() 167 168 async def run(self) -> None: 169 """Main loop. Wakes every ``min_cadence`` and refreshes anything due.""" 170 min_cadence = min(p["cadence_s"] for p in self.profiles.values()) 171 logger.info( 172 "KronosRefresher started — sidecar=%s timeframes=%s min_cadence=%ds", 173 self.sidecar_url, list(self.profiles), min_cadence, 174 ) 175 async with httpx.AsyncClient(timeout=_HTTP_TIMEOUT_S) as http: 176 while not self._stop.is_set(): 177 await self._tick(http) 178 # Sleep in small chunks so stop() exits quickly. 179 for _ in range(min_cadence): 180 if self._stop.is_set(): 181 break 182 await asyncio.sleep(1) 183 184 async def _tick(self, http: httpx.AsyncClient) -> None: 185 """One pass over the watch list × configured timeframes.""" 186 try: 187 symbols = await self.watch.members(kind=KIND_SYMBOL) 188 except Exception as exc: 189 logger.warning("WatchList read failed: %s", exc) 190 await self._write_heartbeat(0) 191 return 192 193 await self._write_heartbeat(len(symbols)) 194 195 if not symbols: 196 logger.debug("KronosRefresher: no watched symbols") 197 return 198 199 now = time.time() 200 tasks: list[asyncio.Task[None]] = [] 201 for entry in symbols: 202 for tf, profile in self.profiles.items(): 203 key = (entry.target_id.upper(), tf) 204 st = self._state.setdefault(key, _SymbolState()) 205 if now < st.next_due_at: 206 continue 207 tasks.append(asyncio.create_task( 208 self._refresh_one(http, entry.target_id.upper(), tf, profile, st), 209 )) 210 if tasks: 211 await asyncio.gather(*tasks, return_exceptions=True) 212 213 async def _refresh_one( 214 self, 215 http: httpx.AsyncClient, 216 symbol: str, 217 timeframe: str, 218 profile: dict[str, Any], 219 state: _SymbolState, 220 ) -> None: 221 async with self._sem: 222 cadence = float(profile["cadence_s"]) 223 try: 224 bars = await self._fetch_bars(symbol, timeframe, profile["lookback_bars"]) 225 if len(bars) < 8: 226 logger.debug( 227 "skip refresh: %s %s has only %d bars (need ≥8)", 228 symbol, timeframe, len(bars), 229 ) 230 state.next_due_at = time.time() + cadence 231 return 232 233 req_body = { 234 "symbol": symbol, 235 "timeframe": timeframe, 236 "history": bars, 237 "pred_len": int(profile["pred_len"]), 238 "sample_count": int(profile["sample_count"]), 239 "include_forecast": self.include_forecast, 240 } 241 t0 = time.perf_counter() 242 resp = await http.post( 243 f"{self.sidecar_url}/forecast", json=req_body, 244 ) 245 if resp.status_code != 200: 246 raise RuntimeError( 247 f"sidecar {resp.status_code}: {resp.text[:200]}", 248 ) 249 payload = resp.json() 250 elapsed = time.perf_counter() - t0 251 252 await self._write_cache(symbol, timeframe, payload, cadence) 253 await self._maybe_emit(symbol, timeframe, payload, state) 254 state.consecutive_failures = 0 255 state.backoff_s = 0.0 256 state.next_due_at = time.time() + cadence 257 logger.info( 258 "kronos forecast refreshed: %s %s dir=%s prob_up=%.2f elapsed=%.1fs", 259 symbol, timeframe, 260 (payload.get("summary") or {}).get("direction", "?"), 261 (payload.get("summary") or {}).get("prob_up", 0.0), 262 elapsed, 263 ) 264 except Exception as exc: 265 state.consecutive_failures += 1 266 # Exponential backoff capped at _BACKOFF_MAX_S 267 state.backoff_s = min( 268 _BACKOFF_MAX_S, 269 (state.backoff_s * 2) if state.backoff_s else cadence, 270 ) 271 state.next_due_at = time.time() + state.backoff_s 272 logger.warning( 273 "kronos refresh failed (%s %s, attempt=%d, backoff=%.0fs): %s", 274 symbol, timeframe, state.consecutive_failures, 275 state.backoff_s, exc, 276 ) 277 278 # ── data helpers ─────────────────────────────────────────────────────── 279 280 async def _fetch_bars( 281 self, symbol: str, timeframe: str, lookback: int, 282 ) -> list[dict[str, Any]]: 283 """Read the last N bars for ``symbol`` from the trtools2 stream. 284 285 The stream entries are dicts; the symbol may live under ``symbol`` 286 or ``ticker``. Volume / amount are optional. We grab a generous 287 window (lookback × 4) to account for multi-symbol interleaving, 288 then filter client-side. 289 """ 290 client = await self._get_redis() 291 stream = self.bars_stream_template.format(timeframe=timeframe) 292 try: 293 rows = await client.xrevrange(stream, count=lookback * 4) 294 except Exception as exc: 295 logger.warning( 296 "bars stream %s read failed: %s — kronos refresh skipped", 297 stream, exc, 298 ) 299 return [] 300 301 out: list[dict[str, Any]] = [] 302 for _id, fields in rows: 303 bar = _decode_bar(fields) 304 if bar is None: 305 continue 306 bar_symbol = str(bar.get("symbol") or bar.get("ticker") or "").upper() 307 if bar_symbol != symbol: 308 continue 309 out.append({ 310 "ts": bar.get("ts") or bar.get("timestamp") or "", 311 "open": _f(bar.get("open")), 312 "high": _f(bar.get("high")), 313 "low": _f(bar.get("low")), 314 "close": _f(bar.get("close")), 315 "volume": _f(bar.get("volume")) if bar.get("volume") is not None else None, 316 "amount": _f(bar.get("amount")) if bar.get("amount") is not None else None, 317 }) 318 if len(out) >= lookback: 319 break 320 # xrevrange returns newest-first; predictor wants oldest-first. 321 out.reverse() 322 return out 323 324 async def _write_cache( 325 self, symbol: str, timeframe: str, 326 payload: dict[str, Any], cadence_s: float, 327 ) -> None: 328 """Write the forecast summary under the canonical key with a TTL.""" 329 client = await self._get_redis() 330 key = self.forecast_key_template.format( 331 symbol=symbol, timeframe=timeframe, 332 ) 333 # Always carry a generated_at so the adapter's staleness check works 334 # regardless of what the sidecar wrote. 335 body = dict(payload) 336 body.setdefault("generated_at", datetime.now(UTC).isoformat()) 337 body.setdefault("symbol", symbol) 338 body.setdefault("timeframe", timeframe) 339 # TTL = 6× cadence so a brief sidecar outage still lets arenas read 340 # a (now-stale) forecast and have ReplanAgent mark it accordingly. 341 await client.set( 342 key, json.dumps(body, default=str), 343 ex=int(max(60, cadence_s * 6)), 344 ) 345 346 async def _maybe_emit( 347 self, symbol: str, timeframe: str, 348 payload: dict[str, Any], state: _SymbolState, 349 ) -> None: 350 """Compact emit on prob_up jump / direction flip.""" 351 summary = payload.get("summary") or {} 352 prev = state.last_summary 353 # Don't use ``x or 0.5`` — a genuine 0.0 prob_up is falsy and would 354 # be silently coerced to 0.5, killing the emit on a strong bearish 355 # forecast. Use explicit ``is None`` checks instead. 356 prev_prob_raw = prev.get("prob_up") 357 new_prob_raw = summary.get("prob_up") 358 first_observation = prev_prob_raw is None 359 prev_prob = float(prev_prob_raw if prev_prob_raw is not None else 0.5) 360 new_prob = float(new_prob_raw if new_prob_raw is not None else 0.5) 361 prev_dir = str(prev.get("direction") or "") 362 new_dir = str(summary.get("direction") or "") 363 364 delta_prob_up = new_prob - prev_prob 365 flipped = (prev_dir and new_dir and prev_dir != new_dir) 366 # First observation always emits (baseline → first forecast is the 367 # most informative event), even when prob_up is exactly 0.5. 368 if not first_observation and ( 369 abs(delta_prob_up) < _MIN_DELTA_PROB_UP and not flipped 370 ): 371 state.last_summary = dict(summary) 372 return 373 374 client = await self._get_redis() 375 evt = { 376 "schema_version": "1", 377 "symbol": symbol, 378 "timeframe": timeframe, 379 "direction": new_dir, 380 "prob_up": new_prob, 381 "prob_up_delta": round(delta_prob_up, 4), 382 "direction_flipped": flipped, 383 "exp_return_pct": float(summary.get("exp_return_pct") or 0.0), 384 "vol_estimate_pct": float(summary.get("vol_estimate_pct") or 0.0), 385 "model": payload.get("model", ""), 386 "horizon_min": payload.get("horizon_min"), 387 "generated_at": payload.get("generated_at", ""), 388 } 389 try: 390 await client.xadd( 391 self.emit_stream, 392 {"data": json.dumps(evt, default=str)}, 393 maxlen=10_000, approximate=True, 394 ) 395 bus = get_event_bus() 396 await bus.publish( 397 "system.status", 398 arena="", phase="", 399 payload={"kind": "kronos.forecast", **evt}, 400 ) 401 except Exception as exc: 402 logger.warning("emit failed: %s", exc) 403 state.last_summary = dict(summary) 404 405 406# ── helpers ──────────────────────────────────────────────────────────────── 407 408 409def _f(v: Any) -> float: 410 try: 411 return float(v) 412 except (TypeError, ValueError): 413 return 0.0 414 415 416def _decode_bar(fields: Any) -> dict[str, Any] | None: 417 """Decode one Redis stream entry into a dict (per-field JSON-aware).""" 418 if not isinstance(fields, dict): 419 return None 420 out: dict[str, Any] = {} 421 for k, v in fields.items(): 422 key = k.decode("utf-8") if isinstance(k, bytes) else str(k) 423 if isinstance(v, bytes): 424 try: 425 v = v.decode("utf-8") 426 except UnicodeDecodeError: 427 continue 428 if isinstance(v, str): 429 try: 430 v = json.loads(v) 431 except (json.JSONDecodeError, TypeError): 432 pass 433 out[key] = v 434 return out 435 436 437__all__ = [ 438 "DEFAULT_EMIT_STREAM", 439 "DEFAULT_FORECAST_KEY", 440 "DEFAULT_PROFILES", 441 "DEFAULT_SIDECAR_URL", 442 "KronosRefresher", 443]