Skip to content

Commit cf86ed2

Browse files
wenjin272claude
andauthored
[fix] Fix ReActAgent failure when output schema is null (#837)
Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 462dbc4 commit cf86ed2

4 files changed

Lines changed: 176 additions & 21 deletions

File tree

api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ public class ReActAgent extends Agent {
6060
public ReActAgent(
6161
ResourceDescriptor descriptor, @Nullable Prompt prompt, @Nullable Object outputSchema) {
6262
this.addResource(DEFAULT_CHAT_MODEL, ResourceType.CHAT_MODEL, descriptor);
63+
Map<String, Object> actionConfig = new HashMap<>();
6364

6465
if (outputSchema != null) {
6566
String jsonSchema;
@@ -82,15 +83,13 @@ public ReActAgent(
8283
"The final response should be json format, and match the schema %s",
8384
jsonSchema));
8485
this.addResource(DEFAULT_SCHEMA_PROMPT, ResourceType.PROMPT, schemaPrompt);
86+
actionConfig.put("output_schema", outputSchema);
8587
}
8688

8789
if (prompt != null) {
8890
this.addResource(DEFAULT_USER_PROMPT, ResourceType.PROMPT, prompt);
8991
}
9092

91-
Map<String, Object> actionConfig = new HashMap<>();
92-
actionConfig.put("output_schema", outputSchema);
93-
9493
try {
9594
Method method =
9695
this.getClass().getMethod("startAction", Event.class, RunnerContext.class);

e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ReActAgentTest.java

Lines changed: 88 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.apache.flink.api.common.typeinfo.TypeInformation;
3535
import org.apache.flink.api.java.functions.KeySelector;
3636
import org.apache.flink.api.java.typeutils.RowTypeInfo;
37+
import org.apache.flink.streaming.api.datastream.DataStream;
3738
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
3839
import org.apache.flink.table.api.DataTypes;
3940
import org.apache.flink.table.api.Schema;
@@ -46,6 +47,7 @@
4647
import org.junit.jupiter.api.Test;
4748

4849
import java.io.IOException;
50+
import java.util.ArrayList;
4951
import java.util.List;
5052

5153
import static org.apache.flink.agents.api.agents.AgentExecutionOptions.ERROR_HANDLING_STRATEGY;
@@ -114,7 +116,7 @@ public void testReActAgent() throws Exception {
114116
agentsEnv.getConfig().set(MAX_RETRIES, 3);
115117

116118
// Declare the ReAct agent.
117-
Agent agent = getAgent();
119+
Agent agent = getAgent(true);
118120

119121
// Create input table from sample data
120122
Table inputTable =
@@ -152,8 +154,74 @@ public void testReActAgent() throws Exception {
152154
checkResult(results);
153155
}
154156

155-
// create ReAct agent.
156-
private static Agent getAgent() {
157+
@Test
158+
public void testReActAgentNoOutputSchema() throws Exception {
159+
Assumptions.assumeTrue(ollamaReady, String.format("%s is not ready", OLLAMA_MODEL));
160+
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
161+
env.setParallelism(1);
162+
163+
// Create the table environment
164+
StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env);
165+
tableEnv.getConfig().set("table.exec.result.display.max-column-width", "100");
166+
167+
// Create agents execution environment
168+
AgentsExecutionEnvironment agentsEnv =
169+
AgentsExecutionEnvironment.getExecutionEnvironment(env, tableEnv);
170+
171+
// register resource to agents execution environment.
172+
agentsEnv
173+
.addResource(
174+
"ollama",
175+
ResourceType.CHAT_MODEL_CONNECTION,
176+
ResourceDescriptor.Builder.newBuilder(
177+
ResourceName.ChatModel.OLLAMA_CONNECTION)
178+
.addInitialArgument("endpoint", "http://localhost:11434")
179+
.addInitialArgument("requestTimeout", 240)
180+
.build())
181+
.addResource(
182+
"add",
183+
ResourceType.TOOL,
184+
Tool.fromMethod(
185+
ReActAgentTest.class.getMethod("add", Double.class, Double.class)))
186+
.addResource(
187+
"multiply",
188+
ResourceType.TOOL,
189+
Tool.fromMethod(
190+
ReActAgentTest.class.getMethod(
191+
"multiply", Double.class, Double.class)));
192+
193+
agentsEnv.getConfig().set(ERROR_HANDLING_STRATEGY, ReActAgent.ErrorHandlingStrategy.RETRY);
194+
agentsEnv.getConfig().set(MAX_RETRIES, 3);
195+
196+
// Declare the ReAct agent without an output schema.
197+
Agent agent = getAgent(false);
198+
199+
// Create input table from sample data
200+
Table inputTable =
201+
tableEnv.fromValues(
202+
DataTypes.ROW(
203+
DataTypes.FIELD("a", DataTypes.DOUBLE()),
204+
DataTypes.FIELD("b", DataTypes.DOUBLE()),
205+
DataTypes.FIELD("c", DataTypes.DOUBLE())),
206+
Row.of(2131, 29847, 3));
207+
208+
// Apply agent to the Table; without an output schema the result is a string.
209+
DataStream<Object> out =
210+
agentsEnv
211+
.fromTable(
212+
inputTable,
213+
(KeySelector<Object, Double>)
214+
value -> (Double) ((Row) value).getField("a"))
215+
.apply(agent)
216+
.toDataStream();
217+
218+
out.print();
219+
220+
env.execute();
221+
}
222+
223+
// create ReAct agent; pass false to skip the output schema.
224+
private static Agent getAgent(boolean withSchema) {
157225
ResourceDescriptor chatModelDescriptor =
158226
ResourceDescriptor.Builder.newBuilder(ResourceName.ChatModel.OLLAMA_SETUP)
159227
.addInitialArgument("connection", "ollama")
@@ -162,21 +230,24 @@ private static Agent getAgent() {
162230
.addInitialArgument("extract_reasoning", true)
163231
.build();
164232

165-
Prompt prompt =
166-
Prompt.fromMessages(
167-
List.of(
168-
new ChatMessage(
169-
MessageRole.SYSTEM,
170-
"Must call function tool to do the calculate."),
171-
new ChatMessage(
172-
MessageRole.SYSTEM,
173-
"An example of output is {\"result\": 30.32}"),
174-
new ChatMessage(MessageRole.USER, "What is ({a} + {b}) * {c}.")));
233+
List<ChatMessage> messages = new ArrayList<>();
234+
messages.add(
235+
new ChatMessage(
236+
MessageRole.SYSTEM, "Must call function tool to do the calculate."));
237+
if (withSchema) {
238+
messages.add(
239+
new ChatMessage(
240+
MessageRole.SYSTEM, "An example of output is {\"result\": 30.32}"));
241+
}
242+
messages.add(new ChatMessage(MessageRole.USER, "What is ({a} + {b}) * {c}."));
243+
175244
RowTypeInfo outputTypeInfo =
176-
new RowTypeInfo(
177-
new TypeInformation[] {BasicTypeInfo.DOUBLE_TYPE_INFO},
178-
new String[] {"result"});
179-
return new ReActAgent(chatModelDescriptor, prompt, outputTypeInfo);
245+
withSchema
246+
? new RowTypeInfo(
247+
new TypeInformation[] {BasicTypeInfo.DOUBLE_TYPE_INFO},
248+
new String[] {"result"})
249+
: null;
250+
return new ReActAgent(chatModelDescriptor, Prompt.fromMessages(messages), outputTypeInfo);
180251
}
181252

182253
private void checkResult(CloseableIterator<?> results) {

python/flink_agents/api/agents/react_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def __init__(
143143
name="start_action",
144144
events=[InputEvent.EVENT_TYPE],
145145
func=self.start_action,
146-
output_schema=OutputSchema(output_schema=output_schema),
146+
output_schema=OutputSchema(output_schema=output_schema) if output_schema else None,
147147
)
148148

149149
@staticmethod

python/flink_agents/e2e_tests/e2e_tests_integration/react_agent_test.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,3 +267,88 @@ def test_react_agent_on_remote_runner(
267267
# through the event-log capture path.
268268
invocations = collect_tool_invocations(log_dir)
269269
assert_tool_invoked(invocations, "multiply", {"a": 4444, "b": 312})
270+
271+
272+
@pytest.mark.skipif(
273+
client is None, reason="Ollama client is not available or test model is missing"
274+
)
275+
def test_react_agent_no_output_schema_on_remote_runner(
276+
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
277+
) -> None:
278+
"""ReAct agent without an output_schema should emit a plain string result."""
279+
monkeypatch.setenv("OLLAMA_CHAT_MODEL", OLLAMA_MODEL)
280+
stream_env = StreamExecutionEnvironment.get_execution_environment()
281+
282+
stream_env.set_parallelism(1)
283+
284+
t_env = StreamTableEnvironment.create(stream_execution_environment=stream_env)
285+
286+
table = t_env.from_elements(
287+
elements=[(2123, 2321, 312)],
288+
schema=DataTypes.ROW(
289+
[
290+
DataTypes.FIELD("a", DataTypes.INT()),
291+
DataTypes.FIELD("b", DataTypes.INT()),
292+
DataTypes.FIELD("c", DataTypes.INT()),
293+
]
294+
),
295+
)
296+
297+
env = AgentsExecutionEnvironment.get_execution_environment(
298+
env=stream_env, t_env=t_env
299+
)
300+
301+
env.get_config().set(
302+
AgentExecutionOptions.ERROR_HANDLING_STRATEGY, ErrorHandlingStrategy.RETRY
303+
)
304+
305+
env.get_config().set(AgentExecutionOptions.MAX_RETRIES, 3)
306+
307+
log_dir = tmp_path / "event_logs"
308+
log_dir.mkdir(parents=True, exist_ok=True)
309+
env.get_config().set_str("baseLogDir", str(log_dir))
310+
311+
# register resource to execution environment
312+
(
313+
env.add_resource(
314+
"ollama",
315+
ResourceType.CHAT_MODEL_CONNECTION,
316+
ResourceDescriptor(
317+
clazz=ResourceName.ChatModel.OLLAMA_CONNECTION, request_timeout=240.0
318+
),
319+
)
320+
.add_resource("add", ResourceType.TOOL, Tool.from_callable(add))
321+
.add_resource("multiply", ResourceType.TOOL, Tool.from_callable(multiply))
322+
)
323+
324+
# prepare prompt
325+
prompt = Prompt.from_messages(
326+
messages=[
327+
ChatMessage(role=MessageRole.USER, content="What is ({a} + {b}) * {c}"),
328+
],
329+
)
330+
331+
# create ReAct agent without an output schema; result is emitted as a string.
332+
agent = ReActAgent(
333+
chat_model=ResourceDescriptor(
334+
clazz=ResourceName.ChatModel.OLLAMA_SETUP,
335+
connection="ollama",
336+
model=OLLAMA_MODEL,
337+
tools=["add", "multiply"],
338+
),
339+
prompt=prompt,
340+
)
341+
342+
output_stream = (
343+
env.from_table(input=table, key_selector=MyKeySelector())
344+
.apply(agent)
345+
.to_datastream()
346+
)
347+
output_stream.print()
348+
349+
env.execute()
350+
351+
# multiply's first arg (4444 = 2123 + 2321) proves the addition was computed
352+
# correctly and threaded into multiply, even without an output schema.
353+
invocations = collect_tool_invocations(log_dir)
354+
assert_tool_invoked(invocations, "multiply", {"a": 4444, "b": 312})

0 commit comments

Comments
 (0)