checking system…
Docs / back / services/kronos-svc/server.py · line 113
Python · 367 lines
  1"""Kronos forecasting sidecar — FastAPI wrapper around KronosPredictor.
  2
  3Wire shape
  4----------
  5
  6``POST /forecast`` accepts a JSON body with the history bars and forecast
  7parameters, runs Kronos inference in a thread (so the event loop stays
  8responsive), and returns a structured summary plus optional per-bar series.
  9
 10We keep this service *thin*: it does not poll Redis, does not talk to
 11QuestDB, and does not own the freshness cadence. The MAF-side
 12:mod:`maf.scheduler.kronos_refresher` is the policy layer; the sidecar is
 13pure inference. That separation means:
 14  - MAF stays torch-free.
 15  - The sidecar is reusable from any caller (not just MAF).
 16  - GPU upgrades replace the sidecar image, MAF doesn't change.
 17
 18Lazy model load
 19---------------
 20The first ``/forecast`` request pays the model-download cost (a few hundred
 21MB for Kronos-small). Subsequent requests are warm. We load the model under
 22an asyncio Lock so concurrent first-requests don't race.
 23
 24Health
 25------
 26``GET /health`` returns 200 + a minimal JSON envelope. It does *not* trigger
 27a model load — that would make health probes painful during cold-start.
 28A ``GET /ready`` endpoint signals model-loaded and is safe for
 29container-orchestrator readiness gates.
 30"""
 31
from __future__ import annotations

import asyncio
import logging
import os
import sys
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, ConfigDict, Field, model_validator

UTC = timezone.utc  # ``datetime.UTC`` is Python 3.11+; the sidecar runs on 3.10+.
 48
 49# Kronos source lives at ../../repos/Kronos when running this dev. The
 50# Dockerfile copies it into /app/kronos so the import works in the container too.
 51_THIS = Path(__file__).resolve().parent
 52_KRONOS_CANDIDATE_PATHS = [
 53    Path("/app/kronos"),  # container layout (Dockerfile copies here)
 54    _THIS / "kronos",      # repo-root override
 55    Path("/home/trbck/workspace/repos/Kronos"),  # dev host
 56]
 57for p in _KRONOS_CANDIDATE_PATHS:
 58    if (p / "model" / "__init__.py").exists():
 59        sys.path.insert(0, str(p))
 60        break
 61
 62
 63logger = logging.getLogger("kronos-svc")
 64logging.basicConfig(
 65    level=os.environ.get("LOG_LEVEL", "INFO").upper(),
 66    format="%(asctime)s %(levelname)s %(name)s%(message)s",
 67)
 68
 69
 70# ── Request / response schemas ────────────────────────────────────────────
 71
 72
 73class HistoryBar(BaseModel):
 74    """One bar of OHLCV history. ``volume`` / ``amount`` are optional."""
 75
 76    model_config = ConfigDict(extra="ignore")
 77
 78    ts: str
 79    open: float
 80    high: float
 81    low: float
 82    close: float
 83    volume: float | None = None
 84    amount: float | None = None
 85
 86
 87class ForecastRequest(BaseModel):
 88    model_config = ConfigDict(extra="ignore")
 89
 90    symbol: str
 91    timeframe: str = "1m"
 92    history: list[HistoryBar] = Field(..., min_length=8)
 93    pred_len: int = Field(60, ge=1, le=240)
 94    sample_count: int = Field(5, ge=1, le=20)
 95    T: float = Field(1.0, ge=0.0, le=2.0)
 96    top_p: float = Field(0.9, ge=0.0, le=1.0)
 97    # Optional explicit future timestamps. When omitted we extend ``history``
 98    # by ``pred_len`` steps using the inferred cadence.
 99    y_timestamps: list[str] | None = None
100    include_forecast: bool = True
101
102
class ForecastSummary(BaseModel):
    """Scalar digest of one forecast, as produced by ``_summarise``."""

    direction: str = Field(...)  # "BULLISH", "BEARISH" or "NEUTRAL"
    prob_up: float = Field(...)  # share of forecast closes above the last history close
    exp_return_pct: float = Field(...)  # end-of-horizon return, in percent
    vol_estimate_pct: float = Field(...)  # std of forecast bar-to-bar returns, in percent
    sample_count: int = Field(...)  # number of forecast bars summarised (len(pred_df))
109
110
class ForecastResponse(BaseModel):
    """Response body returned by ``POST /forecast``."""

    schema_version: str = Field(default="1")  # wire-format version tag
    symbol: str
    timeframe: str
    model: str  # model identifier that produced the forecast
    generated_at: str  # ISO-8601 timestamp (UTC) of generation
    horizon_min: int  # forecast horizon, in minutes
    elapsed_ms: float  # inference wall time, filled in after the call
    summary: ForecastSummary
    # Per-bar forecast rows; left empty when the client opts out.
    forecast: list[dict[str, Any]] = Field(default_factory=list)
