1"""Arena — isolated agent group that loads config, builds phases, and runs.""" 2 3from __future__ import annotations 4 5import logging 6import time 7import uuid 8from typing import Any 9 10from maf.config import ArenaConfig, AgentConfig 11from maf.core.agent import AgentContext, BaseAgent 12from maf.core.graph import PhaseGraph 13from maf.core.phase import Phase 14from maf.core.state import ArenaState, create_initial_state 15from maf.streaming import get_event_bus 16 17logger = logging.getLogger(__name__) 18 19 20# --------------------------------------------------------------------------- 21# Agent factory — resolves role to concrete agent class 22# --------------------------------------------------------------------------- 23 24_ROLE_REGISTRY: dict[str, type[BaseAgent]] = {} 25 26 27def register_role(role: str, cls: type[BaseAgent]) -> None: 28 """Register an agent class for a role name.""" 29 _ROLE_REGISTRY[role] = cls 30 31 32def _create_agent(config: AgentConfig) -> BaseAgent: 33 """Create an agent instance from config.""" 34 cls = _ROLE_REGISTRY.get(config.role) 35 if cls is None: 36 raise ValueError( 37 f"No agent class registered for role {config.role!r}. " 38 f"Available: {list(_ROLE_REGISTRY)}" 39 ) 40 return cls(config) 41 42 43# --------------------------------------------------------------------------- 44# Arena 45# --------------------------------------------------------------------------- 46 47 48class Arena: 49 """A self-contained agent group with its own state, sources, and phases. 50 51 The arena is the main execution unit in MAF. It: 52 1. Loads its configuration (phases, agents, sources, memory) 53 2. Builds a PhaseGraph from the config 54 3. Runs the graph against an ArenaState 55 4. Returns the final state with decisions/signals 56 """ 57 58 def __init__( 59 self, 60 config: ArenaConfig, 61 llm_factory: Any = None, 62 source_registry: Any = None, 63 memory_store: Any = None, 64 ) -> None: 65 self.config = config 66 self.name = config.name 67 self._llm_factory = llm_factory 68 self._source_registry = source_registry 69 self._memory_store = memory_store 70 self._graph: PhaseGraph | None = None 71 72 @property 73 def source_registry(self) -> Any: 74 """Public access to the source registry.""" 75 return self._source_registry 76 77 @property 78 def memory_store(self) -> Any: 79 """Public access to the memory store.""" 80 return self._memory_store 81 82 def _has_target_filters(self) -> bool: 83 """True iff any agent or source in this arena declares 84 ``applicable_target_types``. Used to decide whether to rebuild 85 the graph per-run or use the cached one. 86 """ 87 if any( 88 (a.applicable_target_types or []) 89 for p in self.config.phases for a in p.agents 90 ): 91 return True 92 if any((s.applicable_target_types or []) for s in self.config.sources): 93 return True 94 return False 95 96 def _build_graph(self, target_type: str | None = None) -> PhaseGraph: 97 """Build the phase graph from config. 98 99 When ``target_type`` is given, agents whose ``applicable_target_types`` 100 list is set AND doesn't include the target type are skipped — that's 101 the auto-prune behaviour. Agents with an empty list apply to 102 everything (legacy default). 103 """ 104 phases: list[Phase] = [] 105 for phase_cfg in self.config.phases: 106 agent_configs = phase_cfg.agents 107 108 # Filter agents if selected_analysts is configured 109 if ( 110 self.config.selected_analysts 111 and phase_cfg.pattern == "parallel" 112 and all(a.role in ("analyst", "specialist") for a in agent_configs) 113 ): 114 agent_configs = [ 115 a for a in agent_configs 116 if a.name in self.config.selected_analysts 117 or any( 118 sel in a.name for sel in self.config.selected_analysts 119 ) 120 ] 121 122 # Auto-prune by target type. Empty list = applies everywhere. 123 if target_type: 124 agent_configs = [ 125 a for a in agent_configs 126 if not a.applicable_target_types 127 or target_type in a.applicable_target_types 128 ] 129 130 agents = [_create_agent(ac) for ac in agent_configs] 131 phase = Phase( 132 config=phase_cfg, 133 agents=agents, 134 ctx_factory=self._make_context, 135 ) 136 phases.append(phase) 137 138 return PhaseGraph(phases) 139 140 def _make_context(self, agent_config: AgentConfig) -> AgentContext: 141 """Create an AgentContext for a specific agent.""" 142 from maf.llm.client import LLMClient 143 from maf.sources.registry import SourceRegistry 144 145 # Resolve LLM tier 146 llm: LLMClient 147 if self._llm_factory: 148 llm = self._llm_factory(agent_config.llm_tier) 149 else: 150 llm = LLMClient.null() 151 152 # Resolve sources 153 sources: SourceRegistry 154 if self._source_registry: 155 sources = self._source_registry 156 else: 157 sources = SourceRegistry() 158 159 # Resolve memory 160 memory = None 161 if agent_config.memory and self._memory_store: 162 memory = self._memory_store.get(agent_config.memory) 163 164 return AgentContext( 165 config=agent_config, 166 sources=sources, 167 llm=llm, 168 memory=memory, 169 ) 170 171 async def run( 172 self, 173 target: dict[str, Any] | None = None, 174 state: ArenaState | None = None, 175 *, 176 correlation_id: str | None = None, 177 ) -> ArenaState: 178 """Execute the arena's full pipeline. 179 180 Parameters 181 ---------- 182 target: 183 Context for this run (e.g. {"ticker": "AAPL", "date": "2024-05-10"}). 184 state: 185 Optional pre-existing state to resume from. 186 correlation_id: 187 Optional correlation id to thread through every emitted event 188 (set by the control plane so a caller can stitch arena.start → 189 arena.complete to its run_arena request). 190 191 Returns 192 ------- 193 Final ArenaState with reports, debates, decisions, and signal. 194 """ 195 if state is None: 196 state = create_initial_state( 197 arena_id=str(uuid.uuid4()), 198 arena_name=self.name, 199 target=target, 200 ) 201 elif target: 202 state["target"] = target 203 204 if correlation_id: 205 state.setdefault("metadata", {})["correlation_id"] = correlation_id 206 else: 207 correlation_id = state.get("metadata", {}).get("correlation_id", "") 208 if not correlation_id: 209 correlation_id = uuid.uuid4().hex 210 state.setdefault("metadata", {})["correlation_id"] = correlation_id 211 212 # Make the iteration cap visible to ReplanAgent (read via state). 213 state.setdefault("metadata", {})["max_iterations"] = self.config.max_iterations 214 state.setdefault("iteration", 0) 215 216 # Resolve the target type for auto-prune. Cheap on the cached path 217 # (no filters declared); rebuilds the graph per-run only when at 218 # least one agent or source has ``applicable_target_types`` set. 219 from maf.config import Target as _Target 220 typed_target = _Target.from_dict(state.get("target") or {}) 221 state.setdefault("metadata", {})["target_type"] = typed_target.type 222 if self._has_target_filters(): 223 graph_for_run = self._build_graph(target_type=typed_target.type) 224 else: 225 if self._graph is None: 226 self._graph = self._build_graph() 227 graph_for_run = self._graph 228 229 bus = get_event_bus() 230 arena_id = state.get("arena_id", "") 231 await bus.publish( 232 "arena.start", 233 arena=self.name, arena_id=arena_id, correlation_id=correlation_id, 234 payload={ 235 "target": state.get("target") or {}, 236 "target_type": typed_target.type, 237 "phases": list(graph_for_run.order), 238 "selected_analysts": self.config.selected_analysts or [], 239 }, 240 ) 241 242 logger.info("Arena %r starting (target_type=%s target=%s)", 243 self.name, typed_target.type, target) 244 t0 = time.monotonic() 245 try: 246 state = await graph_for_run.run(state) 247 except Exception as exc: 248 await bus.publish( 249 "arena.error", 250 arena=self.name, arena_id=arena_id, correlation_id=correlation_id, 251 payload={"error": str(exc), "type": type(exc).__name__}, 252 ) 253 raise 254 255 elapsed_s = round(time.monotonic() - t0, 3) 256 state["arena_total_seconds"] = elapsed_s 257 logger.info( 258 "Arena %r completed (signal=%s)", self.name, state.get("signal", "") 259 ) 260 261 await bus.publish( 262 "arena.complete", 263 arena=self.name, arena_id=arena_id, correlation_id=correlation_id, 264 payload={ 265 "signal": state.get("signal", ""), 266 "synthesis_verdict": state.get("synthesis_verdict", ""), 267 "synthesis_score": state.get("synthesis_score", 0.0), 268 "synthesis_confidence": state.get("synthesis_confidence", 0.0), 269 "elapsed_s": elapsed_s, 270 "reports": list((state.get("reports") or {}).keys()), 271 "decisions": list((state.get("decisions") or {}).keys()), 272 "signals": len(state.get("agent_signals") or []), 273 }, 274 ) 275 276 return state