|
| 1 | +import json |
| 2 | +from typing import TYPE_CHECKING, Any |
| 3 | + |
| 4 | +from ag_ui.core.events import EventType, ToolCallArgsEvent, ToolCallEndEvent, ToolCallStartEvent |
| 5 | + |
| 6 | +if TYPE_CHECKING: |
| 7 | + from collections.abc import AsyncIterator, Iterable |
| 8 | + |
| 9 | + from ag_ui.core.events import BaseEvent |
| 10 | + |
| 11 | + |
class SubagentEventFilter:
    """Reorder/suppress AGUI events so subagent frames don't leak into the parent turn.

    Two upstream behaviors collide on ``task``-tool turns:

    1. ag_ui_langgraph drops the parent's ``task`` TOOL_CALL_START on the
       text→tool_call transition chunk (the chunk that ends the parent's text
       stream also carries the new tool_call name, but the handler returns
       after emitting TEXT_MESSAGE_END). Subsequent chunks only have args, so
       OnChatModelStream never reaches ``is_tool_call_start_event``. The
       ``task`` TOOL_CALL_START finally arrives from the OnToolEnd re-emit —
       *after* the subagent has already streamed text/tool calls to the
       client.

    2. With ``stream_subgraphs=True``, every chunk emitted from inside
       ``subagent.ainvoke()`` flows through the parent's stream with a
       nested ``langgraph_checkpoint_ns`` (``"tools:UUID|model:UUID"``).
       Without (1)'s TOOL_CALL_START there is no ``task`` segment to
       suppress them against.

    This filter:

    * captures ``task`` tool_call ids from top-level STATE_SNAPSHOT events,
    * synthesizes TOOL_CALL_START + ARGS + END for each on the first nested
      event so the chat creates the segment *before* the subagent runs,
    * drops every nested event (``|`` in ns),
    * drops the LATE OnToolEnd re-emitted START/ARGS/END for tool_calls we
      already synthesized (deduping by tool_call_id),
    * never synthesizes for a ``task`` call whose real TOOL_CALL_START was
      already forwarded at top level, so the client does not receive two
      START frames for the same ``tool_call_id``.

    The parent's TOOL_CALL_RESULT for the task tool still flows through
    untouched — it's a top-level event with the same ``tool_call_id``, so the
    chat UI flips the synthesized segment to ``done`` exactly like a normal
    tool call.

    Annotations that name ``TYPE_CHECKING``-only imports are written as
    strings so the class can be defined on interpreters that evaluate
    annotations eagerly.
    """

    # Tool name used by deepagents' SubAgentMiddleware to invoke a subagent.
    TASK_TOOL_NAME = "task"

    def __init__(self) -> None:
        """Create a filter with empty per-stream bookkeeping."""
        # Lifecycle of a ``task`` tool_call_id:
        #   ``_pending``  - seen in a snapshot; synthesize on the next nested event.
        #   ``_emitted``  - START/ARGS/END were synthesized here; drop the late
        #                   top-level re-emit of those frames.
        #   ``_streamed`` - the real TOOL_CALL_START was forwarded at top level;
        #                   never synthesize for it.
        # Membership in any of the three is enough to ignore a STATE_SNAPSHOT
        # rebroadcast; ``_emitted`` alone gates the late TOOL_CALL_* drop.
        self._pending: dict[str, tuple[str, Any]] = {}
        self._emitted: set[str] = set()
        self._streamed: set[str] = set()

    async def apply(self, stream: "AsyncIterator[BaseEvent]") -> "AsyncIterator[BaseEvent]":
        """Yield the events of ``stream`` with subagent frames filtered out.

        Args:
            stream: Upstream AGUI events, read in order.

        Yields:
            Every top-level event unchanged, except TOOL_CALL_START/ARGS/END
            frames whose ``tool_call_id`` was already synthesized here. Nested
            events (``|`` in the checkpoint namespace) are never yielded; the
            first one seen after a capture is replaced by the synthesized
            frames for all pending ``task`` calls.
        """
        tool_call_frames = (
            EventType.TOOL_CALL_START,
            EventType.TOOL_CALL_ARGS,
            EventType.TOOL_CALL_END,
        )

        async for event in stream:
            if "|" in self._checkpoint_ns(event):
                # Nested event: announce the pending task calls (if any) in
                # its place and swallow the event itself.
                for frame in self._flush_pending():
                    yield frame
                continue

            if event.type == EventType.STATE_SNAPSHOT:
                self._capture_task_calls(event)
            elif event.type in tool_call_frames:
                tcid = getattr(event, "tool_call_id", None)
                if isinstance(tcid, str):
                    if tcid in self._emitted:
                        # Late re-emit of frames we already synthesized.
                        continue
                    if event.type == EventType.TOOL_CALL_START and (
                        tcid in self._pending
                        or getattr(event, "tool_call_name", None) == self.TASK_TOOL_NAME
                    ):
                        # The real START is reaching the client, so a
                        # synthesized copy would be a duplicate.
                        self._pending.pop(tcid, None)
                        self._streamed.add(tcid)

            yield event

    def _capture_task_calls(self, event: "BaseEvent") -> None:
        """Queue unseen ``task`` tool calls from a top-level STATE_SNAPSHOT."""
        for tcid, name, args in self._iter_latest_task_calls(event):
            if tcid in self._pending or tcid in self._emitted or tcid in self._streamed:
                continue
            self._pending[tcid] = (name, args)

    def _flush_pending(self) -> "Iterable[BaseEvent]":
        """Yield synthesized START/ARGS/END for each pending call, marking it emitted."""
        # Iterate over a copy: entries are removed as they are emitted.
        for tcid, (name, args) in list(self._pending.items()):
            yield ToolCallStartEvent(type=EventType.TOOL_CALL_START, tool_call_id=tcid, tool_call_name=name)
            if args:
                # ``default=str`` so a Pydantic model / datetime / other
                # non-JSON-native object in args doesn't kill the entire
                # chat stream — better a stringified field than RUN_ERROR.
                delta = args if isinstance(args, str) else json.dumps(args, default=str)
                yield ToolCallArgsEvent(type=EventType.TOOL_CALL_ARGS, tool_call_id=tcid, delta=delta)
            yield ToolCallEndEvent(type=EventType.TOOL_CALL_END, tool_call_id=tcid)
            self._emitted.add(tcid)
            del self._pending[tcid]

    @staticmethod
    def _checkpoint_ns(event: "BaseEvent") -> str:
        """Extract the LangGraph checkpoint namespace from an AGUI event's raw_event.

        A ``|`` in the namespace means the event was emitted from a *nested*
        LangGraph execution — i.e. from inside a subagent invoked by the parent's
        ``task`` tool. Top-level events have an empty ns or a single
        ``"<node>:UUID"`` segment (e.g. ``"model:..."``, ``"tools:..."``) with no
        pipe.

        Returns:
            The namespace as a string, or ``""`` when ``raw_event`` is missing,
            is not a dict, or carries no ``langgraph_checkpoint_ns`` metadata.
        """
        raw = getattr(event, "raw_event", None)
        if not isinstance(raw, dict):
            return ""
        md = raw.get("metadata") or {}
        return str(md.get("langgraph_checkpoint_ns", "") or "")

    @classmethod
    def _iter_latest_task_calls(cls, event: "BaseEvent") -> "Iterable[tuple[str, str, Any]]":
        """Yield ``(tool_call_id, name, args)`` for every ``task`` tool_call on the
        snapshot's latest AIMessage.

        Nothing is yielded when the event has no dict ``snapshot``, when its
        ``messages`` entry is not a list, or when no AI/assistant message is
        present. Tool calls whose id is not a string are skipped. The caller is
        responsible for dedup against ids it has already handled — this is
        just the per-snapshot scan.
        """
        snap = getattr(event, "snapshot", None)
        if not isinstance(snap, dict):
            return
        msgs = snap.get("messages")
        if not isinstance(msgs, list):
            return
        # Only the latest AIMessage is scanned: walk backwards and stop at the
        # first AI/assistant message. Ids from older AIMessages were offered on
        # earlier snapshots; the caller's bookkeeping is what guarantees no
        # double-synthesis.
        for m in reversed(msgs):
            if cls._msg_role(m) not in ("ai", "assistant"):
                continue
            for tc in cls._msg_field(m, "tool_calls") or []:
                tcid = cls._msg_field(tc, "id")
                name = cls._msg_field(tc, "name")
                if name == cls.TASK_TOOL_NAME and isinstance(tcid, str):
                    yield tcid, name, cls._msg_field(tc, "args")
            return

    @staticmethod
    def _msg_field(message: Any, name: str, default: Any = None) -> Any:
        """Read a field from a LangChain message or its dict-encoded form.

        STATE_SNAPSHOT can carry either shape depending on whether the snapshot
        has been serialized yet — running through the AGUI encoder turns objects
        into dicts, but the filter here sits *before* the encoder.

        Args:
            message: A dict (read by key) or any other object (read by attribute).
            name: Key or attribute name to look up.
            default: Value returned when the key/attribute is absent.
        """
        if isinstance(message, dict):
            return message.get(name, default)
        return getattr(message, name, default)

    @classmethod
    def _msg_role(cls, message: Any) -> str:
        """Return the message's lowercased ``type`` (or, failing that, ``role``); ``""`` if neither is set."""
        return str(cls._msg_field(message, "type", "") or cls._msg_field(message, "role", "") or "").lower()
0 commit comments