Skip to content

Commit bf2b56d

Browse files
xuanyang15copybara-github
authored andcommitted
fix: recursively extract input/output schema for AgentTool
Fixes: #4154 Co-authored-by: Xuan Yang <xygoogle@google.com> PiperOrigin-RevId: 859440231
1 parent 3d96b78 commit bf2b56d

File tree

2 files changed

+288
-11
lines changed

2 files changed

+288
-11
lines changed

src/google/adk/tools/agent_tool.py

Lines changed: 66 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
from __future__ import annotations
1616

1717
from typing import Any
18+
from typing import Optional
1819
from typing import TYPE_CHECKING
1920

2021
from google.genai import types
22+
from pydantic import BaseModel
2123
from pydantic import model_validator
2224
from typing_extensions import override
2325

@@ -37,6 +39,56 @@
3739
from ..agents.base_agent import BaseAgent
3840

3941

42+
def _get_input_schema(agent: BaseAgent) -> Optional[type[BaseModel]]:
43+
"""Extracts the input_schema from an agent.
44+
45+
For LlmAgent, returns its input_schema directly.
46+
For agents with sub_agents, recursively searches the first sub-agent for an
47+
input_schema.
48+
49+
Args:
50+
agent: The agent to extract input_schema from.
51+
52+
Returns:
53+
The input_schema if found, None otherwise.
54+
"""
55+
from ..agents.llm_agent import LlmAgent
56+
57+
if isinstance(agent, LlmAgent):
58+
return agent.input_schema
59+
60+
# For composite agents, check the first sub-agent
61+
if agent.sub_agents:
62+
return _get_input_schema(agent.sub_agents[0])
63+
64+
return None
65+
66+
67+
def _get_output_schema(agent: BaseAgent) -> Optional[type[BaseModel]]:
68+
"""Extracts the output_schema from an agent.
69+
70+
For LlmAgent, returns its output_schema directly.
71+
For agents with sub_agents, recursively searches the last sub-agent for an
72+
output_schema.
73+
74+
Args:
75+
agent: The agent to extract output_schema from.
76+
77+
Returns:
78+
The output_schema if found, None otherwise.
79+
"""
80+
from ..agents.llm_agent import LlmAgent
81+
82+
if isinstance(agent, LlmAgent):
83+
return agent.output_schema
84+
85+
# For composite agents, check the last sub-agent
86+
if agent.sub_agents:
87+
return _get_output_schema(agent.sub_agents[-1])
88+
89+
return None
90+
91+
4092
class AgentTool(BaseTool):
4193
"""A tool that wraps an agent.
4294
@@ -74,12 +126,14 @@ def populate_name(cls, data: Any) -> Any:
74126

75127
@override
76128
def _get_declaration(self) -> types.FunctionDeclaration:
77-
from ..agents.llm_agent import LlmAgent
78129
from ..utils.variant_utils import GoogleLLMVariant
79130

80-
if isinstance(self.agent, LlmAgent) and self.agent.input_schema:
131+
input_schema = _get_input_schema(self.agent)
132+
output_schema = _get_output_schema(self.agent)
133+
134+
if input_schema:
81135
result = _automatic_function_calling_util.build_function_declaration(
82-
func=self.agent.input_schema, variant=self._api_variant
136+
func=input_schema, variant=self._api_variant
83137
)
84138
# Override the description with the agent's description
85139
result.description = self.agent.description
@@ -114,7 +168,7 @@ def _get_declaration(self) -> types.FunctionDeclaration:
114168
# Set response schema for non-GEMINI_API variants
115169
if self._api_variant != GoogleLLMVariant.GEMINI_API:
116170
# Determine response type based on agent's output schema
117-
if isinstance(self.agent, LlmAgent) and self.agent.output_schema:
171+
if output_schema:
118172
# Agent has structured output schema - response is an object
119173
if is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL):
120174
result.response_json_schema = {'type': 'object'}
@@ -137,15 +191,15 @@ async def run_async(
137191
args: dict[str, Any],
138192
tool_context: ToolContext,
139193
) -> Any:
140-
from ..agents.llm_agent import LlmAgent
141194
from ..runners import Runner
142195
from ..sessions.in_memory_session_service import InMemorySessionService
143196

144197
if self.skip_summarization:
145198
tool_context.actions.skip_summarization = True
146199

147-
if isinstance(self.agent, LlmAgent) and self.agent.input_schema:
148-
input_value = self.agent.input_schema.model_validate(args)
200+
input_schema = _get_input_schema(self.agent)
201+
if input_schema:
202+
input_value = input_schema.model_validate(args)
149203
content = types.Content(
150204
role='user',
151205
parts=[
@@ -212,10 +266,11 @@ async def run_async(
212266
merged_text = '\n'.join(
213267
p.text for p in last_content.parts if p.text and not p.thought
214268
)
215-
if isinstance(self.agent, LlmAgent) and self.agent.output_schema:
216-
tool_result = self.agent.output_schema.model_validate_json(
217-
merged_text
218-
).model_dump(exclude_none=True)
269+
output_schema = _get_output_schema(self.agent)
270+
if output_schema:
271+
tool_result = output_schema.model_validate_json(merged_text).model_dump(
272+
exclude_none=True
273+
)
219274
else:
220275
tool_result = merged_text
221276
return tool_result

tests/unittests/tools/test_agent_tool.py

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -942,3 +942,225 @@ async def test_run_async_handles_none_parts_in_response():
942942
)
943943

944944
assert tool_result == ''
945+
946+
947+
class TestAgentToolWithCompositeAgents:
948+
"""Tests for AgentTool wrapping composite agents (SequentialAgent, etc.)."""
949+
950+
def test_sequential_agent_with_first_sub_agent_input_schema(self):
951+
"""Test that AgentTool exposes input_schema from first sub-agent of SequentialAgent."""
952+
953+
class CustomInput(BaseModel):
954+
query: str
955+
language: str
956+
957+
first_agent = Agent(
958+
name='first_agent',
959+
model=testing_utils.MockModel.create(responses=['response1']),
960+
input_schema=CustomInput,
961+
)
962+
963+
second_agent = Agent(
964+
name='second_agent',
965+
model=testing_utils.MockModel.create(responses=['response2']),
966+
)
967+
968+
sequence = SequentialAgent(
969+
name='sequence',
970+
description='Process the query through multiple steps',
971+
sub_agents=[first_agent, second_agent],
972+
)
973+
974+
agent_tool = AgentTool(agent=sequence)
975+
declaration = agent_tool._get_declaration()
976+
977+
# Should expose CustomInput schema, not fallback to 'request'
978+
assert declaration.name == 'sequence'
979+
assert declaration.description == 'Process the query through multiple steps'
980+
assert declaration.parameters.properties['query'].type == 'STRING'
981+
assert declaration.parameters.properties['language'].type == 'STRING'
982+
assert 'request' not in declaration.parameters.properties
983+
984+
def test_sequential_agent_without_input_schema_falls_back_to_request(self):
985+
"""Test that AgentTool falls back to 'request' when no sub-agent has input_schema."""
986+
987+
first_agent = Agent(
988+
name='first_agent',
989+
model=testing_utils.MockModel.create(responses=['response1']),
990+
)
991+
992+
second_agent = Agent(
993+
name='second_agent',
994+
model=testing_utils.MockModel.create(responses=['response2']),
995+
)
996+
997+
sequence = SequentialAgent(
998+
name='sequence',
999+
description='Process the query through multiple steps',
1000+
sub_agents=[first_agent, second_agent],
1001+
)
1002+
1003+
agent_tool = AgentTool(agent=sequence)
1004+
declaration = agent_tool._get_declaration()
1005+
1006+
# Should fall back to 'request' parameter
1007+
assert declaration.name == 'sequence'
1008+
assert declaration.parameters.properties['request'].type == 'STRING'
1009+
assert 'query' not in declaration.parameters.properties
1010+
1011+
@mark.parametrize(
1012+
'env_variables',
1013+
[
1014+
'VERTEX',
1015+
],
1016+
indirect=True,
1017+
)
1018+
def test_sequential_agent_with_last_sub_agent_output_schema(
1019+
self, env_variables
1020+
):
1021+
"""Test that AgentTool uses output_schema from last sub-agent of SequentialAgent."""
1022+
1023+
class CustomOutput(BaseModel):
1024+
result: str
1025+
1026+
first_agent = Agent(
1027+
name='first_agent',
1028+
model=testing_utils.MockModel.create(responses=['response1']),
1029+
)
1030+
1031+
second_agent = Agent(
1032+
name='second_agent',
1033+
model=testing_utils.MockModel.create(responses=['response2']),
1034+
output_schema=CustomOutput,
1035+
)
1036+
1037+
sequence = SequentialAgent(
1038+
name='sequence',
1039+
description='Process the query',
1040+
sub_agents=[first_agent, second_agent],
1041+
)
1042+
1043+
agent_tool = AgentTool(agent=sequence)
1044+
declaration = agent_tool._get_declaration()
1045+
1046+
# Should have object response schema from last sub-agent
1047+
assert declaration.response is not None
1048+
assert declaration.response.type == types.Type.OBJECT
1049+
1050+
def test_nested_sequential_agent_input_schema(self):
1051+
"""Test that AgentTool recursively finds input_schema in nested composite agents."""
1052+
1053+
class CustomInput(BaseModel):
1054+
deep_query: str
1055+
1056+
inner_agent = Agent(
1057+
name='inner_agent',
1058+
model=testing_utils.MockModel.create(responses=['response1']),
1059+
input_schema=CustomInput,
1060+
)
1061+
1062+
inner_sequence = SequentialAgent(
1063+
name='inner_sequence',
1064+
sub_agents=[inner_agent],
1065+
)
1066+
1067+
outer_sequence = SequentialAgent(
1068+
name='outer_sequence',
1069+
description='Nested sequence',
1070+
sub_agents=[inner_sequence],
1071+
)
1072+
1073+
agent_tool = AgentTool(agent=outer_sequence)
1074+
declaration = agent_tool._get_declaration()
1075+
1076+
# Should recursively find CustomInput from inner_agent
1077+
assert declaration.name == 'outer_sequence'
1078+
assert 'deep_query' in declaration.parameters.properties
1079+
assert declaration.parameters.properties['deep_query'].type == 'STRING'
1080+
assert 'request' not in declaration.parameters.properties
1081+
1082+
@mark.parametrize(
1083+
'env_variables',
1084+
[
1085+
'GOOGLE_AI',
1086+
'VERTEX',
1087+
],
1088+
indirect=True,
1089+
)
1090+
def test_sequential_agent_custom_schema_end_to_end(self, env_variables):
1091+
"""Test end-to-end flow with SequentialAgent using custom input/output schema."""
1092+
1093+
class CustomInput(BaseModel):
1094+
custom_input: str
1095+
1096+
class CustomOutput(BaseModel):
1097+
custom_output: str
1098+
1099+
function_call_seq = Part.from_function_call(
1100+
name='sequence', args={'custom_input': 'test_input'}
1101+
)
1102+
1103+
mock_model = testing_utils.MockModel.create(
1104+
responses=[
1105+
function_call_seq,
1106+
'{"custom_output": "step1_response"}',
1107+
'{"custom_output": "final_response"}',
1108+
'root_response',
1109+
]
1110+
)
1111+
1112+
first_agent = Agent(
1113+
name='first_agent',
1114+
model=mock_model,
1115+
input_schema=CustomInput,
1116+
)
1117+
1118+
second_agent = Agent(
1119+
name='second_agent',
1120+
model=mock_model,
1121+
output_schema=CustomOutput,
1122+
output_key='seq_output',
1123+
)
1124+
1125+
sequence = SequentialAgent(
1126+
name='sequence',
1127+
description='A sequential pipeline',
1128+
sub_agents=[first_agent, second_agent],
1129+
)
1130+
1131+
root_agent = Agent(
1132+
name='root_agent',
1133+
model=mock_model,
1134+
tools=[AgentTool(agent=sequence)],
1135+
)
1136+
1137+
runner = testing_utils.InMemoryRunner(root_agent)
1138+
runner.run('test1')
1139+
1140+
# Verify the tool declaration sent to LLM has the correct schema
1141+
# The first request is from root_agent, which should have the tool declaration
1142+
first_request = mock_model.requests[0]
1143+
tool_declarations = first_request.config.tools
1144+
assert len(tool_declarations) == 1
1145+
1146+
sequence_tool = tool_declarations[0].function_declarations[0]
1147+
assert sequence_tool.name == 'sequence'
1148+
# Should have 'custom_input' parameter from first sub-agent's input_schema
1149+
assert 'custom_input' in sequence_tool.parameters.properties
1150+
# Should NOT have the fallback 'request' parameter
1151+
assert 'request' not in sequence_tool.parameters.properties
1152+
1153+
def test_empty_sequential_agent_falls_back_to_request(self):
1154+
"""Test that AgentTool with empty SequentialAgent falls back to 'request'."""
1155+
1156+
sequence = SequentialAgent(
1157+
name='empty_sequence',
1158+
description='An empty sequence',
1159+
sub_agents=[],
1160+
)
1161+
1162+
agent_tool = AgentTool(agent=sequence)
1163+
declaration = agent_tool._get_declaration()
1164+
1165+
# Should fall back to 'request' parameter
1166+
assert declaration.parameters.properties['request'].type == 'STRING'

0 commit comments

Comments
 (0)