checking system…
Docs / back / src/maf/watch/list.py · line 72
Python · 272 lines
  1"""Redis-backed watch list.
  2
  3Wire shape
  4----------
  5Stored as a sorted set at ``maf:watch:zset``. Each member is a JSON string
  6identifying the watched target; the score is its absolute expiry timestamp
  7(unix seconds). A periodic call to :meth:`WatchList.decay` evicts expired
members. ZADD with GT keeps the latest expiry, not merely the latest signal.
  9
 10Why opaque IDs
 11--------------
 12The list isn't trading-specific. ``kind`` is just a tag — ``"symbol"``,
 13``"question"``, ``"document"``, anything an arena cares about. The Kronos
 14refresher only ever asks for ``kind="symbol"``; Mirofish refreshers will
 15ask for ``kind="document"``. Adding a new kind doesn't require schema
 16changes — pick a string and use it.
 17
 18Concurrency
 19-----------
 20Operations are individual Redis commands; we never read-modify-write across
 21multiple commands, so two writers can't lose each other's updates. ``add``
 22uses ZADD GT so a later interest signal can extend a TTL but an earlier
 23one can't shrink it.
 24
 25Memory + cost
 26-------------
 27Sorted-set entries are tiny (one JSON object per member). 1000 watched
 28items ≈ a few KB. We don't need a Bloom filter or shard scheme until
 29multi-tenant deployment lands (out of scope for Phase 1).
 30"""
 31
 32from __future__ import annotations
 33
 34import json
 35import logging
 36import os
 37import time
 38from dataclasses import dataclass, field
 39from typing import Any
 40
 41logger = logging.getLogger(__name__)
 42
 43
 44DEFAULT_KEY = "maf:watch:zset"
 45DEFAULT_TTL_SECONDS = 6 * 3600   # 6 h — covers a trading session + buffer
 46
 47# Canonical kind tags. Strings are intentionally open-ended; these constants
 48# exist so callers can refer to them without typos.
 49KIND_SYMBOL = "symbol"      # e.g. SPY, NVDA, BTC-USD — used by Kronos refresher
 50KIND_QUESTION = "question"  # mastermind / report_to_action — opaque id
 51KIND_DOCUMENT = "document"  # fomo2 report ids — drives Mirofish event triggers
 52
 53
 54@dataclass(frozen=True)
 55class WatchEntry:
 56    """One row in the watch list — what we hand back to callers."""
 57
 58    target_id: str
 59    kind: str
 60    expires_at: float           # unix seconds
 61    attrs: dict[str, Any] = field(default_factory=dict)
 62
 63    @property
 64    def remaining_seconds(self) -> float:
 65        return max(0.0, self.expires_at - time.time())
 66
 67    @property
 68    def expired(self) -> bool:
 69        return time.time() >= self.expires_at
 70
 71
 72class WatchList:
 73    """Redis sorted-set wrapper for the global watch list.
 74
 75    Use :func:`get` to fetch the process-wide singleton, or instantiate
 76    directly for tests / multi-tenant scenarios (pass a different ``key``).
 77    """
 78
 79    def __init__(
 80        self,
 81        redis_url: str | None = None,
 82        key: str = DEFAULT_KEY,
 83    ) -> None:
 84        self.redis_url = redis_url or os.environ.get(
 85            "REDIS_URL", "redis://localhost:6379/0",
 86        )
 87        self.key = key
 88        self._redis: Any = None
 89
 90    async def _get_redis(self) -> Any:
 91        if self._redis is None:
 92            import redis.asyncio as aioredis
 93            self._redis = aioredis.from_url(self.redis_url)
 94        return self._redis
 95
 96    async def aclose(self) -> None:
 97        if self._redis is None:
 98            return
 99        try:
