checking system…
Docs / back / src/maf/sources/adapters/trtools2_api.py · line 67
Python · 259 lines
  1"""trtools2 HTTP API adapter — call trtools2's dashboard endpoints from MAF.
  2
  3This is the canonical path for "I need Alpaca data MAF doesn't already
  4have in a Redis Stream." trtools2's dashboard (running on port 8888 by
  5default) is the single source of truth for OHLCV bars, news, snapshots,
  6coverage and pipeline health. MAF reaches it over HTTP — never calls
  7Alpaca directly.
  8
  9Configurable via the ``base_url`` config field or the ``TRTOOLS2_API_URL``
 10env var (default ``http://localhost:8888``).
 11
 12Supported query types
 13---------------------
 14
 15``bars``
 16    GET /api/chart/ohlcv/{symbol}?timeframe={tf}&limit={n}
 17    Returns OHLCV bars deduplicated/sorted ascending from QuestDB.
 18
 19``news``
 20    GET /api/live-news?limit={n}
 21    Recent news headlines with on-the-fly Loughran-McDonald sentiment.
 22
 23``prices``
 24    GET /api/live-prices
 25    Last trade price per tracked symbol.
 26
 27``snapshot``
 28    GET /api/symbol/{sym}/info
 29    Coverage + last bar + last news + last trade for one symbol.
 30
 31``coverage``
 32    GET /api/data/symbol/{sym}/coverage
 33    Per-timeframe coverage stats — earliest, latest, bar count, gaps.
 34
 35``feed_stats``
 36    GET /api/feed-stats
 37    Live feed health: rows-per-second, last published, lag.
 38
 39``pipeline_health``
 40    GET /api/data/pipeline/health
 41    Per-node freshness state across the whole pipeline.
 42
 43The agent picks which one to call by setting ``query_type`` either in
 44the binding's static config or at runtime via tool params. Any HTTP
 45failure degrades to ``{"error": "...", "data": []}`` so the calling
 46arena continues with a missing-data marker instead of crashing.
 47"""
 48
 49from __future__ import annotations
 50
 51import logging
 52import os
 53from typing import Any
 54from urllib.parse import quote as urlquote
 55
 56import httpx
 57
 58from maf.sources.base import BaseSource
 59
# Module logger; unexpected (non-connect, non-timeout) failures in
# ``Trtools2ApiSource.fetch`` are logged here with a traceback.
logger = logging.getLogger(__name__)


# Fallback dashboard URL, used when neither the binding's ``base_url`` config
# nor the ``TRTOOLS2_API_URL`` environment variable is set.
DEFAULT_BASE_URL = "http://localhost:8888"
# Per-request HTTP timeout in seconds, used when ``timeout_s`` is unset or falsy.
DEFAULT_TIMEOUT_S = 8.0
 65
 66
 67class Trtools2ApiSource(BaseSource):
 68    """HTTP client for trtools2's dashboard API.
 69
 70    Config keys (overridable per-call via ``params``):
 71        base_url:    trtools2 dashboard URL (default $TRTOOLS2_API_URL or localhost:8888)
 72        api_key:     X-API-Key header value (default $TT2_API_KEY or $TRTOOLS2_API_KEY)
 73        query_type:  one of bars | news | prices | snapshot | coverage | feed_stats | pipeline_health
 74        symbol:      ticker (required for bars / snapshot / coverage)
 75        timeframe:   1m / 5m / 15m / 1h / 1d (for bars; default 1d)
 76        limit:       row cap (default 100 for bars, 30 for news)
 77        timeout_s:   per-request timeout (default 8s)
 78    """
 79
 80    adapter_name = "trtools2_api"
 81
 82    @classmethod
 83    def freshness_spec(cls, binding_config: dict[str, Any]) -> dict[str, Any]:
 84        return {
 85            "type": "external",
 86            "detail": "HTTP API — trtools2 dashboard (default :8888)",
 87        }
 88
 89    @property
 90    def _base_url(self) -> str:
 91        return str(
 92            self.config.get("base_url")
 93            or os.environ.get("TRTOOLS2_API_URL")
 94            or DEFAULT_BASE_URL
 95        ).rstrip("/")
 96
 97    @property
 98    def _api_key(self) -> str | None:
 99        # Honour explicit per-binding config first, then either env name.
