1"""Phase execution — orchestrates agents according to patterns.""" 2 3from __future__ import annotations 4 5import asyncio 6import copy 7import logging 8import time 9from typing import Any 10 11from maf.config import PhaseConfig, AgentConfig 12from maf.core.agent import AgentContext, BaseAgent 13from maf.core.signal import extract_signal 14from maf.core.state import ArenaState, DebateState, DebateMessage 15from maf.streaming import get_event_bus 16 17logger = logging.getLogger(__name__) 18 19 20class Phase: 21 """Executes a group of agents according to a pattern. 22 23 Patterns: 24 - parallel: all agents run concurrently, state merged after 25 - sequential: agents run in order, each sees prior state updates 26 - debate: round-robin debaters for max_rounds, then judge synthesizes 27 """ 28 29 def __init__( 30 self, 31 config: PhaseConfig, 32 agents: list[BaseAgent], 33 ctx_factory: Any, # callable(AgentConfig) -> AgentContext 34 ) -> None: 35 self.config = config 36 self.name = config.name 37 self.pattern = config.pattern 38 self.agents = agents 39 self.ctx_factory = ctx_factory 40 41 async def run(self, state: ArenaState) -> ArenaState: 42 """Execute this phase and return updated state.""" 43 state["current_phase"] = self.name 44 # Record the agent roster so downstream phases (ReplanAgent) can 45 # detect specialists that failed without emitting a signal. 46 if self.pattern in ("parallel", "debate"): 47 state.setdefault("metadata", {})["phase_agents"] = [ 48 a.name for a in self.agents 49 if a.role in ("specialist", "analyst", "debater") 50 ] 51 logger.info("Phase %r starting (pattern=%s, agents=%d)", 52 self.name, self.pattern, len(self.agents)) 53 54 bus = get_event_bus() 55 arena_name = state.get("arena_name", "") 56 arena_id = state.get("arena_id", "") 57 correlation_id = state.get("metadata", {}).get("correlation_id", "") 58 59 t0 = time.monotonic() 60 await bus.publish( 61 "phase.start", 62 arena=arena_name, 63 arena_id=arena_id, 64 phase=self.name, 65 correlation_id=correlation_id, 66 payload={ 67 "pattern": self.pattern, 68 "agents": [a.name for a in self.agents], 69 }, 70 ) 71 72 try: 73 if self.pattern == "parallel": 74 state = await self._run_parallel(state) 75 elif self.pattern == "sequential": 76 state = await self._run_sequential(state) 77 elif self.pattern == "debate": 78 state = await self._run_debate(state) 79 else: 80 raise ValueError(f"Unknown pattern: {self.pattern}") 81 except Exception as exc: 82 await bus.publish( 83 "phase.error", 84 arena=arena_name, arena_id=arena_id, phase=self.name, 85 correlation_id=correlation_id, 86 payload={"error": str(exc), "type": type(exc).__name__}, 87 ) 88 raise 89 90 elapsed_s = round(time.monotonic() - t0, 3) 91 await bus.publish( 92 "phase.complete", 93 arena=arena_name, arena_id=arena_id, phase=self.name, 94 correlation_id=correlation_id, 95 payload={ 96 "elapsed_s": elapsed_s, 97 "agents_completed": len(self.agents), 98 "signals_so_far": len(state.get("agent_signals") or []), 99 "reports_so_far": list((state.get("reports") or {}).keys()), 100 }, 101 ) 102 return state 103 104 async def _run_parallel(self, state: ArenaState) -> ArenaState: 105 """Run all agents in parallel with isolated state copies. 106 107 Each agent gets its own copy of state (message isolation per the paper). 108 Reports are merged back into the original state afterward. 109 """ 110 bus = get_event_bus() 111 arena_name = state.get("arena_name", "") 112 arena_id = state.get("arena_id", "") 113 correlation_id = state.get("metadata", {}).get("correlation_id", "") 114 115 async def _run_one(agent: BaseAgent) -> tuple[str, ArenaState]: 116 agent_state = copy.deepcopy(state) 117 ctx = self.ctx_factory(agent.config) 118 await bus.publish( 119 "agent.start", 120 arena=arena_name, arena_id=arena_id, phase=self.name, 121 correlation_id=correlation_id, 122 payload={"agent": agent.name, "role": agent.role}, 123 ) 124 t0 = time.monotonic() 125 try: 126 result = await agent.run(agent_state, ctx) 127 except Exception as exc: 128 await bus.publish( 129 "agent.error", 130 arena=arena_name, arena_id=arena_id, phase=self.name, 131 correlation_id=correlation_id, 132 payload={"agent": agent.name, "error": str(exc), "type": type(exc).__name__}, 133 ) 134 raise 135 await bus.publish( 136 "agent.complete", 137 arena=arena_name, arena_id=arena_id, phase=self.name, 138 correlation_id=correlation_id, 139 payload={ 140 "agent": agent.name, 141 "elapsed_s": round(time.monotonic() - t0, 3), 142 "report_chars": len(result.get("reports", {}).get(agent.name, "")), 143 }, 144 ) 145 return agent.name, result 146 147 results = await asyncio.gather( 148 *[_run_one(a) for a in self.agents], 149 return_exceptions=True, 150 ) 151 152 # Merge reports and decisions from each agent back 153 for result in results: 154 if isinstance(result, Exception): 155 logger.error("Agent failed in parallel phase: %s", result) 156 continue 157 agent_name, agent_state = result 158 # Merge reports 159 for k, v in agent_state.get("reports", {}).items(): 160 state.setdefault("reports", {})[k] = v 161 # Merge decisions 162 for k, v in agent_state.get("decisions", {}).items(): 163 state.setdefault("decisions", {})[k] = v 164 # Merge structured agent signals (each parallel agent emits one) 165 for sig in agent_state.get("agent_signals", []) or []: 166 state.setdefault("agent_signals", []).append(sig) 167 if isinstance(sig, dict): 168 await bus.publish( 169 "agent.signal", 170 arena=arena_name, arena_id=arena_id, phase=self.name, 171 correlation_id=correlation_id, 172 payload={ 173 "agent": sig.get("agent", agent_name), 174 "domain": sig.get("domain", ""), 175 "signal": sig.get("signal", ""), 176 "confidence": sig.get("confidence", 0.0), 177 "summary": sig.get("summary", "")[:300], 178 }, 179 ) 180 # Merge source metrics (each call gets one row in the trail) 181 for m in agent_state.get("source_metrics", []) or []: 182 state.setdefault("source_metrics", []).append(m) 183 # Merge trace 184 for k, v in (agent_state.get("trace") or {}).items(): 185 state.setdefault("trace", {})[k] = v 186 187 return state 188 189 async def _run_sequential(self, state: ArenaState) -> ArenaState: 190 """Run agents one after another, each seeing prior state.""" 191 for agent in self.agents: 192 try: 193 ctx = self.ctx_factory(agent.config) 194 state = await agent.run(state, ctx) 195 except Exception as exc: 196 logger.error("Agent %r failed in sequential phase: %s", agent.name, exc) 197 state.setdefault("reports", {})[agent.name] = ( 198 f"[Agent {agent.name} encountered an error: {exc}]" 199 ) 200 return state 201 202 async def _run_debate(self, state: ArenaState) -> ArenaState: 203 """Run N-way debate pattern. 204 205 Debaters take turns for max_rounds. Judge (last agent or agent with 206 role="judge") synthesizes after all rounds complete. 207 """ 208 debaters = [a for a in self.agents if a.role == "debater"] 209 judges = [a for a in self.agents if a.role == "judge"] 210 211 if not debaters: 212 raise ValueError(f"Debate phase {self.name!r} has no debater agents") 213 214 # Initialize debate state 215 debate_name = self.name 216 debate_state = DebateState( 217 history=[], 218 per_agent={a.name: [] for a in debaters}, 219 latest_per_agent={a.name: "" for a in debaters}, 220 judge_decision="", 221 count=0, 222 ) 223 224 # Round-robin debate 225 for round_num in range(self.config.max_rounds): 226 for debater in debaters: 227 try: 228 ctx = self.ctx_factory(debater.config) 229 ctx.shared["debate_state"] = debate_state 230 ctx.shared["debate_round"] = round_num 231 ctx.shared["debate_name"] = debate_name 232 233 agent_state = copy.deepcopy(state) 234 agent_state = await debater.run(agent_state, ctx) 235 236 argument = agent_state.get("reports", {}).get(debater.name, "") 237 if argument: 238 msg = DebateMessage( 239 agent=debater.name, 240 content=argument, 241 round=round_num, 242 ) 243 debate_state["history"].append(msg) 244 debate_state["per_agent"][debater.name].append(argument) 245 debate_state["latest_per_agent"][debater.name] = argument 246 debate_state["count"] = debate_state.get("count", 0) + 1 247 except Exception as exc: 248 logger.error( 249 "Debater %r failed (round %d): %s", debater.name, round_num, exc, 250 ) 251 252 # Store debate state 253 state.setdefault("debates", {})[debate_name] = debate_state 254 255 # Judge synthesizes 256 for judge in judges: 257 try: 258 ctx = self.ctx_factory(judge.config) 259 ctx.shared["debate_state"] = debate_state 260 ctx.shared["debate_name"] = debate_name 261 state = await judge.run(state, ctx) 262 except Exception as exc: 263 logger.error("Judge %r failed: %s", judge.name, exc) 264 265 # Signal extraction if configured 266 if self.config.signal_extract: 267 judge_output = state.get("decisions", {}).get(judge.name, "") 268 if judge_output: 269 270 async def _chat(system: str, user: str) -> str: 271 resp = await ctx.llm.chat(system, user) 272 return resp.text 273 274 signal = await extract_signal(judge_output, _chat) 275 state["signal"] = signal 276 277 return state