1"""Kronos forecasting sidecar — FastAPI wrapper around KronosPredictor. 2 3Wire shape 4---------- 5 6``POST /forecast`` accepts a JSON body with the history bars and forecast 7parameters, runs Kronos inference in a thread (so the event loop stays 8responsive), and returns a structured summary plus optional per-bar series. 9 10We keep this service *thin*: it does not poll Redis, does not talk to 11QuestDB, and does not own the freshness cadence. The MAF-side 12:mod:`maf.scheduler.kronos_refresher` is the policy layer; the sidecar is 13pure inference. That separation means: 14 - MAF stays torch-free. 15 - The sidecar is reusable from any caller (not just MAF). 16 - GPU upgrades replace the sidecar image, MAF doesn't change. 17 18Lazy model load 19--------------- 20The first ``/forecast`` request pays the model-download cost (a few hundred 21MB for Kronos-small). Subsequent requests are warm. We load the model under 22an asyncio Lock so concurrent first-requests don't race. 23 24Health 25------ 26``GET /health`` returns 200 + a minimal JSON envelope. It does *not* trigger 27a model load — that would make health probes painful during cold-start. 28A ``GET /ready`` endpoint signals model-loaded and is safe for 29container-orchestrator readiness gates. 30""" 31 32from __future__ import annotations 33 34import asyncio 35import logging 36import os 37import sys 38import time 39from datetime import datetime, timezone 40 41UTC = timezone.utc # ``datetime.UTC`` is Python 3.11+; the sidecar runs on 3.10+. 42from pathlib import Path 43from typing import Any 44 45import pandas as pd 46from fastapi import FastAPI, HTTPException 47from pydantic import BaseModel, ConfigDict, Field 48 49# Kronos source lives at ../../repos/Kronos when running this dev. The 50# Dockerfile copies it into /app/kronos so the import works in the container too. 
_THIS = Path(__file__).resolve().parent
_KRONOS_CANDIDATE_PATHS = [
    Path("/app/kronos"),  # container layout (Dockerfile copies here)
    _THIS / "kronos",  # checkout next to this file (repo-root override)
    Path("/home/trbck/workspace/repos/Kronos"),  # dev host
]
# The first candidate that contains the ``model`` package wins. ``None`` means
# no candidate matched; ``PredictorPool._load_sync`` can then only import
# ``model`` if it is importable some other way.
_KRONOS_PATH: Path | None = None
for p in _KRONOS_CANDIDATE_PATHS:
    if (p / "model" / "__init__.py").exists():
        sys.path.insert(0, str(p))
        _KRONOS_PATH = p
        break


logger = logging.getLogger("kronos-svc")
logging.basicConfig(
    level=os.environ.get("LOG_LEVEL", "INFO").upper(),
    format="%(asctime)s %(levelname)s %(name)s — %(message)s",
)
if _KRONOS_PATH is None:
    # Surface a misconfigured image at startup instead of on the first request.
    logger.warning(
        "Kronos source not found in %s; model import will fail unless "
        "'model' is importable from elsewhere",
        [str(c) for c in _KRONOS_CANDIDATE_PATHS],
    )


# ── Request / response schemas ────────────────────────────────────────────


class HistoryBar(BaseModel):
    """One bar of OHLCV history. ``volume`` / ``amount`` are optional."""

    model_config = ConfigDict(extra="ignore")

    ts: str  # parsed with ``pandas.to_datetime``
    open: float
    high: float
    low: float
    close: float
    volume: float | None = None
    amount: float | None = None


class ForecastRequest(BaseModel):
    """Body of ``POST /forecast``. Unknown keys are ignored."""

    model_config = ConfigDict(extra="ignore")

    symbol: str
    timeframe: str = "1m"  # "<N>m", "<N>h" or "<N>d"
    history: list[HistoryBar] = Field(..., min_length=8)
    pred_len: int = Field(60, ge=1, le=240)
    sample_count: int = Field(5, ge=1, le=20)
    # Sampling temperature. Must be strictly positive: a zero temperature
    # cannot produce a valid sampled forecast (presumably it divides the
    # logits upstream — confirm against Kronos before relaxing this bound).
    T: float = Field(1.0, gt=0.0, le=2.0)
    top_p: float = Field(0.9, ge=0.0, le=1.0)
    # Optional explicit future timestamps; when given, the ``/forecast``
    # endpoint requires exactly ``pred_len`` of them. When omitted (or empty)
    # we extend ``history`` by ``pred_len`` steps using the inferred cadence.
    y_timestamps: list[str] | None = None
    include_forecast: bool = True


