checking system…
Docs / back / src/maf/sources/adapters/trtools2_stream.py · line 240
Python · 291 lines
  1"""Explicit trtools2 Redis-Stream adapters.
  2
  3These adapters give MAF agents direct, low-latency access to the live
  4streams produced by trtools2:
  5
  6``Trtools2BarsSource``
  7    Tails ``trtools2:bars:<timeframe>:<symbol>`` (or a single configured
  8    ``stream``) and returns the latest N bars. Defaults: ``timeframe=1m``.
  9
 10``Trtools2NewsSource``
 11    Tails ``trtools2:news`` for the most recent news items the trtools2
 12    ingest pipeline has enriched.
 13
 14``Trtools2IndicatorsSource``
 15    Tails ``trtools2:indicators`` — pre-computed indicator values keyed by
 16    (symbol, timeframe, indicator). Agents use this instead of recomputing.
 17
 18``Trtools2StrategyEventsSource``
 19    Tails ``trtools2:strategy:events`` — every strategy state transition
 20    (signal fired, position opened/closed) the trtools2 engine publishes.
 21
 22All four share the same low-friction shape: ``fetch(params)`` returns
 23``{"type": "<adapter>", "stream": "...", "count": N, "items": [...]}`` where
 24each item is the decoded XADD payload (JSON-decoded when possible).
 25
 26Why not just use the existing ``trtools2_live`` adapter?
 27--------------------------------------------------------
 28That one expects ``bars:<exchange>:<timeframe>:<symbol>`` keys and a list of
 29symbols. The user's ask for "1m and 1h OHLC" plus "news, indicators,
 30strategy events" calls for distinct, named adapters that don't need a
 31symbol list to be useful — they can tail the firehose. Agents bind whichever
 32fits their need.
 33"""
 34
 35from __future__ import annotations
 36
 37import json
 38import logging
 39from typing import Any
 40
 41from maf.sources.base import BaseSource
 42
 43logger = logging.getLogger(__name__)
 44
 45
 46# Default stream names — keep in sync with config/default.yaml::streams.
 47DEFAULT_BARS_1M = "trtools2:bars:1m"
 48DEFAULT_BARS_1H = "trtools2:bars:1h"
 49DEFAULT_NEWS = "trtools2:news"
 50DEFAULT_INDICATORS = "trtools2:indicators"
 51DEFAULT_STRATEGY_EVENTS = "trtools2:strategy:events"
 52
 53
 54def _decode_fields(fields: Any) -> dict[str, Any]:
 55    """Decode a Redis Stream entry into a JSON-friendly dict.
 56
 57    Tries JSON-decode per field value so structured payloads round-trip
 58    cleanly without forcing the caller to know which fields were stringified.
 59    """
 60    out: dict[str, Any] = {}
 61    if not isinstance(fields, dict):
 62        return out
 63    for k, v in fields.items():
 64        key = k.decode("utf-8") if isinstance(k, bytes) else str(k)
 65        if isinstance(v, bytes):
 66            try:
 67                val: Any = v.decode("utf-8")
 68            except UnicodeDecodeError:
 69                out[key] = repr(v)
 70                continue
 71        else:
 72            val = v
 73        if isinstance(val, str):
 74            try:
 75                val = json.loads(val)
 76            except (json.JSONDecodeError, TypeError):
 77                pass
 78        out[key] = val
 79    return out
 80
 81
 82class _StreamReaderMixin:
 83    """Shared helper: connect to Redis, read most-recent N entries."""
 84
 85    def __init__(self, config: dict[str, Any]) -> None:
 86        super().__init__(config)  # type: ignore[misc]
 87        self._redis: Any = None
 88
 89    async def _get_redis(self) -> Any:
 90        if self._redis is None:
 91            import redis.asyncio as aioredis
 92
 93            self._redis = aioredis.from_url(
 94                self.config.get("redis_url", "redis://localhost:6379/0")
 95            )
 96        return self._redis
 97
 98    async def _xrev(self, stream: str, count: int) -> list[dict[str, Any]]:
 99        """Return up to ``count`` most-recent entries on ``stream``."""
