checking system…
Docs / back / src/maf/core/arena.py · line 48
Python · 277 lines
  1"""Arena — isolated agent group that loads config, builds phases, and runs."""
  2
  3from __future__ import annotations
  4
  5import logging
  6import time
  7import uuid
  8from typing import Any
  9
 10from maf.config import ArenaConfig, AgentConfig
 11from maf.core.agent import AgentContext, BaseAgent
 12from maf.core.graph import PhaseGraph
 13from maf.core.phase import Phase
 14from maf.core.state import ArenaState, create_initial_state
 15from maf.streaming import get_event_bus
 16
 17logger = logging.getLogger(__name__)
 18
 19
 20# ---------------------------------------------------------------------------
 21# Agent factory — resolves role to concrete agent class
 22# ---------------------------------------------------------------------------
 23
 24_ROLE_REGISTRY: dict[str, type[BaseAgent]] = {}
 25
 26
 27def register_role(role: str, cls: type[BaseAgent]) -> None:
 28    """Register an agent class for a role name."""
 29    _ROLE_REGISTRY[role] = cls
 30
 31
 32def _create_agent(config: AgentConfig) -> BaseAgent:
 33    """Create an agent instance from config."""
 34    cls = _ROLE_REGISTRY.get(config.role)
 35    if cls is None:
 36        raise ValueError(
 37            f"No agent class registered for role {config.role!r}. "
 38            f"Available: {list(_ROLE_REGISTRY)}"
 39        )
 40    return cls(config)
 41
 42
 43# ---------------------------------------------------------------------------
 44# Arena
 45# ---------------------------------------------------------------------------
 46
 47
 48class Arena:
 49    """A self-contained agent group with its own state, sources, and phases.
 50
 51    The arena is the main execution unit in MAF. It:
 52    1. Loads its configuration (phases, agents, sources, memory)
 53    2. Builds a PhaseGraph from the config
 54    3. Runs the graph against an ArenaState
 55    4. Returns the final state with decisions/signals
 56    """
 57
 58    def __init__(
 59        self,
 60        config: ArenaConfig,
 61        llm_factory: Any = None,
 62        source_registry: Any = None,
 63        memory_store: Any = None,
 64    ) -> None:
 65        self.config = config
 66        self.name = config.name
 67        self._llm_factory = llm_factory
 68        self._source_registry = source_registry
 69        self._memory_store = memory_store
 70        self._graph: PhaseGraph | None = None
 71
 72    @property
 73    def source_registry(self) -> Any:
 74        """Public access to the source registry."""
 75        return self._source_registry
 76
 77    @property
 78    def memory_store(self) -> Any:
 79        """Public access to the memory store."""
 80        return self._memory_store
 81
 82    def _has_target_filters(self) -> bool:
 83        """True iff any agent or source in this arena declares
 84        ``applicable_target_types``. Used to decide whether to rebuild
 85        the graph per-run or use the cached one.
 86        """
 87        if any(
 88            (a.applicable_target_types or [])
 89            for p in self.config.phases for a in p.agents
 90        ):
 91            return True
 92        if any((s.applicable_target_types or []) for s in self.config.sources):
 93            return True
 94        return False
 95
 96    def _build_graph(self, target_type: str | None = None) -> PhaseGraph:
 97        """Build the phase graph from config.
 98
 99        When ``target_type`` is given, agents whose ``applicable_target_types``
100        list is set AND doesn't include the target type are skipped — that's
101        the auto-prune behaviour. Agents with an empty list apply to
102        everything (legacy default).
103        """
104        phases: list[Phase] = []
105        for phase_cfg in self.config.phases:
106            agent_configs = phase_cfg.agents
107
108            # Filter agents if selected_analysts is configured
109            if (
110                self.config.selected_analysts
111                and phase_cfg.pattern == "parallel"
112                and all(a.role in ("analyst", "specialist") for a in agent_configs)
113            ):
114                agent_configs = [
115                    a for a in agent_configs
116                    if a.name in self.config.selected_analysts
117                    or any(
118                        sel in a.name for sel in self.config.selected_analysts
119                    )
120                ]
121
122            # Auto-prune by target type. Empty list = applies everywhere.
123            if target_type:
124                agent_configs = [
125                    a for a in agent_configs
126                    if not a.applicable_target_types
127                    or target_type in a.applicable_target_types
128                ]
129
130            agents = [_create_agent(ac) for ac in agent_configs]
131            phase = Phase(
132                config=phase_cfg,
133                agents=agents,
134                ctx_factory=self._make_context,
135            )
136            phases.append(phase)
137
138        return PhaseGraph(phases)
139
140    def _make_context(self, agent_config: AgentConfig) -> AgentContext:
141        """Create an AgentContext for a specific agent."""
142        from maf.llm.client import LLMClient
143        from maf.sources.registry import SourceRegistry
144
145        # Resolve LLM tier
146        llm: LLMClient
147        if self._llm_factory:
148            llm = self._llm_factory(agent_config.llm_tier)
149        else:
150            llm = LLMClient.null()
151
152        # Resolve sources
153        sources: SourceRegistry
154        if self._source_registry:
155            sources = self._source_registry
156        else:
157            sources = SourceRegistry()
158
159        # Resolve memory
160        memory = None
161        if agent_config.memory and self._memory_store:
162            memory = self._memory_store.get(agent_config.memory)
163
164        return AgentContext(
165            config=agent_config,
166            sources=sources,
167            llm=llm,
168            memory=memory,
169        )
170
171    async def run(
172        self,
173        target: dict[str, Any] | None = None,
174        state: ArenaState | None = None,
175        *,
176        correlation_id: str | None = None,
177    ) -> ArenaState:
178        """Execute the arena's full pipeline.
179
180        Parameters
181        ----------
182        target:
183            Context for this run (e.g. {"ticker": "AAPL", "date": "2024-05-10"}).
184        state:
185            Optional pre-existing state to resume from.
186        correlation_id:
187            Optional correlation id to thread through every emitted event
188            (set by the control plane so a caller can stitch arena.start →
189            arena.complete to its run_arena request).
190
191        Returns
192        -------
193        Final ArenaState with reports, debates, decisions, and signal.
194        """
195        if state is None:
196            state = create_initial_state(
197                arena_id=str(uuid.uuid4()),
198                arena_name=self.name,
199                target=target,
200            )
201        elif target:
202            state["target"] = target
203
204        if correlation_id:
205            state.setdefault("metadata", {})["correlation_id"] = correlation_id
206        else:
207            correlation_id = state.get("metadata", {}).get("correlation_id", "")
208            if not correlation_id:
209                correlation_id = uuid.uuid4().hex
210                state.setdefault("metadata", {})["correlation_id"] = correlation_id
211
212        # Make the iteration cap visible to ReplanAgent (read via state).
213        state.setdefault("metadata", {})["max_iterations"] = self.config.max_iterations
214        state.setdefault("iteration", 0)
215
216        # Resolve the target type for auto-prune. Cheap on the cached path
217        # (no filters declared); rebuilds the graph per-run only when at
218        # least one agent or source has ``applicable_target_types`` set.
219        from maf.config import Target as _Target
220        typed_target = _Target.from_dict(state.get("target") or {})
221        state.setdefault("metadata", {})["target_type"] = typed_target.type
222        if self._has_target_filters():
223            graph_for_run = self._build_graph(target_type=typed_target.type)
224        else:
225            if self._graph is None:
226                self._graph = self._build_graph()
227            graph_for_run = self._graph
228
229        bus = get_event_bus()
230        arena_id = state.get("arena_id", "")
231        await bus.publish(
232            "arena.start",
233            arena=self.name, arena_id=arena_id, correlation_id=correlation_id,
234            payload={
235                "target": state.get("target") or {},
236                "target_type": typed_target.type,
237                "phases": list(graph_for_run.order),
238                "selected_analysts": self.config.selected_analysts or [],
239            },
240        )
241
242        logger.info("Arena %r starting (target_type=%s target=%s)",
243                    self.name, typed_target.type, target)
244        t0 = time.monotonic()
245        try:
246            state = await graph_for_run.run(state)
247        except Exception as exc:
248            await bus.publish(
249                "arena.error",
250                arena=self.name, arena_id=arena_id, correlation_id=correlation_id,
251                payload={"error": str(exc), "type": type(exc).__name__},
252            )
253            raise
254
255        elapsed_s = round(time.monotonic() - t0, 3)
256        state["arena_total_seconds"] = elapsed_s
257        logger.info(
258            "Arena %r completed (signal=%s)", self.name, state.get("signal", "")
259        )
260
261        await bus.publish(
262            "arena.complete",
263            arena=self.name, arena_id=arena_id, correlation_id=correlation_id,
264            payload={
265                "signal": state.get("signal", ""),
266                "synthesis_verdict": state.get("synthesis_verdict", ""),
267                "synthesis_score": state.get("synthesis_score", 0.0),
268                "synthesis_confidence": state.get("synthesis_confidence", 0.0),
269                "elapsed_s": elapsed_s,
270                "reports": list((state.get("reports") or {}).keys()),
271                "decisions": list((state.get("decisions") or {}).keys()),
272                "signals": len(state.get("agent_signals") or []),
273            },
274        )
275
276        return state