class ForecastSummary(BaseModel):
    """Compressed view of a forecast; see ``_summarise`` for the definitions."""

    direction: str  # "BULLISH" / "BEARISH" / "NEUTRAL"
    prob_up: float
    exp_return_pct: float
    vol_estimate_pct: float
    sample_count: int


class ForecastResponse(BaseModel):
    """Body returned by ``POST /forecast``."""

    schema_version: str = "1"
    symbol: str
    timeframe: str
    model: str
    generated_at: str  # ISO-8601, UTC
    horizon_min: int
    elapsed_ms: float
    summary: ForecastSummary
    forecast: list[dict[str, Any]] = Field(default_factory=list)


# ── Predictor pool ─────────────────────────────────────────────────────────


class PredictorPool:
    """Process-wide holder for the lazily loaded Kronos predictor.

    The model is loaded on first use. An ``asyncio.Lock`` serialises that
    load so concurrent first requests on the event loop trigger only one
    load; it is a coroutine lock and does not protect against other threads.
    If loading raises, nothing is cached and the next request retries.

    Once loaded, every request shares the same predictor and runs
    ``predict`` in a worker thread via ``asyncio.to_thread``.
    NOTE(review): this assumes concurrent ``KronosPredictor.predict`` calls
    are safe — confirm against the Kronos implementation.

    Configuration comes from the environment: ``KRONOS_MODEL``,
    ``KRONOS_TOKENIZER``, ``KRONOS_MAX_CONTEXT`` and ``KRONOS_DEVICE``.
    """

    _instance: "PredictorPool | None" = None

    def __init__(self) -> None:
        self._predictor: Any = None
        self._model_name: str = os.environ.get(
            "KRONOS_MODEL", "NeoQuasar/Kronos-small",
        )
        self._tokenizer_name: str = os.environ.get(
            "KRONOS_TOKENIZER", "NeoQuasar/Kronos-Tokenizer-base",
        )
        self._max_context: int = int(os.environ.get("KRONOS_MAX_CONTEXT", "512"))
        self._device: str = os.environ.get("KRONOS_DEVICE", "cpu")
        self._lock = asyncio.Lock()

    @classmethod
    def get(cls) -> "PredictorPool":
        """Return the shared pool, creating it on first call."""
        if cls._instance is None:
            cls._instance = PredictorPool()
        return cls._instance

    @property
    def model_name(self) -> str:
        """Hugging Face id of the configured Kronos model."""
        return self._model_name

    @property
    def loaded(self) -> bool:
        """True once the predictor has been built successfully."""
        return self._predictor is not None

    async def ensure_loaded(self) -> Any:
        """Return the predictor, loading it (once) in a worker thread if needed."""
        if self._predictor is not None:  # fast path: no lock once warm
            return self._predictor
        async with self._lock:
            # Re-check: another coroutine may have finished the load while
            # this one was waiting for the lock.
            if self._predictor is not None:
                return self._predictor
            logger.info(
                "loading model=%s tokenizer=%s device=%s",
                self._model_name, self._tokenizer_name, self._device,
            )
            t0 = time.perf_counter()
            self._predictor = await asyncio.to_thread(self._load_sync)
            logger.info(
                "model loaded in %.1fs", time.perf_counter() - t0,
            )
            return self._predictor

    def _load_sync(self) -> Any:
        """Import Kronos and build the predictor. Blocking; run off-loop."""
        from model import Kronos, KronosPredictor, KronosTokenizer  # type: ignore

        tokenizer = KronosTokenizer.from_pretrained(self._tokenizer_name)
        model = Kronos.from_pretrained(self._model_name)
        # Hand KRONOS_DEVICE to the predictor so the env var actually selects
        # where inference runs (it is also what gets logged above).
        return KronosPredictor(
            model,
            tokenizer,
            device=self._device,
            max_context=self._max_context,
        )

    async def predict(self, req: ForecastRequest) -> ForecastResponse:
        """Run one forecast; ``elapsed_ms`` covers inference only, not loading."""
        predictor = await self.ensure_loaded()
        t0 = time.perf_counter()
        result = await asyncio.to_thread(self._predict_sync, predictor, req)
        elapsed_ms = (time.perf_counter() - t0) * 1000
        result.elapsed_ms = round(elapsed_ms, 1)
        return result

    def _predict_sync(self, predictor: Any, req: ForecastRequest) -> ForecastResponse:
        """Run a blocking Kronos forecast and package the response.

        ``elapsed_ms`` is left at 0.0 here; :meth:`predict` fills it in.
        """
        x_df, x_timestamp = _history_frame(req.history)
        y_timestamp = _future_timestamps(x_timestamp, req)

        pred_df = predictor.predict(
            df=x_df,
            x_timestamp=x_timestamp,
            y_timestamp=y_timestamp,
            pred_len=req.pred_len,
            T=req.T,
            top_p=req.top_p,
            sample_count=req.sample_count,
        )

        return ForecastResponse(
            symbol=req.symbol.upper(),
            timeframe=req.timeframe,
            model=self.model_name,
            generated_at=datetime.now(UTC).isoformat(),
            horizon_min=_horizon_minutes(req.pred_len, req.timeframe),
            elapsed_ms=0.0,
            summary=_summarise(x_df, pred_df),
            forecast=_forecast_rows(pred_df) if req.include_forecast else [],
        )


