Skip to content

Commit 33bf13d

Browse files
committed
Add forward_message tool
1 parent 4e62596 commit 33bf13d

File tree

3 files changed

+200
-0
lines changed

3 files changed

+200
-0
lines changed

langgraph_supervisor/handoff.py

+61
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,64 @@ def create_handoff_back_messages(
127127
response_metadata={METADATA_KEY_IS_HANDOFF_BACK: True},
128128
),
129129
)
130+
131+
132+
def create_forward_message_tool(supervisor_name: str = "supervisor") -> BaseTool:
133+
"""Create a tool the supervisor can use to forward a worker message by name.
134+
135+
This helps avoid information loss any time the supervisor rewrites a worker query
136+
to the user and also can save some tokens.
137+
138+
Args:
139+
supervisor_name: The name of the supervisor node (used for namespacing the tool).
140+
141+
Returns:
142+
BaseTool: The 'forward_message' tool.
143+
"""
144+
tool_name = "forward_message"
145+
desc = (
146+
"Forwards the latest message from the specified agent to the user"
147+
" without any changes. Use this to preserve information fidelity, avoid"
148+
" misinterpretation of questions or responses, and save time."
149+
)
150+
151+
@tool(tool_name, description=desc)
152+
def forward_message(
153+
from_agent: str,
154+
state: Annotated[dict, InjectedState],
155+
):
156+
target_message = next(
157+
(
158+
(i, m)
159+
for i, m in enumerate(reversed(state["messages"]))
160+
if isinstance(m, AIMessage)
161+
and (m.name or "").lower() == from_agent.lower()
162+
and not m.response_metadata.get(METADATA_KEY_IS_HANDOFF_BACK)
163+
),
164+
None,
165+
)
166+
if not target_message:
167+
found_names = set(
168+
m.name for m in state["messages"] if isinstance(m, AIMessage) and m.name
169+
)
170+
return (
171+
f"Could not find message from source agent {from_agent}. Found names: {found_names}"
172+
)
173+
updates = [
174+
AIMessage(
175+
content=target_message[1].content,
176+
name=supervisor_name,
177+
id=str(uuid.uuid4()),
178+
),
179+
]
180+
181+
return Command(
182+
graph=Command.PARENT,
183+
# NOTE: this does nothing.
184+
goto="__end__",
185+
# we also propagate the update to make sure the handoff messages are applied
186+
# to the parent graph's state
187+
update={"messages": updates},
188+
)
189+
190+
return forward_message

langgraph_supervisor/supervisor.py

