Skip to content

Commit 1aac6aa

Browse files
[KYUUBI #7379][2b/4][FOLLOWUP] Split ReactAgent into LlmStreamClient + composite MiddlewareDispatcher
ReactAgent had grown to mix three concerns: the ReAct control loop, OpenAI streaming/chunk assembly, and middleware fold logic. Extract:

- LlmStreamClient: owns one streaming chat completion call, accumulates content + tool-call deltas, and exposes StreamResult.toAssistantMessage for SDK message construction. Depends only on the OpenAI SDK and AgentRunContext (emits ContentDelta via ctx.emit, no dispatcher reference).
- MiddlewareDispatcher: implements AgentMiddleware as a composite over the configured list. ReactAgent calls onAgentStart / onEvent / beforeLlmCall etc. on the composite the same way it would call any middleware; resolveApproval stays as a non-interface accessor for the approval flow's special case.

Also: afterToolCall now returns Decision<String> for symmetry with the other interceptor hooks; ABORT marks ToolResult.isError=true so listeners can distinguish a middleware-vetoed result from a successful one.

The emit-then-forward step splits cleanly: the composite runs onEvent, and ReactAgent's ctx.setEventEmitter lambda forwards the filtered event to the user's raw consumer.

ReactAgent's run() drops the eventConsumer parameter threading through internal helpers — everywhere downstream uses ctx.emit().
1 parent 9fe9625 commit 1aac6aa

4 files changed

Lines changed: 452 additions & 335 deletions

File tree

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.kyuubi.engine.dataagent.runtime;
19+
20+
import com.openai.client.OpenAIClient;
21+
import com.openai.core.http.StreamResponse;
22+
import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam;
23+
import com.openai.models.chat.completions.ChatCompletionChunk;
24+
import com.openai.models.chat.completions.ChatCompletionCreateParams;
25+
import com.openai.models.chat.completions.ChatCompletionMessageFunctionToolCall;
26+
import com.openai.models.chat.completions.ChatCompletionMessageParam;
27+
import com.openai.models.chat.completions.ChatCompletionMessageToolCall;
28+
import com.openai.models.chat.completions.ChatCompletionStreamOptions;
29+
import java.util.ArrayList;
30+
import java.util.HashMap;
31+
import java.util.List;
32+
import java.util.Map;
33+
import org.apache.kyuubi.engine.dataagent.runtime.event.ContentDelta;
34+
import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry;
35+
import org.slf4j.Logger;
36+
import org.slf4j.LoggerFactory;
37+
38+
/** Streams one chat completion call and assembles assistant content plus streamed tool calls. */
39+
final class LlmStreamClient {
40+
41+
private static final Logger LOG = LoggerFactory.getLogger(LlmStreamClient.class);
42+
43+
private final OpenAIClient client;
44+
private final ToolRegistry toolRegistry;
45+
46+
LlmStreamClient(OpenAIClient client, ToolRegistry toolRegistry) {
47+
this.client = client;
48+
this.toolRegistry = toolRegistry;
49+
}
50+
51+
/**
52+
* Stream LLM response, emitting ContentDelta through {@code ctx} for each text chunk. Assembles
53+
* tool calls directly from streamed chunks with no non-streaming fallback.
54+
*/
55+
StreamResult stream(
56+
AgentRunContext ctx, List<ChatCompletionMessageParam> messages, String effectiveModel) {
57+
ChatCompletionCreateParams.Builder paramsBuilder =
58+
ChatCompletionCreateParams.builder()
59+
.model(effectiveModel)
60+
.streamOptions(ChatCompletionStreamOptions.builder().includeUsage(true).build());
61+
for (ChatCompletionMessageParam msg : messages) {
62+
paramsBuilder.addMessage(msg);
63+
}
64+
toolRegistry.addToolsTo(paramsBuilder);
65+
66+
LOG.info("LLM request: model={}", effectiveModel);
67+
StreamAccumulator acc = new StreamAccumulator();
68+
try (StreamResponse<ChatCompletionChunk> stream =
69+
client.chat().completions().createStreaming(paramsBuilder.build())) {
70+
stream.stream().forEach(chunk -> consumeChunk(ctx, chunk, acc));
71+
}
72+
return new StreamResult(acc.content.toString(), acc.buildToolCalls());
73+
}
74+
75+
/** Fold one streaming chunk into {@code acc}, emitting per-token {@link ContentDelta}s. */
76+
private void consumeChunk(AgentRunContext ctx, ChatCompletionChunk chunk, StreamAccumulator acc) {
77+
if (!acc.serverModelLogged) {
78+
LOG.info("LLM response: server-echoed model={}", chunk.model());
79+
acc.serverModelLogged = true;
80+
}
81+
chunk
82+
.usage()
83+
.ifPresent(u -> ctx.addTokenUsage(u.promptTokens(), u.completionTokens(), u.totalTokens()));
84+
85+
for (ChatCompletionChunk.Choice c : chunk.choices()) {
86+
c.delta()
87+
.content()
88+
.ifPresent(
89+
text -> {
90+
acc.content.append(text);
91+
ctx.emit(new ContentDelta(text));
92+
});
93+
c.delta().toolCalls().ifPresent(acc::mergeToolCallDeltas);
94+
}
95+
}
96+
97+
/**
98+
* Mutable accumulator for a single streaming LLM turn. Tool call fields are keyed by the chunk's
99+
* {@code index} because provider SDKs may deliver a single logical call across multiple chunks
100+
* and only surface the {@code id}/{@code name} on the first one.
101+
*/
102+
private static final class StreamAccumulator {
103+
final StringBuilder content = new StringBuilder();
104+
final Map<Integer, String> toolCallIds = new HashMap<>();
105+
final Map<Integer, String> toolCallNames = new HashMap<>();
106+
final Map<Integer, StringBuilder> toolCallArgs = new HashMap<>();
107+
boolean serverModelLogged = false;
108+
109+
void mergeToolCallDeltas(List<ChatCompletionChunk.Choice.Delta.ToolCall> deltas) {
110+
for (ChatCompletionChunk.Choice.Delta.ToolCall tc : deltas) {
111+
int idx = (int) tc.index();
112+
tc.id().ifPresent(id -> toolCallIds.put(idx, id));
113+
tc.function()
114+
.ifPresent(
115+
fn -> {
116+
fn.name().ifPresent(name -> toolCallNames.put(idx, name));
117+
fn.arguments()
118+
.ifPresent(
119+
args ->
120+
toolCallArgs
121+
.computeIfAbsent(idx, k -> new StringBuilder())
122+
.append(args));
123+
});
124+
}
125+
}
126+
127+
/**
128+
* Materialize accumulated deltas into SDK tool-call objects. Returns {@code null} (not an empty
129+
* list) if no tool calls were seen, matching the existing {@link StreamResult} contract.
130+
*/
131+
List<ChatCompletionMessageToolCall> buildToolCalls() {
132+
if (toolCallIds.isEmpty()) return null;
133+
List<ChatCompletionMessageToolCall> out = new ArrayList<>(toolCallIds.size());
134+
for (Map.Entry<Integer, String> e : toolCallIds.entrySet()) {
135+
int idx = e.getKey();
136+
String id = (e.getValue() == null || e.getValue().isEmpty()) ? synthId() : e.getValue();
137+
String args = toolCallArgs.containsKey(idx) ? toolCallArgs.get(idx).toString() : "{}";
138+
out.add(
139+
ChatCompletionMessageToolCall.ofFunction(
140+
ChatCompletionMessageFunctionToolCall.builder()
141+
.id(id)
142+
.function(
143+
ChatCompletionMessageFunctionToolCall.Function.builder()
144+
.name(toolCallNames.getOrDefault(idx, ""))
145+
.arguments(args)
146+
.build())
147+
.build()));
148+
}
149+
return out;
150+
}
151+
152+
/**
153+
* Synthesize an id for tool calls whose id never arrived on the stream (some OpenAI-compatible
154+
* providers omit it). The id has to be stable within a turn and unique across turns so the
155+
* assistant/tool_result pairing downstream holds.
156+
*/
157+
private static String synthId() {
158+
return "local_" + java.util.UUID.randomUUID().toString().replace("-", "").substring(0, 24);
159+
}
160+
}
161+
162+
/** Result of a streaming LLM call, assembled from chunks. */
163+
static final class StreamResult {
164+
final String content;
165+
final List<ChatCompletionMessageToolCall> toolCalls;
166+
167+
StreamResult(String content, List<ChatCompletionMessageToolCall> toolCalls) {
168+
this.content = content;
169+
this.toolCalls = toolCalls;
170+
}
171+
172+
boolean isEmpty() {
173+
return content.isEmpty() && (toolCalls == null || toolCalls.isEmpty());
174+
}
175+
176+
/** Build the SDK assistant message corresponding to this streamed result. */
177+
ChatCompletionAssistantMessageParam toAssistantMessage() {
178+
ChatCompletionAssistantMessageParam.Builder b = ChatCompletionAssistantMessageParam.builder();
179+
if (!content.isEmpty()) {
180+
b.content(content);
181+
}
182+
if (toolCalls != null && !toolCalls.isEmpty()) {
183+
b.toolCalls(toolCalls);
184+
}
185+
return b.build();
186+
}
187+
}
188+
}
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.kyuubi.engine.dataagent.runtime;
19+
20+
import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam;
21+
import com.openai.models.chat.completions.ChatCompletionMessageParam;
22+
import java.util.List;
23+
import org.apache.kyuubi.engine.dataagent.runtime.event.AgentEvent;
24+
import org.apache.kyuubi.engine.dataagent.runtime.middleware.AgentMiddleware;
25+
import org.apache.kyuubi.engine.dataagent.runtime.middleware.ApprovalMiddleware;
26+
import org.apache.kyuubi.engine.dataagent.runtime.middleware.Decision;
27+
import org.apache.kyuubi.engine.dataagent.runtime.middleware.ToolInvocation;
28+
import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry;
29+
import org.slf4j.Logger;
30+
import org.slf4j.LoggerFactory;
31+
32+
/**
33+
* Composite {@link AgentMiddleware} — folds a list of middlewares into one. Hook ordering follows
34+
* the onion model: {@code before*} / {@code on*Start} run first-to-last, {@code after*} / {@code
35+
* on*Finish} run last-to-first.
36+
*
37+
* <p>Component middlewares are internal framework code. If one throws during ordinary hook
38+
* dispatch, the agent run fails via {@link ReactAgent#run}; lifecycle cleanup hooks ({@link
39+
* #onAgentFinish}, {@link #onSessionClose}, {@link #onStop}) swallow exceptions so later
40+
* middlewares still get a chance to release state.
41+
*/
42+
final class MiddlewareDispatcher implements AgentMiddleware {
43+
44+
private static final Logger LOG = LoggerFactory.getLogger(MiddlewareDispatcher.class);
45+
46+
private final List<AgentMiddleware> middlewares;
47+
private final ApprovalMiddleware approvalMiddleware;
48+
49+
MiddlewareDispatcher(List<AgentMiddleware> middlewares) {
50+
this.middlewares = middlewares;
51+
this.approvalMiddleware = findApprovalMiddleware(middlewares);
52+
}
53+
54+
/**
55+
* Resolve a pending approval request. Not part of {@link AgentMiddleware} — special accessor for
56+
* the approval flow.
57+
*/
58+
boolean resolveApproval(String requestId, boolean approved) {
59+
if (approvalMiddleware == null) return false;
60+
return approvalMiddleware.resolve(requestId, approved);
61+
}
62+
63+
@Override
64+
public void onRegister(ToolRegistry registry) {
65+
for (AgentMiddleware mw : middlewares) {
66+
mw.onRegister(registry);
67+
}
68+
}
69+
70+
@Override
71+
public void onAgentStart(AgentRunContext ctx) {
72+
for (AgentMiddleware mw : middlewares) {
73+
mw.onAgentStart(ctx);
74+
}
75+
}
76+
77+
@Override
78+
public void onAgentFinish(AgentRunContext ctx) {
79+
// Runs even when the agent body threw, so swallow here to ensure every middleware's cleanup
80+
// gets a chance to run; otherwise we'd leak session state in later middlewares.
81+
for (int i = middlewares.size() - 1; i >= 0; i--) {
82+
try {
83+
middlewares.get(i).onAgentFinish(ctx);
84+
} catch (Exception e) {
85+
LOG.warn("Middleware onAgentFinish error", e);
86+
}
87+
}
88+
}
89+
90+
@Override
91+
public void onSessionClose(String sessionId) {
92+
for (AgentMiddleware mw : middlewares) {
93+
try {
94+
mw.onSessionClose(sessionId);
95+
} catch (Exception e) {
96+
LOG.warn("Middleware onSessionClose error", e);
97+
}
98+
}
99+
}
100+
101+
@Override
102+
public void onStop() {
103+
for (AgentMiddleware mw : middlewares) {
104+
try {
105+
mw.onStop();
106+
} catch (Exception e) {
107+
LOG.warn("Middleware onStop error", e);
108+
}
109+
}
110+
}
111+
112+
/**
113+
* Fold {@code onEvent} in onion order. Returns PROCEED if untouched, REPLACE with the final event
114+
* if any middleware rewrote it, or ABORT if any short-circuited.
115+
*/
116+
@Override
117+
public Decision<AgentEvent> onEvent(AgentRunContext ctx, AgentEvent event) {
118+
AgentEvent current = event;
119+
for (AgentMiddleware mw : middlewares) {
120+
Decision<AgentEvent> d = mw.onEvent(ctx, current);
121+
if (d.kind() == Decision.Kind.ABORT) return d;
122+
if (d.kind() == Decision.Kind.REPLACE) current = d.replacement();
123+
}
124+
return Decision.of(event, current);
125+
}
126+
127+
/**
128+
* Fold {@code beforeLlmCall} in onion order so later middlewares see rewritten messages. Returns
129+
* PROCEED if untouched, REPLACE with the final value if any did, or ABORT if any short-circuited.
130+
*/
131+
@Override
132+
public Decision<List<ChatCompletionMessageParam>> beforeLlmCall(
133+
AgentRunContext ctx, List<ChatCompletionMessageParam> messages) {
134+
List<ChatCompletionMessageParam> current = messages;
135+
for (AgentMiddleware mw : middlewares) {
136+
Decision<List<ChatCompletionMessageParam>> d = mw.beforeLlmCall(ctx, current);
137+
if (d.kind() == Decision.Kind.ABORT) return d;
138+
if (d.kind() == Decision.Kind.REPLACE) current = d.replacement();
139+
}
140+
return Decision.of(messages, current);
141+
}
142+
143+
/**
144+
* Fold {@code afterLlmCall} in reverse onion order so earlier middlewares see rewritten
145+
* responses. Returns the final response, or ABORT if any middleware short-circuits.
146+
*/
147+
@Override
148+
public Decision<ChatCompletionAssistantMessageParam> afterLlmCall(
149+
AgentRunContext ctx, ChatCompletionAssistantMessageParam response) {
150+
ChatCompletionAssistantMessageParam current = response;
151+
for (int i = middlewares.size() - 1; i >= 0; i--) {
152+
Decision<ChatCompletionAssistantMessageParam> d =
153+
middlewares.get(i).afterLlmCall(ctx, current);
154+
if (d.kind() == Decision.Kind.ABORT) return d;
155+
if (d.kind() == Decision.Kind.REPLACE) current = d.replacement();
156+
}
157+
return Decision.of(response, current);
158+
}
159+
160+
/**
161+
* Fold {@code beforeToolCall} in onion order so later middlewares can further rewrite. Returns
162+
* PROCEED if untouched, REPLACE with the final invocation otherwise, or ABORT if any middleware
163+
* denies the call.
164+
*/
165+
@Override
166+
public Decision<ToolInvocation> beforeToolCall(AgentRunContext ctx, ToolInvocation call) {
167+
ToolInvocation current = call;
168+
for (AgentMiddleware mw : middlewares) {
169+
Decision<ToolInvocation> d = mw.beforeToolCall(ctx, current);
170+
if (d.kind() == Decision.Kind.ABORT) return d;
171+
if (d.kind() == Decision.Kind.REPLACE) current = d.replacement();
172+
}
173+
return Decision.of(call, current);
174+
}
175+
176+
/**
177+
* Fold {@code afterToolCall} in reverse onion order so earlier middlewares see rewritten results.
178+
* Returns the final result, or ABORT if any middleware short-circuits — caller decides how to
179+
* surface the abort (typically: use {@code reason()} as the result text the LLM sees).
180+
*/
181+
@Override
182+
public Decision<String> afterToolCall(AgentRunContext ctx, ToolInvocation call, String result) {
183+
String current = result;
184+
for (int i = middlewares.size() - 1; i >= 0; i--) {
185+
Decision<String> d = middlewares.get(i).afterToolCall(ctx, call, current);
186+
if (d.kind() == Decision.Kind.ABORT) return d;
187+
if (d.kind() == Decision.Kind.REPLACE) current = d.replacement();
188+
}
189+
return Decision.of(result, current);
190+
}
191+
192+
private static ApprovalMiddleware findApprovalMiddleware(List<AgentMiddleware> middlewares) {
193+
for (AgentMiddleware mw : middlewares) {
194+
if (mw instanceof ApprovalMiddleware) return (ApprovalMiddleware) mw;
195+
}
196+
return null;
197+
}
198+
}

0 commit comments

Comments
 (0)