121
122
123# ── Predictor pool ─────────────────────────────────────────────────────────
124
125
class PredictorPool:
    """Process-wide holder for the lazily loaded Kronos predictor.

    The model and tokenizer are loaded on the first call to
    :meth:`ensure_loaded`. An ``asyncio.Lock`` plus a re-check inside it makes
    concurrent first requests trigger a single load. Both the load and every
    ``predict`` call run in worker threads (``asyncio.to_thread``) so the event
    loop stays responsive.

    Configuration is read from the environment at construction time:
    ``KRONOS_MODEL``, ``KRONOS_TOKENIZER``, ``KRONOS_MAX_CONTEXT`` and
    ``KRONOS_DEVICE``. All of them are forwarded to Kronos when loading.

    NOTE(review): concurrent requests call the one shared predictor from
    several threads at once. This assumes ``KronosPredictor.predict`` keeps
    only per-call state; confirm against the vendored Kronos version.
    """

    # Lazily created shared instance; see :meth:`get`.
    _instance: "PredictorPool | None" = None

    def __init__(self) -> None:
        self._predictor: Any = None
        self._model_name: str = os.environ.get(
            "KRONOS_MODEL", "NeoQuasar/Kronos-small",
        )
        self._tokenizer_name: str = os.environ.get(
            "KRONOS_TOKENIZER", "NeoQuasar/Kronos-Tokenizer-base",
        )
        self._max_context: int = int(os.environ.get("KRONOS_MAX_CONTEXT", "512"))
        self._device: str = os.environ.get("KRONOS_DEVICE", "cpu")
        # Serialises the one-time model load in ``ensure_loaded``.
        self._lock = asyncio.Lock()

    @classmethod
    def get(cls) -> "PredictorPool":
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @property
    def model_name(self) -> str:
        """Model identifier passed to ``Kronos.from_pretrained``."""
        return self._model_name

    @property
    def loaded(self) -> bool:
        """True once the predictor has been built."""
        return self._predictor is not None

    async def ensure_loaded(self) -> Any:
        """Return the predictor, loading it first if necessary.

        A failed load propagates its exception and leaves the pool unloaded,
        so the next call retries.
        """
        if self._predictor is not None:
            return self._predictor
        async with self._lock:
            # Another coroutine may have completed the load while we waited.
            if self._predictor is not None:
                return self._predictor
            logger.info(
                "loading model=%s tokenizer=%s device=%s",
                self._model_name, self._tokenizer_name, self._device,
            )
            t0 = time.perf_counter()
            self._predictor = await asyncio.to_thread(self._load_sync)
            logger.info(
                "model loaded in %.1fs", time.perf_counter() - t0,
            )
            return self._predictor

    def _load_sync(self) -> Any:
        """Blocking load of tokenizer + model; run via ``asyncio.to_thread``."""
        from model import Kronos, KronosPredictor, KronosTokenizer  # type: ignore

        tokenizer = KronosTokenizer.from_pretrained(self._tokenizer_name)
        model = Kronos.from_pretrained(self._model_name)
        # Forward the configured device explicitly. Without it KronosPredictor
        # picks its own default device and KRONOS_DEVICE (which is logged as
        # the device in ``ensure_loaded``) would have no effect.
        return KronosPredictor(
            model, tokenizer, device=self._device, max_context=self._max_context,
        )

    async def predict(self, req: ForecastRequest) -> ForecastResponse:
        """Run one forecast off the event loop and stamp its inference time.

        ``elapsed_ms`` covers only the ``_predict_sync`` call, not any model
        load triggered by ``ensure_loaded``.
        """
        predictor = await self.ensure_loaded()
        t0 = time.perf_counter()
        result = await asyncio.to_thread(self._predict_sync, predictor, req)
        elapsed_ms = (time.perf_counter() - t0) * 1000
        result.elapsed_ms = round(elapsed_ms, 1)
        return result

    def _predict_sync(self, predictor: Any, req: ForecastRequest) -> ForecastResponse:
        """Blocking inference: build inputs, call Kronos, shape the response."""
        x_df, x_timestamp = self._history_frame(req)
        y_timestamp = self._future_timestamps(req, x_timestamp)

        pred_df = predictor.predict(
            df=x_df,
            x_timestamp=x_timestamp,
            y_timestamp=y_timestamp,
            pred_len=req.pred_len,
            T=req.T,
            top_p=req.top_p,
            sample_count=req.sample_count,
        )

        return ForecastResponse(
            symbol=req.symbol.upper(),
            timeframe=req.timeframe,
            model=self.model_name,
            generated_at=datetime.now(UTC).isoformat(),
            horizon_min=_horizon_minutes(req.pred_len, req.timeframe),
            elapsed_ms=0.0,  # overwritten by ``predict`` with the measured time
            summary=_summarise(x_df, pred_df),
            forecast=self._forecast_rows(pred_df) if req.include_forecast else [],
        )

    @staticmethod
    def _history_frame(req: ForecastRequest) -> tuple[pd.DataFrame, pd.Series]:
        """Build Kronos' input frame and timestamp series from ``req.history``.

        ``volume`` / ``amount`` are forwarded only when at least one bar
        carries a value. Missing values in a forwarded column become 0.0.
        The timestamp series is named ``timestamps``.
        """
        df = pd.DataFrame([bar.model_dump() for bar in req.history])
        x_timestamp = pd.to_datetime(df["ts"]).rename("timestamps")
        x_df = df[["open", "high", "low", "close"]].copy()
        for col in ("volume", "amount"):
            if col in df.columns and df[col].notna().any():
                x_df[col] = df[col].fillna(0.0)
        return x_df, x_timestamp

    @staticmethod
    def _future_timestamps(req: ForecastRequest, x_timestamp: pd.Series) -> pd.Series:
        """Timestamps labelling the forecast bars.

        Explicit ``req.y_timestamps`` win. Otherwise ``req.pred_len`` steps are
        extrapolated from the last history bar at the inferred cadence.
        """
        if req.y_timestamps:
            return pd.to_datetime(pd.Series(req.y_timestamps))
        cadence = _infer_cadence(x_timestamp, req.timeframe)
        last = x_timestamp.iloc[-1]
        return pd.Series([last + cadence * (i + 1) for i in range(req.pred_len)])

    @staticmethod
    def _forecast_rows(pred_df: pd.DataFrame) -> list[dict[str, Any]]:
        """Serialise the per-bar forecast frame into JSON-friendly dicts.

        ``ts`` comes from the frame index: ISO-8601 when the index value has
        ``isoformat``, otherwise ``str()``. Price/volume columns that are
        absent are reported as 0.0. Other columns (e.g. ``amount``) are not
        forwarded.
        """
        fields = ("open", "high", "low", "close", "volume")
        return [
            {
                "ts": ts.isoformat() if hasattr(ts, "isoformat") else str(ts),
                **{name: float(row.get(name, 0.0)) for name in fields},
            }
            for ts, row in pred_df.iterrows()
        ]