+8
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from langgraph_supervisor.agent_name import AgentNameMode, with_agent_name
1818
from langgraph_supervisor.handoff import (
1919
METADATA_KEY_HANDOFF_DESTINATION,
20+
create_forward_message_tool,
2021
create_handoff_back_messages,
2122
create_handoff_tool,
2223
)
@@ -116,6 +117,7 @@ def create_supervisor(
116117
add_handoff_back_messages: bool = True,
117118
supervisor_name: str = "supervisor",
118119
include_agent_name: AgentNameMode | None = None,
120+
enable_forwarding: bool = False,
119121
) -> StateGraph:
120122
"""Create a multi-agent supervisor.
121123
@@ -170,6 +172,9 @@ def create_supervisor(
170172
- None: Relies on the LLM provider using the name attribute on the AI message. Currently, only OpenAI supports this.
171173
- "inline": Add the agent name directly into the content field of the AI message using XML-style tags.
172174
Example: "How can I help you" -> "<name>agent_name</name><content>How can I help you?</content>"
175+
enable_forwarding: Whether to add a forward_message tool to the supervisor.
176+
This permits the supervisor to forward the latest message from a specific agent to the user.
177+
Recommended that you set "add_handoff_back_messages" to False when using this option.
173178
"""
174179
agent_names = set()
175180
for agent in agents:
@@ -200,6 +205,9 @@ def create_supervisor(
200205
handoff_destinations = [create_handoff_tool(agent_name=agent.name) for agent in agents]
201206
all_tools = (tools or []) + handoff_destinations
202207

208+
if enable_forwarding:
209+
all_tools.append(create_forward_message_tool(supervisor_name))
210+
203211
if _supports_disable_parallel_tool_calls(model):
204212
model = model.bind_tools(all_tools, parallel_tool_calls=parallel_tool_calls)
205213
else:

tests/test_supervisor.py

+131
Original file line numberDiff line numberDiff line change
@@ -275,3 +275,134 @@ def web_search(query: str) -> str:
275275
assert result_full_history["messages"][17] == math_agent_messages[2]
276276
# final supervisor message
277277
assert result_full_history["messages"][-1] == supervisor_messages[-1]
278+
279+
280+
def test_supervisor_message_forwarding():
281+
"""Test that the supervisor forwards a message to a specific agent and receives the correct response."""
282+
283+
@tool
284+
def echo_tool(text: str) -> str:
285+
"""Echo the input text."""
286+
return text
287+
288+
# Agent that simply echoes the message
289+
echo_model = FakeChatModel(
290+
responses=[
291+
AIMessage(content="Echo: test forwarding!"),
292+
]
293+
)
294+
echo_agent = create_react_agent(
295+
model=echo_model.bind_tools([echo_tool]),
296+
tools=[echo_tool],
297+
name="echo_agent",
298+
)
299+
300+
supervisor_messages = [
301+
AIMessage(
302+
content="",
303+
tool_calls=[
304+
{
305+
"name": "transfer_to_echo_agent",
306+
"args": {},
307+
"id": "call_gyQSgJQm5jJtPcF5ITe8GGGF",
308+
"type": "tool_call",
309+
}
310+
],
311+
),
312+
AIMessage(
313+
content="",
314+
tool_calls=[
315+
{
316+
"name": "forward_message",
317+
"args": {"from_agent": "echo_agent"},
318+
"id": "abcd123",
319+
"type": "tool_call",
320+
}
321+
],
322+
),
323+
]
324+
325+
workflow = create_supervisor(
326+
[echo_agent],
327+
model=FakeChatModel(responses=supervisor_messages),
328+
enable_forwarding=True,
329+
)
330+
app = workflow.compile()
331+
332+
result = app.invoke({"messages": [HumanMessage(content="Scooby-dooby-doo")]})
333+
334+
def get_tool_calls(msg):
335+
tool_calls = getattr(msg, "tool_calls", None)
336+
if tool_calls is None:
337+
return None
338+
return [
339+
{"name": tc["name"], "args": tc["args"]}
340+
for tc in tool_calls
341+
if tc["type"] == "tool_call"
342+
]
343+
344+
received = [
345+
{
346+
"name": msg.name,
347+
"content": msg.content,
348+
"tool_calls": get_tool_calls(msg),
349+
"type": msg.type,
350+
}
351+
for msg in result["messages"]
352+
]
353+
354+
expected = [
355+
{
356+
"name": None,
357+
"content": "Scooby-dooby-doo",
358+
"tool_calls": None,
359+
"type": "human",
360+
},
361+
{
362+
"name": "supervisor",
363+
"content": "",
364+
"tool_calls": [
365+
{
366+
"name": "transfer_to_echo_agent",
367+
"args": {},
368+
}
369+
],
370+
"type": "ai",
371+
},
372+
{
373+
"name": "transfer_to_echo_agent",
374+
"content": "Successfully transferred to echo_agent",
375+
"tool_calls": None,
376+
"type": "tool",
377+
},
378+
{
379+
"name": "echo_agent",
380+
"content": "Echo: test forwarding!",
381+
"tool_calls": [],
382+
"type": "ai",
383+
},
384+
{
385+
"name": "echo_agent",
386+
"content": "Transferring back to supervisor",
387+
"tool_calls": [
388+
{
389+
"name": "transfer_back_to_supervisor",
390+
"args": {},
391+
}
392+
],
393+
"type": "ai",
394+
},
395+
{
396+
"name": "transfer_back_to_supervisor",
397+
"content": "Successfully transferred back to supervisor",
398+
"tool_calls": None,
399+
"type": "tool",
400+
},
401+
{
402+
"name": "supervisor",
403+
"content": "Echo: test forwarding!",
404+
"tool_calls": [],
405+
"type": "ai",
406+
},
407+
]
408+
assert received == expected

0 commit comments

Comments
 (0)