checking system…
Docs / back / src/maf/core/phase.py · line 20
Python · 278 lines
  1"""Phase execution — orchestrates agents according to patterns."""
  2
  3from __future__ import annotations
  4
  5import asyncio
  6import copy
  7import logging
  8import time
  9from typing import Any
 10
 11from maf.config import PhaseConfig, AgentConfig
 12from maf.core.agent import AgentContext, BaseAgent
 13from maf.core.signal import extract_signal
 14from maf.core.state import ArenaState, DebateState, DebateMessage
 15from maf.streaming import get_event_bus
 16
 17logger = logging.getLogger(__name__)
 18
 19
 20class Phase:
 21    """Executes a group of agents according to a pattern.
 22
 23    Patterns:
 24    - parallel: all agents run concurrently, state merged after
 25    - sequential: agents run in order, each sees prior state updates
 26    - debate: round-robin debaters for max_rounds, then judge synthesizes
 27    """
 28
 29    def __init__(
 30        self,
 31        config: PhaseConfig,
 32        agents: list[BaseAgent],
 33        ctx_factory: Any,  # callable(AgentConfig) -> AgentContext
 34    ) -> None:
 35        self.config = config
 36        self.name = config.name
 37        self.pattern = config.pattern
 38        self.agents = agents
 39        self.ctx_factory = ctx_factory
 40
 41    async def run(self, state: ArenaState) -> ArenaState:
 42        """Execute this phase and return updated state."""
 43        state["current_phase"] = self.name
 44        # Record the agent roster so downstream phases (ReplanAgent) can
 45        # detect specialists that failed without emitting a signal.
 46        if self.pattern in ("parallel", "debate"):
 47            state.setdefault("metadata", {})["phase_agents"] = [
 48                a.name for a in self.agents
 49                if a.role in ("specialist", "analyst", "debater")
 50            ]
 51        logger.info("Phase %r starting (pattern=%s, agents=%d)",
 52                     self.name, self.pattern, len(self.agents))
 53
 54        bus = get_event_bus()
 55        arena_name = state.get("arena_name", "")
 56        arena_id = state.get("arena_id", "")
 57        correlation_id = state.get("metadata", {}).get("correlation_id", "")
 58
 59        t0 = time.monotonic()
 60        await bus.publish(
 61            "phase.start",
 62            arena=arena_name,
 63            arena_id=arena_id,
 64            phase=self.name,
 65            correlation_id=correlation_id,
 66            payload={
 67                "pattern": self.pattern,
 68                "agents": [a.name for a in self.agents],
 69            },
 70        )
 71
 72        try:
 73            if self.pattern == "parallel":
 74                state = await self._run_parallel(state)
 75            elif self.pattern == "sequential":
 76                state = await self._run_sequential(state)
 77            elif self.pattern == "debate":
 78                state = await self._run_debate(state)
 79            else:
 80                raise ValueError(f"Unknown pattern: {self.pattern}")
 81        except Exception as exc:
 82            await bus.publish(
 83                "phase.error",
 84                arena=arena_name, arena_id=arena_id, phase=self.name,
 85                correlation_id=correlation_id,
 86                payload={"error": str(exc), "type": type(exc).__name__},
 87            )
 88            raise
 89
 90        elapsed_s = round(time.monotonic() - t0, 3)
 91        await bus.publish(
 92            "phase.complete",
 93            arena=arena_name, arena_id=arena_id, phase=self.name,
 94            correlation_id=correlation_id,
 95            payload={
 96                "elapsed_s": elapsed_s,
 97                "agents_completed": len(self.agents),
 98                "signals_so_far": len(state.get("agent_signals") or []),
 99                "reports_so_far": list((state.get("reports") or {}).keys()),
100            },
101        )
102        return state
103
104    async def _run_parallel(self, state: ArenaState) -> ArenaState:
105        """Run all agents in parallel with isolated state copies.
106
107        Each agent gets its own copy of state (message isolation per the paper).
108        Reports are merged back into the original state afterward.
109        """
110        bus = get_event_bus()
111        arena_name = state.get("arena_name", "")
112        arena_id = state.get("arena_id", "")
113        correlation_id = state.get("metadata", {}).get("correlation_id", "")
114
115        async def _run_one(agent: BaseAgent) -> tuple[str, ArenaState]:
116            agent_state = copy.deepcopy(state)
117            ctx = self.ctx_factory(agent.config)
118            await bus.publish(
119                "agent.start",
120                arena=arena_name, arena_id=arena_id, phase=self.name,
121                correlation_id=correlation_id,
122                payload={"agent": agent.name, "role": agent.role},
123            )
124            t0 = time.monotonic()
125            try:
126                result = await agent.run(agent_state, ctx)
127            except Exception as exc:
128                await bus.publish(
129                    "agent.error",
130                    arena=arena_name, arena_id=arena_id, phase=self.name,
131                    correlation_id=correlation_id,
132                    payload={"agent": agent.name, "error": str(exc), "type": type(exc).__name__},
133                )
134                raise
135            await bus.publish(
136                "agent.complete",
137                arena=arena_name, arena_id=arena_id, phase=self.name,
138                correlation_id=correlation_id,
139                payload={
140                    "agent": agent.name,
141                    "elapsed_s": round(time.monotonic() - t0, 3),
142                    "report_chars": len(result.get("reports", {}).get(agent.name, "")),
143                },
144            )
145            return agent.name, result
146
147        results = await asyncio.gather(
148            *[_run_one(a) for a in self.agents],
149            return_exceptions=True,
150        )
151
152        # Merge reports and decisions from each agent back
153        for result in results:
154            if isinstance(result, Exception):
155                logger.error("Agent failed in parallel phase: %s", result)
156                continue
157            agent_name, agent_state = result
158            # Merge reports
159            for k, v in agent_state.get("reports", {}).items():
160                state.setdefault("reports", {})[k] = v
161            # Merge decisions
162            for k, v in agent_state.get("decisions", {}).items():
163                state.setdefault("decisions", {})[k] = v
164            # Merge structured agent signals (each parallel agent emits one)
165            for sig in agent_state.get("agent_signals", []) or []:
166                state.setdefault("agent_signals", []).append(sig)
167                if isinstance(sig, dict):
168                    await bus.publish(
169                        "agent.signal",
170                        arena=arena_name, arena_id=arena_id, phase=self.name,
171                        correlation_id=correlation_id,
172                        payload={
173                            "agent": sig.get("agent", agent_name),
174                            "domain": sig.get("domain", ""),
175                            "signal": sig.get("signal", ""),
176                            "confidence": sig.get("confidence", 0.0),
177                            "summary": sig.get("summary", "")[:300],
178                        },
179                    )
180            # Merge source metrics (each call gets one row in the trail)
181            for m in agent_state.get("source_metrics", []) or []:
182                state.setdefault("source_metrics", []).append(m)
183            # Merge trace
184            for k, v in (agent_state.get("trace") or {}).items():
185                state.setdefault("trace", {})[k] = v
186
187        return state
188
189    async def _run_sequential(self, state: ArenaState) -> ArenaState:
190        """Run agents one after another, each seeing prior state."""
191        for agent in self.agents:
192            try:
193                ctx = self.ctx_factory(agent.config)
194                state = await agent.run(state, ctx)
195            except Exception as exc:
196                logger.error("Agent %r failed in sequential phase: %s", agent.name, exc)
197                state.setdefault("reports", {})[agent.name] = (
198                    f"[Agent {agent.name} encountered an error: {exc}]"
199                )
200        return state
201
202    async def _run_debate(self, state: ArenaState) -> ArenaState:
203        """Run N-way debate pattern.
204
205        Debaters take turns for max_rounds. Judge (last agent or agent with
206        role="judge") synthesizes after all rounds complete.
207        """
208        debaters = [a for a in self.agents if a.role == "debater"]
209        judges = [a for a in self.agents if a.role == "judge"]
210
211        if not debaters:
212            raise ValueError(f"Debate phase {self.name!r} has no debater agents")
213
214        # Initialize debate state
215        debate_name = self.name
216        debate_state = DebateState(
217            history=[],
218            per_agent={a.name: [] for a in debaters},
219            latest_per_agent={a.name: "" for a in debaters},
220            judge_decision="",
221            count=0,
222        )
223
224        # Round-robin debate
225        for round_num in range(self.config.max_rounds):
226            for debater in debaters:
227                try:
228                    ctx = self.ctx_factory(debater.config)
229                    ctx.shared["debate_state"] = debate_state
230                    ctx.shared["debate_round"] = round_num
231                    ctx.shared["debate_name"] = debate_name
232
233                    agent_state = copy.deepcopy(state)
234                    agent_state = await debater.run(agent_state, ctx)
235
236                    argument = agent_state.get("reports", {}).get(debater.name, "")
237                    if argument:
238                        msg = DebateMessage(
239                            agent=debater.name,
240                            content=argument,
241                            round=round_num,
242                        )
243                        debate_state["history"].append(msg)
244                        debate_state["per_agent"][debater.name].append(argument)
245                        debate_state["latest_per_agent"][debater.name] = argument
246                        debate_state["count"] = debate_state.get("count", 0) + 1
247                except Exception as exc:
248                    logger.error(
249                        "Debater %r failed (round %d): %s", debater.name, round_num, exc,
250                    )
251
252        # Store debate state
253        state.setdefault("debates", {})[debate_name] = debate_state
254
255        # Judge synthesizes
256        for judge in judges:
257            try:
258                ctx = self.ctx_factory(judge.config)
259                ctx.shared["debate_state"] = debate_state
260                ctx.shared["debate_name"] = debate_name
261                state = await judge.run(state, ctx)
262            except Exception as exc:
263                logger.error("Judge %r failed: %s", judge.name, exc)
264
265            # Signal extraction if configured
266            if self.config.signal_extract:
267                judge_output = state.get("decisions", {}).get(judge.name, "")
268                if judge_output:
269
270                    async def _chat(system: str, user: str) -> str:
271                        resp = await ctx.llm.chat(system, user)
272                        return resp.text
273
274                    signal = await extract_signal(judge_output, _chat)
275                    state["signal"] = signal
276
277        return state