250
251
252# ── Helpers ────────────────────────────────────────────────────────────────
253
254
255def _infer_cadence(x_timestamp: pd.Series, timeframe: str) -> pd.Timedelta:
256    """Infer the bar cadence from history. Falls back to ``timeframe``."""
257    if len(x_timestamp) >= 2:
258        deltas = x_timestamp.diff().dropna()
259        if len(deltas) > 0:
260            # Mode is robust to occasional gaps.
261            mode = deltas.mode()
262            if not mode.empty and mode.iloc[0].total_seconds() > 0:
263                return pd.Timedelta(mode.iloc[0])
264    if timeframe.endswith("m"):
265        return pd.Timedelta(minutes=int(timeframe[:-1]))
266    if timeframe.endswith("h"):
267        return pd.Timedelta(hours=int(timeframe[:-1]))
268    if timeframe.endswith("d"):
269        return pd.Timedelta(days=int(timeframe[:-1]))
270    return pd.Timedelta(minutes=1)
271
272
273def _horizon_minutes(pred_len: int, timeframe: str) -> int:
274    if timeframe.endswith("m"):
275        return pred_len * int(timeframe[:-1])
276    if timeframe.endswith("h"):
277        return pred_len * int(timeframe[:-1]) * 60
278    if timeframe.endswith("d"):
279        return pred_len * int(timeframe[:-1]) * 1440
280    return pred_len
281
282
def _summarise(
    x_df: pd.DataFrame,
    pred_df: pd.DataFrame,
    *,
    neutral_band: float = 0.0015,
) -> ForecastSummary:
    """Compress the per-bar forecast into a single ``ForecastSummary``.

    Fields:

    - ``direction``: ``BULLISH`` / ``BEARISH`` from the sign of the
      end-of-horizon return ``pred_close[-1] / hist_close[-1] - 1``.
      ``NEUTRAL`` when the return's magnitude is below ``neutral_band``
      (default 0.0015, i.e. 0.15 %).
    - ``prob_up``: fraction of forecast bars whose close is above the last
      historical close. This is a share across the horizon of one series, not
      a probability across samples. With ``sample_count > 1`` the samples are
      averaged upstream in ``KronosPredictor.predict``, and the averaged
      series is still the outcome of stochastic sampling.
    - ``exp_return_pct``: the end-of-horizon return, in percent.
    - ``vol_estimate_pct``: standard deviation of bar-to-bar forecast close
      returns, in percent. It is 0 when fewer than two returns exist.
    - ``sample_count``: number of forecast bars summarised (``len(pred_df)``).
      NOTE(review): this is not the request's Monte-Carlo ``sample_count``.

    An empty forecast, or a zero/absent last historical close, yields a
    NEUTRAL summary with ``prob_up=0.5`` and ``sample_count=0``.

    Args:
        x_df: History frame; its last ``close`` is the reference price.
        pred_df: Forecast frame with a ``close`` column.
        neutral_band: Absolute end-of-horizon return below which the
            direction is reported as NEUTRAL.
    """
    last_close = float(x_df["close"].iloc[-1]) if len(x_df) else 0.0
    if pred_df.empty or last_close == 0:
        return ForecastSummary(
            direction="NEUTRAL",
            prob_up=0.5,
            exp_return_pct=0.0,
            vol_estimate_pct=0.0,
            sample_count=0,
        )
    pred_close = pred_df["close"].astype(float)
    end_ret = (pred_close.iloc[-1] / last_close) - 1.0
    prob_up = float((pred_close > last_close).mean())
    returns = pred_close.pct_change().dropna()
    # ``std`` of a single value is NaN, so require at least two returns.
    vol = float(returns.std()) if len(returns) > 1 else 0.0

    if abs(end_ret) < neutral_band:
        direction = "NEUTRAL"
    elif end_ret > 0:
        direction = "BULLISH"
    else:
        direction = "BEARISH"
    return ForecastSummary(
        direction=direction,
        prob_up=round(prob_up, 4),
        exp_return_pct=round(end_ret * 100.0, 4),
        vol_estimate_pct=round(vol * 100.0, 4),
        sample_count=int(len(pred_df)),
    )
