checking system…
Docs / back / tests/core/test_target_autoprune.py · line 1
Python · 158 lines
  1"""Tests for the agnostic-target rework:
  2
* :meth:`Target.from_dict` recognises the legacy target shapes.
* :meth:`Arena._build_graph` prunes agents whose
  ``applicable_target_types`` doesn't include the active target type.
* Agents and sources with an empty ``applicable_target_types`` list still
  apply to every target type (backwards-compatible).
  8"""
  9
 10from __future__ import annotations
 11
 12from unittest.mock import MagicMock, patch
 13
 14import pytest
 15
 16from maf.config import (
 17    AgentConfig, ArenaConfig, PhaseConfig, SourceBinding, Target,
 18)
 19
 20
 21# ── Target.from_dict ──
 22
 23
 24def test_target_from_dict_ticker():
 25    t = Target.from_dict({"ticker": "NVDA"})
 26    assert t.type == "ticker"
 27    assert t.primary_id == "NVDA"
 28    assert t.secondary_ids == []
 29
 30
 31def test_target_from_dict_sector_with_peers():
 32    t = Target.from_dict({"sector": "AI semis", "tickers": ["NVDA", "AMD"]})
 33    assert t.type == "sector"
 34    assert t.primary_id == "AI semis"
 35    assert t.secondary_ids == ["NVDA", "AMD"]
 36
 37
 38def test_target_from_dict_tickers_basket():
 39    t = Target.from_dict({"tickers": ["NVDA", "AMD", "AVGO"]})
 40    assert t.type == "tickers"
 41    assert t.primary_id == "NVDA"
 42    assert t.secondary_ids == ["AMD", "AVGO"]
 43
 44
 45def test_target_from_dict_question():
 46    t = Target.from_dict({"question_id": "q-1", "angle": "risk"})
 47    assert t.type == "question"
 48    assert t.primary_id == "q-1"
 49    assert t.metadata == {"angle": "risk"}
 50
 51
 52def test_target_from_dict_empty_and_unknown_become_free_text():
 53    assert Target.from_dict({}).type == "free_text"
 54    assert Target.from_dict({"foo": "bar"}).type == "free_text"
 55    assert Target.from_dict(None).type == "free_text"
 56
 57
 58def test_target_from_dict_typed_pass_through():
 59    t = Target.from_dict({"type": "deal", "primary_id": "d-1", "secondary_ids": []})
 60    assert t.type == "deal"
 61    assert t.primary_id == "d-1"
 62
 63
 64# ── Auto-prune in Arena._build_graph ──
 65
 66
 67def _arena_with_filtered_agents():
 68    """Build a tiny arena with two agents — one ticker-only, one sector-only,
 69    one universal. We don't actually run it, just inspect the resulting graph."""
 70    return ArenaConfig(
 71        name="test_filter",
 72        applicable_target_types=["ticker", "sector"],
 73        sources=[],
 74        phases=[PhaseConfig(
 75            name="analysis", pattern="parallel", agents=[
 76                AgentConfig(name="ticker_only", role="specialist",
 77                            applicable_target_types=["ticker"]),
 78                AgentConfig(name="sector_only", role="specialist",
 79                            applicable_target_types=["sector"]),
 80                AgentConfig(name="universal", role="specialist",
 81                            applicable_target_types=[]),
 82            ],
 83        )],
 84    )
 85
 86
 87def test_has_target_filters_detects_agents():
 88    from maf.core.arena import Arena
 89    arena = Arena(_arena_with_filtered_agents())
 90    assert arena._has_target_filters() is True
 91
 92
 93def test_has_target_filters_false_when_none_set():
 94    from maf.core.arena import Arena
 95    cfg = ArenaConfig(
 96        name="no_filter",
 97        sources=[SourceBinding(name="x", adapter="alpaca")],
 98        phases=[PhaseConfig(name="analysis", pattern="parallel", agents=[
 99            AgentConfig(name="a", role="specialist"),
100        ])],
101    )
102    arena = Arena(cfg)
103    assert arena._has_target_filters() is False
104
105
@pytest.fixture
def registered_specialist_agent(monkeypatch):
    """Register a fake specialist class in the arena's role registry.

    Returns the fake class (a ``MagicMock``) so tests can inspect it.
    ``monkeypatch.setitem`` undoes the registration on teardown. It restores
    the previous ``"specialist"`` entry, or deletes the key if there was none,
    so an existing entry whose value is ``None`` is no longer mistaken for
    "absent".
    """
    from maf.core.arena import _ROLE_REGISTRY  # type: ignore[attr-defined]

    fake_cls = MagicMock()
    # Every call to the fake class returns this same mock instance.
    fake_cls.return_value = MagicMock(name="agent_instance")
    monkeypatch.setitem(_ROLE_REGISTRY, "specialist", fake_cls)
    return fake_cls
119
120
def test_build_graph_prunes_ticker_only_for_sector_target(registered_specialist_agent):
    """A ``sector`` target drops the ticker-only agent and keeps the other two."""
    from maf.core.arena import Arena

    graph = Arena(_arena_with_filtered_agents())._build_graph(target_type="sector")
    # PhaseGraph.phases maps phase name -> Phase.
    kept = graph.phases["analysis"].agents
    # ticker_only must be filtered out; sector_only + universal stay.
    assert len(kept) == 2
128
129
def test_build_graph_includes_all_when_no_target_type(registered_specialist_agent):
    """Legacy behaviour: omitting ``target_type`` disables pruning entirely."""
    from maf.core.arena import Arena

    graph = Arena(_arena_with_filtered_agents())._build_graph()
    kept = graph.phases["analysis"].agents
    # ticker_only, sector_only and universal are all present.
    assert len(kept) == 3
137
138
def test_build_graph_keeps_universal_agents(registered_specialist_agent):
    """For a target type no agent names, only the universal agent survives."""
    from maf.core.arena import Arena

    unrelated_type = "question"
    graph = Arena(_arena_with_filtered_agents())._build_graph(target_type=unrelated_type)
    kept = graph.phases["analysis"].agents
    assert len(kept) == 1  # just 'universal'
146
147
148# ── SourceBinding gets the field too ──
149
150
def test_source_binding_applicable_target_types_optional():
    """``applicable_target_types`` defaults to ``[]`` and keeps explicit values."""
    default_binding = SourceBinding(name="x", adapter="alpaca")
    assert default_binding.applicable_target_types == []

    explicit_binding = SourceBinding(
        name="y",
        adapter="alpaca",
        applicable_target_types=["ticker", "sector"],
    )
    assert explicit_binding.applicable_target_types == ["ticker", "sector"]