# ── Helpers ────────────────────────────────────────────────────────────────

# Timeframe suffix -> ``pd.Timedelta`` keyword. Case-sensitive on purpose:
# "1M" is not treated as one minute.
_TIMEFRAME_UNITS = {"m": "minutes", "h": "hours", "d": "days"}


def _timeframe_delta(timeframe: str) -> pd.Timedelta | None:
    """Parse ``"<N>m"`` / ``"<N>h"`` / ``"<N>d"`` into a Timedelta.

    Returns ``None`` for an unknown suffix, a non-integer count, or a count
    <= 0, so callers can apply their own fallback instead of crashing.
    """
    unit = _TIMEFRAME_UNITS.get(timeframe[-1:])
    if unit is None:
        return None
    try:
        count = int(timeframe[:-1])
    except ValueError:
        return None
    if count <= 0:
        return None
    return pd.Timedelta(**{unit: count})


def _history_frame(history: list[HistoryBar]) -> tuple[pd.DataFrame, pd.Series]:
    """Build the Kronos input frame and timestamp series from request bars.

    Bars are sorted chronologically; the sort is stable, so input that is
    already ordered is unchanged. ``volume`` / ``amount`` columns are
    included only when at least one bar supplies them, and gaps are
    filled with 0.0.
    """
    df = pd.DataFrame([bar.model_dump() for bar in history])
    df["timestamps"] = pd.to_datetime(df["ts"])
    df = (
        df.drop(columns=["ts"])
        .sort_values("timestamps", kind="stable")
        .reset_index(drop=True)
    )

    x_df = df[["open", "high", "low", "close"]].copy()
    for col in ("volume", "amount"):
        if col in df.columns and df[col].notna().any():
            x_df[col] = df[col].fillna(0.0)
    return x_df, df["timestamps"]


def _future_timestamps(x_timestamp: pd.Series, req: ForecastRequest) -> pd.Series:
    """Use the caller's ``y_timestamps`` or extend history by ``pred_len`` steps."""
    if req.y_timestamps:
        return pd.to_datetime(pd.Series(req.y_timestamps))
    cadence = _infer_cadence(x_timestamp, req.timeframe)
    last = x_timestamp.iloc[-1]
    return pd.Series([last + cadence * step for step in range(1, req.pred_len + 1)])


def _forecast_rows(pred_df: pd.DataFrame) -> list[dict[str, Any]]:
    """Serialise per-bar predictions; missing columns are reported as 0.0."""
    return [
        {
            "ts": ts.isoformat() if hasattr(ts, "isoformat") else str(ts),
            "open": float(row.get("open", 0.0)),
            "high": float(row.get("high", 0.0)),
            "low": float(row.get("low", 0.0)),
            "close": float(row.get("close", 0.0)),
            "volume": float(row.get("volume", 0.0)),
        }
        for ts, row in pred_df.iterrows()
    ]


def _infer_cadence(x_timestamp: pd.Series, timeframe: str) -> pd.Timedelta:
    """Infer the bar cadence from history.

    Uses the most common positive gap between consecutive timestamps (the
    mode is robust to occasional gaps). If history gives no usable gap,
    falls back to ``timeframe``, and then to one minute.
    """
    deltas = x_timestamp.diff().dropna()
    if not deltas.empty:
        mode = deltas.mode()
        if not mode.empty and mode.iloc[0].total_seconds() > 0:
            return pd.Timedelta(mode.iloc[0])
    fallback = _timeframe_delta(timeframe)
    if fallback is None:
        return pd.Timedelta(minutes=1)
    return fallback