100        try:
101            client = await self._get_redis()
102            rows = await client.xrevrange(stream, count=count)
103        except Exception as exc:
104            logger.warning("trtools2 stream %s xrevrange failed: %s", stream, exc)
105            return []
106        items: list[dict[str, Any]] = []
107        for entry_id, fields in rows:
108            sid = (
109                entry_id.decode("utf-8")
110                if isinstance(entry_id, bytes)
111                else str(entry_id)
112            )
113            decoded = _decode_fields(fields)
114            decoded["_id"] = sid
115            items.append(decoded)
116        return items
117
118
class Trtools2BarsSource(_StreamReaderMixin, BaseSource):
    """Reads OHLC bars from a configurable trtools2 stream.

    Config (each key may be overridden per call via ``fetch(params)``):
        timeframe:   "1h" (or "60m"/"1hour"/"1hr") selects ``trtools2:bars:1h``;
                     anything else, including the default "1m", selects
                     ``trtools2:bars:1m``.
        stream:      explicit stream name; overrides ``timeframe``. An
                     empty/None value is ignored.
        symbol:      optional case-insensitive symbol filter, applied
                     client-side.
        count:       max items (default 50; None also means 50).

    With a ``symbol`` filter only the newest ``2 * count`` stream entries are
    scanned, so fewer than ``count`` matches may be returned on a busy,
    multi-symbol stream.
    """

    adapter_name = "trtools2_bars"

    @classmethod
    def freshness_spec(cls, binding_config: dict[str, Any]) -> dict[str, Any]:
        """Report the stream this binding reads, resolved like ``fetch`` does."""
        stream = binding_config.get("stream") or _bars_stream_for(binding_config.get("timeframe", "1m"))
        return {"type": "stream", "stream": stream}

    async def fetch(self, params: dict[str, Any] | None = None) -> dict[str, Any]:
        """Return the latest bars as ``{"type", "stream", "count", "items"}``."""
        p = {**self.config, **(params or {})}
        stream = p.get("stream") or _bars_stream_for(p.get("timeframe", "1m"))
        # ``count`` may be present but None (e.g. a blank YAML key); treat that
        # as "use the default" instead of crashing in ``int(None)``.
        raw_count = p.get("count")
        count = 50 if raw_count is None else int(raw_count)
        symbol_filter = str(p.get("symbol") or "").upper()

        # Over-fetch when filtering so matches further back can still fill ``count``.
        items = await self._xrev(stream, count * 2 if symbol_filter else count)
        if symbol_filter:
            items = [
                it for it in items
                if str(it.get("symbol", "")).upper() == symbol_filter
            ][:count]

        return {
            "type": "trtools2_bars",
            "stream": stream,
            "count": len(items),
            "items": items,
        }
155
156
class Trtools2NewsSource(_StreamReaderMixin, BaseSource):
    """Latest news entries from ``trtools2:news``.

    Config (each key may be overridden per call via ``fetch(params)``):
        stream:      override stream name (default ``trtools2:news``); an
                     empty/None value falls back to the default.
        count:       max items (default 30; None also means 30).
        symbol:      optional case-insensitive filter, matched client-side
                     against an item's ``symbol`` (or ``ticker``) field or any
                     element of its ``tickers`` list.

    With a ``symbol`` filter only the newest ``2 * count`` entries are
    scanned, so fewer than ``count`` matches may be returned.
    """

    adapter_name = "trtools2_news"

    @classmethod
    def freshness_spec(cls, binding_config: dict[str, Any]) -> dict[str, Any]:
        """Report the stream this binding reads, resolved like ``fetch`` does."""
        return {"type": "stream", "stream": binding_config.get("stream") or DEFAULT_NEWS}

    async def fetch(self, params: dict[str, Any] | None = None) -> dict[str, Any]:
        """Return the latest news items as ``{"type", "stream", "count", "items"}``."""
        p = {**self.config, **(params or {})}
        # ``or`` rather than a ``get`` default: a present-but-empty ``stream``
        # (e.g. a blank YAML key) must fall back to the default, matching
        # ``freshness_spec``, instead of querying a ``None`` stream.
        stream = p.get("stream") or DEFAULT_NEWS
        raw_count = p.get("count")
        count = 30 if raw_count is None else int(raw_count)
        symbol_filter = str(p.get("symbol") or "").upper()

        # Over-fetch when filtering so matches further back can still fill ``count``.
        items = await self._xrev(stream, count * 2 if symbol_filter else count)
        if symbol_filter:

            def matches(it: dict[str, Any]) -> bool:
                sym = str(it.get("symbol") or it.get("ticker") or "").upper()
                if sym == symbol_filter:
                    return True
                # Some pipelines put tickers in a list under "tickers".
                tickers = it.get("tickers")
                if isinstance(tickers, list):
                    return any(str(t).upper() == symbol_filter for t in tickers)
                return False

            items = [it for it in items if matches(it)][:count]
        return {
            "type": "trtools2_news",
            "stream": stream,
            "count": len(items),
            "items": items,
        }