100            ac = getattr(self._redis, "aclose", None)
101            if ac:
102                await ac()
103            else:
104                await self._redis.close()
105        except Exception:
106            pass
107
108    # ── Writes ──────────────────────────────────────────────────────────────
109
110    async def add(
111        self,
112        target_id: str,
113        *,
114        kind: str = KIND_SYMBOL,
115        ttl_seconds: int = DEFAULT_TTL_SECONDS,
116        attrs: dict[str, Any] | None = None,
117    ) -> WatchEntry:
118        """Add ``target_id`` (or extend its TTL if already present).
119
120        Uses ``ZADD GT`` so a later expiry wins. An *earlier* interest signal
121        (i.e. one that would shorten the TTL) is silently ignored — interest
122        only extends, never shortens.
123        """
124        client = await self._get_redis()
125        member = _encode_member(target_id, kind, attrs or {})
126        score = time.time() + max(1, int(ttl_seconds))
127        # ZADD GT XX would skip non-members; we want to upsert. ZADD GT
128        # alone upserts and only updates if the new score > old.
129        await client.zadd(self.key, {member: score}, gt=True)
130        return WatchEntry(
131            target_id=target_id, kind=kind,
132            expires_at=score, attrs=dict(attrs or {}),
133        )
134
135    async def remove(self, target_id: str, *, kind: str | None = None) -> int:
136        """Remove every member matching ``target_id`` (and optionally ``kind``).
137
138        Returns the number of members removed. Because attrs are part of the
139        encoded member, the same target can appear with different attrs;
140        ``remove`` clears them all.
141        """
142        client = await self._get_redis()
143        # Scan the whole set; for our cardinality this is fine. If we ever
144        # hit 10k+ entries we'd switch to a hash sidecar keyed by target_id.
145        members = await client.zrange(self.key, 0, -1, withscores=False)
146        to_remove: list[bytes | str] = []
147        for raw in members:
148            entry = _decode_member(raw)
149            if entry is None:
150                continue
151            if entry.target_id != target_id:
152                continue
153            if kind is not None and entry.kind != kind:
154                continue
155            to_remove.append(raw)
156        if not to_remove:
157            return 0
158        return int(await client.zrem(self.key, *to_remove))
159
160    async def clear(self) -> int:
161        """Drop every entry. Mostly for tests."""
162        client = await self._get_redis()
163        return int(await client.delete(self.key))
164
165    # ── Reads ───────────────────────────────────────────────────────────────
166
167    async def members(
168        self,
169        *,
170        kind: str | None = None,
171        include_expired: bool = False,
172    ) -> list[WatchEntry]:
173        """Return current watch entries, optionally filtered by ``kind``."""
174        client = await self._get_redis()
175        rows = await client.zrange(self.key, 0, -1, withscores=True)
176        now = time.time()
177        out: list[WatchEntry] = []
178        for raw, score in rows:
179            entry = _decode_member(raw)
180            if entry is None:
181                continue
182            # Re-build with the score from Redis (the stored expiry).
183            entry = WatchEntry(
184                target_id=entry.target_id,
185                kind=entry.kind,
186                expires_at=float(score),
187                attrs=entry.attrs,
188            )
189            if not include_expired and entry.expires_at <= now:
190                continue
191            if kind is not None and entry.kind != kind:
192                continue
193            out.append(entry)
194        return out
195
196    async def is_watched(self, target_id: str, *, kind: str | None = None) -> bool:
197        for e in await self.members(kind=kind):
198            if e.target_id == target_id:
199                return True
200        return False
201
202    # ── Maintenance ─────────────────────────────────────────────────────────
203
204    async def decay(self) -> int:
205        """Drop expired members. Returns the number evicted.
206
207        Safe to call frequently; uses ``ZREMRANGEBYSCORE`` which is O(log N + M).
208        """
209        client = await self._get_redis()
210        return int(await client.zremrangebyscore(self.key, 0, time.time()))
211
212
213# ── Process-global singleton ───────────────────────────────────────────────
214
215
_INSTANCE: WatchList | None = None


def get_watch_list(
    redis_url: str | None = None,
    key: str = DEFAULT_KEY,
) -> WatchList:
    """Return the process-wide :class:`WatchList`, building it on first use.

    ``redis_url`` and ``key`` only matter on the call that creates the
    instance. Every later call hands back the cached object and ignores its
    arguments, so callers don't have to thread config everywhere.
    """
    global _INSTANCE
    cached = _INSTANCE
    if cached is not None:
        return cached
    cached = WatchList(redis_url=redis_url, key=key)
    _INSTANCE = cached
    return cached
233
234
def reset_watch_list_singleton() -> None:
    """For tests — drop the cached singleton so the next call rebuilds it.

    Only clears the module-level reference; the previous instance is not
    closed here.
    """
    global _INSTANCE
    _INSTANCE = None
239
240
241# ── Encoding helpers ───────────────────────────────────────────────────────
242
243
244def _encode_member(target_id: str, kind: str, attrs: dict[str, Any]) -> str:
245    """Stable JSON encoding for the ZSET member.
246
247    Sort keys so two callers with the same logical entry produce the same
248    Redis member (lets ZADD GT deduplicate properly).
249    """
250    body: dict[str, Any] = {"id": target_id, "kind": kind}
251    if attrs:
252        body["attrs"] = dict(attrs)
253    return json.dumps(body, sort_keys=True, separators=(",", ":"))
254
255
def _decode_member(raw: Any) -> WatchEntry | None:
    """Parse one raw ZSET member back into a :class:`WatchEntry`.

    Args:
        raw: The member as returned by Redis. ``bytes`` are decoded as
            UTF-8, with undecodable bytes replaced rather than raising.

    Returns:
        A :class:`WatchEntry` with ``expires_at`` set to ``0.0`` (the caller
        fills in the real expiry from the score), or ``None`` when the member
        is not a JSON object containing both ``id`` and ``kind``.

    Never raises for bad stored data: callers decode every member of the set
    in a loop, so one corrupt member must not abort the whole scan.
    """
    if isinstance(raw, bytes):
        raw = raw.decode("utf-8", errors="replace")
    try:
        body = json.loads(raw)
    except (ValueError, TypeError):
        # JSONDecodeError is a ValueError; TypeError covers non-string input.
        logger.debug("WatchList: skipping malformed member %r", raw)
        return None
    if not isinstance(body, dict) or "id" not in body or "kind" not in body:
        logger.debug("WatchList: skipping malformed member %r", raw)
        return None
    attrs = body.get("attrs")
    if not isinstance(attrs, dict):
        # ``dict("abc")`` or ``dict(5)`` would raise here. Keep the entry
        # (so it can still be listed and removed) but drop the unusable attrs.
        if attrs:
            logger.debug(
                "WatchList: ignoring non-object attrs in member %r", raw,
            )
        attrs = {}
    return WatchEntry(
        target_id=str(body["id"]),
        kind=str(body["kind"]),
        expires_at=0.0,  # filled in by the caller from the zscore
        attrs=dict(attrs),
    )