321
322
323# ── FastAPI app ────────────────────────────────────────────────────────────
324
325
# The ASGI application; the routes below are registered on it.
app = FastAPI(version="1.0.0", title="kronos-svc")
327
328
@app.get("/health")
async def health() -> dict[str, Any]:
    """Liveness probe: always succeeds and reports the model's load state.

    Only reads pool state and never starts a model load, so it stays cheap
    during cold start.
    """
    pool = PredictorPool.get()
    payload: dict[str, Any] = {"status": "ok"}
    payload["model"] = pool.model_name
    payload["loaded"] = pool.loaded
    return payload
337
338
# Strong reference to the in-flight background model load started by
# ``/ready``. The event loop keeps only weak references to tasks, so an
# unreferenced task could be garbage-collected mid-flight.
_warmup_task: "asyncio.Task[Any] | None" = None


def _on_warmup_done(task: "asyncio.Task[Any]") -> None:
    """Clear the warm-up handle and log (rather than drop) a failed load."""
    global _warmup_task
    if _warmup_task is task:
        _warmup_task = None
    if not task.cancelled() and task.exception() is not None:
        logger.error("background model load failed: %r", task.exception())


@app.get("/ready")
async def ready() -> dict[str, Any]:
    """Readiness probe: 200 once the model is loaded, 503 until then.

    Normally the model loads on the first ``/forecast``. An orchestrator
    readiness gate, however, keeps traffic away from a pod until this probe
    passes, so that request would never arrive. A 503 here therefore also
    starts one background load (at most one in flight) so the pod can become
    ready. A failed load is logged and retried on a later probe.
    """
    global _warmup_task
    pool = PredictorPool.get()
    if not pool.loaded:
        if _warmup_task is None:
            _warmup_task = asyncio.create_task(pool.ensure_loaded())
            _warmup_task.add_done_callback(_on_warmup_done)
        raise HTTPException(503, "model not loaded yet")
    return {"status": "ready", "model": pool.model_name}
345
346
@app.post("/forecast", response_model=ForecastResponse)
async def forecast(req: ForecastRequest) -> ForecastResponse:
    """Run one Kronos forecast for ``req``.

    Any failure inside ``PredictorPool.predict`` becomes an HTTP 500:

    - a ``FileNotFoundError`` is reported as a model-load failure;
    - any other exception is logged with its traceback and reported by type
      and message.
    """
    predictor_pool = PredictorPool.get()
    try:
        response = await predictor_pool.predict(req)
    except FileNotFoundError as err:
        raise HTTPException(500, f"model load failed: {err}") from err
    except Exception as err:
        logger.exception("forecast failed")
        raise HTTPException(500, f"{type(err).__name__}: {err}") from err
    else:
        return response
357
358
if __name__ == "__main__":
    import uvicorn

    # Serve the already-imported ``app`` object rather than the "server:app"
    # import string. With the string, uvicorn imports this file a second time
    # as module ``server``, re-running all module-level setup (sys.path
    # insertion, logging config, app construction). Passing the object is
    # fine because auto-reload is disabled.
    uvicorn.run(
        app,
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "5102")),
        reload=False,
    )