@@ -942,3 +942,225 @@ async def test_run_async_handles_none_parts_in_response():
942942 )
943943
944944 assert tool_result == ''
945+
946+
947+ class TestAgentToolWithCompositeAgents :
948+ """Tests for AgentTool wrapping composite agents (SequentialAgent, etc.)."""
949+
950+ def test_sequential_agent_with_first_sub_agent_input_schema (self ):
951+ """Test that AgentTool exposes input_schema from first sub-agent of SequentialAgent."""
952+
953+ class CustomInput (BaseModel ):
954+ query : str
955+ language : str
956+
957+ first_agent = Agent (
958+ name = 'first_agent' ,
959+ model = testing_utils .MockModel .create (responses = ['response1' ]),
960+ input_schema = CustomInput ,
961+ )
962+
963+ second_agent = Agent (
964+ name = 'second_agent' ,
965+ model = testing_utils .MockModel .create (responses = ['response2' ]),
966+ )
967+
968+ sequence = SequentialAgent (
969+ name = 'sequence' ,
970+ description = 'Process the query through multiple steps' ,
971+ sub_agents = [first_agent , second_agent ],
972+ )
973+
974+ agent_tool = AgentTool (agent = sequence )
975+ declaration = agent_tool ._get_declaration ()
976+
977+ # Should expose CustomInput schema, not fallback to 'request'
978+ assert declaration .name == 'sequence'
979+ assert declaration .description == 'Process the query through multiple steps'
980+ assert declaration .parameters .properties ['query' ].type == 'STRING'
981+ assert declaration .parameters .properties ['language' ].type == 'STRING'
982+ assert 'request' not in declaration .parameters .properties
983+
984+ def test_sequential_agent_without_input_schema_falls_back_to_request (self ):
985+ """Test that AgentTool falls back to 'request' when no sub-agent has input_schema."""
986+
987+ first_agent = Agent (
988+ name = 'first_agent' ,
989+ model = testing_utils .MockModel .create (responses = ['response1' ]),
990+ )
991+
992+ second_agent = Agent (
993+ name = 'second_agent' ,
994+ model = testing_utils .MockModel .create (responses = ['response2' ]),
995+ )
996+
997+ sequence = SequentialAgent (
998+ name = 'sequence' ,
999+ description = 'Process the query through multiple steps' ,
1000+ sub_agents = [first_agent , second_agent ],
1001+ )
1002+
1003+ agent_tool = AgentTool (agent = sequence )
1004+ declaration = agent_tool ._get_declaration ()
1005+
1006+ # Should fall back to 'request' parameter
1007+ assert declaration .name == 'sequence'
1008+ assert declaration .parameters .properties ['request' ].type == 'STRING'
1009+ assert 'query' not in declaration .parameters .properties
1010+
1011+ @mark .parametrize (
1012+ 'env_variables' ,
1013+ [
1014+ 'VERTEX' ,
1015+ ],
1016+ indirect = True ,
1017+ )
1018+ def test_sequential_agent_with_last_sub_agent_output_schema (
1019+ self , env_variables
1020+ ):
1021+ """Test that AgentTool uses output_schema from last sub-agent of SequentialAgent."""
1022+
1023+ class CustomOutput (BaseModel ):
1024+ result : str
1025+
1026+ first_agent = Agent (
1027+ name = 'first_agent' ,
1028+ model = testing_utils .MockModel .create (responses = ['response1' ]),
1029+ )
1030+
1031+ second_agent = Agent (
1032+ name = 'second_agent' ,
1033+ model = testing_utils .MockModel .create (responses = ['response2' ]),
1034+ output_schema = CustomOutput ,
1035+ )
1036+
1037+ sequence = SequentialAgent (
1038+ name = 'sequence' ,
1039+ description = 'Process the query' ,
1040+ sub_agents = [first_agent , second_agent ],
1041+ )
1042+
1043+ agent_tool = AgentTool (agent = sequence )
1044+ declaration = agent_tool ._get_declaration ()
1045+
1046+ # Should have object response schema from last sub-agent
1047+ assert declaration .response is not None
1048+ assert declaration .response .type == types .Type .OBJECT
1049+
1050+ def test_nested_sequential_agent_input_schema (self ):
1051+ """Test that AgentTool recursively finds input_schema in nested composite agents."""
1052+
1053+ class CustomInput (BaseModel ):
1054+ deep_query : str
1055+
1056+ inner_agent = Agent (
1057+ name = 'inner_agent' ,
1058+ model = testing_utils .MockModel .create (responses = ['response1' ]),
1059+ input_schema = CustomInput ,
1060+ )
1061+
1062+ inner_sequence = SequentialAgent (
1063+ name = 'inner_sequence' ,
1064+ sub_agents = [inner_agent ],
1065+ )
1066+
1067+ outer_sequence = SequentialAgent (
1068+ name = 'outer_sequence' ,
1069+ description = 'Nested sequence' ,
1070+ sub_agents = [inner_sequence ],
1071+ )
1072+
1073+ agent_tool = AgentTool (agent = outer_sequence )
1074+ declaration = agent_tool ._get_declaration ()
1075+
1076+ # Should recursively find CustomInput from inner_agent
1077+ assert declaration .name == 'outer_sequence'
1078+ assert 'deep_query' in declaration .parameters .properties
1079+ assert declaration .parameters .properties ['deep_query' ].type == 'STRING'
1080+ assert 'request' not in declaration .parameters .properties
1081+
1082+ @mark .parametrize (
1083+ 'env_variables' ,
1084+ [
1085+ 'GOOGLE_AI' ,
1086+ 'VERTEX' ,
1087+ ],
1088+ indirect = True ,
1089+ )
1090+ def test_sequential_agent_custom_schema_end_to_end (self , env_variables ):
1091+ """Test end-to-end flow with SequentialAgent using custom input/output schema."""
1092+
1093+ class CustomInput (BaseModel ):
1094+ custom_input : str
1095+
1096+ class CustomOutput (BaseModel ):
1097+ custom_output : str
1098+
1099+ function_call_seq = Part .from_function_call (
1100+ name = 'sequence' , args = {'custom_input' : 'test_input' }
1101+ )
1102+
1103+ mock_model = testing_utils .MockModel .create (
1104+ responses = [
1105+ function_call_seq ,
1106+ '{"custom_output": "step1_response"}' ,
1107+ '{"custom_output": "final_response"}' ,
1108+ 'root_response' ,
1109+ ]
1110+ )
1111+
1112+ first_agent = Agent (
1113+ name = 'first_agent' ,
1114+ model = mock_model ,
1115+ input_schema = CustomInput ,
1116+ )
1117+
1118+ second_agent = Agent (
1119+ name = 'second_agent' ,
1120+ model = mock_model ,
1121+ output_schema = CustomOutput ,
1122+ output_key = 'seq_output' ,
1123+ )
1124+
1125+ sequence = SequentialAgent (
1126+ name = 'sequence' ,
1127+ description = 'A sequential pipeline' ,
1128+ sub_agents = [first_agent , second_agent ],
1129+ )
1130+
1131+ root_agent = Agent (
1132+ name = 'root_agent' ,
1133+ model = mock_model ,
1134+ tools = [AgentTool (agent = sequence )],
1135+ )
1136+
1137+ runner = testing_utils .InMemoryRunner (root_agent )
1138+ runner .run ('test1' )
1139+
1140+ # Verify the tool declaration sent to LLM has the correct schema
1141+ # The first request is from root_agent, which should have the tool declaration
1142+ first_request = mock_model .requests [0 ]
1143+ tool_declarations = first_request .config .tools
1144+ assert len (tool_declarations ) == 1
1145+
1146+ sequence_tool = tool_declarations [0 ].function_declarations [0 ]
1147+ assert sequence_tool .name == 'sequence'
1148+ # Should have 'custom_input' parameter from first sub-agent's input_schema
1149+ assert 'custom_input' in sequence_tool .parameters .properties
1150+ # Should NOT have the fallback 'request' parameter
1151+ assert 'request' not in sequence_tool .parameters .properties
1152+
1153+ def test_empty_sequential_agent_falls_back_to_request (self ):
1154+ """Test that AgentTool with empty SequentialAgent falls back to 'request'."""
1155+
1156+ sequence = SequentialAgent (
1157+ name = 'empty_sequence' ,
1158+ description = 'An empty sequence' ,
1159+ sub_agents = [],
1160+ )
1161+
1162+ agent_tool = AgentTool (agent = sequence )
1163+ declaration = agent_tool ._get_declaration ()
1164+
1165+ # Should fall back to 'request' parameter
1166+ assert declaration .parameters .properties ['request' ].type == 'STRING'
0 commit comments