100        # trtools2 reads ``TT2_API_KEY``; we also accept ``TRTOOLS2_API_KEY``
101        # for readability inside MAF.
102        return (
103            self.config.get("api_key")
104            or os.environ.get("TT2_API_KEY")
105            or os.environ.get("TRTOOLS2_API_KEY")
106        )
107
108    def _headers(self) -> dict[str, str]:
109        key = self._api_key
110        return {"X-API-Key": key} if key else {}
111
112    async def fetch(self, params: dict[str, Any] | None = None) -> dict[str, Any]:
113        cfg = {**self.config, **(params or {})}
114        query_type = (cfg.get("query_type") or "bars").lower()
115        timeout = float(cfg.get("timeout_s") or DEFAULT_TIMEOUT_S)
116        base = self._base_url
117
118        try:
119            async with httpx.AsyncClient(timeout=timeout, headers=self._headers()) as http:
120                if query_type == "bars":
121                    return await self._bars(http, base, cfg)
122                if query_type == "news":
123                    return await self._news(http, base, cfg)
124                if query_type == "prices":
125                    return await self._prices(http, base)
126                if query_type == "snapshot":
127                    return await self._snapshot(http, base, cfg)
128                if query_type == "coverage":
129                    return await self._coverage(http, base, cfg)
130                if query_type == "feed_stats":
131                    return await self._feed_stats(http, base)
132                if query_type == "pipeline_health":
133                    return await self._pipeline_health(http, base)
134                return {
135                    "type": "trtools2_api", "query_type": query_type, "data": [],
136                    "error": f"unknown query_type {query_type!r}",
137                }
138        except httpx.ConnectError as exc:
139            return self._err(query_type, base,
140                f"trtools2 dashboard unreachable ({exc}) — is it running on {base}?")
141        except httpx.TimeoutException as exc:
142            return self._err(query_type, base, f"trtools2 timeout ({exc})")
143        except Exception as exc:
144            logger.exception("trtools2_api %s failed", query_type)
145            return self._err(query_type, base, f"{type(exc).__name__}: {exc}")
146
147    @staticmethod
148    def _err(qtype: str, base: str, msg: str) -> dict[str, Any]:
149        return {
150            "type": "trtools2_api", "query_type": qtype, "base_url": base,
151            "data": [], "error": msg,
152        }
153
154    @staticmethod
155    def _symbol(cfg: dict[str, Any]) -> str | None:
156        s = cfg.get("symbol") or cfg.get("ticker")
157        if not s:
158            syms = cfg.get("symbols") or cfg.get("tickers")
159            if isinstance(syms, (list, tuple)) and syms:
160                s = syms[0]
161            elif isinstance(syms, str) and syms:
162                s = syms.split(",")[0]
163        return str(s).strip().upper() if s else None
164
165    async def _bars(self, http: httpx.AsyncClient, base: str, cfg: dict[str, Any]) -> dict[str, Any]:
166        sym = self._symbol(cfg)
167        if not sym:
168            return self._err("bars", base, "symbol required")
169        tf = cfg.get("timeframe") or "1d"
170        limit = int(cfg.get("limit") or 100)
171        url = f"{base}/api/chart/ohlcv/{urlquote(sym)}?timeframe={urlquote(tf)}&limit={limit}"
172        r = await http.get(url)
173        if r.status_code != 200:
174            return self._err("bars", base, f"HTTP {r.status_code}{r.text[:200]}")
175        body = r.json()
176        # trtools2 returns either {"bars": [...]} or {"error": "..."} or a raw list.
177        bars = body.get("bars") if isinstance(body, dict) else body
178        if isinstance(body, dict) and "error" in body:
179            return self._err("bars", base, body["error"])
180        return {
181            "type": "trtools2_api", "query_type": "bars",
182            "symbol": sym, "timeframe": tf, "limit": limit,
183            "count": len(bars or []), "data": bars or [],
184        }
185
186    async def _news(self, http: httpx.AsyncClient, base: str, cfg: dict[str, Any]) -> dict[str, Any]:
187        limit = int(cfg.get("limit") or 30)
188        url = f"{base}/api/live-news?limit={limit}"
189        r = await http.get(url)
190        if r.status_code != 200:
191            return self._err("news", base, f"HTTP {r.status_code}")
192        body = r.json()
193        rows = body if isinstance(body, list) else body.get("rows") or body.get("data") or []
194        # Optional client-side filter by symbol.
195        wanted = self._symbol(cfg)
196        if wanted:
197            rows = [
198                r for r in rows
199                if wanted.upper() in (str(r.get("symbols") or "")).upper()
200            ]
201        return {
202            "type": "trtools2_api", "query_type": "news",
203            "limit": limit, "filter_symbol": wanted,
204            "count": len(rows), "data": rows,
205        }
206
207    async def _prices(self, http: httpx.AsyncClient, base: str) -> dict[str, Any]:
208        r = await http.get(f"{base}/api/live-prices")
209        if r.status_code != 200:
210            return self._err("prices", base, f"HTTP {r.status_code}")
211        body = r.json()
212        rows = body if isinstance(body, list) else body.get("rows") or body.get("data") or []
213        return {
214            "type": "trtools2_api", "query_type": "prices",
215            "count": len(rows), "data": rows,
216        }
217
218    async def _snapshot(self, http: httpx.AsyncClient, base: str, cfg: dict[str, Any]) -> dict[str, Any]:
219        sym = self._symbol(cfg)
220        if not sym:
221            return self._err("snapshot", base, "symbol required")
222        r = await http.get(f"{base}/api/symbol/{urlquote(sym)}/info")
223        if r.status_code != 200:
224            return self._err("snapshot", base, f"HTTP {r.status_code}")
225        return {
226            "type": "trtools2_api", "query_type": "snapshot",
227            "symbol": sym, "data": r.json(),
228        }
229
230    async def _coverage(self, http: httpx.AsyncClient, base: str, cfg: dict[str, Any]) -> dict[str, Any]:
231        sym = self._symbol(cfg)
232        if not sym:
233            return self._err("coverage", base, "symbol required")
234        r = await http.get(f"{base}/api/data/symbol/{urlquote(sym)}/coverage")
235        if r.status_code != 200:
236            return self._err("coverage", base, f"HTTP {r.status_code}")
237        return {
238            "type": "trtools2_api", "query_type": "coverage",
239            "symbol": sym, "data": r.json(),
240        }
241
242    async def _feed_stats(self, http: httpx.AsyncClient, base: str) -> dict[str, Any]:
243        r = await http.get(f"{base}/api/feed-stats")
244        if r.status_code != 200:
245            return self._err("feed_stats", base, f"HTTP {r.status_code}")
246        return {
247            "type": "trtools2_api", "query_type": "feed_stats",
248            "data": r.json(),
249        }
250
251    async def _pipeline_health(self, http: httpx.AsyncClient, base: str) -> dict[str, Any]:
252        r = await http.get(f"{base}/api/data/pipeline/health")
253        if r.status_code != 200:
254            return self._err("pipeline_health", base, f"HTTP {r.status_code}")
255        return {
256            "type": "trtools2_api", "query_type": "pipeline_health",
257            "data": r.json(),
258        }