def _horizon_minutes(pred_len: int, timeframe: str) -> int:
    """Forecast horizon in minutes; ``pred_len`` if ``timeframe`` is unparseable."""
    delta = _timeframe_delta(timeframe)
    if delta is None:
        return pred_len
    return pred_len * int(delta.total_seconds()) // 60


def _summarise(
    x_df: pd.DataFrame,
    pred_df: pd.DataFrame,
    *,
    neutral_band: float = 0.0015,
) -> ForecastSummary:
    """Compress per-bar predictions into a single ``ForecastSummary``.

    Direction comes from the sign of ``(pred_close_last / hist_close_last) - 1``.
    Moves smaller than ``neutral_band`` (default 0.15%) count as NEUTRAL.
    ``prob_up`` is the share of predicted closes above the last historical
    close. With ``sample_count>1``, averaging happens upstream in
    ``KronosPredictor.predict``, so this is a single deterministic series;
    we still expose ``prob_up`` as a sanity-checkable metric.
    ``vol_estimate_pct`` is the std of returns over the predicted window.
    ``sample_count`` in the summary is the number of predicted bars
    summarised (``len(pred_df)``), not the request's sampling count.
    An empty forecast or a zero last close yields a NEUTRAL placeholder.
    """
    last_close = float(x_df["close"].iloc[-1]) if len(x_df) else 0.0
    if pred_df.empty or last_close == 0:
        return ForecastSummary(
            direction="NEUTRAL",
            prob_up=0.5,
            exp_return_pct=0.0,
            vol_estimate_pct=0.0,
            sample_count=0,
        )
    pred_close = pred_df["close"].astype(float)
    end_ret = (pred_close.iloc[-1] / last_close) - 1.0
    prob_up = float((pred_close > last_close).mean())
    returns = pred_close.pct_change().dropna()
    vol = float(returns.std()) if len(returns) > 1 else 0.0

    if abs(end_ret) < neutral_band:
        direction = "NEUTRAL"
    elif end_ret > 0:
        direction = "BULLISH"
    else:
        direction = "BEARISH"
    return ForecastSummary(
        direction=direction,
        prob_up=round(prob_up, 4),
        exp_return_pct=round(end_ret * 100.0, 4),
        vol_estimate_pct=round(vol * 100.0, 4),
        sample_count=int(len(pred_df)),
    )


# ── FastAPI app ────────────────────────────────────────────────────────────


app = FastAPI(title="kronos-svc", version="1.0.0")


@app.get("/health")
async def health() -> dict[str, Any]:
    """Liveness probe; never triggers a model load."""
    pool = PredictorPool.get()
    return {
        "status": "ok",
        "model": pool.model_name,
        "loaded": pool.loaded,
    }


@app.get("/ready")
async def ready() -> dict[str, Any]:
    """Readiness probe: 503 until the model has been loaded."""
    pool = PredictorPool.get()
    if not pool.loaded:
        raise HTTPException(503, "model not loaded yet")
    return {"status": "ready", "model": pool.model_name}


@app.post("/forecast", response_model=ForecastResponse)
async def forecast(req: ForecastRequest) -> ForecastResponse:
    """Run a forecast.

    Responds 422 if the explicit ``y_timestamps`` do not match ``pred_len``;
    this check runs before the model is loaded or inference is spent.
    Any failure during load or inference is reported as 500.
    """
    if req.y_timestamps and len(req.y_timestamps) != req.pred_len:
        raise HTTPException(
            422,
            f"y_timestamps has {len(req.y_timestamps)} entries; "
            f"expected pred_len={req.pred_len}",
        )
    pool = PredictorPool.get()
    try:
        return await pool.predict(req)
    except FileNotFoundError as exc:
        raise HTTPException(500, f"model load failed: {exc}") from exc
    except Exception as exc:
        logger.exception("forecast failed")
        raise HTTPException(500, f"{type(exc).__name__}: {exc}") from exc


if __name__ == "__main__":
    import uvicorn

    # Pass the app object rather than a "module:app" string so this works
    # whatever the file is named and the module is not imported twice.
    uvicorn.run(
        app,
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "5102")),
        reload=False,
    )