196
197
class Trtools2IndicatorsSource(_StreamReaderMixin, BaseSource):
    """Latest indicator values from ``trtools2:indicators``.

    Config (each key may be overridden per call via ``fetch(params)``):
        stream:      override stream name (default ``trtools2:indicators``);
                     an empty/None value falls back to the default.
        count:       max items (default 50; None also means 50).
        symbol:      optional case-insensitive symbol filter.
        indicator:   optional exact-match (case-sensitive) filter, e.g. "RSI_14".

    With any filter only the newest ``3 * count`` entries are scanned, so
    fewer than ``count`` matches may be returned.
    """

    adapter_name = "trtools2_indicators"

    @classmethod
    def freshness_spec(cls, binding_config: dict[str, Any]) -> dict[str, Any]:
        """Report the stream this binding reads, resolved like ``fetch`` does."""
        return {"type": "stream", "stream": binding_config.get("stream") or DEFAULT_INDICATORS}

    async def fetch(self, params: dict[str, Any] | None = None) -> dict[str, Any]:
        """Return the latest indicator values as ``{"type", "stream", "count", "items"}``."""
        p = {**self.config, **(params or {})}
        # ``or`` rather than a ``get`` default: a present-but-empty ``stream``
        # must fall back to the default, matching ``freshness_spec``.
        stream = p.get("stream") or DEFAULT_INDICATORS
        raw_count = p.get("count")
        count = 50 if raw_count is None else int(raw_count)
        sym = str(p.get("symbol") or "").upper()
        ind = str(p.get("indicator") or "")

        # Over-fetch when filtering so matches further back can still fill ``count``.
        items = await self._xrev(stream, count * 3 if (sym or ind) else count)
        if sym:
            items = [it for it in items if str(it.get("symbol", "")).upper() == sym]
        if ind:
            items = [it for it in items if str(it.get("indicator", "")) == ind]
        items = items[:count]
        return {
            "type": "trtools2_indicators",
            "stream": stream,
            "count": len(items),
            "items": items,
        }
233
234
class Trtools2StrategyEventsSource(_StreamReaderMixin, BaseSource):
    """Latest strategy events from ``trtools2:strategy:events``.

    Config (each key may be overridden per call via ``fetch(params)``):
        stream:      override stream name (default
                     ``trtools2:strategy:events``); an empty/None value falls
                     back to the default.
        count:       max items (default 30; None also means 30).
        strategy:    optional exact-match strategy-name filter.
        symbol:      optional case-insensitive symbol filter.

    With any filter only the newest ``3 * count`` entries are scanned, so
    fewer than ``count`` matches may be returned.
    """

    adapter_name = "trtools2_strategy_events"

    @classmethod
    def freshness_spec(cls, binding_config: dict[str, Any]) -> dict[str, Any]:
        """Report the stream this binding reads, resolved like ``fetch`` does."""
        return {"type": "stream", "stream": binding_config.get("stream") or DEFAULT_STRATEGY_EVENTS}

    async def fetch(self, params: dict[str, Any] | None = None) -> dict[str, Any]:
        """Return the latest strategy events as ``{"type", "stream", "count", "items"}``."""
        p = {**self.config, **(params or {})}
        # ``or`` rather than a ``get`` default: a present-but-empty ``stream``
        # must fall back to the default, matching ``freshness_spec``.
        stream = p.get("stream") or DEFAULT_STRATEGY_EVENTS
        raw_count = p.get("count")
        count = 30 if raw_count is None else int(raw_count)
        strategy = str(p.get("strategy") or "")
        sym = str(p.get("symbol") or "").upper()

        # Over-fetch when filtering so matches further back can still fill ``count``.
        items = await self._xrev(stream, count * 3 if (sym or strategy) else count)
        if strategy:
            items = [it for it in items if str(it.get("strategy", "")) == strategy]
        if sym:
            items = [it for it in items if str(it.get("symbol", "")).upper() == sym]
        items = items[:count]
        return {
            "type": "trtools2_strategy_events",
            "stream": stream,
            "count": len(items),
            "items": items,
        }
270
271
def _bars_stream_for(timeframe: str) -> str:
    """Map a timeframe string to the default trtools2 bars stream.

    Only two bar streams exist: ``DEFAULT_BARS_1H`` for the hourly aliases
    ("1h", "60m", "1hour", "1hr") and ``DEFAULT_BARS_1M`` for minute bars.
    Matching is case-insensitive and ignores surrounding whitespace. An
    empty/None timeframe means "1m".

    Any other timeframe (e.g. "5m", "1d") still falls back to the 1m stream,
    but a warning is logged. Otherwise a misconfigured binding would silently
    receive bars of the wrong timeframe.
    """
    tf = str(timeframe or "").strip().lower() or "1m"
    if tf in ("1h", "60m", "1hour", "1hr"):
        return DEFAULT_BARS_1H
    if tf not in ("1m", "1min", "1minute"):
        logger.warning(
            "trtools2 bars: unsupported timeframe %r; falling back to %s",
            timeframe,
            DEFAULT_BARS_1M,
        )
    return DEFAULT_BARS_1M
278
279
# Public surface: the default stream names and the four adapter classes.
__all__ = [
    # Default stream names.
    "DEFAULT_BARS_1H", "DEFAULT_BARS_1M", "DEFAULT_INDICATORS",
    "DEFAULT_NEWS", "DEFAULT_STRATEGY_EVENTS",
    # Adapter classes.
    "Trtools2BarsSource", "Trtools2IndicatorsSource",
    "Trtools2NewsSource", "Trtools2StrategyEventsSource",
]