From ce4eeccf5aa812f0ccd02967a6612bbed476bac5 Mon Sep 17 00:00:00 2001 From: wangzhigang Date: Thu, 23 Apr 2026 00:14:33 +0800 Subject: [PATCH 01/10] [KYUUBI #7379][2b/4] Data Agent Engine: agent runtime, middleware stack, OpenAI provider, and live E2E tests This PR delivers the runtime layer of the Data Agent Engine on top of the tool system and data source plumbing from 2a/4: - ReactAgent: ReAct-style loop with streaming LLM responses, per-step tool dispatch, and AgentRunContext tracking token usage, iterations, and session. - Middleware stack (AgentMiddleware + ReactAgent.Builder): * LoggingMiddleware -- structured per-step/LLM/tool/finish logs with MDC. * ApprovalMiddleware -- CompletableFuture-based resolve for DESTRUCTIVE tools; modes NORMAL / STRICT / AUTO_APPROVE. * CompactionMiddleware -- token-threshold-triggered history summarization with KEEP_RECENT_TURNS=4, emits a Compaction AgentEvent so clients can observe the mechanism firing. * ToolResultOffloadMiddleware -- spills large tool outputs to disk and surfaces `read_tool_output` / `grep_tool_output` companion tools for the LLM to re-query truncated previews. - OpenAiProvider: single shared ReactAgent, per-session ConversationMemory, streaming chat completions, Hikari-pooled JDBC data source; reads model and thresholds from KyuubiConf. - ExecuteStatement (Scala): encodes all AgentEvents (including compaction and approval_request) as SSE JSON rows streamed through the JDBC reply column. - KyuubiConf: new keys for LLM provider/api-url/model/api-key, approval mode, compaction trigger tokens, offload root/thresholds, max iterations, etc. - Tests: * Unit tests for runtime, middlewares, offload store, and event shapes. * Live tests gated on DATA_AGENT_LLM_API_KEY covering full LLM round-trips: ReactAgentLiveTest (offload+grep, approval approve/deny), DataAgentE2ESuite and DataAgentApprovalE2ESuite (JDBC layer), DataAgentCompactionE2ESuite (JDBC-observable compaction event + post-compaction recovery), CompactionMiddlewareLiveTest. * Compatibility verified against qwen3.6-plus, glm-5, and kimi-k2.5 via per-call `model=` logging in ReactAgent. --- docs/configuration/settings.md | 1 + externals/kyuubi-data-agent-engine/pom.xml | 43 +- .../dataagent/datasource/JdbcDialect.java | 6 + .../{ => dialect}/GenericDialect.java | 4 +- .../{ => dialect}/MysqlDialect.java | 6 +- .../{ => dialect}/SparkDialect.java | 6 +- .../{ => dialect}/SqliteDialect.java | 6 +- .../{ => dialect}/TrinoDialect.java | 6 +- .../provider/ProviderRunRequest.java | 26 +- .../provider/openai/OpenAiProvider.java | 176 +++++ .../dataagent/runtime/AgentInvocation.java | 72 +++ .../dataagent/runtime/AgentRunContext.java | 112 ++++ .../dataagent/runtime/ApprovalMode.java | 28 + .../dataagent/runtime/ConversationMemory.java | 200 ++++++ .../engine/dataagent/runtime/ReactAgent.java | 606 ++++++++++++++++++ .../dataagent/runtime/ToolOutputStore.java | 242 +++++++ .../dataagent/runtime/event/Compaction.java | 69 ++ .../dataagent/runtime/event/EventType.java | 5 + .../runtime/middleware/AgentMiddleware.java | 152 +++++ .../middleware/ApprovalMiddleware.java | 152 +++++ .../middleware/CompactionMiddleware.java | 409 ++++++++++++ .../runtime/middleware/LoggingMiddleware.java | 160 +++++ .../ToolResultOffloadMiddleware.java | 191 ++++++ .../engine/dataagent/tool/AgentTool.java | 5 +- .../engine/dataagent/tool/ToolContext.java | 40 ++ .../engine/dataagent/tool/ToolRegistry.java | 149 +++-- .../tool/output/GrepToolOutputArgs.java | 39 ++ .../tool/output/GrepToolOutputTool.java | 68 ++ .../tool/output/ReadToolOutputArgs.java | 37 ++ .../tool/output/ReadToolOutputTool.java | 71 ++ .../tool/sql/RunMutationQueryTool.java | 3 +- .../tool/sql/RunSelectQueryTool.java | 3 +- .../tool/sql/SqlReadOnlyChecker.java | 4 +- .../engine/dataagent/util/ConfUtils.java | 62 ++ .../operation/ExecuteStatement.scala | 13 +- .../dataagent/datasource/JdbcDialectTest.java | 1 + .../engine/dataagent/mysql/DialectTest.java | 5 +- .../provider/mock/MockLlmProvider.java | 185 ++++++ .../runtime/ConversationMemoryTest.java | 47 ++ .../dataagent/runtime/ReactAgentLiveTest.java | 568 ++++++++++++++++ .../runtime/ToolOutputStoreTest.java | 116 ++++ .../dataagent/runtime/event/EventTest.java | 3 +- .../middleware/ApprovalMiddlewareTest.java | 294 +++++++++ .../CompactionMiddlewareLiveTest.java | 100 +++ .../middleware/CompactionMiddlewareTest.java | 322 ++++++++++ .../ToolResultOffloadMiddlewareTest.java | 141 ++++ .../tool/ToolRegistryThreadSafetyTest.java | 6 +- .../engine/dataagent/tool/ToolTest.java | 4 +- .../tool/sql/RunMutationQueryToolTest.java | 17 +- .../tool/sql/RunSelectQueryToolTest.java | 45 +- .../DataAgentCompactionE2ESuite.scala | 195 ++++++ .../operation/DataAgentE2ESuite.scala | 92 ++- .../org/apache/kyuubi/config/KyuubiConf.scala | 15 + 53 files changed, 5187 insertions(+), 141 deletions(-) rename externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/{ => dialect}/GenericDialect.java (92%) rename externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/{ => dialect}/MysqlDialect.java (85%) rename externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/{ => dialect}/SparkDialect.java (85%) rename externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/{ => dialect}/SqliteDialect.java (85%) rename externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/{ => dialect}/TrinoDialect.java (85%) create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/openai/OpenAiProvider.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/AgentInvocation.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/AgentRunContext.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ApprovalMode.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ConversationMemory.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgent.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ToolOutputStore.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/event/Compaction.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/AgentMiddleware.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddleware.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddleware.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/LoggingMiddleware.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddleware.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolContext.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/GrepToolOutputArgs.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/GrepToolOutputTool.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/ReadToolOutputArgs.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/ReadToolOutputTool.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/util/ConfUtils.java create mode 100644 externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/provider/mock/MockLlmProvider.java create mode 100644 externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ConversationMemoryTest.java create mode 100644 externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgentLiveTest.java create mode 100644 externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ToolOutputStoreTest.java create mode 100644 externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddlewareTest.java create mode 100644 externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareLiveTest.java create mode 100644 externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareTest.java create mode 100644 externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddlewareTest.java create mode 100644 externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentCompactionE2ESuite.scala diff --git a/docs/configuration/settings.md b/docs/configuration/settings.md index 3e34882b28f..a2910abfca6 100644 --- a/docs/configuration/settings.md +++ b/docs/configuration/settings.md @@ -144,6 +144,7 @@ You can configure the Kyuubi properties in `$KYUUBI_HOME/conf/kyuubi-defaults.co | kyuubi.engine.chat.provider | ECHO | The provider for the Chat engine. Candidates: | string | 1.8.0 | | kyuubi.engine.connection.url.use.hostname | true | (deprecated) When true, the engine registers with hostname to zookeeper. When Spark runs on K8s with cluster mode, set to false to ensure that server can connect to engine | boolean | 1.3.0 | | kyuubi.engine.data.agent.approval.mode | NORMAL | Default approval mode for tool execution in the Data Agent engine. Candidates: | string | 1.12.0 | +| kyuubi.engine.data.agent.compaction.trigger.tokens | 128000 | The prompt-token threshold above which the Data Agent's compaction middleware summarizes older conversation history into a compact message. The check is made each turn as real_prompt_tokens_of_previous_LLM_call + estimate_of_newly_appended_tail; when this predicted prompt size reaches the configured value, older messages are replaced by a single summary message while the most recent exchanges are kept verbatim. Set to a very large value (e.g., 9223372036854775807) to effectively disable compaction. | long | 1.12.0 | | kyuubi.engine.data.agent.extra.classpath | <undefined> | The extra classpath for the Data Agent engine, for configuring the location of the LLM SDK and etc. | string | 1.12.0 | | kyuubi.engine.data.agent.java.options | <undefined> | The extra Java options for the Data Agent engine | string | 1.12.0 | | kyuubi.engine.data.agent.jdbc.url | <undefined> | The JDBC URL for the Data Agent engine to connect to the target database. If not set, the Data Agent will connect back to Kyuubi server via ZooKeeper service discovery. | string | 1.12.0 | diff --git a/externals/kyuubi-data-agent-engine/pom.xml b/externals/kyuubi-data-agent-engine/pom.xml index c34d049360c..74da5005784 100644 --- a/externals/kyuubi-data-agent-engine/pom.xml +++ b/externals/kyuubi-data-agent-engine/pom.xml @@ -50,19 +50,48 @@ ${project.version} + com.openai openai-java + ${openai.sdk.version} + com.github.victools jsonschema-generator + ${victools.jsonschema.version} - com.github.victools jsonschema-module-jackson + ${victools.jsonschema.version} + + + + + org.xerial + sqlite-jdbc + ${sqlite.version} + + + + + com.mysql + mysql-connector-j + + + + + io.trino + trino-jdbc + + + + + com.zaxxer + HikariCP @@ -74,24 +103,12 @@ test - - org.xerial - sqlite-jdbc - test - - org.testcontainers testcontainers-mysql test - - com.mysql - mysql-connector-j - test - - junit junit diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialect.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialect.java index c3be1dad61a..c771ad222aa 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialect.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialect.java @@ -17,6 +17,12 @@ package org.apache.kyuubi.engine.dataagent.datasource; +import org.apache.kyuubi.engine.dataagent.datasource.dialect.GenericDialect; +import org.apache.kyuubi.engine.dataagent.datasource.dialect.MysqlDialect; +import org.apache.kyuubi.engine.dataagent.datasource.dialect.SparkDialect; +import org.apache.kyuubi.engine.dataagent.datasource.dialect.SqliteDialect; +import org.apache.kyuubi.engine.dataagent.datasource.dialect.TrinoDialect; + /** * SQL dialect abstraction for datasource-specific SQL generation. * diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/GenericDialect.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/GenericDialect.java similarity index 92% rename from externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/GenericDialect.java rename to externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/GenericDialect.java index 3ea22ed54e3..d8c4512de03 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/GenericDialect.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/GenericDialect.java @@ -15,7 +15,9 @@ * limitations under the License. */ -package org.apache.kyuubi.engine.dataagent.datasource; +package org.apache.kyuubi.engine.dataagent.datasource.dialect; + +import org.apache.kyuubi.engine.dataagent.datasource.JdbcDialect; /** * Fallback dialect for JDBC subprotocols that have no dedicated implementation. Carries the diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/MysqlDialect.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/MysqlDialect.java similarity index 85% rename from externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/MysqlDialect.java rename to externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/MysqlDialect.java index 98747ffa30c..350789a6a87 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/MysqlDialect.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/MysqlDialect.java @@ -15,12 +15,14 @@ * limitations under the License. */ -package org.apache.kyuubi.engine.dataagent.datasource; +package org.apache.kyuubi.engine.dataagent.datasource.dialect; + +import org.apache.kyuubi.engine.dataagent.datasource.JdbcDialect; /** MySQL dialect. Uses backtick quoting for identifiers. */ public final class MysqlDialect implements JdbcDialect { - static final MysqlDialect INSTANCE = new MysqlDialect(); + public static final MysqlDialect INSTANCE = new MysqlDialect(); private MysqlDialect() {} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/SparkDialect.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/SparkDialect.java similarity index 85% rename from externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/SparkDialect.java rename to externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/SparkDialect.java index 3adb43fa398..34e20034bfb 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/SparkDialect.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/SparkDialect.java @@ -15,12 +15,14 @@ * limitations under the License. */ -package org.apache.kyuubi.engine.dataagent.datasource; +package org.apache.kyuubi.engine.dataagent.datasource.dialect; + +import org.apache.kyuubi.engine.dataagent.datasource.JdbcDialect; /** Spark SQL dialect. Uses backtick quoting for identifiers. */ public final class SparkDialect implements JdbcDialect { - static final SparkDialect INSTANCE = new SparkDialect(); + public static final SparkDialect INSTANCE = new SparkDialect(); private SparkDialect() {} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/SqliteDialect.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/SqliteDialect.java similarity index 85% rename from externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/SqliteDialect.java rename to externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/SqliteDialect.java index a53255a9c67..eb98ca8edfa 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/SqliteDialect.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/SqliteDialect.java @@ -15,12 +15,14 @@ * limitations under the License. */ -package org.apache.kyuubi.engine.dataagent.datasource; +package org.apache.kyuubi.engine.dataagent.datasource.dialect; + +import org.apache.kyuubi.engine.dataagent.datasource.JdbcDialect; /** SQLite dialect. Uses double-quote quoting for identifiers. */ public final class SqliteDialect implements JdbcDialect { - static final SqliteDialect INSTANCE = new SqliteDialect(); + public static final SqliteDialect INSTANCE = new SqliteDialect(); private SqliteDialect() {} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/TrinoDialect.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/TrinoDialect.java similarity index 85% rename from externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/TrinoDialect.java rename to externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/TrinoDialect.java index edacf2f87e2..75fbd4bb242 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/TrinoDialect.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/TrinoDialect.java @@ -15,12 +15,14 @@ * limitations under the License. */ -package org.apache.kyuubi.engine.dataagent.datasource; +package org.apache.kyuubi.engine.dataagent.datasource.dialect; + +import org.apache.kyuubi.engine.dataagent.datasource.JdbcDialect; /** Trino SQL dialect. Uses double-quote quoting for identifiers. */ public final class TrinoDialect implements JdbcDialect { - static final TrinoDialect INSTANCE = new TrinoDialect(); + public static final TrinoDialect INSTANCE = new TrinoDialect(); private TrinoDialect() {} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/ProviderRunRequest.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/ProviderRunRequest.java index f4e40b2fae8..26ad8be77fb 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/ProviderRunRequest.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/ProviderRunRequest.java @@ -17,13 +17,23 @@ package org.apache.kyuubi.engine.dataagent.provider; +import org.apache.kyuubi.engine.dataagent.runtime.ApprovalMode; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + /** * User-facing request parameters for a provider-level agent invocation. Only contains fields from * the caller (question, model override, etc.). Adding new per-request options does not require * changing the {@link DataAgentProvider} interface. + * + *

The approval mode is accepted as a raw string (natural for config-driven callers) and parsed + * into {@link ApprovalMode} by {@link #getApprovalMode()}. Unrecognised values fall back to {@link + * ApprovalMode#NORMAL} with a warning. */ public class ProviderRunRequest { + private static final Logger LOG = LoggerFactory.getLogger(ProviderRunRequest.class); + private final String question; private String modelName; private String approvalMode; @@ -45,8 +55,20 @@ public ProviderRunRequest modelName(String modelName) { return this; } - public String getApprovalMode() { - return approvalMode; + /** + * Resolved approval mode. Returns {@link ApprovalMode#NORMAL} when the caller did not set one or + * supplied an unknown value. + */ + public ApprovalMode getApprovalMode() { + if (approvalMode == null || approvalMode.isEmpty()) { + return ApprovalMode.NORMAL; + } + try { + return ApprovalMode.valueOf(approvalMode.toUpperCase()); + } catch (IllegalArgumentException e) { + LOG.warn("Unknown approval mode '{}', using default NORMAL", approvalMode); + return ApprovalMode.NORMAL; + } } public ProviderRunRequest approvalMode(String approvalMode) { diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/openai/OpenAiProvider.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/openai/OpenAiProvider.java new file mode 100644 index 00000000000..bcd647b9326 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/openai/OpenAiProvider.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.provider.openai; + +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.zaxxer.hikari.HikariDataSource; +import java.time.Duration; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; +import javax.sql.DataSource; +import org.apache.kyuubi.config.KyuubiConf; +import org.apache.kyuubi.config.KyuubiReservedKeys; +import org.apache.kyuubi.engine.dataagent.datasource.DataSourceFactory; +import org.apache.kyuubi.engine.dataagent.datasource.JdbcDialect; +import org.apache.kyuubi.engine.dataagent.prompt.SystemPromptBuilder; +import org.apache.kyuubi.engine.dataagent.provider.DataAgentProvider; +import org.apache.kyuubi.engine.dataagent.provider.ProviderRunRequest; +import org.apache.kyuubi.engine.dataagent.runtime.AgentInvocation; +import org.apache.kyuubi.engine.dataagent.runtime.ConversationMemory; +import org.apache.kyuubi.engine.dataagent.runtime.ReactAgent; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentEvent; +import org.apache.kyuubi.engine.dataagent.runtime.middleware.ApprovalMiddleware; +import org.apache.kyuubi.engine.dataagent.runtime.middleware.CompactionMiddleware; +import org.apache.kyuubi.engine.dataagent.runtime.middleware.LoggingMiddleware; +import org.apache.kyuubi.engine.dataagent.runtime.middleware.ToolResultOffloadMiddleware; +import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry; +import org.apache.kyuubi.engine.dataagent.tool.sql.RunMutationQueryTool; +import org.apache.kyuubi.engine.dataagent.tool.sql.RunSelectQueryTool; +import org.apache.kyuubi.engine.dataagent.util.ConfUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * An OpenAI-compatible provider that wires up the full ReactAgent with streaming LLM, tools, and + * middleware pipeline. Uses the official OpenAI Java SDK. + * + *

The ReactAgent, DataSource, and ToolRegistry are shared across all sessions within this engine + * instance. Each session only maintains its own {@link ConversationMemory}. This works because each + * engine is bound to one user + one datasource, so all sessions within the engine naturally share + * the same data connection. + */ +public class OpenAiProvider implements DataAgentProvider { + + private static final Logger LOG = LoggerFactory.getLogger(OpenAiProvider.class); + + private final ReactAgent agent; + private final ToolRegistry toolRegistry; + private final DataSource dataSource; + private final OpenAIClient client; + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + + public OpenAiProvider(KyuubiConf conf) { + String apiKey = ConfUtils.requireString(conf, KyuubiConf.ENGINE_DATA_AGENT_LLM_API_KEY()); + String baseUrl = ConfUtils.requireString(conf, KyuubiConf.ENGINE_DATA_AGENT_LLM_API_URL()); + String modelName = ConfUtils.requireString(conf, KyuubiConf.ENGINE_DATA_AGENT_LLM_MODEL()); + + int maxIterations = ConfUtils.intConf(conf, KyuubiConf.ENGINE_DATA_AGENT_MAX_ITERATIONS()); + long compactionTriggerTokens = + ConfUtils.longConf(conf, KyuubiConf.ENGINE_DATA_AGENT_COMPACTION_TRIGGER_TOKENS()); + int queryTimeoutSeconds = + (int) ConfUtils.millisAsSeconds(conf, KyuubiConf.ENGINE_DATA_AGENT_QUERY_TIMEOUT()); + long toolCallTimeoutSeconds = + ConfUtils.millisAsSeconds(conf, KyuubiConf.ENGINE_DATA_AGENT_TOOL_CALL_TIMEOUT()); + + this.client = + OpenAIOkHttpClient.builder() + .apiKey(apiKey) + .baseUrl(baseUrl) + .maxRetries(3) + .timeout(Duration.ofSeconds(180)) + .build(); + + this.toolRegistry = new ToolRegistry(toolCallTimeoutSeconds); + + SystemPromptBuilder promptBuilder = SystemPromptBuilder.create(); + this.dataSource = attachJdbcDataSource(conf, toolRegistry, promptBuilder, queryTimeoutSeconds); + + this.agent = + ReactAgent.builder() + .client(client) + .modelName(modelName) + .toolRegistry(toolRegistry) + .addMiddleware(new ToolResultOffloadMiddleware()) + .addMiddleware(new LoggingMiddleware()) + .addMiddleware(new CompactionMiddleware(client, modelName, compactionTriggerTokens)) + .addMiddleware(new ApprovalMiddleware()) + .maxIterations(maxIterations) + .systemPrompt(promptBuilder.build()) + .build(); + } + + /** + * Register JDBC-backed SQL tools if a JDBC URL is configured. Returns the created {@link + * DataSource} so the provider can close it on shutdown, or {@code null} when no JDBC is wired. + */ + private static DataSource attachJdbcDataSource( + KyuubiConf conf, + ToolRegistry registry, + SystemPromptBuilder promptBuilder, + int queryTimeoutSeconds) { + String jdbcUrl = ConfUtils.optionalString(conf, KyuubiConf.ENGINE_DATA_AGENT_JDBC_URL()); + if (jdbcUrl == null) { + return null; + } + LOG.info("Data Agent JDBC URL configured ({})", jdbcUrl.replaceAll("//.*@", "//@")); + + String sessionUser = + ConfUtils.optionalString(conf, KyuubiReservedKeys.KYUUBI_SESSION_USER_KEY()); + + DataSource ds = DataSourceFactory.create(jdbcUrl, sessionUser); + registry.register(new RunSelectQueryTool(ds, queryTimeoutSeconds)); + registry.register(new RunMutationQueryTool(ds, queryTimeoutSeconds)); + promptBuilder.datasource(JdbcDialect.fromUrl(jdbcUrl).datasourceName()); + return ds; + } + + @Override + public void open(String sessionId, String user) { + sessions.put(sessionId, new ConversationMemory()); + LOG.info("Opened Data Agent session {} for user {}", sessionId, user); + } + + @Override + public void run(String sessionId, ProviderRunRequest request, Consumer onEvent) { + ConversationMemory memory = sessions.get(sessionId); + if (memory == null) { + throw new IllegalStateException("No open Data Agent session for id=" + sessionId); + } + + AgentInvocation invocation = + new AgentInvocation(request.getQuestion()) + .modelName(request.getModelName()) + .approvalMode(request.getApprovalMode()) + .sessionId(sessionId); + agent.run(invocation, memory, onEvent); + } + + @Override + public boolean resolveApproval(String requestId, boolean approved) { + return agent.resolveApproval(requestId, approved); + } + + @Override + public void close(String sessionId) { + sessions.remove(sessionId); + agent.closeSession(sessionId); + LOG.info("Closed Data Agent session {}", sessionId); + } + + @Override + public void stop() { + agent.stop(); + toolRegistry.close(); + if (dataSource instanceof HikariDataSource) { + ((HikariDataSource) dataSource).close(); + LOG.info("Closed Data Agent connection pool"); + } + client.close(); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/AgentInvocation.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/AgentInvocation.java new file mode 100644 index 00000000000..0695d556058 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/AgentInvocation.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime; + +import java.util.Objects; + +/** + * User-facing request parameters for a single agent invocation. Only contains fields that come from + * the caller (question, model override, etc.). Framework-level concerns like memory and event + * consumer are separate method parameters. + * + *

Adding new per-request options (e.g. temperature, maxTokens) does not require changing the + * {@code ReactAgent.run()} signature. + */ +public class AgentInvocation { + + private final String userInput; + private String modelName; + private ApprovalMode approvalMode = ApprovalMode.NORMAL; + private String sessionId; + + public AgentInvocation(String userInput) { + this.userInput = Objects.requireNonNull(userInput, "userInput must not be null"); + } + + public String getUserInput() { + return userInput; + } + + public String getModelName() { + return modelName; + } + + public AgentInvocation modelName(String modelName) { + this.modelName = modelName; + return this; + } + + public ApprovalMode getApprovalMode() { + return approvalMode; + } + + public AgentInvocation approvalMode(ApprovalMode approvalMode) { + this.approvalMode = approvalMode; + return this; + } + + public String getSessionId() { + return sessionId; + } + + /** Upstream session id, propagated into {@link AgentRunContext#getSessionId()}. */ + public AgentInvocation sessionId(String sessionId) { + this.sessionId = sessionId; + return this; + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/AgentRunContext.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/AgentRunContext.java new file mode 100644 index 00000000000..e7c92df8033 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/AgentRunContext.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime; + +import java.util.function.Consumer; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentEvent; + +/** + * Mutable context passed through the middleware pipeline and agent loop. Tracks the current state + * of agent execution including iteration count, token usage, and custom middleware state. + */ +public class AgentRunContext { + + private final ConversationMemory memory; + private final String sessionId; + private Consumer eventEmitter; + private int iteration; + private long promptTokens; + private long completionTokens; + private long totalTokens; + private ApprovalMode approvalMode; + + public AgentRunContext(ConversationMemory memory, ApprovalMode approvalMode) { + this(memory, approvalMode, null); + } + + public AgentRunContext(ConversationMemory memory, ApprovalMode approvalMode, String sessionId) { + this.memory = memory; + this.iteration = 0; + this.approvalMode = approvalMode; + this.sessionId = sessionId; + } + + public ConversationMemory getMemory() { + return memory; + } + + /** + * The upstream session identifier this run belongs to. Threaded down from {@code + * DataAgentProvider.run(sessionId, ...)}. May be {@code null} in unit tests that do not exercise + * session-scoped middleware. + */ + public String getSessionId() { + return sessionId; + } + + public int getIteration() { + return iteration; + } + + public void setIteration(int iteration) { + this.iteration = iteration; + } + + public long getPromptTokens() { + return promptTokens; + } + + public long getCompletionTokens() { + return completionTokens; + } + + public long getTotalTokens() { + return totalTokens; + } + + /** + * Record one LLM call's usage. Updates both the per-run counters on this context and the + * session-level cumulative on the underlying {@link ConversationMemory}, so middlewares that need + * a session-wide picture can read it directly from memory without keeping their own bookkeeping. + */ + public void addTokenUsage(long prompt, long completion, long total) { + this.promptTokens += prompt; + this.completionTokens += completion; + this.totalTokens += total; + memory.addCumulativeTokens(prompt, completion, total); + } + + public ApprovalMode getApprovalMode() { + return approvalMode; + } + + public void setApprovalMode(ApprovalMode approvalMode) { + this.approvalMode = approvalMode; + } + + public void setEventEmitter(Consumer eventEmitter) { + this.eventEmitter = eventEmitter; + } + + /** Emit an event through the agent's event pipeline. Available for use by middlewares. */ + public void emit(AgentEvent event) { + if (eventEmitter != null) { + eventEmitter.accept(event); + } + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ApprovalMode.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ApprovalMode.java new file mode 100644 index 00000000000..57bc20bc2bb --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ApprovalMode.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime; + +/** Approval modes for tool execution in the Data Agent engine. */ +public enum ApprovalMode { + /** All tools require explicit user approval. */ + STRICT, + /** Only non-readonly tools require approval. */ + NORMAL, + /** All tools are auto-approved. */ + AUTO_APPROVE +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ConversationMemory.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ConversationMemory.java new file mode 100644 index 00000000000..0bfae26ec27 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ConversationMemory.java @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime; + +import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam; +import com.openai.models.chat.completions.ChatCompletionMessageParam; +import com.openai.models.chat.completions.ChatCompletionSystemMessageParam; +import com.openai.models.chat.completions.ChatCompletionToolMessageParam; +import com.openai.models.chat.completions.ChatCompletionUserMessageParam; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * Manages conversation history for a Data Agent session. Ensures tool result messages are never + * orphaned from their corresponding AI messages. + * + *

Each instance is session-scoped and accessed sequentially within a single ReAct loop — no + * synchronization is needed. Cross-session concurrency is handled by the provider's session map. + */ +public class ConversationMemory { + + /** + * System prompt prepended to every LLM call built from this memory. Rebuilt by the provider on + * each invocation (datasource/tool metadata can change between turns), so it lives outside the + * {@link #messages} list rather than being inserted as the first entry. May be {@code null} until + * the first {@link #setSystemPrompt} call — {@link #buildLlmMessages()} will simply omit the + * system slot in that case. + */ + private String systemPrompt; + + /** + * The raw content of the most recent user message added via {@link #addUserMessage}. Cached + * separately from {@link #messages} so middleware and callers can recover the current turn's + * question after compaction rewrites history. Not used by the LLM call path itself. + */ + private String lastUserInput; + + /** + * The ordered conversation history: user / assistant / tool messages, in the order the LLM will + * see them. The system prompt is intentionally NOT stored here (see {@link #systemPrompt}). + * Mutated in place by {@link #replaceHistory} during compaction; otherwise append-only. + */ + private final List messages = new ArrayList<>(); + + /** + * Session-level running total of {@code prompt_tokens} reported by every LLM call on this + * conversation (across ReAct turns). Intended for billing, quota, and observability — not used by + * any runtime decision. Updated via {@link #addCumulativeTokens}. + */ + private long cumulativePromptTokens; + + /** + * Session-level running total of {@code completion_tokens}. See {@link #cumulativePromptTokens}. + */ + private long cumulativeCompletionTokens; + + /** + * Session-level running total of {@code total_tokens} (prompt + completion as reported by the + * provider — not necessarily the sum of the two counters above, since providers may count + * cached/reasoning tokens differently). See {@link #cumulativePromptTokens}. + */ + private long cumulativeTotalTokens; + + /** + * The {@code total_tokens} reported by the single most recent LLM call, or {@code 0} if no call + * has completed yet. Distinct from the cumulative counters: this is a snapshot, overwritten every + * call. Used by {@link + * org.apache.kyuubi.engine.dataagent.runtime.middleware.CompactionMiddleware} to estimate the + * next prompt size (the last response becomes part of the next prompt, so the next call's prompt + * is at least {@code lastTotalTokens}). Persists across ReAct turns until the next call + * overwrites it. + */ + private long lastTotalTokens; + + public ConversationMemory() {} + + public String getSystemPrompt() { + return systemPrompt; + } + + public void setSystemPrompt(String prompt) { + this.systemPrompt = prompt; + } + + public void addUserMessage(String content) { + this.lastUserInput = content; + messages.add( + ChatCompletionMessageParam.ofUser( + ChatCompletionUserMessageParam.builder().content(content).build())); + } + + public String getLastUserInput() { + return lastUserInput; + } + + public void addAssistantMessage(ChatCompletionAssistantMessageParam message) { + messages.add(ChatCompletionMessageParam.ofAssistant(message)); + } + + public void addToolResult(String toolCallId, String content) { + messages.add( + ChatCompletionMessageParam.ofTool( + ChatCompletionToolMessageParam.builder() + .toolCallId(toolCallId) + .content(content) + .build())); + } + + /** + * Build the full message list for LLM API invocation: [system prompt] + conversation history. + * + *

No windowing is applied — callers are responsible for managing context length (e.g. via a + * token-based truncation strategy). + * + * @see #getHistory() for history-only access without system prompt + */ + public List buildLlmMessages() { + List result = new ArrayList<>(messages.size() + 1); + if (systemPrompt != null) { + result.add( + ChatCompletionMessageParam.ofSystem( + ChatCompletionSystemMessageParam.builder().content(systemPrompt).build())); + } + result.addAll(messages); + return Collections.unmodifiableList(result); + } + + /** + * Returns the conversation history (user, assistant, tool messages) without the system prompt. + * Useful for middleware that needs to inspect or compact history. + */ + public List getHistory() { + return Collections.unmodifiableList(new ArrayList<>(messages)); + } + + /** + * Replace the conversation history with a compacted version. Useful for context-length management + * strategies (e.g., summarizing older messages). + * + *

Also clears {@link #lastTotalTokens}: the prior snapshot referred to a prompt whose bulk we + * just discarded, so it no longer describes anything in memory. Leaving it stale would keep the + * compaction trigger armed until the next successful LLM call overwrites it — fine on the happy + * path, but if that call fails the next turn would re-enter compaction against already-compacted + * history. Zeroing means "unknown, wait for the next real usage report". Cumulative totals are + * intentionally preserved (session-level accounting, must not regress on internal compaction). + */ + public void replaceHistory(List compacted) { + messages.clear(); + messages.addAll(compacted); + this.lastTotalTokens = 0; + } + + public void clear() { + messages.clear(); + } + + public int size() { + return messages.size(); + } + + public long getCumulativePromptTokens() { + return cumulativePromptTokens; + } + + public long getCumulativeCompletionTokens() { + return cumulativeCompletionTokens; + } + + public long getCumulativeTotalTokens() { + return cumulativeTotalTokens; + } + + public long getLastTotalTokens() { + return lastTotalTokens; + } + + /** Add one LLM call's usage to the session cumulative. Intended for {@link AgentRunContext}. */ + public void addCumulativeTokens(long prompt, long completion, long total) { + this.cumulativePromptTokens += prompt; + this.cumulativeCompletionTokens += completion; + this.cumulativeTotalTokens += total; + this.lastTotalTokens = total; + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgent.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgent.java new file mode 100644 index 00000000000..520cd963ef8 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgent.java @@ -0,0 +1,606 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.openai.client.OpenAIClient; +import com.openai.core.http.StreamResponse; +import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam; +import com.openai.models.chat.completions.ChatCompletionChunk; +import com.openai.models.chat.completions.ChatCompletionCreateParams; +import com.openai.models.chat.completions.ChatCompletionMessageFunctionToolCall; +import com.openai.models.chat.completions.ChatCompletionMessageParam; +import com.openai.models.chat.completions.ChatCompletionMessageToolCall; +import com.openai.models.chat.completions.ChatCompletionStreamOptions; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentError; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentEvent; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentFinish; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentStart; +import org.apache.kyuubi.engine.dataagent.runtime.event.ContentComplete; +import org.apache.kyuubi.engine.dataagent.runtime.event.ContentDelta; +import org.apache.kyuubi.engine.dataagent.runtime.event.StepEnd; +import org.apache.kyuubi.engine.dataagent.runtime.event.StepStart; +import org.apache.kyuubi.engine.dataagent.runtime.event.ToolCall; +import org.apache.kyuubi.engine.dataagent.runtime.event.ToolResult; +import org.apache.kyuubi.engine.dataagent.runtime.middleware.AgentMiddleware; +import org.apache.kyuubi.engine.dataagent.runtime.middleware.ApprovalMiddleware; +import org.apache.kyuubi.engine.dataagent.tool.ToolContext; +import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * ReAct (Reasoning + Acting) agent loop using the OpenAI official Java SDK. Iterates through LLM + * reasoning, tool execution, and result verification until the agent produces a final answer or + * hits the iteration limit. + * + *

Emits {@link AgentEvent}s via the provided consumer for real-time token-level streaming. + */ +public class ReactAgent { + + private static final Logger LOG = LoggerFactory.getLogger(ReactAgent.class); + private static final ObjectMapper JSON = new ObjectMapper(); + + private final OpenAIClient client; + private final String defaultModelName; + private final ToolRegistry toolRegistry; + private final List middlewares; + private final ApprovalMiddleware approvalMiddleware; + private final int maxIterations; + private final String systemPrompt; + + public ReactAgent( + OpenAIClient client, + String modelName, + ToolRegistry toolRegistry, + List middlewares, + int maxIterations, + String systemPrompt) { + this.client = client; + this.defaultModelName = modelName; + this.toolRegistry = toolRegistry; + this.middlewares = middlewares; + this.approvalMiddleware = findApprovalMiddleware(middlewares); + this.maxIterations = maxIterations; + this.systemPrompt = systemPrompt; + } + + private static ApprovalMiddleware findApprovalMiddleware(List middlewares) { + for (AgentMiddleware mw : middlewares) { + if (mw instanceof ApprovalMiddleware) return (ApprovalMiddleware) mw; + } + return null; + } + + /** Resolve a pending approval request. Returns false if no pending request matches. */ + public boolean resolveApproval(String requestId, boolean approved) { + if (approvalMiddleware == null) return false; + return approvalMiddleware.resolve(requestId, approved); + } + + /** Fan out session-close to every middleware. Errors in one middleware don't block others. */ + public void closeSession(String sessionId) { + for (AgentMiddleware mw : middlewares) { + try { + mw.onSessionClose(sessionId); + } catch (Exception e) { + LOG.warn("Middleware onSessionClose error", e); + } + } + } + + /** Fan out engine-stop to every middleware. Errors in one middleware don't block others. */ + public void stop() { + for (AgentMiddleware mw : middlewares) { + try { + mw.onStop(); + } catch (Exception e) { + LOG.warn("Middleware onStop error", e); + } + } + } + + /** + * Run the ReAct loop for the given request. + * + * @param request user-facing parameters (question, model override, etc.) + * @param memory the conversation memory (may contain prior context) + * @param eventConsumer callback for each agent event (token-level streaming) + */ + public void run( + AgentInvocation request, ConversationMemory memory, Consumer eventConsumer) { + String userInput = request.getUserInput(); + ApprovalMode approvalMode = request.getApprovalMode(); + String modelNameOverride = request.getModelName(); + + String effectiveModel = + (modelNameOverride != null && !modelNameOverride.isEmpty()) + ? modelNameOverride + : defaultModelName; + + // System prompt is immutable for the lifetime of this agent — only set it on first run + // to avoid redundant overwrites on multi-turn conversations. + if (memory.getSystemPrompt() == null) { + memory.setSystemPrompt(systemPrompt); + } + memory.addUserMessage(userInput); + + AgentRunContext ctx = new AgentRunContext(memory, approvalMode, request.getSessionId()); + ctx.setEventEmitter(event -> emit(ctx, event, eventConsumer)); + dispatchAgentStart(ctx); + emit(ctx, new AgentStart(), eventConsumer); + + try { + for (int step = 1; step <= maxIterations; step++) { + ctx.setIteration(step); + emit(ctx, new StepStart(step), eventConsumer); + + List messages = + resolveMessagesForCall(ctx, memory.buildLlmMessages(), eventConsumer); + if (messages == null) { + // Middleware asked us to skip — AgentError + AgentFinish have already been emitted. + return; + } + + StreamResult result = streamLlmResponse(ctx, messages, effectiveModel, eventConsumer); + if (result.isEmpty()) { + emit(ctx, new AgentError("LLM returned empty response"), eventConsumer); + emitFinish(ctx, eventConsumer); + return; + } + + if (!result.content.isEmpty()) { + emit(ctx, new ContentComplete(result.content), eventConsumer); + } + ChatCompletionAssistantMessageParam assistantMsg = buildAssistantMessage(result); + memory.addAssistantMessage(assistantMsg); + dispatchAfterLlmCall(ctx, assistantMsg); + + if (result.toolCalls == null || result.toolCalls.isEmpty()) { + // No tool calls — agent is done. + emit(ctx, new StepEnd(step), eventConsumer); + emitFinish(ctx, eventConsumer); + return; + } + + executeToolCalls(ctx, memory, result.toolCalls, eventConsumer); + emit(ctx, new StepEnd(step), eventConsumer); + } + + emit( + ctx, new AgentError("Reached maximum iterations (" + maxIterations + ")"), eventConsumer); + emitFinish(ctx, eventConsumer); + + } catch (Exception e) { + LOG.error("Agent run failed", e); + emit( + ctx, new AgentError(e.getClass().getSimpleName() + ": " + e.getMessage()), eventConsumer); + emitFinish(ctx, eventConsumer); + } finally { + dispatchAgentFinish(ctx); + } + } + + /** + * Run {@code beforeLlmCall} middleware against {@code messages}. Returns the messages to send, + * possibly rewritten by middleware, or {@code null} if middleware aborted the call (in which case + * this method has already emitted the terminal events). + */ + private List resolveMessagesForCall( + AgentRunContext ctx, + List messages, + Consumer eventConsumer) { + AgentMiddleware.LlmCallAction action = dispatchBeforeLlmCall(ctx, messages); + if (action instanceof AgentMiddleware.LlmSkip) { + String reason = ((AgentMiddleware.LlmSkip) action).reason(); + emit(ctx, new AgentError("LLM call skipped by middleware: " + reason), eventConsumer); + emitFinish(ctx, eventConsumer); + return null; + } + if (action instanceof AgentMiddleware.LlmModifyMessages) { + return ((AgentMiddleware.LlmModifyMessages) action).messages(); + } + return messages; + } + + private static ChatCompletionAssistantMessageParam buildAssistantMessage(StreamResult result) { + ChatCompletionAssistantMessageParam.Builder b = ChatCompletionAssistantMessageParam.builder(); + if (!result.content.isEmpty()) { + b.content(result.content); + } + if (result.toolCalls != null && !result.toolCalls.isEmpty()) { + b.toolCalls(result.toolCalls); + } + return b.build(); + } + + /** + * Execute the assistant's tool calls in 3 phases: + * + *

    + *
  1. Serial: run {@code beforeToolCall} middleware, emit {@link ToolCall} events, and collect + * the calls that survived approval. + *
  2. Concurrent: fan out to {@link ToolRegistry#submitTool}, which always returns a future + * that completes normally — timeouts and execution errors surface as error strings. + *
  3. Serial: join futures in order, run {@code afterToolCall}, and record results to memory. + *
+ */ + private void executeToolCalls( + AgentRunContext ctx, + ConversationMemory memory, + List toolCalls, + Consumer eventConsumer) { + List approved = new ArrayList<>(); + for (ChatCompletionMessageToolCall toolCall : toolCalls) { + ChatCompletionMessageFunctionToolCall fnCall = toolCall.asFunction(); + String toolName = fnCall.function().name(); + Map toolArgs; + try { + toolArgs = parseToolArgs(fnCall.function().arguments()); + } catch (IllegalArgumentException e) { + // Malformed JSON from the LLM: record an error tool_result (preserves the + // assistant/tool_result pairing the next API call needs) and let the loop self-correct. + String err = "Tool call failed: " + e.getMessage(); + memory.addToolResult(fnCall.id(), err); + emit(ctx, new ToolResult(fnCall.id(), toolName, err, true), eventConsumer); + continue; + } + + AgentMiddleware.ToolCallDenial denial = + dispatchBeforeToolCall(ctx, fnCall.id(), toolName, toolArgs); + if (denial != null) { + String denied = "Tool call denied: " + denial.reason(); + memory.addToolResult(fnCall.id(), denied); + emit(ctx, new ToolResult(fnCall.id(), toolName, denied, true), eventConsumer); + continue; + } + + emit(ctx, new ToolCall(fnCall.id(), toolName, toolArgs), eventConsumer); + approved.add(new ToolCallEntry(fnCall, toolName, toolArgs)); + } + + ToolContext toolCtx = new ToolContext(ctx.getSessionId()); + List> futures = new ArrayList<>(approved.size()); + for (ToolCallEntry entry : approved) { + futures.add( + toolRegistry.submitTool(entry.toolName, entry.fnCall.function().arguments(), toolCtx)); + } + + for (int i = 0; i < approved.size(); i++) { + ToolCallEntry entry = approved.get(i); + String output = futures.get(i).join(); + String modified = dispatchAfterToolCall(ctx, entry.toolName, entry.toolArgs, output); + if (modified != null) { + output = modified; + } + memory.addToolResult(entry.fnCall.id(), output); + emit(ctx, new ToolResult(entry.fnCall.id(), entry.toolName, output, false), eventConsumer); + } + } + + /** Result of a streaming LLM call, assembled from chunks. */ + private static class StreamResult { + final String content; + final List toolCalls; + + StreamResult(String content, List toolCalls) { + this.content = content; + this.toolCalls = toolCalls; + } + + boolean isEmpty() { + return content.isEmpty() && (toolCalls == null || toolCalls.isEmpty()); + } + } + + /** Holds an approved tool call's parsed metadata for the 3-phase execution pipeline. */ + private static class ToolCallEntry { + final ChatCompletionMessageFunctionToolCall fnCall; + final String toolName; + final Map toolArgs; + + ToolCallEntry( + ChatCompletionMessageFunctionToolCall fnCall, + String toolName, + Map toolArgs) { + this.fnCall = fnCall; + this.toolName = toolName; + this.toolArgs = toolArgs; + } + } + + /** + * Stream LLM response, emitting ContentDelta for each text chunk. Assembles tool calls directly + * from streamed chunks — no non-streaming fallback. Exceptions propagate to the top-level handler + * in {@link #run}. + */ + private StreamResult streamLlmResponse( + AgentRunContext ctx, + List messages, + String effectiveModel, + Consumer eventConsumer) { + ChatCompletionCreateParams.Builder paramsBuilder = + ChatCompletionCreateParams.builder() + .model(effectiveModel) + .streamOptions(ChatCompletionStreamOptions.builder().includeUsage(true).build()); + for (ChatCompletionMessageParam msg : messages) { + paramsBuilder.addMessage(msg); + } + toolRegistry.addToolsTo(paramsBuilder); + + LOG.info("LLM request: model={}", effectiveModel); + StreamAccumulator acc = new StreamAccumulator(); + try (StreamResponse stream = + client.chat().completions().createStreaming(paramsBuilder.build())) { + stream.stream().forEach(chunk -> consumeChunk(ctx, chunk, acc, eventConsumer)); + } + return new StreamResult(acc.content.toString(), acc.buildToolCalls()); + } + + /** Fold one streaming chunk into {@code acc}, emitting per-token {@link ContentDelta}s. */ + private void consumeChunk( + AgentRunContext ctx, + ChatCompletionChunk chunk, + StreamAccumulator acc, + Consumer eventConsumer) { + if (!acc.serverModelLogged) { + LOG.info("LLM response: server-echoed model={}", chunk.model()); + acc.serverModelLogged = true; + } + chunk + .usage() + .ifPresent(u -> ctx.addTokenUsage(u.promptTokens(), u.completionTokens(), u.totalTokens())); + + for (ChatCompletionChunk.Choice c : chunk.choices()) { + c.delta() + .content() + .ifPresent( + text -> { + acc.content.append(text); + emit(ctx, new ContentDelta(text), eventConsumer); + }); + c.delta().toolCalls().ifPresent(acc::mergeToolCallDeltas); + } + } + + /** + * Mutable accumulator for a single streaming LLM turn. Tool call fields are keyed by the chunk's + * {@code index} because provider SDKs may deliver a single logical call across multiple chunks + * and only surface the {@code id}/{@code name} on the first one. + */ + private static final class StreamAccumulator { + final StringBuilder content = new StringBuilder(); + final Map toolCallIds = new HashMap<>(); + final Map toolCallNames = new HashMap<>(); + final Map toolCallArgs = new HashMap<>(); + boolean serverModelLogged = false; + + void mergeToolCallDeltas(List deltas) { + for (ChatCompletionChunk.Choice.Delta.ToolCall tc : deltas) { + int idx = (int) tc.index(); + tc.id().ifPresent(id -> toolCallIds.put(idx, id)); + tc.function() + .ifPresent( + fn -> { + fn.name().ifPresent(name -> toolCallNames.put(idx, name)); + fn.arguments() + .ifPresent( + args -> + toolCallArgs + .computeIfAbsent(idx, k -> new StringBuilder()) + .append(args)); + }); + } + } + + /** + * Materialize accumulated deltas into SDK tool-call objects. Returns {@code null} (not an empty + * list) if no tool calls were seen, matching the existing {@link StreamResult} contract. + */ + List buildToolCalls() { + if (toolCallIds.isEmpty()) return null; + List out = new ArrayList<>(toolCallIds.size()); + for (Map.Entry e : toolCallIds.entrySet()) { + int idx = e.getKey(); + String id = (e.getValue() == null || e.getValue().isEmpty()) ? synthId() : e.getValue(); + String args = toolCallArgs.containsKey(idx) ? toolCallArgs.get(idx).toString() : "{}"; + out.add( + ChatCompletionMessageToolCall.ofFunction( + ChatCompletionMessageFunctionToolCall.builder() + .id(id) + .function( + ChatCompletionMessageFunctionToolCall.Function.builder() + .name(toolCallNames.getOrDefault(idx, "")) + .arguments(args) + .build()) + .build())); + } + return out; + } + + /** + * Synthesize an id for tool calls whose id never arrived on the stream (some OpenAI-compatible + * providers omit it). The id has to be stable within a turn and unique across turns so the + * assistant/tool_result pairing downstream holds. + */ + private static String synthId() { + return "local_" + java.util.UUID.randomUUID().toString().replace("-", "").substring(0, 24); + } + } + + private static Map parseToolArgs(String json) { + if (json == null || json.isEmpty()) { + return new HashMap<>(); + } + try { + return JSON.readValue(json, new TypeReference>() {}); + } catch (java.io.IOException e) { + throw new IllegalArgumentException("Malformed tool-call arguments JSON: " + json, e); + } + } + + // --- Middleware dispatch methods --- + // + // Middlewares are internal framework code. If one throws, the agent run fails via the + // top-level catch in run() — we do not wrap individual dispatch calls in try/catch. + + private void emitFinish(AgentRunContext ctx, Consumer eventConsumer) { + emit( + ctx, + new AgentFinish( + ctx.getIteration(), + ctx.getPromptTokens(), + ctx.getCompletionTokens(), + ctx.getTotalTokens()), + eventConsumer); + } + + private void emit(AgentRunContext ctx, AgentEvent event, Consumer consumer) { + AgentEvent filtered = event; + for (AgentMiddleware mw : middlewares) { + filtered = mw.onEvent(ctx, filtered); + if (filtered == null) return; + } + consumer.accept(filtered); + } + + private void dispatchAgentStart(AgentRunContext ctx) { + for (AgentMiddleware mw : middlewares) { + mw.onAgentStart(ctx); + } + } + + private void dispatchAgentFinish(AgentRunContext ctx) { + // Runs even when the agent body threw, so swallow here to ensure every middleware's cleanup + // gets a chance to run; otherwise we'd leak session state in later middlewares. + for (int i = middlewares.size() - 1; i >= 0; i--) { + try { + middlewares.get(i).onAgentFinish(ctx); + } catch (Exception e) { + LOG.warn("Middleware onAgentFinish error", e); + } + } + } + + private AgentMiddleware.LlmCallAction dispatchBeforeLlmCall( + AgentRunContext ctx, List messages) { + for (AgentMiddleware mw : middlewares) { + AgentMiddleware.LlmCallAction action = mw.beforeLlmCall(ctx, messages); + if (action != null) return action; + } + return null; + } + + private void dispatchAfterLlmCall( + AgentRunContext ctx, ChatCompletionAssistantMessageParam response) { + for (int i = middlewares.size() - 1; i >= 0; i--) { + middlewares.get(i).afterLlmCall(ctx, response); + } + } + + private AgentMiddleware.ToolCallDenial dispatchBeforeToolCall( + AgentRunContext ctx, String toolCallId, String toolName, Map toolArgs) { + for (AgentMiddleware mw : middlewares) { + AgentMiddleware.ToolCallDenial denial = + mw.beforeToolCall(ctx, toolCallId, toolName, toolArgs); + if (denial != null) return denial; + } + return null; + } + + private String dispatchAfterToolCall( + AgentRunContext ctx, String toolName, Map toolArgs, String result) { + String modified = null; + for (int i = middlewares.size() - 1; i >= 0; i--) { + String mwResult = + middlewares + .get(i) + .afterToolCall(ctx, toolName, toolArgs, modified != null ? modified : result); + if (mwResult != null) { + modified = mwResult; + } + } + return modified; + } + + // --- Builder --- + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private OpenAIClient client; + private String modelName; + private ToolRegistry toolRegistry = new ToolRegistry(ToolRegistry.DEFAULT_TIMEOUT_SECONDS); + private final List middlewares = new ArrayList<>(); + private int maxIterations = 20; + private String systemPrompt; + + public Builder client(OpenAIClient client) { + this.client = client; + return this; + } + + public Builder modelName(String modelName) { + this.modelName = modelName; + return this; + } + + public Builder toolRegistry(ToolRegistry toolRegistry) { + this.toolRegistry = toolRegistry; + return this; + } + + public Builder addMiddleware(AgentMiddleware middleware) { + this.middlewares.add(middleware); + return this; + } + + public Builder maxIterations(int maxIterations) { + if (maxIterations < 1) { + throw new IllegalArgumentException("maxIterations must be >= 1, got " + maxIterations); + } + this.maxIterations = maxIterations; + return this; + } + + public Builder systemPrompt(String systemPrompt) { + this.systemPrompt = systemPrompt; + return this; + } + + public ReactAgent build() { + if (client == null) throw new IllegalStateException("client is required"); + if (modelName == null) throw new IllegalStateException("modelName is required"); + if (toolRegistry == null) throw new IllegalStateException("toolRegistry is required"); + for (AgentMiddleware mw : middlewares) { + mw.onRegister(toolRegistry); + } + return new ReactAgent( + client, modelName, toolRegistry, middlewares, maxIterations, systemPrompt); + } + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ToolOutputStore.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ToolOutputStore.java new file mode 100644 index 00000000000..d4a8a97e97d --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ToolOutputStore.java @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime; + +import java.io.BufferedReader; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.regex.Pattern; +import java.util.regex.PatternSyntaxException; +import java.util.stream.Stream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Per-session temp-file store for large tool outputs that the input gate has offloaded from + * conversation history. + * + *

Layout: {@code //tool_.txt}, where {@code } is a + * per-engine randomly-named temp directory (see {@link #create()}). UTF-8 text. + * + *

Isolation: every {@code read}/{@code grep} call requires the current session id and + * validates that the caller-supplied path resolves under {@code /}, not merely + * under {@code }. A session cannot read another session's offloaded output even if it somehow + * obtained the absolute path. The per-engine random root additionally isolates different engine + * processes sharing the same host. + * + *

Failures (traversal, missing file, session-id mismatch, IO) are reported as error strings + * rather than thrown — the tool must keep the agent loop alive. + */ +public class ToolOutputStore implements AutoCloseable { + + private static final Logger LOG = LoggerFactory.getLogger(ToolOutputStore.class); + private static final String ROOT_PREFIX = "kyuubi-data-agent-"; + + private final Path root; + + /** + * Create a store backed by a fresh, per-engine random temp directory under {@code + * java.io.tmpdir}. + */ + public static ToolOutputStore create() { + try { + return new ToolOutputStore(Files.createTempDirectory(ROOT_PREFIX)); + } catch (IOException e) { + throw new IllegalStateException("Failed to create ToolOutputStore temp root", e); + } + } + + private ToolOutputStore(Path root) { + try { + Files.createDirectories(root); + this.root = root.toRealPath(); + } catch (IOException e) { + throw new IllegalStateException("Failed to initialize ToolOutputStore root: " + root, e); + } + } + + public Path getRoot() { + return root; + } + + /** Write {@code content} to {@code //tool_.txt}. */ + public Path write(String sessionId, String toolCallId, String content) throws IOException { + Path dir = root.resolve(safeSegment(sessionId)); + Files.createDirectories(dir); + Path file = dir.resolve("tool_" + safeSegment(toolCallId) + ".txt"); + Files.write(file, content.getBytes(StandardCharsets.UTF_8)); + return file; + } + + /** + * Read a line window. Returns a human-readable block including a 1-based {@code [lines X-Y of Z + * total]} header, or an error string on traversal / IO failure / cross-session access. + */ + public String read(String sessionId, String pathStr, long offset, int limit) { + Path file = validatePath(sessionId, pathStr); + if (file == null) { + return "Error: path is outside this session's tool-output directory or does not exist: " + + pathStr; + } + if (offset < 0) offset = 0; + if (limit <= 0) limit = 1; + + List taken = new ArrayList<>(limit); + long totalLines = 0; + try (BufferedReader br = Files.newBufferedReader(file, StandardCharsets.UTF_8)) { + String line; + while ((line = br.readLine()) != null) { + if (totalLines >= offset && taken.size() < limit) { + taken.add(line); + } + totalLines++; + } + } catch (IOException e) { + return "Error reading " + pathStr + ": " + e.getMessage(); + } + + long fromLine = offset + 1; // 1-based + long toLineExclusive = Math.min(offset + limit, totalLines); + StringBuilder sb = new StringBuilder(); + sb.append("[lines ") + .append(fromLine) + .append("-") + .append(toLineExclusive) + .append(" of ") + .append(totalLines) + .append(" total]\n"); + for (String line : taken) { + sb.append(line).append('\n'); + } + return sb.toString(); + } + + /** + * Stream-grep the file. Returns at most {@code maxMatches} matches as {@code lineNo:content}, one + * per line; or an error string on traversal / regex / IO failure / cross-session access. + */ + public String grep(String sessionId, String pathStr, String patternStr, int maxMatches) { + Path file = validatePath(sessionId, pathStr); + if (file == null) { + return "Error: path is outside this session's tool-output directory or does not exist: " + + pathStr; + } + if (patternStr == null || patternStr.isEmpty()) { + return "Error: 'pattern' parameter is required."; + } + if (maxMatches <= 0) maxMatches = 50; + + Pattern pattern; + try { + pattern = Pattern.compile(patternStr); + } catch (PatternSyntaxException e) { + return "Error: invalid regex pattern: " + e.getMessage(); + } + + StringBuilder sb = new StringBuilder(); + int matches = 0; + long lineNo = 0; + try (BufferedReader br = Files.newBufferedReader(file, StandardCharsets.UTF_8)) { + String line; + while ((line = br.readLine()) != null) { + lineNo++; + if (pattern.matcher(line).find()) { + sb.append(lineNo).append(':').append(line).append('\n'); + matches++; + if (matches >= maxMatches) break; + } + } + } catch (IOException e) { + return "Error reading " + pathStr + ": " + e.getMessage(); + } + if (matches == 0) { + return "[no matches for pattern: " + patternStr + "]"; + } + return "[" + matches + " match" + (matches == 1 ? "" : "es") + "]\n" + sb; + } + + /** Recursively delete the session's subtree. Safe to call on missing sessions. */ + public void cleanupSession(String sessionId) { + if (sessionId == null) return; + Path dir = root.resolve(safeSegment(sessionId)); + deleteTree(dir); + } + + /** Delete everything below (and including) the root. Idempotent; safe to call multiple times. */ + @Override + public void close() { + deleteTree(root); + } + + private static void deleteTree(Path dir) { + if (!Files.exists(dir)) return; + try (Stream stream = Files.walk(dir)) { + stream.sorted(Comparator.reverseOrder()).forEach(ToolOutputStore::deleteQuietly); + } catch (IOException e) { + LOG.warn("Failed to clean up dir {}", dir, e); + } + } + + private static void deleteQuietly(Path p) { + try { + Files.deleteIfExists(p); + } catch (IOException e) { + LOG.debug("Failed to delete {}", p, e); + } + } + + /** + * Resolve {@code pathStr} and return it only if (a) it exists as a regular file and (b) the real + * path is under {@code /}. Returns null on any violation — including a null or + * empty session id, since without one we cannot scope the check. + */ + private Path validatePath(String sessionId, String pathStr) { + if (pathStr == null || pathStr.isEmpty()) return null; + if (sessionId == null || sessionId.isEmpty()) return null; + Path sessionRoot = root.resolve(safeSegment(sessionId)); + try { + Path real = Paths.get(pathStr).toRealPath(); + if (!real.startsWith(sessionRoot)) return null; + if (!Files.isRegularFile(real)) return null; + return real; + } catch (IOException | SecurityException e) { + return null; + } + } + + /** Strip anything that could escape a single path segment. */ + private static String safeSegment(String raw) { + if (raw == null || raw.isEmpty()) return "_"; + StringBuilder sb = new StringBuilder(raw.length()); + for (int i = 0; i < raw.length(); i++) { + char c = raw.charAt(i); + if (Character.isLetterOrDigit(c) || c == '-' || c == '_' || c == '.') { + sb.append(c); + } else { + sb.append('_'); + } + } + return sb.toString(); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/event/Compaction.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/event/Compaction.java new file mode 100644 index 00000000000..26eb6024e98 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/event/Compaction.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime.event; + +/** + * Emitted by {@code CompactionMiddleware} after it has summarized a prefix of the conversation + * history and replaced it in memory. Purely observational — the LLM call that immediately follows + * uses the already-compacted history, so consumers just see this as a side-channel notice that + * compaction happened. The summary text itself is intentionally not included: it can be large and + * would bloat the event stream; operators who need it can read the middleware log. + */ +public final class Compaction extends AgentEvent { + private final int summarizedCount; + private final int keptCount; + private final long triggerTokens; + private final long observedTokens; + + public Compaction(int summarizedCount, int keptCount, long triggerTokens, long observedTokens) { + super(EventType.COMPACTION); + this.summarizedCount = summarizedCount; + this.keptCount = keptCount; + this.triggerTokens = triggerTokens; + this.observedTokens = observedTokens; + } + + public int summarizedCount() { + return summarizedCount; + } + + public int keptCount() { + return keptCount; + } + + public long triggerTokens() { + return triggerTokens; + } + + public long observedTokens() { + return observedTokens; + } + + @Override + public String toString() { + return "Compaction{summarized=" + + summarizedCount + + ", kept=" + + keptCount + + ", triggerTokens=" + + triggerTokens + + ", observedTokens=" + + observedTokens + + "}"; + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/event/EventType.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/event/EventType.java index 937422e2bf5..d58e5de2ee7 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/event/EventType.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/event/EventType.java @@ -21,6 +21,8 @@ * Enumerates the types of events emitted by the ReAct agent loop. Each value maps to a * corresponding {@link AgentEvent} subclass and carries an SSE event name used for wire * serialization. + * + * @see org.apache.kyuubi.engine.dataagent.runtime.ReactAgent */ public enum EventType { @@ -51,6 +53,9 @@ public enum EventType { /** The agent requires user approval before executing a tool. */ APPROVAL_REQUEST("approval_request"), + /** The conversation history was compacted by the compaction middleware. */ + COMPACTION("compaction"), + /** The agent has finished its analysis. */ AGENT_FINISH("agent_finish"); diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/AgentMiddleware.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/AgentMiddleware.java new file mode 100644 index 00000000000..c934bb0882a --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/AgentMiddleware.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime.middleware; + +import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam; +import com.openai.models.chat.completions.ChatCompletionMessageParam; +import java.util.List; +import java.util.Map; +import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentEvent; +import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry; + +/** + * Middleware interface for the Data Agent ReAct loop. Middlewares are executed in onion-model + * order: before_* hooks run first-to-last, after_* hooks run last-to-first. + * + *

All hooks have default no-op implementations. Override only what you need. + */ +public interface AgentMiddleware { + + /** + * Called once when the middleware is wired into the agent. Register companion tools that are part + * of the middleware's contract, or capture a reference to the registry for later use. Dispatched + * by {@code ReactAgent.Builder.build} before the agent accepts any requests. + */ + default void onRegister(ToolRegistry registry) {} + + /** Called when the agent starts processing a user query. Runs first-to-last. */ + default void onAgentStart(AgentRunContext ctx) {} + + /** Called when the agent finishes. Runs last-to-first (cleanup order). */ + default void onAgentFinish(AgentRunContext ctx) {} + + /** + * Called before each LLM invocation. Return non-null to skip or modify the LLM call. Runs + * first-to-last. + * + * @return {@code null} to proceed normally, {@link LlmSkip} to abort, or {@link + * LlmModifyMessages} to replace the message list for this call. + */ + default LlmCallAction beforeLlmCall( + AgentRunContext ctx, List messages) { + return null; + } + + /** Called after each LLM invocation. Runs last-to-first. */ + default void afterLlmCall(AgentRunContext ctx, ChatCompletionAssistantMessageParam response) {} + + /** Called before each tool execution. Return non-null to deny the call. Runs first-to-last. */ + default ToolCallDenial beforeToolCall( + AgentRunContext ctx, String toolCallId, String toolName, Map toolArgs) { + return null; + } + + /** + * Called after each tool execution. Runs last-to-first. + * + *

Returns {@code String} (not {@code void}) so that middlewares can intercept and transform + * the tool result before it is fed back to the LLM — e.g. for data masking, output truncation, or + * injecting metadata. Return {@code null} to keep the original result unchanged; return a + * non-null value to replace it. + */ + default String afterToolCall( + AgentRunContext ctx, String toolName, Map toolArgs, String result) { + return null; + } + + /** + * Called for every event before it is emitted. Return null to suppress the event. Runs + * first-to-last. + */ + default AgentEvent onEvent(AgentRunContext ctx, AgentEvent event) { + return event; + } + + /** + * Called when a session is closed. Clean up per-session state (scratch files, pending tasks, + * counters). Idempotent. Dispatched by {@code ReactAgent.closeSession}. + */ + default void onSessionClose(String sessionId) {} + + /** + * Called when the engine is stopping. Release global resources and unblock any threads still + * waiting on this middleware. Dispatched by {@code ReactAgent.stop}. + */ + default void onStop() {} + + /** + * Base type for {@code beforeLlmCall} return values. Subtypes: {@link LlmSkip} to abort the LLM + * call, {@link LlmModifyMessages} to replace the message list for this call. + */ + abstract class LlmCallAction { + private LlmCallAction() {} + } + + /** Returned from {@code beforeLlmCall} to skip the LLM call and abort the agent loop. */ + class LlmSkip extends LlmCallAction { + private final String reason; + + public LlmSkip(String reason) { + this.reason = reason; + } + + public String reason() { + return reason; + } + } + + /** + * Returned from {@code beforeLlmCall} to replace the message list for this LLM invocation. The + * agent loop continues normally with the modified messages. + */ + class LlmModifyMessages extends LlmCallAction { + private final List messages; + + public LlmModifyMessages(List messages) { + this.messages = messages; + } + + public List messages() { + return messages; + } + } + + /** Returned from {@code beforeToolCall} to deny a tool call. Non-null means denied. */ + class ToolCallDenial { + private final String reason; + + public ToolCallDenial(String reason) { + this.reason = reason; + } + + public String reason() { + return reason; + } + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddleware.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddleware.java new file mode 100644 index 00000000000..92d25b47b9d --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddleware.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime.middleware; + +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext; +import org.apache.kyuubi.engine.dataagent.runtime.ApprovalMode; +import org.apache.kyuubi.engine.dataagent.runtime.event.ApprovalRequest; +import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry; +import org.apache.kyuubi.engine.dataagent.tool.ToolRiskLevel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Middleware that enforces human-in-the-loop approval for tool calls based on the {@link + * ApprovalMode} and the tool's {@link ToolRiskLevel}. + * + *

When approval is required, an {@link ApprovalRequest} event is emitted to the client via + * {@link AgentRunContext#emit}, and the agent thread blocks until the client responds via {@link + * #resolve} or the timeout expires. + */ +public class ApprovalMiddleware implements AgentMiddleware { + + private static final Logger LOG = LoggerFactory.getLogger(ApprovalMiddleware.class); + + private static final long DEFAULT_TIMEOUT_SECONDS = 300; // 5 minutes + + private final long timeoutSeconds; + private final ConcurrentHashMap> pending = + new ConcurrentHashMap<>(); + private ToolRegistry toolRegistry; + + public ApprovalMiddleware() { + this(DEFAULT_TIMEOUT_SECONDS); + } + + public ApprovalMiddleware(long timeoutSeconds) { + this.timeoutSeconds = timeoutSeconds; + } + + @Override + public void onRegister(ToolRegistry registry) { + this.toolRegistry = registry; + } + + @Override + public ToolCallDenial beforeToolCall( + AgentRunContext ctx, String toolCallId, String toolName, Map toolArgs) { + ToolRiskLevel riskLevel = toolRegistry.getRiskLevel(toolName); + + if (shouldAutoApprove(ctx.getApprovalMode(), riskLevel)) { + return null; + } + + String requestId = UUID.randomUUID().toString(); + CompletableFuture future = new CompletableFuture<>(); + pending.put(requestId, future); + + ctx.emit(new ApprovalRequest(requestId, toolCallId, toolName, toolArgs, riskLevel)); + LOG.info("Approval requested for tool '{}' (requestId={})", toolName, requestId); + + try { + boolean approved = future.get(timeoutSeconds, TimeUnit.SECONDS); + if (!approved) { + LOG.info("Tool '{}' denied by user (requestId={})", toolName, requestId); + return new ToolCallDenial("User denied execution of " + toolName); + } + LOG.info("Tool '{}' approved by user (requestId={})", toolName, requestId); + return null; + } catch (TimeoutException e) { + // Complete the future so that a late resolve() call is a harmless no-op + // instead of completing a dangling future. + future.completeExceptionally(e); + LOG.warn("Approval timed out for tool '{}' (requestId={})", toolName, requestId); + return new ToolCallDenial("Approval timed out after " + timeoutSeconds + "s for " + toolName); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return new ToolCallDenial("Approval interrupted for " + toolName); + } catch (Exception e) { + LOG.error("Unexpected error waiting for approval", e); + return new ToolCallDenial("Approval error: " + e.getMessage()); + } finally { + pending.remove(requestId); + } + } + + /** + * Resolve a pending approval request. Called by the external approval channel (e.g. a Kyuubi + * operation or REST endpoint). + * + * @param requestId the request ID from the {@link ApprovalRequest} event + * @param approved true to approve, false to deny + * @return true if the request was found and resolved, false if not found (already timed out or + * invalid ID) + */ + public boolean resolve(String requestId, boolean approved) { + CompletableFuture future = pending.get(requestId); + if (future != null) { + return future.complete(approved); + } + LOG.warn("No pending approval found for requestId={}", requestId); + return false; + } + + /** + * Cancel all pending approval requests to unblock any waiting agent threads. Invoked as part of + * engine shutdown via {@code ReactAgent.stop}. + */ + @Override + public void onStop() { + InterruptedException ex = new InterruptedException("Session closed"); + pending.forEachKey( + Long.MAX_VALUE, + key -> { + CompletableFuture future = pending.remove(key); + if (future != null) { + future.completeExceptionally(ex); + } + }); + } + + private static boolean shouldAutoApprove(ApprovalMode mode, ToolRiskLevel riskLevel) { + if (mode == ApprovalMode.AUTO_APPROVE) { + return true; + } + if (mode == ApprovalMode.NORMAL && riskLevel == ToolRiskLevel.SAFE) { + return true; + } + // STRICT: all tools require approval + return false; + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddleware.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddleware.java new file mode 100644 index 00000000000..acab7ae8b75 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddleware.java @@ -0,0 +1,409 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime.middleware; + +import com.openai.client.OpenAIClient; +import com.openai.models.chat.completions.ChatCompletion; +import com.openai.models.chat.completions.ChatCompletionCreateParams; +import com.openai.models.chat.completions.ChatCompletionMessageParam; +import com.openai.models.chat.completions.ChatCompletionMessageToolCall; +import com.openai.models.chat.completions.ChatCompletionSystemMessageParam; +import com.openai.models.chat.completions.ChatCompletionUserMessageParam; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext; +import org.apache.kyuubi.engine.dataagent.runtime.ConversationMemory; +import org.apache.kyuubi.engine.dataagent.runtime.event.Compaction; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Middleware that compacts conversation history when the prompt grows large. + * + *

Trigger formula: + * + *

+ *   predicted_this_turn_prompt_tokens
+ *       = last_llm_call_total_tokens            // prompt + completion of the previous call
+ *       + estimate_new_tail(messages)           // chars / 4, for content appended after the
+ *                                               // last assistant message (tool results, new user)
+ * 
+ * + * History (already sent to the LLM) must use real token counts — we read them straight off {@link + * ConversationMemory#getLastTotalTokens()}, which the provider updates after every call. We use + * total (prompt + completion) rather than just prompt because the previous call's completion + * — e.g. an assistant {@code tool_call} message — is already appended to history and will be + * tokenized into the next prompt; the tail estimator starts strictly after the last + * assistant message, so without completion we would miss it. Only content appended since that + * assistant message is estimated, which catches spikes before the next API report arrives. + * + *

Post-compaction persistence: the summary + kept tail replace {@link + * ConversationMemory}'s history via {@link ConversationMemory#replaceHistory}. All subsequent turns + * in this session read the compacted history — we do not re-summarize each turn. Because the next + * LLM call uses the compacted messages, its reported {@code prompt_tokens} will be small, naturally + * preventing retriggering. + * + *

Thread safety / shared instance: Instances of this middleware are shared across all + * sessions inside a provider (see {@code OpenAiProvider} javadoc). All per-session state + * (cumulative and last-call totals) lives on {@link ConversationMemory}, so this middleware itself + * is stateless across sessions and requires no per-session cleanup. + * + *

Tool-call pair invariant: The split never separates an assistant message bearing {@code + * tool_calls} from the {@code tool_result} messages that satisfy those calls — an orphan + * tool_result is rejected by the OpenAI API with HTTP 400. + * + *

Failure handling: summarizer failures propagate out of {@code beforeLlmCall} to the + * agent's top-level catch, which surfaces an {@code AgentError} event. We don't silently skip + * compaction — a broken summarizer is a real problem the operator needs to see. + * + *

Disabling: to effectively turn compaction off, construct with a very large {@code + * triggerPromptTokens} (e.g., {@link Long#MAX_VALUE}). + * + *

TODO: support a separate (cheaper) summarization model distinct from the main agent model. + */ +public class CompactionMiddleware implements AgentMiddleware { + + private static final Logger LOG = LoggerFactory.getLogger(CompactionMiddleware.class); + + /** Number of recent user turns (and their assistant/tool companions) to preserve verbatim. */ + private static final int KEEP_RECENT_TURNS = 4; + + private static final String COMPACTION_SYSTEM_PROMPT = + "SYSTEM OPERATION — this is an automated context compaction step, NOT a user message.\n" + + "\n" + + "You are summarizing a conversation between a user and a Data Agent (a ReAct agent" + + " that executes SQL and analytics tools against Kyuubi). The user has NOT asked you" + + " to summarize anything. Do not address the user. Do not ask questions. Produce only" + + " the summary in the schema below.\n" + + "\n" + + "Goal: produce a dense, structured summary the agent can resume from without losing" + + " critical context. Preserve concrete details verbatim — file paths, table names," + + " schema definitions, SQL snippets, column names, error messages.\n" + + "\n" + + "Output EXACTLY these 8 sections, in order, as markdown headers:\n" + + "\n" + + "1. ## User Intent\n" + + " The user's original request, restated in full. Preserve the literal phrasing of" + + " the ask. Include follow-up refinements.\n" + + "\n" + + "2. ## Key Concepts\n" + + " Domain terms, data sources, tables, schemas, SQL dialects, and business logic" + + " the agent has been reasoning about.\n" + + "\n" + + "3. ## Files and Code\n" + + " File paths, query text, DDL, or code artifacts referenced. Include verbatim SQL" + + " snippets that produced meaningful results.\n" + + "\n" + + "4. ## Errors and Recoveries\n" + + " Errors encountered (SQL syntax, permission, timeout, tool failures), what was" + + " tried, and what resolved them. Preserve error messages verbatim.\n" + + "\n" + + "5. ## Pending Work\n" + + " Tasks the agent identified but has not completed yet.\n" + + "\n" + + "6. ## Current State\n" + + " Where the agent is right now — what question is open, what data has been" + + " retrieved, what hypothesis is being tested.\n" + + "\n" + + "7. ## Next Step\n" + + " The immediate next action the agent should take when resuming.\n" + + "\n" + + "8. ## Tool Usage Summary\n" + + " Which tools were called, how many times, and notable results.\n" + + "\n" + + "CRITICAL:\n" + + "- DO NOT ask the user about this summary.\n" + + "- DO NOT mention that compaction occurred in any future assistant response.\n" + + "- DO NOT invent details not present in the conversation.\n" + + "- DO NOT output anything outside the 8 sections.\n"; + + private final OpenAIClient client; + private final String summarizerModel; + private final long triggerPromptTokens; + + public CompactionMiddleware( + OpenAIClient client, String summarizerModel, long triggerPromptTokens) { + this.client = client; + this.summarizerModel = summarizerModel; + this.triggerPromptTokens = triggerPromptTokens; + } + + @Override + public LlmCallAction beforeLlmCall( + AgentRunContext ctx, List messages) { + ConversationMemory mem = ctx.getMemory(); + // 1) Real token count of the previous LLM call (prompt + completion, i.e. everything through + // the last assistant message, which is now part of history). 0 on the first call. + long lastTotal = mem.getLastTotalTokens(); + // 2) Estimated tokens appended to the tail after the last assistant (tool_results, new user). + long newTailEstimate = estimateTailAfterLastAssistant(messages); + + if (lastTotal + newTailEstimate < triggerPromptTokens) { + return null; + } + + List history = mem.getHistory(); + + // 3) Split history into old (to summarize) and kept (recent tail), never orphaning a + // tool_result. + Split split = computeSplit(history, KEEP_RECENT_TURNS); + if (split.old.isEmpty()) { + return null; + } + + String summary = summarize(mem.getSystemPrompt(), split.old); + + // 4) Build the compacted history and persist into ConversationMemory. + List compacted = new ArrayList<>(1 + split.kept.size()); + compacted.add(wrapSummaryAsUserMessage(summary)); + compacted.addAll(split.kept); + mem.replaceHistory(compacted); + + LOG.info( + "Compacted {} old msgs into 1 summary; kept {} tail msgs (lastTotal={}, newTail~={})", + split.old.size(), + split.kept.size(), + lastTotal, + newTailEstimate); + + ctx.emit( + new Compaction( + split.old.size(), split.kept.size(), triggerPromptTokens, lastTotal + newTailEstimate)); + + return new LlmModifyMessages(mem.buildLlmMessages()); + } + + /** Call the LLM to produce a summary of {@code oldMessages}. Failures propagate. */ + private String summarize(String agentSystemPrompt, List oldMessages) { + String systemPrompt = COMPACTION_SYSTEM_PROMPT; + if (agentSystemPrompt != null && !agentSystemPrompt.isEmpty()) { + systemPrompt = + systemPrompt + + "\n---\nFor context, the agent's own system prompt is:\n" + + agentSystemPrompt; + } + + String rendered = renderAsText(oldMessages); + + ChatCompletionCreateParams params = + ChatCompletionCreateParams.builder() + .model(summarizerModel) + .temperature(0.0) + .addMessage( + ChatCompletionMessageParam.ofSystem( + ChatCompletionSystemMessageParam.builder().content(systemPrompt).build())) + .addMessage( + ChatCompletionMessageParam.ofUser( + ChatCompletionUserMessageParam.builder().content(rendered).build())) + .build(); + + ChatCompletion response = client.chat().completions().create(params); + return response.choices().get(0).message().content().get(); + } + + // ----- helpers ----- + + /** Sum of content characters in messages after the last assistant, using ~4 chars per token. */ + static long estimateTailAfterLastAssistant(List messages) { + int lastAssistantIdx = -1; + for (int i = messages.size() - 1; i >= 0; i--) { + if (messages.get(i).isAssistant()) { + lastAssistantIdx = i; + break; + } + } + long totalChars = 0; + for (int i = lastAssistantIdx + 1; i < messages.size(); i++) { + totalChars += contentCharCount(messages.get(i)); + } + return totalChars / 4; + } + + private static long contentCharCount(ChatCompletionMessageParam msg) { + if (msg.isUser()) { + return msg.asUser().content().text().map(String::length).orElse(0); + } + if (msg.isTool()) { + return msg.asTool().content().text().map(String::length).orElse(0); + } + if (msg.isAssistant()) { + return msg.asAssistant().content().flatMap(c -> c.text()).map(String::length).orElse(0); + } + if (msg.isSystem()) { + return msg.asSystem().content().text().map(String::length).orElse(0); + } + return 0; + } + + /** + * Render a list of messages as plain text for the summarizer's user turn. Tool calls and tool + * results are rendered as tagged text so the summarizer LLM doesn't try to continue them as live + * agent state. + */ + static String renderAsText(List messages) { + StringBuilder sb = new StringBuilder(); + for (ChatCompletionMessageParam msg : messages) { + if (sb.length() > 0) sb.append("\n\n"); + if (msg.isUser()) { + sb.append("USER: ").append(extractUserContent(msg)); + } else if (msg.isAssistant()) { + sb.append("ASSISTANT: ").append(extractAssistantContent(msg)); + msg.asAssistant() + .toolCalls() + .ifPresent( + calls -> { + for (ChatCompletionMessageToolCall tc : calls) { + if (tc.isFunction()) { + sb.append("\n[tool_call: ") + .append(tc.asFunction().function().name()) + .append("(") + .append(tc.asFunction().function().arguments()) + .append(") id=") + .append(tc.asFunction().id()) + .append("]"); + } + } + }); + } else if (msg.isTool()) { + sb.append("[tool_result id=") + .append(msg.asTool().toolCallId()) + .append("]: ") + .append(extractToolContent(msg)); + } else if (msg.isSystem()) { + // system prompt should not appear in oldMessages, but render defensively + sb.append("SYSTEM: ").append(extractSystemContent(msg)); + } + } + return sb.toString(); + } + + private static String extractUserContent(ChatCompletionMessageParam msg) { + return msg.asUser().content().text().orElse("[non-text content]"); + } + + private static String extractAssistantContent(ChatCompletionMessageParam msg) { + return msg.asAssistant().content().map(c -> c.text().orElse("[non-text content]")).orElse(""); + } + + private static String extractToolContent(ChatCompletionMessageParam msg) { + return msg.asTool().content().text().orElse("[non-text content]"); + } + + private static String extractSystemContent(ChatCompletionMessageParam msg) { + return msg.asSystem().content().text().orElse("[non-text content]"); + } + + /** Result of splitting the history into a summarizable prefix and a kept tail. */ + static final class Split { + + final List old; + final List kept; + + Split(List old, List kept) { + this.old = old; + this.kept = kept; + } + } + + /** + * Split the history at a boundary that preserves the last {@code keepRecentTurns} user messages, + * with adjustments so that no assistant-tool_use is separated from its tool_results. + */ + static Split computeSplit(List history, int keepRecentTurns) { + if (history.size() <= 2) { + return new Split(new ArrayList<>(), new ArrayList<>(history)); + } + + // Walk from the tail, count user boundaries. If the history does not contain enough user + // messages to satisfy keepRecentTurns, keep everything (splitIdx = 0); the empty-old check + // in beforeLlmCall will then skip this turn gracefully. + int userBoundariesFound = 0; + int splitIdx = 0; + for (int i = history.size() - 1; i >= 0; i--) { + if (history.get(i).isUser()) { + userBoundariesFound++; + if (userBoundariesFound == keepRecentTurns) { + splitIdx = i; + break; + } + } + } + + // Protect tool-call / tool-result pairing: never split between an assistant that issued + // tool_calls and the tool_results that satisfy them. + while (splitIdx > 0) { + ChatCompletionMessageParam prev = history.get(splitIdx - 1); + if (prev.isTool()) { + splitIdx--; + continue; + } + if (prev.isAssistant()) { + boolean hasToolCalls = prev.asAssistant().toolCalls().map(List::size).orElse(0) > 0; + if (hasToolCalls) { + splitIdx--; + continue; + } + } + break; + } + + // Also guard against the edge case: if kept contains a tool_result whose tool_call id is + // defined only in old, pull that assistant (and its siblings) into kept too. + Set keptCallIds = collectToolCallIds(history.subList(splitIdx, history.size())); + if (!keptCallIds.isEmpty()) { + while (splitIdx > 0) { + ChatCompletionMessageParam prev = history.get(splitIdx - 1); + if (!prev.isAssistant()) break; + List calls = prev.asAssistant().toolCalls().orElse(null); + if (calls == null || calls.isEmpty()) break; + boolean satisfiesKept = false; + for (ChatCompletionMessageToolCall tc : calls) { + if (tc.isFunction() && keptCallIds.contains(tc.asFunction().id())) { + satisfiesKept = true; + break; + } + } + if (!satisfiesKept) break; + splitIdx--; + } + } + + List oldPart = new ArrayList<>(history.subList(0, splitIdx)); + List keptPart = + new ArrayList<>(history.subList(splitIdx, history.size())); + return new Split(oldPart, keptPart); + } + + private static Set collectToolCallIds(List slice) { + Set ids = new HashSet<>(); + for (ChatCompletionMessageParam m : slice) { + if (m.isTool()) { + ids.add(m.asTool().toolCallId()); + } + } + return ids; + } + + private static ChatCompletionMessageParam wrapSummaryAsUserMessage(String summary) { + String body = "\n" + summary + "\n"; + return ChatCompletionMessageParam.ofUser( + ChatCompletionUserMessageParam.builder().content(body).build()); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/LoggingMiddleware.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/LoggingMiddleware.java new file mode 100644 index 00000000000..e0a5c2364eb --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/LoggingMiddleware.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime.middleware; + +import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam; +import com.openai.models.chat.completions.ChatCompletionMessageParam; +import java.util.List; +import java.util.Map; +import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentError; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentEvent; +import org.apache.kyuubi.engine.dataagent.runtime.event.StepStart; +import org.apache.kyuubi.engine.dataagent.runtime.event.ToolResult; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.slf4j.MDC; + +/** + * Logging middleware that prints agent lifecycle events for debugging and observability. + * + *

Picks up {@code operationId} and {@code sessionId} from SLF4J MDC (set by ExecuteStatement) to + * tag every log line with the Kyuubi operation/session context. + * + *

Log structure: + * + *

+ *   [op:abcd1234] START user_input="..."
+ *   [op:abcd1234] Step 1
+ *   [op:abcd1234] LLM call: step=1, messages=3
+ *   [op:abcd1234] LLM response: step=1, content="...(truncated)", tool_calls=1
+ *   [op:abcd1234] Tool call: sql_query {sql=SELECT ...}
+ *   [op:abcd1234] Tool result: sql_query -> "| col1 | col2 |...(truncated)"
+ *   [op:abcd1234] FINISH steps=2, tokens=1234
+ * 
+ */ +public class LoggingMiddleware implements AgentMiddleware { + + private static final Logger LOG = LoggerFactory.getLogger("DataAgent"); + + private static final int MAX_PREVIEW_LENGTH = 500; + + private static String prefix() { + String sessionId = MDC.get("sessionId"); + String opId = MDC.get("operationId"); + StringBuilder sb = new StringBuilder(); + if (sessionId != null) { + sb.append("[s:").append(shortId(sessionId)).append("]"); + } + if (opId != null) { + sb.append("[op:").append(shortId(opId)).append("]"); + } + if (sb.length() > 0) { + sb.append(" "); + } + return sb.toString(); + } + + /** + * Take the first segment of a UUID (before the first dash). e.g. "327d8c5b-91ef-..." → "327d8c5b" + */ + private static String shortId(String id) { + int dash = id.indexOf('-'); + return dash > 0 ? id.substring(0, dash) : id; + } + + @Override + public void onAgentStart(AgentRunContext ctx) { + LOG.debug("{}START user_input=\"{}\"", prefix(), truncate(ctx.getMemory().getLastUserInput())); + } + + @Override + public void onAgentFinish(AgentRunContext ctx) { + LOG.info( + "{}FINISH steps={}, prompt_tokens={}, completion_tokens={}, total_tokens={}", + prefix(), + ctx.getIteration(), + ctx.getPromptTokens(), + ctx.getCompletionTokens(), + ctx.getTotalTokens()); + } + + @Override + public LlmCallAction beforeLlmCall( + AgentRunContext ctx, List messages) { + LOG.info("{}LLM call: step={}, messages={}", prefix(), ctx.getIteration(), messages.size()); + return null; + } + + @Override + public void afterLlmCall(AgentRunContext ctx, ChatCompletionAssistantMessageParam response) { + String content = response.content().map(Object::toString).orElse(""); + int toolCallCount = response.toolCalls().map(List::size).orElse(0); + LOG.info( + "{}LLM response: step={}, content=\"{}\", tool_calls={}, " + + "usage(cumulative): prompt={}, completion={}, total={}", + prefix(), + ctx.getIteration(), + truncate(content), + toolCallCount, + ctx.getPromptTokens(), + ctx.getCompletionTokens(), + ctx.getTotalTokens()); + } + + @Override + public ToolCallDenial beforeToolCall( + AgentRunContext ctx, String toolCallId, String toolName, Map toolArgs) { + LOG.info("{}Tool call: id={}, name={}", prefix(), toolCallId, toolName); + LOG.debug("{}Tool args: {}", prefix(), toolArgs); + return null; + } + + @Override + public String afterToolCall( + AgentRunContext ctx, String toolName, Map toolArgs, String result) { + LOG.info("{}Tool result: {} -> \"{}\"", prefix(), toolName, truncate(result)); + return null; + } + + @Override + public AgentEvent onEvent(AgentRunContext ctx, AgentEvent event) { + switch (event.eventType()) { + case STEP_START: + LOG.info("{}Step {}", prefix(), ((StepStart) event).stepNumber()); + break; + case ERROR: + LOG.error("{}ERROR: {}", prefix(), ((AgentError) event).message()); + break; + case TOOL_RESULT: + ToolResult tr = (ToolResult) event; + if (tr.isError()) { + LOG.warn("{}Tool error: {} -> \"{}\"", prefix(), tr.toolName(), truncate(tr.output())); + } + break; + default: + break; + } + return event; + } + + private static String truncate(String s) { + if (s == null) return ""; + return s.length() <= MAX_PREVIEW_LENGTH ? s : s.substring(0, MAX_PREVIEW_LENGTH) + "..."; + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddleware.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddleware.java new file mode 100644 index 00000000000..87aad9f3255 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddleware.java @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime.middleware; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; +import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext; +import org.apache.kyuubi.engine.dataagent.runtime.ToolOutputStore; +import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry; +import org.apache.kyuubi.engine.dataagent.tool.output.GrepToolOutputTool; +import org.apache.kyuubi.engine.dataagent.tool.output.ReadToolOutputTool; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Input-gate middleware that offloads oversized tool outputs to per-session temp files and replaces + * the in-memory tool result with a small head + tail preview plus a retrieval hint. The cheapest + * and highest-ROI defense against context rot. + * + *

Trigger: {@code result.lines > MAX_LINES} OR {@code result.bytes > MAX_BYTES}, first to + * trip wins. Thresholds are hardcoded — the ReAct loop can compensate for suboptimal defaults by + * calling the retrieval tools more aggressively. + * + *

Exempt tools: {@link ReadToolOutputTool} and {@link GrepToolOutputTool} never go + * through the gate — the agent would otherwise recursively re-offload its own retrieval output. + * + *

Session lifecycle: a monotonic counter per session is used to name temp files. {@link + * #onSessionClose(String)} wipes the counter and the per-session temp dir; call it from the + * provider's {@code close(sessionId)} hook (not from {@link #onAgentFinish}, which fires on every + * turn and would invalidate paths the LLM still needs to reference). + */ +public class ToolResultOffloadMiddleware implements AgentMiddleware { + + private static final Logger LOG = LoggerFactory.getLogger(ToolResultOffloadMiddleware.class); + + static final int MAX_LINES = 500; + static final int MAX_BYTES = 50 * 1024; + static final int PREVIEW_HEAD_LINES = 20; + static final int PREVIEW_TAIL_LINES = 20; + + private static final Set EXEMPT_TOOLS = + new HashSet<>(Arrays.asList(ReadToolOutputTool.NAME, GrepToolOutputTool.NAME)); + + private final ToolOutputStore store = ToolOutputStore.create(); + private final ConcurrentHashMap counters = new ConcurrentHashMap<>(); + + /** + * Register the companion retrieval tools so the LLM can reach back into offloaded files. Paired + * with the preview hint emitted from {@link #afterToolCall}; skipping this registration would + * leave the LLM dangling on file paths it can never read. + */ + @Override + public void onRegister(ToolRegistry registry) { + registry.register(new ReadToolOutputTool(store)); + registry.register(new GrepToolOutputTool(store)); + } + + @Override + public String afterToolCall( + AgentRunContext ctx, String toolName, Map toolArgs, String result) { + if (result.isEmpty()) return null; + if (EXEMPT_TOOLS.contains(toolName)) return null; + + int bytes = result.getBytes(StandardCharsets.UTF_8).length; + int lines = countLines(result); + if (lines <= MAX_LINES && bytes <= MAX_BYTES) { + return null; + } + + // AgentRunContext.sessionId is null in unit-test constructions that don't exercise offload. + // In production the provider always threads it through, so treat null as "skip offload". + String sessionId = ctx.getSessionId(); + if (sessionId == null) return null; + + long n = counters.computeIfAbsent(sessionId, k -> new AtomicLong()).incrementAndGet(); + String toolCallId = toolName + "_" + n; + + Path file; + try { + file = store.write(sessionId, toolCallId, result); + } catch (IOException e) { + LOG.warn( + "Tool output offload failed for tool={} session={}; passing through full output", + toolName, + sessionId, + e); + return null; + } + + LOG.info( + "Offloaded tool={} session={} ({} lines / {} bytes) -> {}", + toolName, + sessionId, + lines, + bytes, + file.getFileName()); + return buildPreview(result, lines, bytes, file); + } + + /** Clean up counter and temp dir for a closed session. Idempotent. */ + @Override + public void onSessionClose(String sessionId) { + if (sessionId == null) return; + counters.remove(sessionId); + store.cleanupSession(sessionId); + } + + /** Engine-wide shutdown: drop the temp root. */ + @Override + public void onStop() { + store.close(); + } + + static int countLines(String s) { + if (s.isEmpty()) return 0; + int count = 1; + for (int i = 0; i < s.length(); i++) { + if (s.charAt(i) == '\n') count++; + } + // Trailing newline means the last "line" is empty, but still counted. Good enough for + // gating decisions — we're not trying to match `wc -l` exactly. + return count; + } + + static String buildPreview(String full, int lines, int bytes, Path file) { + String[] split = full.split("\n", -1); + int headEnd = Math.min(PREVIEW_HEAD_LINES, split.length); + int tailStart = Math.max(headEnd, split.length - PREVIEW_TAIL_LINES); + + StringBuilder sb = new StringBuilder(); + sb.append("[Tool output truncated: ") + .append(lines) + .append(" lines, ") + .append(humanBytes(bytes)) + .append("]\n") + .append("Saved to: ") + .append(file.toString()) + .append("\n\n--- First ") + .append(headEnd) + .append(" lines ---\n"); + for (int i = 0; i < headEnd; i++) { + sb.append(split[i]).append('\n'); + } + if (tailStart < split.length) { + int tailCount = split.length - tailStart; + sb.append("--- Last ").append(tailCount).append(" lines ---\n"); + for (int i = tailStart; i < split.length; i++) { + sb.append(split[i]).append('\n'); + } + } + sb.append("\nUse ") + .append(ReadToolOutputTool.NAME) + .append("(path, offset, limit) to read windows, or ") + .append(GrepToolOutputTool.NAME) + .append("(path, pattern, max_matches) to search."); + return sb.toString(); + } + + private static String humanBytes(long bytes) { + if (bytes < 1024) return bytes + " B"; + if (bytes < 1024 * 1024) return String.format("%.1f KB", bytes / 1024.0); + return String.format("%.1f MB", bytes / (1024.0 * 1024.0)); + } + + /** Visible for testing. */ + int trackedSessions() { + return counters.size(); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/AgentTool.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/AgentTool.java index 297a7c8d74a..b19f6787c87 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/AgentTool.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/AgentTool.java @@ -50,7 +50,10 @@ default ToolRiskLevel riskLevel() { * Execute the tool with the given deserialized arguments. * * @param args the deserialized arguments from the LLM's tool call + * @param ctx per-invocation context (session id, etc.); never null — use {@link + * ToolContext#EMPTY} for calls without a session. Tools that are session-agnostic may ignore + * it. * @return the result string to feed back to the LLM */ - String execute(T args); + String execute(T args, ToolContext ctx); } diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolContext.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolContext.java new file mode 100644 index 00000000000..2625f3bab59 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolContext.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.tool; + +/** + * Per-invocation context handed to {@link AgentTool#execute(Object, ToolContext)}. Today it carries + * just the session id so session-scoped tools (e.g. the offloaded tool-output retrievers) can + * restrict their filesystem view; extend here when a tool needs user/approval/etc. + */ +public final class ToolContext { + + /** Sentinel for call sites that have no session to attribute — tests, direct CLI use. */ + public static final ToolContext EMPTY = new ToolContext(null); + + private final String sessionId; + + public ToolContext(String sessionId) { + this.sessionId = sessionId; + } + + /** Upstream session id, or {@code null} when invoked outside a session. */ + public String sessionId() { + return sessionId; + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistry.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistry.java index a403c66b58d..3a11bab567c 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistry.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistry.java @@ -26,15 +26,16 @@ import com.openai.models.chat.completions.ChatCompletionTool; import java.util.LinkedHashMap; import java.util.Map; -import java.util.concurrent.Callable; -import java.util.concurrent.ExecutionException; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; import java.util.concurrent.SynchronousQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicLong; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -68,15 +69,20 @@ public class ToolRegistry implements AutoCloseable { */ private static final int MAX_POOL_SIZE = 8; + /** Default wall-clock cap for a single tool call, used when no explicit value is configured. */ + public static final long DEFAULT_TIMEOUT_SECONDS = 300; + private final Map> tools = new LinkedHashMap<>(); private volatile Map cachedSpecs; private final long toolCallTimeoutSeconds; private final ExecutorService executor; + private final ScheduledExecutorService timeoutScheduler; /** - * @param toolCallTimeoutSeconds wall-clock cap on every {@link #executeTool} call, sourced from - * {@code kyuubi.engine.data.agent.tool.call.timeout}. When the timeout fires, the thread is - * interrupted and a descriptive error is returned to the LLM. + * @param toolCallTimeoutSeconds wall-clock cap on every {@link #executeTool} / {@link + * #submitTool} call, sourced from {@code kyuubi.engine.data.agent.tool.call.timeout}. When + * the timeout fires, the worker thread is interrupted and a descriptive error is returned to + * the LLM. */ public ToolRegistry(long toolCallTimeoutSeconds) { this.toolCallTimeoutSeconds = toolCallTimeoutSeconds; @@ -93,12 +99,20 @@ public ToolRegistry(long toolCallTimeoutSeconds) { t.setDaemon(true); return t; }); + this.timeoutScheduler = + Executors.newSingleThreadScheduledExecutor( + r -> { + Thread t = new Thread(r, "tool-call-timeout"); + t.setDaemon(true); + return t; + }); } - /** Shut down the worker pool. Idempotent. */ + /** Shut down the worker pool and the timeout scheduler. Idempotent. */ @Override public void close() { executor.shutdownNow(); + timeoutScheduler.shutdownNow(); } /** Register a tool. Keyed by {@link AgentTool#name()}. */ @@ -138,61 +152,100 @@ private synchronized Map ensureSpecs() { } /** - * Execute a tool call: deserialize the JSON args, then delegate to the tool, with a wall-clock - * timeout sourced from {@code kyuubi.engine.data.agent.tool.call.timeout}. If the tool does not - * finish within the timeout, the worker thread is interrupted and a descriptive error is returned - * to the LLM so it can react (e.g. simplify the query, retry with LIMIT). + * Synchronous entry point. Blocks until the tool finishes, times out, or the registry rejects the + * submission. Errors are surfaced as strings (never as exceptions) so the LLM can observe and + * react to them. * * @param toolName the function name from the LLM response * @param argsJson the raw JSON arguments string * @return the result string, or an error message */ - @SuppressWarnings("unchecked") public String executeTool(String toolName, String argsJson) { - AgentTool tool; + return submitTool(toolName, argsJson, ToolContext.EMPTY).join(); + } + + /** Synchronous entry point with an explicit {@link ToolContext}. */ + public String executeTool(String toolName, String argsJson, ToolContext ctx) { + return submitTool(toolName, argsJson, ctx).join(); + } + + public CompletableFuture submitTool(String toolName, String argsJson) { + return submitTool(toolName, argsJson, ToolContext.EMPTY); + } + + /** + * Asynchronous entry point. Deserialize args, run the tool on the worker pool, and apply a + * wall-clock timeout sourced from {@code kyuubi.engine.data.agent.tool.call.timeout}. The + * returned future is guaranteed to complete normally — timeouts, pool saturation, unknown tool, + * and execution failures are all translated into error strings. Callers can therefore use {@code + * .join()} / {@code .get()} without handling {@link java.util.concurrent.TimeoutException} or + * {@link java.util.concurrent.ExecutionException}. + * + * @param toolName the function name from the LLM response + * @param argsJson the raw JSON arguments string + */ + @SuppressWarnings("unchecked") + public CompletableFuture submitTool(String toolName, String argsJson, ToolContext ctx) { + AgentTool tool; synchronized (this) { - tool = tools.get(toolName); + tool = (AgentTool) tools.get(toolName); } if (tool == null) { - return "Error: unknown tool '" + toolName + "'"; + return CompletableFuture.completedFuture("Error: unknown tool '" + toolName + "'"); } - return executeWithTimeout((AgentTool) tool, argsJson); - } + ToolContext toolCtx = ctx != null ? ctx : ToolContext.EMPTY; - private String executeWithTimeout(AgentTool tool, String argsJson) { - Callable task = - () -> { - T args = JSON.readValue(argsJson, tool.argsType()); - return tool.execute(args); - }; - Future future; + CompletableFuture result = new CompletableFuture<>(); + Future submitted; try { - future = executor.submit(task); + submitted = + executor.submit( + () -> { + try { + Object args = JSON.readValue(argsJson, tool.argsType()); + String out = tool.execute(args, toolCtx); + // When the timeout handler interrupts us, the tool may still unwind cleanly and + // produce a stale return value — don't race the scheduler's timeout message with + // it. Let the timeout path be the single authority for the final result. + if (!Thread.currentThread().isInterrupted()) { + result.complete(out); + } + } catch (Exception e) { + result.complete("Error executing " + toolName + ": " + e.getMessage()); + } + }); } catch (RejectedExecutionException e) { - LOG.warn("Tool call '{}' rejected — worker pool saturated at {}", tool.name(), MAX_POOL_SIZE); - return "Error: tool call '" - + tool.name() - + "' rejected — server is handling too many concurrent tool calls. " - + "Retry in a moment."; - } - try { - return future.get(toolCallTimeoutSeconds, TimeUnit.SECONDS); - } catch (TimeoutException e) { - future.cancel(true); - LOG.warn("Tool call '{}' timed out after {} seconds", tool.name(), toolCallTimeoutSeconds); - return "Error: tool call '" - + tool.name() - + "' timed out after " - + toolCallTimeoutSeconds - + " seconds. " - + "Try simplifying the query or adding filters to reduce execution time."; - } catch (ExecutionException e) { - Throwable cause = e.getCause() != null ? e.getCause() : e; - return "Error executing " + tool.name() + ": " + cause.getMessage(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - return "Error: tool call '" + tool.name() + "' was interrupted."; + LOG.warn("Tool call '{}' rejected — worker pool saturated at {}", toolName, MAX_POOL_SIZE); + return CompletableFuture.completedFuture( + "Error: tool call '" + + toolName + + "' rejected — server is handling too many concurrent tool calls. " + + "Retry in a moment."); } + + ScheduledFuture timer = + timeoutScheduler.schedule( + () -> { + if (!result.isDone()) { + // cancel(true) interrupts the worker thread directly — the inner task's + // catch-all will see the interrupt and call result.complete(...), but the + // timeout message below wins because complete() is idempotent on first-winner. + submitted.cancel(true); + LOG.warn( + "Tool call '{}' timed out after {} seconds", toolName, toolCallTimeoutSeconds); + result.complete( + "Error: tool call '" + + toolName + + "' timed out after " + + toolCallTimeoutSeconds + + " seconds. " + + "Try simplifying the query or adding filters to reduce execution time."); + } + }, + toolCallTimeoutSeconds, + TimeUnit.SECONDS); + result.whenComplete((r, e) -> timer.cancel(false)); + return result; } private static ChatCompletionTool buildChatCompletionTool(AgentTool tool) { diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/GrepToolOutputArgs.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/GrepToolOutputArgs.java new file mode 100644 index 00000000000..b05a199090a --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/GrepToolOutputArgs.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.tool.output; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; + +/** Args for {@link GrepToolOutputTool}. */ +public class GrepToolOutputArgs { + + @JsonProperty(required = true) + @JsonPropertyDescription( + "Absolute path to the offloaded tool-output file, as reported by the truncation notice.") + public String path; + + @JsonProperty(required = true) + @JsonPropertyDescription( + "Java regex pattern to search for. Matches are returned as ':'.") + public String pattern; + + @JsonProperty("max_matches") + @JsonPropertyDescription("Maximum number of matching lines to return. Defaults to 50.") + public Integer maxMatches; +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/GrepToolOutputTool.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/GrepToolOutputTool.java new file mode 100644 index 00000000000..f94fb21b315 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/GrepToolOutputTool.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.tool.output; + +import org.apache.kyuubi.engine.dataagent.runtime.ToolOutputStore; +import org.apache.kyuubi.engine.dataagent.tool.AgentTool; +import org.apache.kyuubi.engine.dataagent.tool.ToolContext; + +/** + * Regex-search a previously offloaded tool-output file. Companion to {@code + * ToolResultOffloadMiddleware}. + */ +public class GrepToolOutputTool implements AgentTool { + + public static final String NAME = "grep_tool_output"; + private static final int DEFAULT_MAX_MATCHES = 50; + + private final ToolOutputStore store; + + public GrepToolOutputTool(ToolOutputStore store) { + this.store = store; + } + + @Override + public String name() { + return NAME; + } + + @Override + public String description() { + return "Regex-search a previously offloaded tool-output file " + + "(the path is supplied in the truncation notice of a prior tool result). " + + "Cheaper than read_tool_output when you know what you're looking for. " + + "Returns matching lines as ':'."; + } + + @Override + public Class argsType() { + return GrepToolOutputArgs.class; + } + + @Override + public String execute(GrepToolOutputArgs args, ToolContext ctx) { + if (args == null || args.path == null || args.path.isEmpty()) { + return "Error: 'path' parameter is required."; + } + if (ctx == null || ctx.sessionId() == null) { + return "Error: grep_tool_output requires a session context."; + } + int max = args.maxMatches != null ? args.maxMatches : DEFAULT_MAX_MATCHES; + return store.grep(ctx.sessionId(), args.path, args.pattern, max); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/ReadToolOutputArgs.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/ReadToolOutputArgs.java new file mode 100644 index 00000000000..458fbfa4f66 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/ReadToolOutputArgs.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.tool.output; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; + +/** Args for {@link ReadToolOutputTool}. */ +public class ReadToolOutputArgs { + + @JsonProperty(required = true) + @JsonPropertyDescription( + "Absolute path to the offloaded tool-output file, as reported by the truncation notice.") + public String path; + + @JsonPropertyDescription("0-based line offset into the file. Defaults to 0.") + public Integer offset; + + @JsonPropertyDescription( + "Number of lines to return starting at 'offset'. Defaults to 200; capped at 500.") + public Integer limit; +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/ReadToolOutputTool.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/ReadToolOutputTool.java new file mode 100644 index 00000000000..93e837998f3 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/ReadToolOutputTool.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.tool.output; + +import org.apache.kyuubi.engine.dataagent.runtime.ToolOutputStore; +import org.apache.kyuubi.engine.dataagent.tool.AgentTool; +import org.apache.kyuubi.engine.dataagent.tool.ToolContext; + +/** + * Read a line window from a previously offloaded tool-output file. Companion to {@code + * ToolResultOffloadMiddleware}. + */ +public class ReadToolOutputTool implements AgentTool { + + public static final String NAME = "read_tool_output"; + private static final int DEFAULT_LIMIT = 200; + private static final int MAX_LIMIT = 500; + + private final ToolOutputStore store; + + public ReadToolOutputTool(ToolOutputStore store) { + this.store = store; + } + + @Override + public String name() { + return NAME; + } + + @Override + public String description() { + return "Read a line window from a previously offloaded tool-output file " + + "(the path is supplied in the truncation notice of a prior tool result). " + + "Returns '[lines X-Y of Z total]' header followed by the requested window. " + + "Use when a prior tool's output was truncated and you need to inspect more of it."; + } + + @Override + public Class argsType() { + return ReadToolOutputArgs.class; + } + + @Override + public String execute(ReadToolOutputArgs args, ToolContext ctx) { + if (args == null || args.path == null || args.path.isEmpty()) { + return "Error: 'path' parameter is required."; + } + if (ctx == null || ctx.sessionId() == null) { + return "Error: read_tool_output requires a session context."; + } + int offset = args.offset != null ? args.offset : 0; + int limit = args.limit != null ? args.limit : DEFAULT_LIMIT; + if (limit > MAX_LIMIT) limit = MAX_LIMIT; + return store.read(ctx.sessionId(), args.path, offset, limit); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryTool.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryTool.java index 06b12f2be72..88838ca477a 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryTool.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryTool.java @@ -19,6 +19,7 @@ import javax.sql.DataSource; import org.apache.kyuubi.engine.dataagent.tool.AgentTool; +import org.apache.kyuubi.engine.dataagent.tool.ToolContext; import org.apache.kyuubi.engine.dataagent.tool.ToolRiskLevel; /** @@ -69,7 +70,7 @@ public Class argsType() { } @Override - public String execute(SqlQueryArgs args) { + public String execute(SqlQueryArgs args, ToolContext ctx) { return SqlExecutor.execute(dataSource, args.sql, queryTimeoutSeconds); } } diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryTool.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryTool.java index 0c57cc049ed..9136b0aa903 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryTool.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryTool.java @@ -19,6 +19,7 @@ import javax.sql.DataSource; import org.apache.kyuubi.engine.dataagent.tool.AgentTool; +import org.apache.kyuubi.engine.dataagent.tool.ToolContext; import org.apache.kyuubi.engine.dataagent.tool.ToolRiskLevel; /** @@ -69,7 +70,7 @@ public Class argsType() { } @Override - public String execute(SqlQueryArgs args) { + public String execute(SqlQueryArgs args, ToolContext ctx) { String sql = args.sql; if (sql == null || sql.trim().isEmpty()) { return "Error: 'sql' parameter is required."; diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/SqlReadOnlyChecker.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/SqlReadOnlyChecker.java index d2cac9ae518..52b83182add 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/SqlReadOnlyChecker.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/SqlReadOnlyChecker.java @@ -53,6 +53,7 @@ final class SqlReadOnlyChecker { *
  • {@code USE} — switch session catalog/database; not data-mutating *
  • {@code LIST} — Spark {@code LIST FILE} / {@code LIST JAR} inspection *
  • {@code HELP} — some engines expose interactive help + *
  • {@code PRAGMA} — SQLite schema/metadata inspection * */ private static final Set READ_ONLY_KEYWORDS = @@ -70,7 +71,8 @@ final class SqlReadOnlyChecker { "EXPLAIN", "USE", "LIST", - "HELP"))); + "HELP", + "PRAGMA"))); private SqlReadOnlyChecker() {} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/util/ConfUtils.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/util/ConfUtils.java new file mode 100644 index 00000000000..f4366094670 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/util/ConfUtils.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.util; + +import org.apache.kyuubi.config.ConfigEntry; +import org.apache.kyuubi.config.KyuubiConf; +import org.apache.kyuubi.config.OptionalConfigEntry; + +/** Small helpers for reading typed values out of {@link KyuubiConf}. */ +public final class ConfUtils { + + private ConfUtils() {} + + /** Return the string value, or throw if the entry is not set. */ + public static String requireString(KyuubiConf conf, OptionalConfigEntry key) { + scala.Option opt = conf.get(key); + if (opt.isEmpty()) { + throw new IllegalArgumentException(key.key() + " is required"); + } + return opt.get(); + } + + /** Return the string value, or {@code null} if the entry is not set. */ + public static String optionalString(KyuubiConf conf, OptionalConfigEntry key) { + scala.Option opt = conf.get(key); + return opt.isDefined() ? opt.get() : null; + } + + /** Return the value for a raw key, or {@code null} if not set. */ + public static String optionalString(KyuubiConf conf, String key) { + scala.Option opt = conf.getOption(key); + return opt.isDefined() ? opt.get() : null; + } + + public static int intConf(KyuubiConf conf, ConfigEntry key) { + return ((Number) conf.get(key)).intValue(); + } + + public static long longConf(KyuubiConf conf, ConfigEntry key) { + return ((Number) conf.get(key)).longValue(); + } + + /** Read a millisecond-valued entry and return it as whole seconds. */ + public static long millisAsSeconds(KyuubiConf conf, ConfigEntry key) { + return longConf(conf, key) / 1000L; + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/scala/org/apache/kyuubi/engine/dataagent/operation/ExecuteStatement.scala b/externals/kyuubi-data-agent-engine/src/main/scala/org/apache/kyuubi/engine/dataagent/operation/ExecuteStatement.scala index 04d0defa2dc..3d902677a0a 100644 --- a/externals/kyuubi-data-agent-engine/src/main/scala/org/apache/kyuubi/engine/dataagent/operation/ExecuteStatement.scala +++ b/externals/kyuubi-data-agent-engine/src/main/scala/org/apache/kyuubi/engine/dataagent/operation/ExecuteStatement.scala @@ -25,7 +25,7 @@ import org.slf4j.MDC import org.apache.kyuubi.{KyuubiSQLException, Logging} import org.apache.kyuubi.config.KyuubiConf import org.apache.kyuubi.engine.dataagent.provider.{DataAgentProvider, ProviderRunRequest} -import org.apache.kyuubi.engine.dataagent.runtime.event.{AgentError, AgentEvent, AgentFinish, ApprovalRequest, ContentDelta, EventType, StepEnd, StepStart, ToolCall, ToolResult} +import org.apache.kyuubi.engine.dataagent.runtime.event.{AgentError, AgentEvent, AgentFinish, ApprovalRequest, Compaction, ContentDelta, EventType, StepEnd, StepStart, ToolCall, ToolResult} import org.apache.kyuubi.operation.OperationState import org.apache.kyuubi.operation.log.OperationLog import org.apache.kyuubi.session.Session @@ -117,7 +117,7 @@ class ExecuteStatement( n.put("type", sseType) n.put("id", toolCall.toolCallId()) n.put("name", toolCall.toolName()) - n.set("args", JSON.valueToTree(toolCall.toolArgs())) + n.set[ObjectNode]("args", JSON.valueToTree(toolCall.toolArgs())) })) case EventType.TOOL_RESULT => val toolResult = event.asInstanceOf[ToolResult] @@ -148,6 +148,15 @@ class ExecuteStatement( n.set("args", JSON.valueToTree(req.toolArgs())) n.put("riskLevel", req.riskLevel().name()) })) + case EventType.COMPACTION => + val c = event.asInstanceOf[Compaction] + incrementalIter.append(Array(toJson { n => + n.put("type", sseType) + n.put("summarized", c.summarizedCount()) + n.put("kept", c.keptCount()) + n.put("triggerTokens", c.triggerTokens()) + n.put("observedTokens", c.observedTokens()) + })) case EventType.AGENT_FINISH => val finish = event.asInstanceOf[AgentFinish] incrementalIter.append(Array(toJson { n => diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialectTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialectTest.java index e728ff871e7..c43942a8f35 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialectTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialectTest.java @@ -19,6 +19,7 @@ import static org.junit.Assert.*; +import org.apache.kyuubi.engine.dataagent.datasource.dialect.GenericDialect; import org.junit.Test; public class JdbcDialectTest { diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/DialectTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/DialectTest.java index 4e713b9cca3..cc45ebdc7e7 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/DialectTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/DialectTest.java @@ -20,8 +20,9 @@ import static org.junit.Assert.*; import org.apache.kyuubi.engine.dataagent.datasource.JdbcDialect; -import org.apache.kyuubi.engine.dataagent.datasource.MysqlDialect; +import org.apache.kyuubi.engine.dataagent.datasource.dialect.MysqlDialect; import org.apache.kyuubi.engine.dataagent.prompt.SystemPromptBuilder; +import org.apache.kyuubi.engine.dataagent.tool.ToolContext; import org.apache.kyuubi.engine.dataagent.tool.sql.RunSelectQueryTool; import org.apache.kyuubi.engine.dataagent.tool.sql.SqlQueryArgs; import org.junit.BeforeClass; @@ -66,7 +67,7 @@ public void testBacktickQuotingWithReservedWord() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT " + quotedCol + " FROM " + quotedTable + " WHERE id = 1"; - String result = selectTool.execute(args); + String result = selectTool.execute(args, ToolContext.EMPTY); assertFalse(result.startsWith("Error:")); assertTrue(result.contains("value1")); diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/provider/mock/MockLlmProvider.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/provider/mock/MockLlmProvider.java new file mode 100644 index 00000000000..2756b3e2087 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/provider/mock/MockLlmProvider.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.provider.mock; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import javax.sql.DataSource; +import org.apache.kyuubi.config.KyuubiConf; +import org.apache.kyuubi.engine.dataagent.datasource.DataSourceFactory; +import org.apache.kyuubi.engine.dataagent.provider.DataAgentProvider; +import org.apache.kyuubi.engine.dataagent.provider.ProviderRunRequest; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentEvent; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentFinish; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentStart; +import org.apache.kyuubi.engine.dataagent.runtime.event.ContentComplete; +import org.apache.kyuubi.engine.dataagent.runtime.event.ContentDelta; +import org.apache.kyuubi.engine.dataagent.runtime.event.StepEnd; +import org.apache.kyuubi.engine.dataagent.runtime.event.StepStart; +import org.apache.kyuubi.engine.dataagent.runtime.event.ToolCall; +import org.apache.kyuubi.engine.dataagent.runtime.event.ToolResult; +import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry; +import org.apache.kyuubi.engine.dataagent.tool.sql.RunSelectQueryTool; +import org.apache.kyuubi.engine.dataagent.util.ConfUtils; + +/** + * A mock LLM provider for testing the full tool-call pipeline without a real LLM. Simulates the + * ReAct loop: extracts SQL from the user question, executes it via SqlQueryTool, and returns the + * result as a formatted answer. + * + *

    Recognizes two patterns: + * + *

      + *
    • Questions containing SQL keywords (SELECT, SHOW, DESCRIBE) — extracts and executes the SQL + *
    • All other questions — returns a canned response without tool calls + *
    + */ +public class MockLlmProvider implements DataAgentProvider { + + private static final Pattern SQL_PATTERN = + Pattern.compile( + "(SELECT\\b.+|SHOW\\b.+|DESCRIBE\\b.+)", Pattern.CASE_INSENSITIVE | Pattern.DOTALL); + + /** + * Simple natural-language-to-SQL mappings so tests can use human-readable questions instead of + * raw SQL. Checked before the regex pattern — if a question matches a key (case-insensitive + * prefix), the mapped SQL is executed. + */ + private static final Map NL_TO_SQL = new java.util.LinkedHashMap<>(); + + static { + NL_TO_SQL.put( + "list all employee names and departments", + "SELECT name, department FROM employees ORDER BY id"); + NL_TO_SQL.put( + "how many employees in each department", + "SELECT department, COUNT(*) as cnt FROM employees GROUP BY department"); + NL_TO_SQL.put("count the total number of employees", "SELECT COUNT(*) FROM employees"); + } + + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + private final ToolRegistry toolRegistry; + private final DataSource dataSource; + + public MockLlmProvider(KyuubiConf conf) { + String jdbcUrl = ConfUtils.requireString(conf, KyuubiConf.ENGINE_DATA_AGENT_JDBC_URL()); + this.dataSource = DataSourceFactory.create(jdbcUrl); + this.toolRegistry = new ToolRegistry(30); + this.toolRegistry.register(new RunSelectQueryTool(dataSource, 0)); + } + + @Override + public void open(String sessionId, String user) { + sessions.put(sessionId, new Object()); + } + + @Override + public void run(String sessionId, ProviderRunRequest request, Consumer onEvent) { + String question = request.getQuestion(); + onEvent.accept(new AgentStart()); + + // Trigger an error for testing the error path in ExecuteStatement + if (question.trim().equalsIgnoreCase("__error__")) { + throw new RuntimeException("MockLlmProvider simulated failure"); + } + + // First check natural-language mappings, then fall back to SQL pattern extraction + String sql = resolveToSql(question); + if (sql != null) { + runWithToolCall(sql, onEvent); + } else { + runWithoutToolCall(question, onEvent); + } + } + + private void runWithToolCall(String sql, Consumer onEvent) { + // Step 1: LLM "decides" to call sql_query tool + onEvent.accept(new StepStart(1)); + String toolCallId = "mock_call_" + System.nanoTime(); + Map toolArgs = new HashMap<>(); + toolArgs.put("sql", sql); + onEvent.accept(new ToolCall(toolCallId, "run_select_query", toolArgs)); + + // Execute the tool + String toolOutput = + toolRegistry.executeTool("run_select_query", "{\"sql\":\"" + escapeJson(sql) + "\"}"); + onEvent.accept(new ToolResult(toolCallId, "run_select_query", toolOutput, false)); + onEvent.accept(new StepEnd(1)); + + // Step 2: LLM "summarizes" the result + onEvent.accept(new StepStart(2)); + String answer = "Based on the query result:\n\n" + toolOutput; + for (String token : answer.split("(?<=\\n)")) { + onEvent.accept(new ContentDelta(token)); + } + onEvent.accept(new ContentComplete(answer)); + onEvent.accept(new StepEnd(2)); + onEvent.accept(new AgentFinish(2, 100, 50, 150)); + } + + private void runWithoutToolCall(String question, Consumer onEvent) { + onEvent.accept(new StepStart(1)); + String answer = "[MockLLM] No SQL detected in: " + question; + onEvent.accept(new ContentDelta(answer)); + onEvent.accept(new ContentComplete(answer)); + onEvent.accept(new StepEnd(1)); + onEvent.accept(new AgentFinish(1, 50, 20, 70)); + } + + @Override + public void close(String sessionId) { + sessions.remove(sessionId); + } + + @Override + public void stop() { + if (dataSource instanceof com.zaxxer.hikari.HikariDataSource) { + ((com.zaxxer.hikari.HikariDataSource) dataSource).close(); + } + } + + /** + * Resolve a user question to SQL. Checks NL_TO_SQL mappings first, then falls back to regex + * extraction of raw SQL from the input. + */ + private static String resolveToSql(String question) { + String lower = question.toLowerCase().trim(); + for (Map.Entry entry : NL_TO_SQL.entrySet()) { + if (lower.startsWith(entry.getKey())) { + return entry.getValue(); + } + } + Matcher matcher = SQL_PATTERN.matcher(question); + if (matcher.find()) { + return matcher.group(1).trim(); + } + return null; + } + + private static String escapeJson(String s) { + return s.replace("\\", "\\\\") + .replace("\"", "\\\"") + .replace("\n", "\\n") + .replace("\r", "\\r") + .replace("\t", "\\t"); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ConversationMemoryTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ConversationMemoryTest.java new file mode 100644 index 00000000000..d8eb96cfbc1 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ConversationMemoryTest.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime; + +import static org.junit.Assert.*; + +import com.openai.models.chat.completions.ChatCompletionMessageParam; +import com.openai.models.chat.completions.ChatCompletionUserMessageParam; +import java.util.Collections; +import org.junit.Test; + +public class ConversationMemoryTest { + + @Test + public void testReplaceHistoryClearsLastTotalTokensButKeepsCumulative() { + ConversationMemory memory = new ConversationMemory(); + memory.addCumulativeTokens(100, 50, 150); + memory.addCumulativeTokens(200, 80, 280); + assertEquals(280L, memory.getLastTotalTokens()); + assertEquals(300L, memory.getCumulativePromptTokens()); + assertEquals(430L, memory.getCumulativeTotalTokens()); + + ChatCompletionMessageParam summary = + ChatCompletionMessageParam.ofUser( + ChatCompletionUserMessageParam.builder().content("summary").build()); + memory.replaceHistory(Collections.singletonList(summary)); + + assertEquals("lastTotalTokens reset after compaction", 0L, memory.getLastTotalTokens()); + assertEquals("cumulative totals preserved", 300L, memory.getCumulativePromptTokens()); + assertEquals("cumulative totals preserved", 430L, memory.getCumulativeTotalTokens()); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgentLiveTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgentLiveTest.java new file mode 100644 index 00000000000..75553f46998 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgentLiveTest.java @@ -0,0 +1,568 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime; + +import static org.junit.Assert.*; +import static org.junit.Assume.assumeTrue; + +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import java.io.File; +import java.sql.Connection; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; +import java.util.stream.Collectors; +import org.apache.kyuubi.engine.dataagent.prompt.SystemPromptBuilder; +import org.apache.kyuubi.engine.dataagent.runtime.event.*; +import org.apache.kyuubi.engine.dataagent.runtime.middleware.ApprovalMiddleware; +import org.apache.kyuubi.engine.dataagent.runtime.middleware.LoggingMiddleware; +import org.apache.kyuubi.engine.dataagent.runtime.middleware.ToolResultOffloadMiddleware; +import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry; +import org.apache.kyuubi.engine.dataagent.tool.sql.RunMutationQueryTool; +import org.apache.kyuubi.engine.dataagent.tool.sql.RunSelectQueryTool; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.sqlite.SQLiteDataSource; + +/** + * Live integration test with a real LLM and real SQLite database. Exercises the full ReAct loop: + * LLM reasoning -> tool calls -> result verification. Requires DATA_AGENT_LLM_API_KEY and + * DATA_AGENT_LLM_API_URL environment variables. Works with any OpenAI-compatible LLM service. + */ +public class ReactAgentLiveTest { + + private static final String API_KEY = System.getenv().getOrDefault("DATA_AGENT_LLM_API_KEY", ""); + private static final String BASE_URL = System.getenv().getOrDefault("DATA_AGENT_LLM_API_URL", ""); + private static final String MODEL_NAME = + System.getenv().getOrDefault("DATA_AGENT_LLM_MODEL", "gpt-4o"); + + private static final String SYSTEM_PROMPT = + SystemPromptBuilder.create().datasource("sqlite").build(); + + private final List tempFiles = new ArrayList<>(); + private OpenAIClient client; + + @Before + public void setUp() { + assumeTrue("DATA_AGENT_LLM_API_KEY not set, skipping live tests", !API_KEY.isEmpty()); + assumeTrue("DATA_AGENT_LLM_API_URL not set, skipping live tests", !BASE_URL.isEmpty()); + client = OpenAIOkHttpClient.builder().apiKey(API_KEY).baseUrl(BASE_URL).build(); + } + + @After + public void tearDown() { + tempFiles.forEach(File::delete); + } + + @Test + public void testPlainTextStreamingWithoutTools() { + ReactAgent agent = + ReactAgent.builder() + .client(client) + .modelName(MODEL_NAME) + .toolRegistry(new ToolRegistry(30)) + .addMiddleware(new LoggingMiddleware()) + .maxIterations(3) + .systemPrompt("You are a helpful assistant. Answer concisely in 1-2 sentences.") + .build(); + + List events = new CopyOnWriteArrayList<>(); + ConversationMemory memory = new ConversationMemory(); + + agent.run(new AgentInvocation("What is Apache Kyuubi?"), memory, events::add); + + List deltas = + events.stream() + .filter(e -> e instanceof ContentDelta) + .map(e -> ((ContentDelta) e).text()) + .collect(Collectors.toList()); + + assertTrue("Expected multiple ContentDelta events", deltas.size() > 1); + assertFalse("Streamed text should not be empty", String.join("", deltas).isEmpty()); + assertTrue(events.stream().anyMatch(e -> e instanceof StepStart)); + assertTrue(events.stream().anyMatch(e -> e instanceof ContentComplete)); + assertTrue(events.get(events.size() - 1) instanceof AgentFinish); + assertEquals(2, memory.getHistory().size()); // user + assistant + } + + @Test + public void testFullReActLoopWithSchemaInspectAndSqlQuery() { + SQLiteDataSource ds = createSalesDatabase(); + ToolRegistry registry = new ToolRegistry(30); + registry.register(new RunSelectQueryTool(ds, 0)); + + ReactAgent agent = + ReactAgent.builder() + .client(client) + .modelName(MODEL_NAME) + .toolRegistry(registry) + .addMiddleware(new LoggingMiddleware()) + .maxIterations(10) + .systemPrompt(SYSTEM_PROMPT) + .build(); + + List events = new CopyOnWriteArrayList<>(); + ConversationMemory memory = new ConversationMemory(); + + agent.run( + new AgentInvocation( + "What is the total revenue by product category? Which category has the highest revenue?"), + memory, + events::add); + + // Verify tool calls happened + List toolCalls = + events.stream() + .filter(e -> e instanceof ToolCall) + .map(e -> (ToolCall) e) + .collect(Collectors.toList()); + List toolResults = + events.stream() + .filter(e -> e instanceof ToolResult) + .map(e -> (ToolResult) e) + .collect(Collectors.toList()); + + assertFalse("Agent should have called at least one tool", toolCalls.isEmpty()); + assertFalse("Agent should have received tool results", toolResults.isEmpty()); + assertTrue("Tool calls should not error", toolResults.stream().noneMatch(ToolResult::isError)); + + // Verify SQL query was called + assertTrue( + "Agent should execute at least one SQL query", + toolCalls.stream().anyMatch(tc -> "run_select_query".equals(tc.toolName()))); + + // Verify final answer mentions "Electronics" (highest revenue) + List completions = + events.stream() + .filter(e -> e instanceof ContentComplete) + .map(e -> ((ContentComplete) e).fullText()) + .collect(Collectors.toList()); + String lastAnswer = completions.get(completions.size() - 1); + assertTrue( + "Final answer should mention Electronics, got: " + lastAnswer, + lastAnswer.toLowerCase().contains("electronics")); + + // Verify agent finished successfully + AgentEvent last = events.get(events.size() - 1); + assertTrue(last instanceof AgentFinish); + assertTrue("Should take multiple steps", ((AgentFinish) last).totalSteps() > 1); + } + + @Test + public void testMultiTurnConversationWithToolUse() { + SQLiteDataSource ds = createSalesDatabase(); + ToolRegistry registry = new ToolRegistry(30); + registry.register(new RunSelectQueryTool(ds, 0)); + + ReactAgent agent = + ReactAgent.builder() + .client(client) + .modelName(MODEL_NAME) + .toolRegistry(registry) + .addMiddleware(new LoggingMiddleware()) + .maxIterations(10) + .systemPrompt(SYSTEM_PROMPT) + .build(); + + // Shared memory across turns + ConversationMemory memory = new ConversationMemory(); + + // Turn 1 + List events1 = new CopyOnWriteArrayList<>(); + agent.run(new AgentInvocation("How many orders are there in total?"), memory, events1::add); + + assertTrue( + "Turn 1 should query the database", + events1.stream() + .anyMatch( + e -> + e instanceof ToolCall && "run_select_query".equals(((ToolCall) e).toolName()))); + assertTrue(events1.get(events1.size() - 1) instanceof AgentFinish); + + // Turn 2: follow-up relying on conversation context + List events2 = new CopyOnWriteArrayList<>(); + agent.run( + new AgentInvocation("Now show me only orders above 500 dollars."), memory, events2::add); + + assertTrue( + "Turn 2 should also query the database", + events2.stream() + .anyMatch( + e -> + e instanceof ToolCall && "run_select_query".equals(((ToolCall) e).toolName()))); + assertTrue(events2.get(events2.size() - 1) instanceof AgentFinish); + + // Verify memory accumulated across both turns + assertTrue( + "Memory should contain messages from both turns, got " + memory.getHistory().size(), + memory.getHistory().size() > 4); + } + + @Test + public void testToolOutputOffloadThenGrep() throws Exception { + // Large result forces ToolResultOffloadMiddleware to truncate the tool output and + // emit a preview hint telling the LLM to use grep_tool_output / read_tool_output. + // A correct answer proves the LLM read the hint and drove the retrieval itself. + SQLiteDataSource ds = createNeedleInHaystackDatabase(); + ToolRegistry registry = new ToolRegistry(30); + registry.register(new RunSelectQueryTool(ds, 0)); + + ReactAgent agent = + ReactAgent.builder() + .client(client) + .modelName(MODEL_NAME) + .toolRegistry(registry) + .addMiddleware(new LoggingMiddleware()) + .addMiddleware(new ToolResultOffloadMiddleware()) + .maxIterations(10) + .systemPrompt(SYSTEM_PROMPT) + .build(); + + List events = new CopyOnWriteArrayList<>(); + ConversationMemory memory = new ConversationMemory(); + + // Explicit workflow: full-table scan first, then retrieval tool. Without this hint a + // capable LLM just issues SELECT note FROM events WHERE tag='NEEDLE' and skips the + // offload path entirely -- smart behavior, but defeats the purpose of this test. + agent.run( + new AgentInvocation( + "Step 1: issue exactly this query: SELECT id, tag, note FROM events (no" + + " WHERE clause, return every row). Step 2: the result will be" + + " truncated; call the grep_tool_output tool with pattern 'NEEDLE' on" + + " the saved output file to find the matching row. Step 3: respond with" + + " ONLY the note text from that row, nothing else. Do NOT add a WHERE" + + " clause. Do NOT issue any other SQL.") + // Offload middleware requires a non-null session id -- without it the offload + // path skips entirely (see ToolResultOffloadMiddleware.afterToolCall). + .sessionId("offload-live-" + java.util.UUID.randomUUID()), + memory, + events::add); + + List toolCalls = + events.stream() + .filter(e -> e instanceof ToolCall) + .map(e -> (ToolCall) e) + .collect(Collectors.toList()); + List toolResults = + events.stream() + .filter(e -> e instanceof ToolResult) + .map(e -> (ToolResult) e) + .collect(Collectors.toList()); + + // Dump the tool trace -- diagnoses whether the LLM followed the workflow, added a + // WHERE clause despite instructions, or deviated in some other way. + for (ToolCall tc : toolCalls) { + System.out.println("[ToolCall] " + tc.toolName() + " args=" + tc.toolArgs()); + } + for (ToolResult tr : toolResults) { + String out = tr.output(); + String preview = + out.length() > 300 + ? out.substring(0, 300) + "...(+" + (out.length() - 300) + " chars)" + : out; + System.out.println("[ToolResult] " + tr.toolName() + " -> " + preview); + } + + assertTrue( + "Agent should have run a select query first", + toolCalls.stream().anyMatch(tc -> "run_select_query".equals(tc.toolName()))); + + // The SELECT returned 800 rows, which must trip the offload threshold. + assertTrue( + "Expected at least one offload preview marker in tool results", + toolResults.stream().anyMatch(tr -> tr.output().contains("Tool output truncated"))); + + assertTrue( + "Agent should have used grep_tool_output or read_tool_output after seeing" + + " the offload preview; actual tool calls: " + + toolCalls.stream().map(ToolCall::toolName).collect(Collectors.toList()), + toolCalls.stream() + .anyMatch( + tc -> + "grep_tool_output".equals(tc.toolName()) + || "read_tool_output".equals(tc.toolName()))); + + String finalAnswer = + events.stream() + .filter(e -> e instanceof ContentComplete) + .map(e -> ((ContentComplete) e).fullText()) + .reduce((a, b) -> b) + .orElse(""); + assertTrue( + "Final answer should contain the needle note 'the-answer-is-42', got: " + finalAnswer, + finalAnswer.contains("the-answer-is-42")); + } + + @Test + public void testApprovalApproveFlow() throws Exception { + // Real LLM picks run_mutation_query (DESTRUCTIVE) -> ApprovalMiddleware pauses -> + // background thread resolves(approved=true) -> mutation actually runs in SQLite. + SQLiteDataSource ds = createCountersDatabase(); + ToolRegistry registry = new ToolRegistry(30); + registry.register(new RunSelectQueryTool(ds, 0)); + registry.register(new RunMutationQueryTool(ds, 0)); + + ApprovalMiddleware approval = new ApprovalMiddleware(30); + + ReactAgent agent = + ReactAgent.builder() + .client(client) + .modelName(MODEL_NAME) + .toolRegistry(registry) + .addMiddleware(new LoggingMiddleware()) + .addMiddleware(approval) + .maxIterations(10) + .systemPrompt(SYSTEM_PROMPT) + .build(); + + List events = new CopyOnWriteArrayList<>(); + ConversationMemory memory = new ConversationMemory(); + + // Auto-approve any approval request that shows up. + ExecutorService approver = Executors.newSingleThreadExecutor(); + Consumer listener = + event -> { + events.add(event); + if (event instanceof ApprovalRequest) { + String rid = ((ApprovalRequest) event).requestId(); + approver.submit(() -> approval.resolve(rid, true)); + } + }; + + try { + agent.run( + new AgentInvocation( + "Increment the 'hits' counter in the counters table by 1, then tell me its" + + " new value. Respond with ONLY the new value, no explanation."), + memory, + listener); + } finally { + approver.shutdown(); + approver.awaitTermination(5, TimeUnit.SECONDS); + } + + List approvals = + events.stream() + .filter(e -> e instanceof ApprovalRequest) + .map(e -> (ApprovalRequest) e) + .collect(Collectors.toList()); + assertFalse("Expected at least one ApprovalRequest", approvals.isEmpty()); + assertTrue( + "ApprovalRequest should target run_mutation_query", + approvals.stream().anyMatch(a -> "run_mutation_query".equals(a.toolName()))); + + // Mutation must have actually executed — check the DB directly. + try (Connection conn = ds.getConnection(); + Statement stmt = conn.createStatement(); + java.sql.ResultSet rs = stmt.executeQuery("SELECT value FROM counters WHERE name='hits'")) { + assertTrue(rs.next()); + assertEquals("Counter should be 1 after approved mutation", 1, rs.getInt(1)); + } + + String finalAnswer = + events.stream() + .filter(e -> e instanceof ContentComplete) + .map(e -> ((ContentComplete) e).fullText()) + .reduce((a, b) -> b) + .orElse(""); + assertTrue("Final answer should mention 1, got: " + finalAnswer, finalAnswer.contains("1")); + } + + @Test + public void testApprovalDenyFlow() throws Exception { + // Same setup as the approve test, but the approval listener denies. The mutation + // must NOT run, and the LLM must surface the denial to the user naturally. + SQLiteDataSource ds = createCountersDatabase(); + ToolRegistry registry = new ToolRegistry(30); + registry.register(new RunSelectQueryTool(ds, 0)); + registry.register(new RunMutationQueryTool(ds, 0)); + + ApprovalMiddleware approval = new ApprovalMiddleware(30); + + ReactAgent agent = + ReactAgent.builder() + .client(client) + .modelName(MODEL_NAME) + .toolRegistry(registry) + .addMiddleware(new LoggingMiddleware()) + .addMiddleware(approval) + .maxIterations(10) + .systemPrompt(SYSTEM_PROMPT) + .build(); + + List events = new CopyOnWriteArrayList<>(); + ConversationMemory memory = new ConversationMemory(); + + ExecutorService approver = Executors.newSingleThreadExecutor(); + Consumer listener = + event -> { + events.add(event); + if (event instanceof ApprovalRequest) { + String rid = ((ApprovalRequest) event).requestId(); + approver.submit(() -> approval.resolve(rid, false)); + } + }; + + try { + agent.run( + new AgentInvocation( + "Delete all rows from the counters table. If you cannot, explain why."), + memory, + listener); + } finally { + approver.shutdown(); + approver.awaitTermination(5, TimeUnit.SECONDS); + } + + assertTrue( + "Expected at least one ApprovalRequest", + events.stream().anyMatch(e -> e instanceof ApprovalRequest)); + + // DB must be untouched. + try (Connection conn = ds.getConnection(); + Statement stmt = conn.createStatement(); + java.sql.ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM counters")) { + assertTrue(rs.next()); + assertTrue("Counters rows must survive denied mutation", rs.getInt(1) > 0); + } + + // LLM should tell the user the operation didn't go through. Loose lexical check to + // absorb model wording drift. + String finalAnswer = + events.stream() + .filter(e -> e instanceof ContentComplete) + .map(e -> ((ContentComplete) e).fullText()) + .reduce((a, b) -> b) + .orElse("") + .toLowerCase(); + assertTrue( + "Final answer should indicate the deletion was refused/denied/not executed, got: " + + finalAnswer, + finalAnswer.contains("den") + || finalAnswer.contains("reject") + || finalAnswer.contains("not ") + || finalAnswer.contains("refus") + || finalAnswer.contains("unable") + || finalAnswer.contains("could not")); + } + + // --- Helpers --- + + private SQLiteDataSource createSalesDatabase() { + SQLiteDataSource ds = createDataSource(); + try (Connection conn = ds.getConnection(); + Statement stmt = conn.createStatement()) { + stmt.execute( + "CREATE TABLE products (" + + "id INTEGER PRIMARY KEY, name TEXT NOT NULL, " + + "category TEXT NOT NULL, price REAL NOT NULL)"); + stmt.execute( + "INSERT INTO products VALUES " + + "(1, 'Laptop', 'Electronics', 999.99), " + + "(2, 'Headphones', 'Electronics', 199.99), " + + "(3, 'T-Shirt', 'Clothing', 29.99), " + + "(4, 'Jeans', 'Clothing', 59.99), " + + "(5, 'Novel', 'Books', 14.99), " + + "(6, 'Textbook', 'Books', 89.99)"); + stmt.execute( + "CREATE TABLE orders (" + + "id INTEGER PRIMARY KEY, product_id INTEGER NOT NULL, " + + "customer_name TEXT NOT NULL, quantity INTEGER NOT NULL, " + + "order_date TEXT NOT NULL, " + + "FOREIGN KEY (product_id) REFERENCES products(id))"); + stmt.execute( + "INSERT INTO orders VALUES " + + "(1, 1, 'Alice', 1, '2024-01-15'), " + + "(2, 2, 'Bob', 2, '2024-01-20'), " + + "(3, 3, 'Charlie', 3, '2024-02-01'), " + + "(4, 4, 'Alice', 1, '2024-02-10'), " + + "(5, 5, 'Bob', 5, '2024-02-15'), " + + "(6, 1, 'Diana', 1, '2024-03-01'), " + + "(7, 6, 'Charlie', 2, '2024-03-05'), " + + "(8, 2, 'Diana', 1, '2024-03-10')"); + } catch (Exception e) { + throw new RuntimeException(e); + } + return ds; + } + + private SQLiteDataSource createNeedleInHaystackDatabase() { + SQLiteDataSource ds = createDataSource(); + try (Connection conn = ds.getConnection(); + Statement stmt = conn.createStatement()) { + stmt.execute( + "CREATE TABLE events (" + + "id INTEGER PRIMARY KEY, tag TEXT NOT NULL, note TEXT NOT NULL)"); + conn.setAutoCommit(false); + try (java.sql.PreparedStatement ps = + conn.prepareStatement("INSERT INTO events VALUES (?, ?, ?)")) { + // 800 filler rows, guaranteed to blow past ToolResultOffloadMiddleware's 500-line + // threshold when the LLM issues a SELECT *. Exactly one row carries the NEEDLE tag. + int needleId = 573; + for (int i = 1; i <= 800; i++) { + ps.setInt(1, i); + if (i == needleId) { + ps.setString(2, "NEEDLE"); + ps.setString(3, "the-answer-is-42"); + } else { + ps.setString(2, "FILLER"); + ps.setString(3, "filler-note-" + i); + } + ps.addBatch(); + } + ps.executeBatch(); + } + conn.commit(); + } catch (Exception e) { + throw new RuntimeException(e); + } + return ds; + } + + private SQLiteDataSource createCountersDatabase() { + SQLiteDataSource ds = createDataSource(); + try (Connection conn = ds.getConnection(); + Statement stmt = conn.createStatement()) { + stmt.execute("CREATE TABLE counters (name TEXT PRIMARY KEY, value INTEGER NOT NULL)"); + stmt.execute("INSERT INTO counters VALUES ('hits', 0)"); + } catch (Exception e) { + throw new RuntimeException(e); + } + return ds; + } + + private SQLiteDataSource createDataSource() { + try { + File tmpFile = File.createTempFile("kyuubi-agent-live-", ".db"); + tmpFile.deleteOnExit(); + tempFiles.add(tmpFile); + SQLiteDataSource ds = new SQLiteDataSource(); + ds.setUrl("jdbc:sqlite:" + tmpFile.getAbsolutePath()); + return ds; + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ToolOutputStoreTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ToolOutputStoreTest.java new file mode 100644 index 00000000000..e3eda5ff5e4 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ToolOutputStoreTest.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime; + +import static org.junit.Assert.*; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import org.junit.Before; +import org.junit.Test; + +public class ToolOutputStoreTest { + + private ToolOutputStore store; + + @Before + public void setUp() throws IOException { + store = ToolOutputStore.create(); + } + + @Test + public void writeAndReadWindow() throws IOException { + StringBuilder sb = new StringBuilder(); + for (int i = 1; i <= 100; i++) sb.append("row").append(i).append('\n'); + Path p = store.write("sess1", "call1", sb.toString()); + assertTrue(Files.exists(p)); + + String out = store.read("sess1", p.toString(), 10, 5); + assertTrue(out, out.contains("lines 11-15 of")); + assertTrue(out, out.contains("row11")); + assertTrue(out, out.contains("row15")); + assertFalse(out, out.contains("row16")); + assertFalse(out, out.contains("row10")); + } + + @Test + public void grepReturnsMatchingLinesWithLineNumbers() throws IOException { + String content = "apple\nbanana\ncherry\napple pie\ndate\n"; + Path p = store.write("sess1", "call1", content); + + String out = store.grep("sess1", p.toString(), "apple", 10); + assertTrue(out, out.contains("1:apple")); + assertTrue(out, out.contains("4:apple pie")); + assertFalse(out, out.contains("banana")); + } + + @Test + public void grepRespectsMaxMatches() throws IOException { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 20; i++) sb.append("hit\n"); + Path p = store.write("sess1", "call1", sb.toString()); + + String out = store.grep("sess1", p.toString(), "hit", 3); + assertTrue(out, out.contains("[3 matches]")); + assertTrue(out, out.contains("1:hit")); + assertTrue(out, out.contains("3:hit")); + assertFalse("should stop after 3 matches", out.contains("4:hit")); + } + + @Test + public void grepInvalidRegexReturnsError() throws IOException { + Path p = store.write("sess1", "call1", "x\n"); + String out = store.grep("sess1", p.toString(), "[", 10); + assertTrue(out, out.startsWith("Error:")); + } + + @Test + public void readRejectsCrossSessionPath() throws IOException { + Path victim = store.write("victim", "secret_call", "top secret\n"); + assertTrue(Files.exists(victim)); + + String out = store.read("attacker", victim.toString(), 0, 10); + assertTrue(out, out.startsWith("Error:")); + assertFalse(out, out.contains("top secret")); + } + + @Test + public void grepRejectsCrossSessionPath() throws IOException { + Path victim = store.write("victim", "secret_call", "api_key=xyz\n"); + String out = store.grep("attacker", victim.toString(), "api_key", 10); + assertTrue(out, out.startsWith("Error:")); + assertFalse(out, out.contains("xyz")); + } + + @Test + public void cleanupSessionRemovesSubtree() throws IOException { + Path p1 = store.write("sessA", "call1", "a\n"); + Path p2 = store.write("sessA", "call2", "b\n"); + Path p3 = store.write("sessB", "call1", "c\n"); + assertTrue(Files.exists(p1)); + assertTrue(Files.exists(p2)); + assertTrue(Files.exists(p3)); + + store.cleanupSession("sessA"); + + assertFalse(Files.exists(p1)); + assertFalse(Files.exists(p2)); + assertTrue("other sessions untouched", Files.exists(p3)); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/event/EventTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/event/EventTest.java index 50d22da416d..b6a5f093b61 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/event/EventTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/event/EventTest.java @@ -113,6 +113,7 @@ public void testEventTypeSseNames() { assertEquals("step_end", EventType.STEP_END.sseEventName()); assertEquals("error", EventType.ERROR.sseEventName()); assertEquals("approval_request", EventType.APPROVAL_REQUEST.sseEventName()); + assertEquals("compaction", EventType.COMPACTION.sseEventName()); assertEquals("agent_finish", EventType.AGENT_FINISH.sseEventName()); } @@ -123,6 +124,6 @@ public void testAllEventTypesHaveUniqueSseNames() { for (EventType type : values) { assertTrue("Duplicate SSE name: " + type.sseEventName(), names.add(type.sseEventName())); } - assertEquals(10, values.length); + assertEquals(11, values.length); } } diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddlewareTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddlewareTest.java new file mode 100644 index 00000000000..a84bbc25948 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddlewareTest.java @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime.middleware; + +import static org.junit.Assert.*; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext; +import org.apache.kyuubi.engine.dataagent.runtime.ApprovalMode; +import org.apache.kyuubi.engine.dataagent.runtime.ConversationMemory; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentEvent; +import org.apache.kyuubi.engine.dataagent.runtime.event.ApprovalRequest; +import org.apache.kyuubi.engine.dataagent.runtime.event.EventType; +import org.apache.kyuubi.engine.dataagent.tool.AgentTool; +import org.apache.kyuubi.engine.dataagent.tool.ToolContext; +import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry; +import org.apache.kyuubi.engine.dataagent.tool.ToolRiskLevel; +import org.junit.Before; +import org.junit.Test; + +public class ApprovalMiddlewareTest { + + private ToolRegistry registry; + private List emittedEvents; + + @Before + public void setUp() { + registry = new ToolRegistry(30); + registry.register(safeTool("safe_tool")); + registry.register(destructiveTool("dangerous_tool")); + emittedEvents = Collections.synchronizedList(new ArrayList<>()); + } + + // --- Auto-approve mode: all tools pass --- + + @Test + public void testAutoApproveModeSkipsAllApproval() { + ApprovalMiddleware mw = newApprovalMiddleware(); + AgentRunContext ctx = makeContext(ApprovalMode.AUTO_APPROVE); + + assertNull(mw.beforeToolCall(ctx, "tc1", "dangerous_tool", Collections.emptyMap())); + assertNull(mw.beforeToolCall(ctx, "tc2", "safe_tool", Collections.emptyMap())); + assertTrue("No approval events should be emitted", emittedEvents.isEmpty()); + } + + // --- Normal mode: safe auto-approved, destructive needs approval --- + + @Test + public void testNormalModeAutoApprovesSafeTool() { + ApprovalMiddleware mw = newApprovalMiddleware(); + AgentRunContext ctx = makeContext(ApprovalMode.NORMAL); + + assertNull(mw.beforeToolCall(ctx, "tc1", "safe_tool", Collections.emptyMap())); + assertTrue(emittedEvents.isEmpty()); + } + + @Test + public void testNormalModeRequiresApprovalForDestructiveTool() throws Exception { + ApprovalMiddleware mw = newApprovalMiddleware(5); + AgentRunContext ctx = makeContext(ApprovalMode.NORMAL); + + ExecutorService exec = Executors.newSingleThreadExecutor(); + try { + CountDownLatch eventEmitted = new CountDownLatch(1); + // Capture the emitted event to get the requestId + ctx.setEventEmitter( + event -> { + emittedEvents.add(event); + eventEmitted.countDown(); + }); + + Future future = + exec.submit( + () -> mw.beforeToolCall(ctx, "tc1", "dangerous_tool", Collections.emptyMap())); + + // Wait for the approval request event + assertTrue("Approval event should be emitted", eventEmitted.await(2, TimeUnit.SECONDS)); + assertEquals(1, emittedEvents.size()); + assertEquals(EventType.APPROVAL_REQUEST, emittedEvents.get(0).eventType()); + + ApprovalRequest req = (ApprovalRequest) emittedEvents.get(0); + assertEquals("dangerous_tool", req.toolName()); + assertEquals(ToolRiskLevel.DESTRUCTIVE, req.riskLevel()); + + // Approve + assertTrue(mw.resolve(req.requestId(), true)); + assertNull("Approved tool should return null (no denial)", future.get(2, TimeUnit.SECONDS)); + } finally { + exec.shutdownNow(); + } + } + + @Test + public void testDeniedToolReturnsToolCallDenial() throws Exception { + ApprovalMiddleware mw = newApprovalMiddleware(5); + AgentRunContext ctx = makeContext(ApprovalMode.NORMAL); + + ExecutorService exec = Executors.newSingleThreadExecutor(); + try { + CountDownLatch eventEmitted = new CountDownLatch(1); + ctx.setEventEmitter( + event -> { + emittedEvents.add(event); + eventEmitted.countDown(); + }); + + Future future = + exec.submit( + () -> mw.beforeToolCall(ctx, "tc1", "dangerous_tool", Collections.emptyMap())); + + assertTrue(eventEmitted.await(2, TimeUnit.SECONDS)); + ApprovalRequest req = (ApprovalRequest) emittedEvents.get(0); + + // Deny + assertTrue(mw.resolve(req.requestId(), false)); + AgentMiddleware.ToolCallDenial denial = future.get(2, TimeUnit.SECONDS); + assertNotNull(denial); + assertTrue(denial.reason().contains("denied")); + } finally { + exec.shutdownNow(); + } + } + + // --- Strict mode: all tools need approval --- + + @Test + public void testStrictModeRequiresApprovalForSafeTool() throws Exception { + ApprovalMiddleware mw = newApprovalMiddleware(5); + AgentRunContext ctx = makeContext(ApprovalMode.STRICT); + + ExecutorService exec = Executors.newSingleThreadExecutor(); + try { + CountDownLatch eventEmitted = new CountDownLatch(1); + ctx.setEventEmitter( + event -> { + emittedEvents.add(event); + eventEmitted.countDown(); + }); + + Future future = + exec.submit(() -> mw.beforeToolCall(ctx, "tc1", "safe_tool", Collections.emptyMap())); + + assertTrue(eventEmitted.await(2, TimeUnit.SECONDS)); + ApprovalRequest req = (ApprovalRequest) emittedEvents.get(0); + assertEquals("safe_tool", req.toolName()); + + assertTrue(mw.resolve(req.requestId(), true)); + assertNull(future.get(2, TimeUnit.SECONDS)); + } finally { + exec.shutdownNow(); + } + } + + // --- Timeout --- + + @Test + public void testApprovalTimeoutReturnsDenial() throws Exception { + ApprovalMiddleware mw = newApprovalMiddleware(1); // 1 second timeout + AgentRunContext ctx = makeContext(ApprovalMode.STRICT); + ctx.setEventEmitter(emittedEvents::add); + + ExecutorService exec = Executors.newSingleThreadExecutor(); + try { + Future future = + exec.submit(() -> mw.beforeToolCall(ctx, "tc1", "safe_tool", Collections.emptyMap())); + + // Don't resolve — let it time out + AgentMiddleware.ToolCallDenial denial = future.get(5, TimeUnit.SECONDS); + assertNotNull("Timeout should produce a denial", denial); + assertTrue(denial.reason().contains("timed out")); + } finally { + exec.shutdownNow(); + } + } + + // --- Cancel all --- + + @Test + public void testOnStopUnblocksPendingRequests() throws Exception { + ApprovalMiddleware mw = newApprovalMiddleware(30); + AgentRunContext ctx = makeContext(ApprovalMode.STRICT); + ctx.setEventEmitter(emittedEvents::add); + + ExecutorService exec = Executors.newSingleThreadExecutor(); + try { + CountDownLatch started = new CountDownLatch(1); + Future future = + exec.submit( + () -> { + started.countDown(); + return mw.beforeToolCall(ctx, "tc1", "safe_tool", Collections.emptyMap()); + }); + + assertTrue(started.await(2, TimeUnit.SECONDS)); + Thread.sleep(100); // let the thread enter the blocking wait + + mw.onStop(); + + AgentMiddleware.ToolCallDenial denial = future.get(2, TimeUnit.SECONDS); + assertNotNull("onStop should unblock with a denial", denial); + } finally { + exec.shutdownNow(); + } + } + + // --- Helpers --- + + private ApprovalMiddleware newApprovalMiddleware() { + ApprovalMiddleware mw = new ApprovalMiddleware(); + mw.onRegister(registry); + return mw; + } + + private ApprovalMiddleware newApprovalMiddleware(long timeoutSeconds) { + ApprovalMiddleware mw = new ApprovalMiddleware(timeoutSeconds); + mw.onRegister(registry); + return mw; + } + + private AgentRunContext makeContext(ApprovalMode mode) { + AgentRunContext ctx = new AgentRunContext(new ConversationMemory(), mode); + ctx.setEventEmitter(emittedEvents::add); + return ctx; + } + + private static AgentTool safeTool(String name) { + return new DummyTool(name, ToolRiskLevel.SAFE); + } + + private static AgentTool destructiveTool(String name) { + return new DummyTool(name, ToolRiskLevel.DESTRUCTIVE); + } + + public static class DummyArgs { + public String value; + } + + private static class DummyTool implements AgentTool { + private final String name; + private final ToolRiskLevel riskLevel; + + DummyTool(String name, ToolRiskLevel riskLevel) { + this.name = name; + this.riskLevel = riskLevel; + } + + @Override + public String name() { + return name; + } + + @Override + public String description() { + return "dummy tool"; + } + + @Override + public ToolRiskLevel riskLevel() { + return riskLevel; + } + + @Override + public Class argsType() { + return DummyArgs.class; + } + + @Override + public String execute(DummyArgs args, ToolContext ctx) { + return "ok"; + } + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareLiveTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareLiveTest.java new file mode 100644 index 00000000000..1c4eeb7c5ad --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareLiveTest.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime.middleware; + +import static org.junit.Assert.*; +import static org.junit.Assume.assumeTrue; + +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam; +import com.openai.models.chat.completions.ChatCompletionMessageParam; +import java.util.List; +import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext; +import org.apache.kyuubi.engine.dataagent.runtime.ApprovalMode; +import org.apache.kyuubi.engine.dataagent.runtime.ConversationMemory; +import org.junit.Before; +import org.junit.Test; + +/** + * Live integration test for {@link CompactionMiddleware}: exercises the full compaction path + * against a real OpenAI-compatible LLM. Requires {@code DATA_AGENT_LLM_API_KEY} and {@code + * DATA_AGENT_LLM_API_URL} environment variables; skipped otherwise. + */ +public class CompactionMiddlewareLiveTest { + + private static final String API_KEY = System.getenv().getOrDefault("DATA_AGENT_LLM_API_KEY", ""); + private static final String BASE_URL = System.getenv().getOrDefault("DATA_AGENT_LLM_API_URL", ""); + private static final String MODEL_NAME = System.getenv().getOrDefault("DATA_AGENT_LLM_MODEL", ""); + + private OpenAIClient client; + + @Before + public void setUp() { + assumeTrue("DATA_AGENT_LLM_API_KEY not set, skipping live tests", !API_KEY.isEmpty()); + assumeTrue("DATA_AGENT_LLM_API_URL not set, skipping live tests", !BASE_URL.isEmpty()); + client = OpenAIOkHttpClient.builder().apiKey(API_KEY).baseUrl(BASE_URL).build(); + } + + @Test + public void compactsHistoryWhenThresholdCrossed() { + // Seed a realistic ReAct-style history so the summarizer has something non-trivial to + // summarize. ~20 alternating user/assistant turns. + ConversationMemory memory = new ConversationMemory(); + memory.setSystemPrompt( + "You are a data agent. You previously helped the user investigate the orders table."); + for (int i = 0; i < 10; i++) { + memory.addUserMessage( + "Follow-up question " + i + ": what about the column customer_id in orders?"); + memory.addAssistantMessage( + ChatCompletionAssistantMessageParam.builder() + .content( + "Assistant turn " + + i + + ": the orders table has a customer_id BIGINT column referencing customers.id.") + .build()); + } + int originalSize = memory.size(); + + AgentRunContext ctx = new AgentRunContext(memory, ApprovalMode.AUTO_APPROVE); + // Simulate the previous LLM call having reported a large prompt_tokens so the next + // beforeLlmCall trips the threshold. + ctx.addTokenUsage(60_000, 0, 60_000); + + CompactionMiddleware mw = new CompactionMiddleware(client, MODEL_NAME, /* trigger */ 50_000L); + + AgentMiddleware.LlmCallAction action = mw.beforeLlmCall(ctx, memory.buildLlmMessages()); + + assertNotNull("expected compaction to fire", action); + assertTrue(action instanceof AgentMiddleware.LlmModifyMessages); + + // History got rewritten: [summary user msg] + kept tail. + List hist = memory.getHistory(); + assertTrue(hist.size() < originalSize); + assertTrue(hist.get(0).isUser()); + String first = hist.get(0).asUser().content().text().orElse(""); + assertTrue( + "summary message should be wrapped in ", + first.contains("")); + + // The LLM was told to emit 8 markdown sections; sanity-check a couple show up. + assertTrue( + "summary should contain '## User Intent' section", + first.contains("## User Intent") || first.contains("User Intent")); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareTest.java new file mode 100644 index 00000000000..d4c0cc581db --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareTest.java @@ -0,0 +1,322 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime.middleware; + +import static org.junit.Assert.*; + +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam; +import com.openai.models.chat.completions.ChatCompletionMessageFunctionToolCall; +import com.openai.models.chat.completions.ChatCompletionMessageParam; +import com.openai.models.chat.completions.ChatCompletionMessageToolCall; +import com.openai.models.chat.completions.ChatCompletionToolMessageParam; +import com.openai.models.chat.completions.ChatCompletionUserMessageParam; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext; +import org.apache.kyuubi.engine.dataagent.runtime.ApprovalMode; +import org.apache.kyuubi.engine.dataagent.runtime.ConversationMemory; +import org.junit.Before; +import org.junit.Test; + +/** + * Unit tests that exercise only the deterministic, LLM-free parts of {@link CompactionMiddleware}: + * + *
      + *
    • Static helpers {@code computeSplit} and {@code estimateTailAfterLastAssistant}. + *
    • Pre-summarizer gating paths in {@code beforeLlmCall} (below threshold, empty old after + * split) that never reach the LLM call. + *
    • Constructor validation. + *
    + * + * The end-to-end behaviour (actual summarization output, compacted-history structure, no re-trigger + * next turn) is covered by {@code CompactionMiddlewareLiveTest}, which is gated on real LLM + * credentials. + */ +public class CompactionMiddlewareTest { + + /** + * A minimally-configured real {@link OpenAIClient}. Never invoked in these tests because every + * code path exercised here returns before reaching the summarizer. Uses dummy credentials and a + * bogus base URL so an accidental invocation would fail loudly. + */ + private static final OpenAIClient DUMMY_CLIENT = + OpenAIOkHttpClient.builder() + .apiKey("dummy-for-unit-tests") + .baseUrl("http://127.0.0.1:0") + .build(); + + private ConversationMemory memory; + private AgentRunContext ctx; + + @Before + public void setUp() { + memory = new ConversationMemory(); + memory.setSystemPrompt("SYS"); + ctx = new AgentRunContext(memory, ApprovalMode.AUTO_APPROVE); + } + + // ----- computeSplit ----- + + @Test + public void computeSplit_shortHistoryReturnsAllInKept() { + List history = + Arrays.asList(userMsg("u0"), asstMsg(asstPlain("a0"))); + CompactionMiddleware.Split s = CompactionMiddleware.computeSplit(history, 4); + assertEquals(0, s.old.size()); + assertEquals(2, s.kept.size()); + } + + @Test + public void computeSplit_simpleAlternationSplitsCorrectly() { + // 10 msgs: u0,a0,u1,a1,u2,a2,u3,a3,u4,a4. keep=2 → boundary at u3 (2 users from tail). + List history = alternatingHistory(10); + CompactionMiddleware.Split s = CompactionMiddleware.computeSplit(history, 2); + // 4 users from tail: u4,u3 → splitIdx = index of u3 = 6 → old=[0..5], kept=[6..9] + assertEquals(6, s.old.size()); + assertEquals(4, s.kept.size()); + assertTrue(s.kept.get(0).isUser()); + } + + @Test + public void computeSplit_neverOrphansToolResult() { + // Layout: u0 a0 u1 a1 u2 a2 u3 a3(tc1) tool(tc1) u4 a4 u5 a5 u6 a6 u7 a7 + // keep=4 → naive splitIdx lands between tool(tc1) and u4; pair-protection must shift it back + // to before a3(tc1), so a3(tc1) + tool(tc1) end up in kept. + List history = new ArrayList<>(); + history.add(userMsg("u0")); + history.add(asstMsg(asstPlain("a0"))); + history.add(userMsg("u1")); + history.add(asstMsg(asstPlain("a1"))); + history.add(userMsg("u2")); + history.add(asstMsg(asstPlain("a2"))); + history.add(userMsg("u3")); + history.add(asstMsg(asstWithToolCall("a3", "tc1", "sql_query", "{}"))); + history.add(toolMsg("tc1", "r1")); + history.add(userMsg("u4")); + history.add(asstMsg(asstPlain("a4"))); + history.add(userMsg("u5")); + history.add(asstMsg(asstPlain("a5"))); + history.add(userMsg("u6")); + history.add(asstMsg(asstPlain("a6"))); + history.add(userMsg("u7")); + history.add(asstMsg(asstPlain("a7"))); + + CompactionMiddleware.Split s = CompactionMiddleware.computeSplit(history, 4); + assertNoOrphanToolResult(s.kept); + // Verify the tc1 pair really did end up in kept, not old. + assertTrue("a3(tc1) must be in kept", containsToolCallId(s.kept, "tc1")); + assertTrue("tool_result(tc1) must be in kept", containsToolCallIdAsResult(s.kept, "tc1")); + } + + @Test + public void computeSplit_keepCountExceedsAvailableUsers() { + // Only 2 user msgs but we ask to keep 4 — boundary walks to the top, old=[], kept=all. + List history = alternatingHistory(4); // u0,a0,u1,a1 + CompactionMiddleware.Split s = CompactionMiddleware.computeSplit(history, 4); + assertEquals(0, s.old.size()); + assertEquals(4, s.kept.size()); + } + + // ----- estimateTailAfterLastAssistant ----- + + @Test + public void estimateTail_afterLastAssistant() { + // u(200 chars) a(50) u(100) → last assistant at index 1, tail is the final user = 100 chars + // → 100/4 = 25 tokens + List msgs = + Arrays.asList( + userMsg(repeat('x', 200)), + asstMsg(asstPlain(repeat('y', 50))), + userMsg(repeat('z', 100))); + assertEquals(25L, CompactionMiddleware.estimateTailAfterLastAssistant(msgs)); + } + + @Test + public void estimateTail_noAssistantMeansEverythingIsTail() { + List msgs = + Arrays.asList(userMsg(repeat('x', 400)), userMsg(repeat('y', 400))); + assertEquals(200L, CompactionMiddleware.estimateTailAfterLastAssistant(msgs)); + } + + @Test + public void estimateTail_emptyReturnsZero() { + assertEquals(0L, CompactionMiddleware.estimateTailAfterLastAssistant(Collections.emptyList())); + } + + // ----- beforeLlmCall pre-summarizer gating ----- + + @Test + public void belowThresholdReturnsNull() { + seedSimpleHistory(6); + ctx.addTokenUsage(1000, 0, 1000); + CompactionMiddleware mw = new CompactionMiddleware(DUMMY_CLIENT, "m", 50_000L); + + assertNull(mw.beforeLlmCall(ctx, memory.buildLlmMessages())); + // Nothing was mutated. + assertEquals(6, memory.size()); + } + + @Test + public void aboveThresholdButHistoryTooShortReturnsNull() { + // Threshold crossed (60k cumulative) but history has only 2 user turns → computeSplit + // can't satisfy KEEP_RECENT_TURNS=4 and keeps everything, leaving split.old empty; so + // beforeLlmCall bails out before ever invoking the summarizer. + memory.addUserMessage("u0"); + memory.addAssistantMessage(asstPlain("a0")); + memory.addUserMessage("u1"); + ctx.addTokenUsage(60_000, 0, 60_000); + CompactionMiddleware mw = new CompactionMiddleware(DUMMY_CLIENT, "m", 50_000L); + + assertNull(mw.beforeLlmCall(ctx, memory.buildLlmMessages())); + assertEquals(3, memory.size()); + } + + @Test + public void triggerUsesLastCallTotalNotCumulative() { + // Two consecutive calls with total_tokens below threshold. The middleware must key on the + // *last* call's total (prompt + completion), not the session cumulative — otherwise a session + // that has accumulated large cumulative cost but then compacted would misfire. Using total + // (not just prompt) also covers the last assistant message — e.g. a tool_call's completion + // tokens — which is part of the next prompt but sits beyond the tail estimator's window. + seedSimpleHistory(6); + CompactionMiddleware mw = new CompactionMiddleware(DUMMY_CLIENT, "m", 50_000L); + + ctx.addTokenUsage(4_000, 1_000, 5_000); + assertNull(mw.beforeLlmCall(ctx, memory.buildLlmMessages())); + + ctx.addTokenUsage(8_000, 2_000, 10_000); + assertNull(mw.beforeLlmCall(ctx, memory.buildLlmMessages())); + + assertEquals(10_000L, memory.getLastTotalTokens()); + assertEquals(15_000L, memory.getCumulativeTotalTokens()); + } + + // ----- helpers ----- + + private void seedSimpleHistory(int n) { + for (int i = 0; i < n; i++) { + if (i % 2 == 0) { + memory.addUserMessage("u" + i); + } else { + memory.addAssistantMessage(asstPlain("a" + i)); + } + } + } + + private static List alternatingHistory(int n) { + List out = new ArrayList<>(n); + for (int i = 0; i < n; i++) { + if (i % 2 == 0) { + out.add(userMsg("u" + i)); + } else { + out.add(asstMsg(asstPlain("a" + i))); + } + } + return out; + } + + private static ChatCompletionMessageParam userMsg(String text) { + return ChatCompletionMessageParam.ofUser( + ChatCompletionUserMessageParam.builder().content(text).build()); + } + + private static ChatCompletionMessageParam asstMsg(ChatCompletionAssistantMessageParam p) { + return ChatCompletionMessageParam.ofAssistant(p); + } + + private static ChatCompletionMessageParam toolMsg(String toolCallId, String content) { + return ChatCompletionMessageParam.ofTool( + ChatCompletionToolMessageParam.builder().toolCallId(toolCallId).content(content).build()); + } + + private static ChatCompletionAssistantMessageParam asstPlain(String text) { + return ChatCompletionAssistantMessageParam.builder().content(text).build(); + } + + private static ChatCompletionAssistantMessageParam asstWithToolCall( + String text, String toolCallId, String toolName, String args) { + List calls = new ArrayList<>(); + calls.add( + ChatCompletionMessageToolCall.ofFunction( + ChatCompletionMessageFunctionToolCall.builder() + .id(toolCallId) + .function( + ChatCompletionMessageFunctionToolCall.Function.builder() + .name(toolName) + .arguments(args) + .build()) + .build())); + return ChatCompletionAssistantMessageParam.builder().content(text).toolCalls(calls).build(); + } + + private static String repeat(char c, int n) { + char[] arr = new char[n]; + Arrays.fill(arr, c); + return new String(arr); + } + + private static boolean containsToolCallId(List msgs, String id) { + for (ChatCompletionMessageParam m : msgs) { + if (m.isAssistant()) { + List calls = m.asAssistant().toolCalls().orElse(null); + if (calls == null) continue; + for (ChatCompletionMessageToolCall tc : calls) { + if (tc.isFunction() && id.equals(tc.asFunction().id())) return true; + } + } + } + return false; + } + + private static boolean containsToolCallIdAsResult( + List msgs, String id) { + for (ChatCompletionMessageParam m : msgs) { + if (m.isTool() && id.equals(m.asTool().toolCallId())) return true; + } + return false; + } + + private static void assertNoOrphanToolResult(List msgs) { + Set issued = new HashSet<>(); + for (ChatCompletionMessageParam m : msgs) { + if (m.isAssistant()) { + m.asAssistant() + .toolCalls() + .ifPresent( + calls -> { + for (ChatCompletionMessageToolCall tc : calls) { + if (tc.isFunction()) issued.add(tc.asFunction().id()); + } + }); + } + } + for (ChatCompletionMessageParam m : msgs) { + if (m.isTool()) { + assertTrue( + "tool_result id=" + m.asTool().toolCallId() + " has no matching tool_call", + issued.contains(m.asTool().toolCallId())); + } + } + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddlewareTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddlewareTest.java new file mode 100644 index 00000000000..cb107b775f5 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddlewareTest.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime.middleware; + +import static org.junit.Assert.*; + +import java.util.Collections; +import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext; +import org.apache.kyuubi.engine.dataagent.runtime.ApprovalMode; +import org.apache.kyuubi.engine.dataagent.runtime.ConversationMemory; +import org.apache.kyuubi.engine.dataagent.tool.output.GrepToolOutputTool; +import org.apache.kyuubi.engine.dataagent.tool.output.ReadToolOutputTool; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +public class ToolResultOffloadMiddlewareTest { + + private ToolResultOffloadMiddleware mw; + private AgentRunContext ctxWithSession; + private AgentRunContext ctxNoSession; + + @Before + public void setUp() { + mw = new ToolResultOffloadMiddleware(); + ctxWithSession = + new AgentRunContext(new ConversationMemory(), ApprovalMode.AUTO_APPROVE, "sess-1"); + ctxNoSession = new AgentRunContext(new ConversationMemory(), ApprovalMode.AUTO_APPROVE, null); + } + + @After + public void tearDown() { + mw.onStop(); + } + + @Test + public void underThresholdPassesThrough() { + String small = "row1\nrow2\nrow3\n"; + String out = + mw.afterToolCall(ctxWithSession, "run_select_query", Collections.emptyMap(), small); + assertNull(out); + } + + @Test + public void overLineThresholdTriggersOffload() { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 600; i++) sb.append("row").append(i).append('\n'); + String out = + mw.afterToolCall(ctxWithSession, "run_select_query", Collections.emptyMap(), sb.toString()); + + assertNotNull(out); + assertTrue(out, out.contains("Tool output truncated")); + assertTrue(out, out.contains("Saved to:")); + assertTrue(out, out.contains(ReadToolOutputTool.NAME)); + assertTrue(out, out.contains(GrepToolOutputTool.NAME)); + assertTrue(out, out.contains("row0")); + assertTrue(out, out.contains("row599")); + } + + @Test + public void overByteThresholdTriggersOffload() { + // 60 lines of ~1 KB each = ~60 KB — over the byte threshold but well under the line threshold. + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 60; i++) { + for (int j = 0; j < 1024; j++) sb.append('a'); + sb.append('\n'); + } + String out = + mw.afterToolCall(ctxWithSession, "run_select_query", Collections.emptyMap(), sb.toString()); + + assertNotNull("byte threshold should trigger", out); + assertTrue(out, out.contains("Tool output truncated")); + } + + @Test + public void retrievalToolsAreExemptFromGate() { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 2000; i++) sb.append("row").append(i).append('\n'); + assertNull( + mw.afterToolCall( + ctxWithSession, ReadToolOutputTool.NAME, Collections.emptyMap(), sb.toString())); + assertNull( + mw.afterToolCall( + ctxWithSession, GrepToolOutputTool.NAME, Collections.emptyMap(), sb.toString())); + } + + @Test + public void missingSessionIdPassesThrough() { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 1000; i++) sb.append("row").append(i).append('\n'); + String out = + mw.afterToolCall(ctxNoSession, "run_select_query", Collections.emptyMap(), sb.toString()); + assertNull("without sessionId, cannot offload safely — pass through", out); + } + + @Test + public void onSessionCloseClearsCounterAndFiles() { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 600; i++) sb.append("row").append(i).append('\n'); + mw.afterToolCall(ctxWithSession, "run_select_query", Collections.emptyMap(), sb.toString()); + assertEquals(1, mw.trackedSessions()); + + mw.onSessionClose("sess-1"); + assertEquals(0, mw.trackedSessions()); + } + + @Test + public void multipleOffloadsReuseSameSessionDir() { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 600; i++) sb.append("row").append(i).append('\n'); + String out1 = + mw.afterToolCall(ctxWithSession, "run_select_query", Collections.emptyMap(), sb.toString()); + String out2 = + mw.afterToolCall(ctxWithSession, "run_select_query", Collections.emptyMap(), sb.toString()); + // Both previews reference the same session dir, different file names. + assertNotEquals(extractPath(out1), extractPath(out2)); + assertTrue(extractPath(out1).contains("sess-1")); + assertTrue(extractPath(out2).contains("sess-1")); + } + + private static String extractPath(String preview) { + int i = preview.indexOf("Saved to:"); + int eol = preview.indexOf('\n', i); + return preview.substring(i + "Saved to:".length(), eol).trim(); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistryThreadSafetyTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistryThreadSafetyTest.java index c3ceb1a4dc0..7f790e1238e 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistryThreadSafetyTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistryThreadSafetyTest.java @@ -66,7 +66,7 @@ public Class argsType() { } @Override - public String execute(DummyArgs args) { + public String execute(DummyArgs args, ToolContext ctx) { return "result_" + idx; } }); @@ -118,7 +118,7 @@ public Class argsType() { } @Override - public String execute(DummyArgs args) { + public String execute(DummyArgs args, ToolContext ctx) { return "existing_result"; } }); @@ -159,7 +159,7 @@ public Class argsType() { } @Override - public String execute(DummyArgs args) { + public String execute(DummyArgs args, ToolContext ctx) { return "dynamic"; } }); diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolTest.java index 8e0a5cd01b0..777017c4438 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolTest.java @@ -109,7 +109,7 @@ public Class argsType() { } @Override - public String execute(ToolRegistryThreadSafetyTest.DummyArgs args) { + public String execute(ToolRegistryThreadSafetyTest.DummyArgs args, ToolContext ctx) { try { Thread.sleep(60_000); } catch (InterruptedException e) { @@ -147,7 +147,7 @@ public Class argsType() { } @Override - public String execute(ToolRegistryThreadSafetyTest.DummyArgs args) { + public String execute(ToolRegistryThreadSafetyTest.DummyArgs args, ToolContext ctx) { throw new RuntimeException("intentional failure"); } }); diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryToolTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryToolTest.java index c46fc2501bc..5384bd937d9 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryToolTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryToolTest.java @@ -24,6 +24,7 @@ import java.sql.Statement; import java.util.ArrayList; import java.util.List; +import org.apache.kyuubi.engine.dataagent.tool.ToolContext; import org.apache.kyuubi.engine.dataagent.tool.ToolRiskLevel; import org.junit.After; import org.junit.Before; @@ -60,7 +61,7 @@ public void testRiskLevelDestructive() { public void testInsert() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "INSERT INTO t VALUES (9999, 'hello')"; - String result = tool.execute(args); + String result = tool.execute(args, ToolContext.EMPTY); assertTrue(result.contains("1 row(s) affected")); } @@ -68,21 +69,21 @@ public void testInsert() { public void testUpdate() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "UPDATE t SET v = 'updated' WHERE id = 1"; - assertTrue(tool.execute(args).contains("1 row(s) affected")); + assertTrue(tool.execute(args, ToolContext.EMPTY).contains("1 row(s) affected")); } @Test public void testDelete() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "DELETE FROM t WHERE id = 1"; - assertTrue(tool.execute(args).contains("1 row(s) affected")); + assertTrue(tool.execute(args, ToolContext.EMPTY).contains("1 row(s) affected")); } @Test public void testCreateTable() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "CREATE TABLE new_t (id INTEGER PRIMARY KEY, v TEXT)"; - assertTrue(tool.execute(args).contains("executed successfully")); + assertTrue(tool.execute(args, ToolContext.EMPTY).contains("executed successfully")); } @Test @@ -90,7 +91,7 @@ public void testAlsoAcceptsSelect() { // Mutation tool does not enforce read-only check; SELECT works fine here. SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT v FROM t WHERE id = 1"; - String result = tool.execute(args); + String result = tool.execute(args, ToolContext.EMPTY); assertFalse(result.startsWith("Error:")); } @@ -98,18 +99,18 @@ public void testAlsoAcceptsSelect() { public void testRejectsEmptyAndNullSql() { SqlQueryArgs emptyArgs = new SqlQueryArgs(); emptyArgs.sql = ""; - assertTrue(tool.execute(emptyArgs).startsWith("Error:")); + assertTrue(tool.execute(emptyArgs, ToolContext.EMPTY).startsWith("Error:")); SqlQueryArgs nullArgs = new SqlQueryArgs(); nullArgs.sql = null; - assertTrue(tool.execute(nullArgs).startsWith("Error:")); + assertTrue(tool.execute(nullArgs, ToolContext.EMPTY).startsWith("Error:")); } @Test public void testInvalidSqlReturnsError() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "INSERT INTO nonexistent_table VALUES (1)"; - assertTrue(tool.execute(args).startsWith("Error:")); + assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:")); } // --- Helpers --- diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryToolTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryToolTest.java index 3c6579cf9a7..d1015ede63d 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryToolTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryToolTest.java @@ -24,6 +24,7 @@ import java.sql.Statement; import java.util.ArrayList; import java.util.List; +import org.apache.kyuubi.engine.dataagent.tool.ToolContext; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -56,7 +57,7 @@ public void tearDown() { public void testRejectsInsert() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "INSERT INTO large_table VALUES (9999, 'x')"; - String result = tool.execute(args); + String result = tool.execute(args, ToolContext.EMPTY); assertTrue(result.startsWith("Error:")); assertTrue(result.contains("read-only")); assertTrue(result.contains("run_mutation_query")); @@ -66,28 +67,28 @@ public void testRejectsInsert() { public void testRejectsUpdate() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "UPDATE large_table SET value = 'x' WHERE id = 1"; - assertTrue(tool.execute(args).startsWith("Error:")); + assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:")); } @Test public void testRejectsDelete() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "DELETE FROM large_table WHERE id = 1"; - assertTrue(tool.execute(args).startsWith("Error:")); + assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:")); } @Test public void testRejectsCreateTable() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "CREATE TABLE x (id INT)"; - assertTrue(tool.execute(args).startsWith("Error:")); + assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:")); } @Test public void testAllowsSelect() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT id FROM large_table LIMIT 100"; - String result = tool.execute(args); + String result = tool.execute(args, ToolContext.EMPTY); assertFalse(result.startsWith("Error:")); assertTrue(result.contains("[100 row(s) returned]")); } @@ -96,7 +97,7 @@ public void testAllowsSelect() { public void testAllowsCte() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "WITH cte AS (SELECT id, value FROM large_table LIMIT 5) SELECT * FROM cte"; - String result = tool.execute(args); + String result = tool.execute(args, ToolContext.EMPTY); assertFalse(result.startsWith("Error:")); assertTrue(result.contains("row(s)")); } @@ -107,7 +108,7 @@ public void testAllowsCte() { public void testRespectsLimitInSql() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT id FROM large_table LIMIT 5"; - assertTrue(tool.execute(args).contains("[5 row(s) returned]")); + assertTrue(tool.execute(args, ToolContext.EMPTY).contains("[5 row(s) returned]")); } @Test @@ -116,7 +117,7 @@ public void testNoClientSideCapWhenLimitOmitted() { // Cap discipline is delegated to the LLM via the system prompt. SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT id FROM large_table"; - assertTrue(tool.execute(args).contains("[1500 row(s) returned]")); + assertTrue(tool.execute(args, ToolContext.EMPTY).contains("[1500 row(s) returned]")); } // --- Zero-row result --- @@ -125,7 +126,7 @@ public void testNoClientSideCapWhenLimitOmitted() { public void testZeroRowsResult() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT id FROM large_table WHERE id < 0"; - String result = tool.execute(args); + String result = tool.execute(args, ToolContext.EMPTY); assertFalse(result.startsWith("Error:")); assertTrue(result.contains("[0 row(s) returned]")); } @@ -136,15 +137,15 @@ public void testZeroRowsResult() { public void testSelectWithLeadingBlockComment() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "/* get count */ SELECT COUNT(*) FROM large_table"; - assertFalse(tool.execute(args).startsWith("Error:")); + assertFalse(tool.execute(args, ToolContext.EMPTY).startsWith("Error:")); } @Test public void testRejectsMutationHiddenBehindComment() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "-- looks innocent\nDROP TABLE large_table"; - assertTrue(tool.execute(args).startsWith("Error:")); - assertTrue(tool.execute(args).contains("read-only")); + assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:")); + assertTrue(tool.execute(args, ToolContext.EMPTY).contains("read-only")); } // --- Edge cases --- @@ -153,25 +154,25 @@ public void testRejectsMutationHiddenBehindComment() { public void testRejectsEmptyAndNullSql() { SqlQueryArgs emptyArgs = new SqlQueryArgs(); emptyArgs.sql = ""; - assertTrue(tool.execute(emptyArgs).startsWith("Error:")); + assertTrue(tool.execute(emptyArgs, ToolContext.EMPTY).startsWith("Error:")); SqlQueryArgs nullArgs = new SqlQueryArgs(); nullArgs.sql = null; - assertTrue(tool.execute(nullArgs).startsWith("Error:")); + assertTrue(tool.execute(nullArgs, ToolContext.EMPTY).startsWith("Error:")); } @Test public void testRejectsWhitespaceOnlySql() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = " \t\n "; - assertTrue(tool.execute(args).startsWith("Error:")); + assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:")); } @Test public void testInvalidSqlReturnsError() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT * FROM nonexistent_table"; - assertTrue(tool.execute(args).startsWith("Error:")); + assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:")); } // --- Output formatting --- @@ -188,7 +189,7 @@ public void testNullValuesRenderedAsNULL() { } SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT id, name FROM nullable_test ORDER BY ROWID"; - String result = tool.execute(args); + String result = tool.execute(args, ToolContext.EMPTY); assertTrue(result.contains("NULL")); assertTrue(result.contains("Alice")); } @@ -204,7 +205,7 @@ public void testPipeCharacterEscapedInOutput() { } SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT val FROM pipe_test"; - String result = tool.execute(args); + String result = tool.execute(args, ToolContext.EMPTY); assertTrue("Pipe should be escaped for markdown table", result.contains("a\\|b\\|c")); } @@ -214,7 +215,7 @@ public void testPipeCharacterEscapedInOutput() { public void testExtractRootCauseFromNestedExceptions() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT * FROM this_table_does_not_exist_at_all"; - String result = tool.execute(args); + String result = tool.execute(args, ToolContext.EMPTY); assertTrue(result.startsWith("Error:")); assertTrue(result.contains("this_table_does_not_exist_at_all")); } @@ -223,7 +224,7 @@ public void testExtractRootCauseFromNestedExceptions() { public void testErrorMessageIsConcise() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELEC INVALID SYNTAX HERE !!!"; - String result = tool.execute(args); + String result = tool.execute(args, ToolContext.EMPTY); assertTrue(result.startsWith("Error:")); long newlines = result.chars().filter(c -> c == '\n').count(); assertTrue("Error should be concise (<=2 newlines), got " + newlines, newlines <= 2); @@ -236,7 +237,7 @@ public void testCustomQueryTimeout() { RunSelectQueryTool customTool = new RunSelectQueryTool(ds, 5); SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT COUNT(*) FROM large_table"; - assertFalse(customTool.execute(args).startsWith("Error:")); + assertFalse(customTool.execute(args, ToolContext.EMPTY).startsWith("Error:")); } @Test @@ -314,7 +315,7 @@ public boolean isWrapperFor(Class iface) { RunSelectQueryTool timeoutTool = new RunSelectQueryTool(slowDs, 1); SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT * FROM large_table"; - String result = timeoutTool.execute(args); + String result = timeoutTool.execute(args, ToolContext.EMPTY); assertTrue("Expected error on timeout", result.startsWith("Error:")); assertTrue("Expected timeout message", result.contains("timed out")); } diff --git a/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentCompactionE2ESuite.scala b/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentCompactionE2ESuite.scala new file mode 100644 index 00000000000..5fd95bb0fdc --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentCompactionE2ESuite.scala @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.operation + +import java.sql.DriverManager + +import scala.collection.JavaConverters._ + +import com.fasterxml.jackson.databind.{JsonNode, ObjectMapper} + +import org.apache.kyuubi.config.KyuubiConf._ +import org.apache.kyuubi.engine.dataagent.WithDataAgentEngine +import org.apache.kyuubi.operation.HiveJDBCTestHelper + +/** + * End-to-end test that forces CompactionMiddleware to fire inside the engine and verifies + * two things simultaneously: + * 1. The `compaction` SSE event reaches the JDBC client (wiring works). + * 2. The agent still answers correctly *after* compaction, proving the summary preserved + * the facts the follow-up question depends on. + * + * The trigger threshold is set extremely low (500 tokens) so that the schema dump plus the + * first turn's completion already blow past it, forcing compaction before turn 2 or 3. + * + * Requires DATA_AGENT_LLM_API_KEY and DATA_AGENT_LLM_API_URL. + */ +class DataAgentCompactionE2ESuite extends HiveJDBCTestHelper with WithDataAgentEngine { + + private val apiKey = sys.env.getOrElse("DATA_AGENT_LLM_API_KEY", "") + private val apiUrl = sys.env.getOrElse("DATA_AGENT_LLM_API_URL", "") + private val modelName = sys.env.getOrElse("DATA_AGENT_LLM_MODEL", "") + private val dbPath = { + val tmp = System.getProperty("java.io.tmpdir") + val uid = java.util.UUID.randomUUID() + s"$tmp/dataagent_compaction_e2e_$uid.db" + } + + override def withKyuubiConf: Map[String, String] = Map( + ENGINE_DATA_AGENT_PROVIDER.key -> "OPENAI_COMPATIBLE", + ENGINE_DATA_AGENT_LLM_API_KEY.key -> apiKey, + ENGINE_DATA_AGENT_LLM_API_URL.key -> apiUrl, + ENGINE_DATA_AGENT_LLM_MODEL.key -> modelName, + ENGINE_DATA_AGENT_MAX_ITERATIONS.key -> "10", + ENGINE_DATA_AGENT_APPROVAL_MODE.key -> "AUTO_APPROVE", + // Force compaction to fire aggressively -- any realistic prompt + one LLM round-trip + // will exceed 500 tokens, so compaction must trigger by turn 2 or 3. + ENGINE_DATA_AGENT_COMPACTION_TRIGGER_TOKENS.key -> "500", + ENGINE_DATA_AGENT_JDBC_URL.key -> s"jdbc:sqlite:$dbPath") + + override protected def jdbcUrl: String = jdbcConnectionUrl + + private val enabled: Boolean = apiKey.nonEmpty && apiUrl.nonEmpty + + override def beforeAll(): Unit = { + if (enabled) { + setupTestDatabase() + super.beforeAll() + } + } + + override def afterAll(): Unit = { + if (enabled) { + super.afterAll() + new java.io.File(dbPath).delete() + } + } + + private def setupTestDatabase(): Unit = { + new java.io.File(dbPath).delete() + val conn = DriverManager.getConnection(s"jdbc:sqlite:$dbPath") + try { + val stmt = conn.createStatement() + stmt.execute( + """ + |CREATE TABLE employees ( + | id INTEGER PRIMARY KEY, + | name TEXT NOT NULL, + | department TEXT NOT NULL, + | salary REAL NOT NULL + |)""".stripMargin) + // 6 employees across 3 departments; Frank is unambiguously the top earner. + stmt.execute("INSERT INTO employees VALUES (1, 'Alice', 'Engineering', 25000)") + stmt.execute("INSERT INTO employees VALUES (2, 'Bob', 'Engineering', 30000)") + stmt.execute("INSERT INTO employees VALUES (3, 'Charlie', 'Sales', 20000)") + stmt.execute("INSERT INTO employees VALUES (4, 'Diana', 'Sales', 22000)") + stmt.execute("INSERT INTO employees VALUES (5, 'Eve', 'Marketing', 18000)") + stmt.execute("INSERT INTO employees VALUES (6, 'Frank', 'Engineering', 35000)") + } finally { + conn.close() + } + } + + private val mapper = new ObjectMapper() + + private def drainReply(rs: java.sql.ResultSet): String = { + val sb = new StringBuilder + while (rs.next()) { + sb.append(rs.getString("reply")) + } + sb.toString() + } + + private def parseEvents(stream: String): Seq[JsonNode] = { + val parser = mapper.getFactory.createParser(stream) + try mapper.readValues(parser, classOf[JsonNode]).asScala.toList + finally parser.close() + } + + private def extractAnswer(events: Seq[JsonNode]): String = { + val sb = new StringBuilder + events.foreach { node => + if ("content_delta" == node.path("type").asText()) { + sb.append(node.path("text").asText("")) + } + } + sb.toString() + } + + private val strictFormatHint = + "Respond with ONLY the answer, no explanation, no markdown, no punctuation." + + test("E2E: compaction fires mid-conversation and preserves facts across turns") { + assume(enabled, "DATA_AGENT_LLM_API_KEY/API_URL not set, skipping") + + // CompactionMiddleware.KEEP_RECENT_TURNS is hardcoded to 4. computeSplit needs a + // non-empty 'old' slice, so at least 5 distinct user turns must accumulate before + // compaction can fire -- turns 1..(N-4) become the old slice, 4 most recent are + // kept verbatim. Turn 5 is the observable trigger point. + // + // Turn 5 is phrased to force a fresh SQL query rather than relying on recall, + // because summary quality varies across LLMs (some drop the top-earner fact). + // Correctness of the final answer then validates that the post-compaction history + // still gives the agent enough context to pick the right tool and query -- which is + // the compaction contract we actually care about: mechanism fires, agent recovers. + withJdbcStatement() { stmt => + Seq( + "List every department that appears in the employees table.", + "How many employees work in Engineering?", + "What salaries do Sales employees earn?", + "Who works in Marketing?").zipWithIndex.foreach { case (q, i) => + val events = parseEvents(drainReply(stmt.executeQuery(q))) + info(s"Turn ${i + 1} answer: ${extractAnswer(events)}") + } + + // Turn 5 -- explicitly instruct the agent to re-query so the answer does not + // depend on summary fidelity, only on the agent still functioning after + // compaction rewrote history. + val events5 = parseEvents(drainReply(stmt.executeQuery( + "Run a SELECT against the employees table to find the single employee with" + + " the highest salary. Report ONLY that employee's name." + + s" $strictFormatHint"))) + val answer5 = extractAnswer(events5) + info(s"Turn 5 answer: $answer5") + + val compactionEvents = + events5.filter(_.path("type").asText() == "compaction") + assert( + compactionEvents.nonEmpty, + "Expected at least one compaction event in turn 5 (trigger=500 tokens, 5 turns)") + + // Sanity-check event shape -- field names must match ExecuteStatement's SSE encoder. + val c = compactionEvents.head + assert( + c.has("summarized") && c.get("summarized").asInt() > 0, + s"compaction event should carry a positive summarized count: $c") + assert(c.has("kept") && c.get("kept").asInt() >= 0, s"compaction event missing kept: $c") + assert( + c.has("triggerTokens") && c.get("triggerTokens").asLong() == 500L, + s"compaction event should echo configured trigger: $c") + + // Turn 5 was told to SELECT fresh; Frank (35000) is unambiguously the top earner. + // If we don't get "Frank", either the agent failed to re-query after compaction + // (real bug in post-compaction history) or the tool layer is broken. + assert( + answer5.contains("Frank"), + s"Turn 5 should identify Frank as the top earner after re-querying; the agent" + + s" must remain functional post-compaction. Got: $answer5") + } + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentE2ESuite.scala b/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentE2ESuite.scala index 977c655dabd..cfb83ecc9b2 100644 --- a/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentE2ESuite.scala +++ b/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentE2ESuite.scala @@ -19,6 +19,10 @@ package org.apache.kyuubi.engine.dataagent.operation import java.sql.DriverManager +import scala.collection.JavaConverters._ + +import com.fasterxml.jackson.databind.{JsonNode, ObjectMapper} + import org.apache.kyuubi.config.KyuubiConf._ import org.apache.kyuubi.engine.dataagent.WithDataAgentEngine import org.apache.kyuubi.operation.HiveJDBCTestHelper @@ -34,7 +38,7 @@ class DataAgentE2ESuite extends HiveJDBCTestHelper with WithDataAgentEngine { private val apiKey = sys.env.getOrElse("DATA_AGENT_LLM_API_KEY", "") private val apiUrl = sys.env.getOrElse("DATA_AGENT_LLM_API_URL", "") - private val modelName = sys.env.getOrElse("DATA_AGENT_LLM_MODEL", "gpt-4o") + private val modelName = sys.env.getOrElse("DATA_AGENT_LLM_MODEL", "") private val dbPath = s"${System.getProperty("java.io.tmpdir")}/dataagent_e2e_test_${java.util.UUID.randomUUID()}.db" @@ -107,33 +111,71 @@ class DataAgentE2ESuite extends HiveJDBCTestHelper with WithDataAgentEngine { new java.io.File(dbPath).delete() } + private val mapper = new ObjectMapper() + + private def drainReply(rs: java.sql.ResultSet): String = { + val sb = new StringBuilder + while (rs.next()) { + sb.append(rs.getString("reply")) + } + val stream = sb.toString() + info(s"Agent event stream: $stream") + stream + } + + /** + * The JDBC `reply` column is a concatenated stream of SSE events + * (`agent_start`, `tool_call`, `tool_result`, `content_delta`, ...). Only + * `content_delta.text` is actual model output - this pulls those out and + * joins them to recover the final natural-language answer. + */ + private def extractAnswer(eventStream: String): String = { + val parser = mapper.getFactory.createParser(eventStream) + val sb = new StringBuilder + try { + mapper.readValues(parser, classOf[JsonNode]).asScala.foreach { node => + if ("content_delta" == node.path("type").asText()) { + sb.append(node.path("text").asText("")) + } + } + } finally { + parser.close() + } + sb.toString() + } + + private val strictFormatHint = + "Respond with ONLY the answer, no explanation, no markdown, no punctuation." + test("E2E: agent answers data question through full Kyuubi pipeline") { assume(enabled, "DATA_AGENT_LLM_API_KEY/API_URL not set, skipping E2E tests") - // scalastyle:off println withJdbcStatement() { stmt => - // Ask a question that requires schema exploration + SQL execution - val result = stmt.executeQuery( - "Which department has the highest average salary?") - - val sb = new StringBuilder - while (result.next()) { - val chunk = result.getString("reply") - sb.append(chunk) - print(chunk) // real-time output for debugging - } - println() - - val reply = sb.toString() - - // The agent should have: - // 1. Explored the schema (mentioned table names or columns) - // 2. Executed SQL (the reply should contain actual data) - // 3. Answered with "Engineering" (avg salary 30000) - assert(reply.nonEmpty, "Agent should return a non-empty response") - assert( - reply.toLowerCase.contains("engineering") || reply.contains("30000"), - s"Expected the answer to mention 'Engineering' or '30000', got: ${reply.take(500)}") + val stream = drainReply( + stmt.executeQuery( + s"Which department has the highest average salary? $strictFormatHint")) + assert(extractAnswer(stream) == "Engineering") + } + } + + test("E2E: agent resolves follow-up question using prior conversation context") { + assume(enabled, "DATA_AGENT_LLM_API_KEY/API_URL not set, skipping E2E tests") + // Two executeQuery calls on the same Statement share the JDBC session, which means + // the provider reuses the same ConversationMemory across turns. Turn 2 uses the + // demonstrative "that department" - it can only be answered correctly if Turn 1's + // answer (Engineering) is carried over in the agent's conversation history. + withJdbcStatement() { stmt => + val stream1 = drainReply( + stmt.executeQuery( + s"Which department has the highest average salary? $strictFormatHint")) + assert(extractAnswer(stream1) == "Engineering") + + // Engineering has 3 employees (Alice, Bob, Frank). If memory is not shared + // the agent cannot resolve "that department" and cannot produce the exact + // integer 3 - nothing in Turn 2's prompt points to Engineering. + val stream2 = drainReply( + stmt.executeQuery( + s"How many employees work in that department? $strictFormatHint")) + assert(extractAnswer(stream2) == "3") } - // scalastyle:on println } } diff --git a/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala b/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala index 215a76e26d4..31d308e9c67 100644 --- a/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala +++ b/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala @@ -3848,6 +3848,21 @@ object KyuubiConf { .checkValue(_ > 0, "must be positive number") .createWithDefault(100) + val ENGINE_DATA_AGENT_COMPACTION_TRIGGER_TOKENS: ConfigEntry[Long] = + buildConf("kyuubi.engine.data.agent.compaction.trigger.tokens") + .doc("The prompt-token threshold above which the Data Agent's compaction middleware " + + "summarizes older conversation history into a compact message. The check is made each " + + "turn as " + + "real_prompt_tokens_of_previous_LLM_call + estimate_of_newly_appended_tail; " + + "when this predicted prompt size reaches the configured value, older messages are " + + "replaced by a single summary message while the most recent exchanges are kept verbatim. " + + "Set to a very large value (e.g., 9223372036854775807) to effectively " + + "disable compaction.") + .version("1.12.0") + .longConf + .checkValue(_ > 0, "must be positive number") + .createWithDefault(128000L) + val ENGINE_DATA_AGENT_QUERY_TIMEOUT: ConfigEntry[Long] = buildConf("kyuubi.engine.data.agent.query.timeout") .doc("The JDBC query execution timeout for the Data Agent SQL tools. " + From 52235617db8f52a211d205581377eb4e0aada950 Mon Sep 17 00:00:00 2001 From: wangzhigang Date: Thu, 23 Apr 2026 10:59:03 +0800 Subject: [PATCH 02/10] [KYUUBI #7379][2b/4] Move mysql-connector-j to test scope MySQL Connector/J is GPL-licensed and cannot be bundled in an Apache binary release. Users who need the MySQL/StarRocks datasource at runtime should provide the driver jar themselves on the engine classpath. Addresses review feedback on #7417. --- externals/kyuubi-data-agent-engine/pom.xml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/externals/kyuubi-data-agent-engine/pom.xml b/externals/kyuubi-data-agent-engine/pom.xml index 74da5005784..43a5008dda9 100644 --- a/externals/kyuubi-data-agent-engine/pom.xml +++ b/externals/kyuubi-data-agent-engine/pom.xml @@ -76,10 +76,11 @@ ${sqlite.version} - + com.mysql mysql-connector-j + test From 39fb9c52f4d4ef55b83dd911100c929de83f241f Mon Sep 17 00:00:00 2001 From: wangzhigang Date: Thu, 30 Apr 2026 23:49:45 +0800 Subject: [PATCH 03/10] [KYUUBI #7379][2b/4][FOLLOWUP] Tighten data-agent dependencies: drop SQLite/PostgreSQL bundle, pin kotlin/okhttp/okio Two pom-level cleanups requested in #7417 review: 1. Drop sqlite-jdbc and postgresql JDBC drivers from the binary bundle. sqlite-jdbc moves to test scope (still needed for unit tests); postgresql is no longer declared. Users targeting those databases provide the driver jar on the engine classpath the same way they do for any other JDBC source. Trims ~14 MB from the bundled tgz. 2. Pin the kotlin runtime and okhttp/okio versions transitively introduced by openai-java in the data-agent module's pom, so any drift across openai-java upgrades becomes a deliberate change rather than a silent transitive shift. Versions pinned at the values openai-java currently resolves to (kotlin-stdlib* 1.8.0, kotlin-reflect 2.0.21, okhttp 4.12.0, okio 3.6.0); the dependency tree is identical to before. Addresses review feedback on #7417. --- externals/kyuubi-data-agent-engine/pom.xml | 61 +++++++++++++++++++++- 1 file changed, 59 insertions(+), 2 deletions(-) diff --git a/externals/kyuubi-data-agent-engine/pom.xml b/externals/kyuubi-data-agent-engine/pom.xml index 43a5008dda9..be29d409208 100644 --- a/externals/kyuubi-data-agent-engine/pom.xml +++ b/externals/kyuubi-data-agent-engine/pom.xml @@ -30,6 +30,15 @@ Kyuubi Project Engine Data Agent https://kyuubi.apache.org/ + + + 1.8.0 + 2.0.21 + 4.12.0 + 3.6.0 + + @@ -57,6 +66,53 @@ ${openai.sdk.version} + + + org.jetbrains.kotlin + kotlin-stdlib + ${kotlin.stdlib.version} + + + org.jetbrains.kotlin + kotlin-stdlib-common + ${kotlin.stdlib.version} + + + org.jetbrains.kotlin + kotlin-stdlib-jdk7 + ${kotlin.stdlib.version} + + + org.jetbrains.kotlin + kotlin-stdlib-jdk8 + ${kotlin.stdlib.version} + + + org.jetbrains.kotlin + kotlin-reflect + ${kotlin.reflect.version} + + + com.squareup.okhttp3 + okhttp + ${okhttp.version} + + + com.squareup.okhttp3 + logging-interceptor + ${okhttp.version} + + + com.squareup.okio + okio + ${okio.version} + + + com.squareup.okio + okio-jvm + ${okio.version} + + com.github.victools @@ -69,14 +125,15 @@ ${victools.jsonschema.version} - + org.xerial sqlite-jdbc ${sqlite.version} + test - com.mysql mysql-connector-j From 87c1f44f8acd362044bddfb26e1d627b8e51dc62 Mon Sep 17 00:00:00 2001 From: wangzhigang Date: Thu, 30 Apr 2026 23:50:40 +0800 Subject: [PATCH 04/10] [KYUUBI #7379][2b/4][FOLLOWUP] Adopt Trino-style config keys, rename OpenAiProvider to ChatCompletionProvider MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous config namespace and provider class name were ambiguous — both implied an LLM/vendor identity (OpenAI) when in practice they denote the OpenAI-compatible chat-completion protocol that virtually every modern LLM endpoint speaks. Reviewers asked for vendor-neutral naming aligned with Trino's ai.* function configuration style. Config keys: kyuubi.engine.data.agent.llm.api.key -> openai.api.key kyuubi.engine.data.agent.llm.api.url -> openai.endpoint kyuubi.engine.data.agent.llm.model -> model Provider class: org.apache.kyuubi.engine.dataagent.provider.openai.OpenAiProvider -> org.apache.kyuubi.engine.dataagent.provider.chatcompletion .ChatCompletionProvider Env vars consumed by tests/E2E suites and the regenerated settings.md follow the same renames. Addresses review feedback on #7417. --- docs/configuration/settings.md | 6 ++--- .../ChatCompletionProvider.java} | 14 ++++++------ .../middleware/CompactionMiddleware.java | 2 +- .../operation/ExecuteStatement.scala | 2 +- .../dataagent/runtime/ReactAgentLiveTest.java | 16 ++++++++------ .../CompactionMiddlewareLiveTest.java | 16 ++++++++------ .../DataAgentCompactionE2ESuite.scala | 16 +++++++------- .../operation/DataAgentE2ESuite.scala | 18 +++++++-------- .../org/apache/kyuubi/config/KyuubiConf.scala | 22 ++++++++++--------- .../dataagent/DataAgentProcessBuilder.scala | 2 +- .../DataAgentProcessBuilderSuite.scala | 2 +- 11 files changed, 61 insertions(+), 55 deletions(-) rename externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/{openai/OpenAiProvider.java => chatcompletion/ChatCompletionProvider.java} (95%) diff --git a/docs/configuration/settings.md b/docs/configuration/settings.md index a2910abfca6..043084182b8 100644 --- a/docs/configuration/settings.md +++ b/docs/configuration/settings.md @@ -148,11 +148,11 @@ You can configure the Kyuubi properties in `$KYUUBI_HOME/conf/kyuubi-defaults.co | kyuubi.engine.data.agent.extra.classpath | <undefined> | The extra classpath for the Data Agent engine, for configuring the location of the LLM SDK and etc. | string | 1.12.0 | | kyuubi.engine.data.agent.java.options | <undefined> | The extra Java options for the Data Agent engine | string | 1.12.0 | | kyuubi.engine.data.agent.jdbc.url | <undefined> | The JDBC URL for the Data Agent engine to connect to the target database. If not set, the Data Agent will connect back to Kyuubi server via ZooKeeper service discovery. | string | 1.12.0 | -| kyuubi.engine.data.agent.llm.api.key | <undefined> | The API key to access the LLM service for the Data Agent engine. | string | 1.12.0 | -| kyuubi.engine.data.agent.llm.api.url | <undefined> | The API base URL for the LLM service used by the Data Agent engine. | string | 1.12.0 | -| kyuubi.engine.data.agent.llm.model | <undefined> | The model ID used by the Data Agent engine LLM provider. | string | 1.12.0 | | kyuubi.engine.data.agent.max.iterations | 100 | The maximum number of ReAct loop iterations for the Data Agent engine. | int | 1.12.0 | | kyuubi.engine.data.agent.memory | 1g | The heap memory for the Data Agent engine | string | 1.12.0 | +| kyuubi.engine.data.agent.model | <undefined> | The model ID used by the Data Agent engine. | string | 1.12.0 | +| kyuubi.engine.data.agent.openai.api.key | <undefined> | The API key for the OpenAI-compatible chat-completion endpoint used by the Data Agent engine. | string | 1.12.0 | +| kyuubi.engine.data.agent.openai.endpoint | <undefined> | The base URL of the OpenAI-compatible chat-completion endpoint used by the Data Agent engine. | string | 1.12.0 | | kyuubi.engine.data.agent.provider | ECHO | The provider for the Data Agent engine. Candidates:
    • ECHO: simply echoes the input, for testing purpose.
    • OPENAI_COMPATIBLE: OpenAI-compatible LLM provider.
    | string | 1.12.0 | | kyuubi.engine.data.agent.query.timeout | PT3M | The JDBC query execution timeout for the Data Agent SQL tools. Passed to Statement.setQueryTimeout so the server (Spark/Trino/...) can cooperatively cancel long-running queries and release cluster resources. Should be set lower than kyuubi.engine.data.agent.tool.call.timeout so server-side cancellation has time to react before the outer wall-clock cap fires. | duration | 1.12.0 | | kyuubi.engine.data.agent.tool.call.timeout | PT5M | The maximum wall-clock execution time for any tool call in the Data Agent engine. Acts as the outer safety net enforced by the agent runtime via Future.cancel(), applied uniformly to every tool. For SQL tools the inner JDBC-level timeout is controlled separately by kyuubi.engine.data.agent.query.timeout, which should be set lower so server-side cancellation has time to react before this hard cap fires. | duration | 1.12.0 | diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/openai/OpenAiProvider.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/chatcompletion/ChatCompletionProvider.java similarity index 95% rename from externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/openai/OpenAiProvider.java rename to externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/chatcompletion/ChatCompletionProvider.java index bcd647b9326..b9bb30edb7e 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/openai/OpenAiProvider.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/chatcompletion/ChatCompletionProvider.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.kyuubi.engine.dataagent.provider.openai; +package org.apache.kyuubi.engine.dataagent.provider.chatcompletion; import com.openai.client.OpenAIClient; import com.openai.client.okhttp.OpenAIOkHttpClient; @@ -55,9 +55,9 @@ * engine is bound to one user + one datasource, so all sessions within the engine naturally share * the same data connection. */ -public class OpenAiProvider implements DataAgentProvider { +public class ChatCompletionProvider implements DataAgentProvider { - private static final Logger LOG = LoggerFactory.getLogger(OpenAiProvider.class); + private static final Logger LOG = LoggerFactory.getLogger(ChatCompletionProvider.class); private final ReactAgent agent; private final ToolRegistry toolRegistry; @@ -65,10 +65,10 @@ public class OpenAiProvider implements DataAgentProvider { private final OpenAIClient client; private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); - public OpenAiProvider(KyuubiConf conf) { - String apiKey = ConfUtils.requireString(conf, KyuubiConf.ENGINE_DATA_AGENT_LLM_API_KEY()); - String baseUrl = ConfUtils.requireString(conf, KyuubiConf.ENGINE_DATA_AGENT_LLM_API_URL()); - String modelName = ConfUtils.requireString(conf, KyuubiConf.ENGINE_DATA_AGENT_LLM_MODEL()); + public ChatCompletionProvider(KyuubiConf conf) { + String apiKey = ConfUtils.requireString(conf, KyuubiConf.ENGINE_DATA_AGENT_OPENAI_API_KEY()); + String baseUrl = ConfUtils.requireString(conf, KyuubiConf.ENGINE_DATA_AGENT_OPENAI_ENDPOINT()); + String modelName = ConfUtils.requireString(conf, KyuubiConf.ENGINE_DATA_AGENT_MODEL()); int maxIterations = ConfUtils.intConf(conf, KyuubiConf.ENGINE_DATA_AGENT_MAX_ITERATIONS()); long compactionTriggerTokens = diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddleware.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddleware.java index acab7ae8b75..b89dbe015cc 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddleware.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddleware.java @@ -61,7 +61,7 @@ * preventing retriggering. * *

    Thread safety / shared instance: Instances of this middleware are shared across all - * sessions inside a provider (see {@code OpenAiProvider} javadoc). All per-session state + * sessions inside a provider (see {@code ChatCompletionProvider} javadoc). All per-session state * (cumulative and last-call totals) lives on {@link ConversationMemory}, so this middleware itself * is stateless across sessions and requires no per-session cleanup. * diff --git a/externals/kyuubi-data-agent-engine/src/main/scala/org/apache/kyuubi/engine/dataagent/operation/ExecuteStatement.scala b/externals/kyuubi-data-agent-engine/src/main/scala/org/apache/kyuubi/engine/dataagent/operation/ExecuteStatement.scala index 3d902677a0a..cb901462d0f 100644 --- a/externals/kyuubi-data-agent-engine/src/main/scala/org/apache/kyuubi/engine/dataagent/operation/ExecuteStatement.scala +++ b/externals/kyuubi-data-agent-engine/src/main/scala/org/apache/kyuubi/engine/dataagent/operation/ExecuteStatement.scala @@ -88,7 +88,7 @@ class ExecuteStatement( val request = new ProviderRunRequest(statement) // Merge session-level conf with per-statement confOverlay (overlay takes precedence) val mergedConf = session.conf ++ confOverlay - mergedConf.get(KyuubiConf.ENGINE_DATA_AGENT_LLM_MODEL.key).foreach(request.modelName) + mergedConf.get(KyuubiConf.ENGINE_DATA_AGENT_MODEL.key).foreach(request.modelName) val approvalMode = mergedConf.getOrElse( KyuubiConf.ENGINE_DATA_AGENT_APPROVAL_MODE.key, session.sessionManager.getConf.get(KyuubiConf.ENGINE_DATA_AGENT_APPROVAL_MODE)) diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgentLiveTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgentLiveTest.java index 75553f46998..e21410a99bf 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgentLiveTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgentLiveTest.java @@ -48,15 +48,17 @@ /** * Live integration test with a real LLM and real SQLite database. Exercises the full ReAct loop: - * LLM reasoning -> tool calls -> result verification. Requires DATA_AGENT_LLM_API_KEY and - * DATA_AGENT_LLM_API_URL environment variables. Works with any OpenAI-compatible LLM service. + * LLM reasoning -> tool calls -> result verification. Requires DATA_AGENT_OPENAI_API_KEY and + * DATA_AGENT_OPENAI_ENDPOINT environment variables. Works with any OpenAI-compatible LLM service. */ public class ReactAgentLiveTest { - private static final String API_KEY = System.getenv().getOrDefault("DATA_AGENT_LLM_API_KEY", ""); - private static final String BASE_URL = System.getenv().getOrDefault("DATA_AGENT_LLM_API_URL", ""); + private static final String API_KEY = + System.getenv().getOrDefault("DATA_AGENT_OPENAI_API_KEY", ""); + private static final String BASE_URL = + System.getenv().getOrDefault("DATA_AGENT_OPENAI_ENDPOINT", ""); private static final String MODEL_NAME = - System.getenv().getOrDefault("DATA_AGENT_LLM_MODEL", "gpt-4o"); + System.getenv().getOrDefault("DATA_AGENT_MODEL", "gpt-4o"); private static final String SYSTEM_PROMPT = SystemPromptBuilder.create().datasource("sqlite").build(); @@ -66,8 +68,8 @@ public class ReactAgentLiveTest { @Before public void setUp() { - assumeTrue("DATA_AGENT_LLM_API_KEY not set, skipping live tests", !API_KEY.isEmpty()); - assumeTrue("DATA_AGENT_LLM_API_URL not set, skipping live tests", !BASE_URL.isEmpty()); + assumeTrue("DATA_AGENT_OPENAI_API_KEY not set, skipping live tests", !API_KEY.isEmpty()); + assumeTrue("DATA_AGENT_OPENAI_ENDPOINT not set, skipping live tests", !BASE_URL.isEmpty()); client = OpenAIOkHttpClient.builder().apiKey(API_KEY).baseUrl(BASE_URL).build(); } diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareLiveTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareLiveTest.java index 1c4eeb7c5ad..49187623bd1 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareLiveTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareLiveTest.java @@ -33,21 +33,23 @@ /** * Live integration test for {@link CompactionMiddleware}: exercises the full compaction path - * against a real OpenAI-compatible LLM. Requires {@code DATA_AGENT_LLM_API_KEY} and {@code - * DATA_AGENT_LLM_API_URL} environment variables; skipped otherwise. + * against a real OpenAI-compatible LLM. Requires {@code DATA_AGENT_OPENAI_API_KEY} and {@code + * DATA_AGENT_OPENAI_ENDPOINT} environment variables; skipped otherwise. */ public class CompactionMiddlewareLiveTest { - private static final String API_KEY = System.getenv().getOrDefault("DATA_AGENT_LLM_API_KEY", ""); - private static final String BASE_URL = System.getenv().getOrDefault("DATA_AGENT_LLM_API_URL", ""); - private static final String MODEL_NAME = System.getenv().getOrDefault("DATA_AGENT_LLM_MODEL", ""); + private static final String API_KEY = + System.getenv().getOrDefault("DATA_AGENT_OPENAI_API_KEY", ""); + private static final String BASE_URL = + System.getenv().getOrDefault("DATA_AGENT_OPENAI_ENDPOINT", ""); + private static final String MODEL_NAME = System.getenv().getOrDefault("DATA_AGENT_MODEL", ""); private OpenAIClient client; @Before public void setUp() { - assumeTrue("DATA_AGENT_LLM_API_KEY not set, skipping live tests", !API_KEY.isEmpty()); - assumeTrue("DATA_AGENT_LLM_API_URL not set, skipping live tests", !BASE_URL.isEmpty()); + assumeTrue("DATA_AGENT_OPENAI_API_KEY not set, skipping live tests", !API_KEY.isEmpty()); + assumeTrue("DATA_AGENT_OPENAI_ENDPOINT not set, skipping live tests", !BASE_URL.isEmpty()); client = OpenAIOkHttpClient.builder().apiKey(API_KEY).baseUrl(BASE_URL).build(); } diff --git a/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentCompactionE2ESuite.scala b/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentCompactionE2ESuite.scala index 5fd95bb0fdc..9bef716c0ae 100644 --- a/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentCompactionE2ESuite.scala +++ b/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentCompactionE2ESuite.scala @@ -37,13 +37,13 @@ import org.apache.kyuubi.operation.HiveJDBCTestHelper * The trigger threshold is set extremely low (500 tokens) so that the schema dump plus the * first turn's completion already blow past it, forcing compaction before turn 2 or 3. * - * Requires DATA_AGENT_LLM_API_KEY and DATA_AGENT_LLM_API_URL. + * Requires DATA_AGENT_OPENAI_API_KEY and DATA_AGENT_OPENAI_ENDPOINT. */ class DataAgentCompactionE2ESuite extends HiveJDBCTestHelper with WithDataAgentEngine { - private val apiKey = sys.env.getOrElse("DATA_AGENT_LLM_API_KEY", "") - private val apiUrl = sys.env.getOrElse("DATA_AGENT_LLM_API_URL", "") - private val modelName = sys.env.getOrElse("DATA_AGENT_LLM_MODEL", "") + private val apiKey = sys.env.getOrElse("DATA_AGENT_OPENAI_API_KEY", "") + private val apiUrl = sys.env.getOrElse("DATA_AGENT_OPENAI_ENDPOINT", "") + private val modelName = sys.env.getOrElse("DATA_AGENT_MODEL", "") private val dbPath = { val tmp = System.getProperty("java.io.tmpdir") val uid = java.util.UUID.randomUUID() @@ -52,9 +52,9 @@ class DataAgentCompactionE2ESuite extends HiveJDBCTestHelper with WithDataAgentE override def withKyuubiConf: Map[String, String] = Map( ENGINE_DATA_AGENT_PROVIDER.key -> "OPENAI_COMPATIBLE", - ENGINE_DATA_AGENT_LLM_API_KEY.key -> apiKey, - ENGINE_DATA_AGENT_LLM_API_URL.key -> apiUrl, - ENGINE_DATA_AGENT_LLM_MODEL.key -> modelName, + ENGINE_DATA_AGENT_OPENAI_API_KEY.key -> apiKey, + ENGINE_DATA_AGENT_OPENAI_ENDPOINT.key -> apiUrl, + ENGINE_DATA_AGENT_MODEL.key -> modelName, ENGINE_DATA_AGENT_MAX_ITERATIONS.key -> "10", ENGINE_DATA_AGENT_APPROVAL_MODE.key -> "AUTO_APPROVE", // Force compaction to fire aggressively -- any realistic prompt + one LLM round-trip @@ -135,7 +135,7 @@ class DataAgentCompactionE2ESuite extends HiveJDBCTestHelper with WithDataAgentE "Respond with ONLY the answer, no explanation, no markdown, no punctuation." test("E2E: compaction fires mid-conversation and preserves facts across turns") { - assume(enabled, "DATA_AGENT_LLM_API_KEY/API_URL not set, skipping") + assume(enabled, "DATA_AGENT_OPENAI_API_KEY/API_URL not set, skipping") // CompactionMiddleware.KEEP_RECENT_TURNS is hardcoded to 4. computeSplit needs a // non-empty 'old' slice, so at least 5 distinct user turns must accumulate before diff --git a/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentE2ESuite.scala b/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentE2ESuite.scala index cfb83ecc9b2..28963147059 100644 --- a/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentE2ESuite.scala +++ b/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentE2ESuite.scala @@ -32,21 +32,21 @@ import org.apache.kyuubi.operation.HiveJDBCTestHelper * Full pipeline: JDBC Client -> Kyuubi Thrift -> DataAgentEngine * -> LLM -> Tools -> SQLite -> Results * - * Requires DATA_AGENT_LLM_API_KEY and DATA_AGENT_LLM_API_URL environment variables. + * Requires DATA_AGENT_OPENAI_API_KEY and DATA_AGENT_OPENAI_ENDPOINT environment variables. */ class DataAgentE2ESuite extends HiveJDBCTestHelper with WithDataAgentEngine { - private val apiKey = sys.env.getOrElse("DATA_AGENT_LLM_API_KEY", "") - private val apiUrl = sys.env.getOrElse("DATA_AGENT_LLM_API_URL", "") - private val modelName = sys.env.getOrElse("DATA_AGENT_LLM_MODEL", "") + private val apiKey = sys.env.getOrElse("DATA_AGENT_OPENAI_API_KEY", "") + private val apiUrl = sys.env.getOrElse("DATA_AGENT_OPENAI_ENDPOINT", "") + private val modelName = sys.env.getOrElse("DATA_AGENT_MODEL", "") private val dbPath = s"${System.getProperty("java.io.tmpdir")}/dataagent_e2e_test_${java.util.UUID.randomUUID()}.db" override def withKyuubiConf: Map[String, String] = Map( ENGINE_DATA_AGENT_PROVIDER.key -> "OPENAI_COMPATIBLE", - ENGINE_DATA_AGENT_LLM_API_KEY.key -> apiKey, - ENGINE_DATA_AGENT_LLM_API_URL.key -> apiUrl, - ENGINE_DATA_AGENT_LLM_MODEL.key -> modelName, + ENGINE_DATA_AGENT_OPENAI_API_KEY.key -> apiKey, + ENGINE_DATA_AGENT_OPENAI_ENDPOINT.key -> apiUrl, + ENGINE_DATA_AGENT_MODEL.key -> modelName, ENGINE_DATA_AGENT_MAX_ITERATIONS.key -> "10", ENGINE_DATA_AGENT_APPROVAL_MODE.key -> "AUTO_APPROVE", ENGINE_DATA_AGENT_JDBC_URL.key -> s"jdbc:sqlite:$dbPath") @@ -148,7 +148,7 @@ class DataAgentE2ESuite extends HiveJDBCTestHelper with WithDataAgentEngine { "Respond with ONLY the answer, no explanation, no markdown, no punctuation." test("E2E: agent answers data question through full Kyuubi pipeline") { - assume(enabled, "DATA_AGENT_LLM_API_KEY/API_URL not set, skipping E2E tests") + assume(enabled, "DATA_AGENT_OPENAI_API_KEY/API_URL not set, skipping E2E tests") withJdbcStatement() { stmt => val stream = drainReply( stmt.executeQuery( @@ -158,7 +158,7 @@ class DataAgentE2ESuite extends HiveJDBCTestHelper with WithDataAgentEngine { } test("E2E: agent resolves follow-up question using prior conversation context") { - assume(enabled, "DATA_AGENT_LLM_API_KEY/API_URL not set, skipping E2E tests") + assume(enabled, "DATA_AGENT_OPENAI_API_KEY/API_URL not set, skipping E2E tests") // Two executeQuery calls on the same Statement share the JDBC session, which means // the provider reuses the same ConversationMemory across turns. Turn 2 uses the // demonstrative "that department" - it can only be answered correctly if Turn 1's diff --git a/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala b/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala index 31d308e9c67..2a31cae2138 100644 --- a/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala +++ b/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala @@ -3814,28 +3814,30 @@ object KyuubiConf { case "ECHO" | "echo" => "org.apache.kyuubi.engine.dataagent.provider.echo.EchoProvider" case "OPENAI_COMPATIBLE" | "openai_compatible" | "openai-compatible" => - "org.apache.kyuubi.engine.dataagent.provider.openai.OpenAiProvider" + "org.apache.kyuubi.engine.dataagent.provider.chatcompletion.ChatCompletionProvider" case other => other } .createWithDefault("ECHO") - val ENGINE_DATA_AGENT_LLM_API_KEY: OptionalConfigEntry[String] = - buildConf("kyuubi.engine.data.agent.llm.api.key") - .doc("The API key to access the LLM service for the Data Agent engine.") + val ENGINE_DATA_AGENT_MODEL: OptionalConfigEntry[String] = + buildConf("kyuubi.engine.data.agent.model") + .doc("The model ID used by the Data Agent engine.") .version("1.12.0") .stringConf .createOptional - val ENGINE_DATA_AGENT_LLM_MODEL: OptionalConfigEntry[String] = - buildConf("kyuubi.engine.data.agent.llm.model") - .doc("The model ID used by the Data Agent engine LLM provider.") + val ENGINE_DATA_AGENT_OPENAI_API_KEY: OptionalConfigEntry[String] = + buildConf("kyuubi.engine.data.agent.openai.api.key") + .doc("The API key for the OpenAI-compatible chat-completion endpoint used by " + + "the Data Agent engine.") .version("1.12.0") .stringConf .createOptional - val ENGINE_DATA_AGENT_LLM_API_URL: OptionalConfigEntry[String] = - buildConf("kyuubi.engine.data.agent.llm.api.url") - .doc("The API base URL for the LLM service used by the Data Agent engine.") + val ENGINE_DATA_AGENT_OPENAI_ENDPOINT: OptionalConfigEntry[String] = + buildConf("kyuubi.engine.data.agent.openai.endpoint") + .doc("The base URL of the OpenAI-compatible chat-completion endpoint used by " + + "the Data Agent engine.") .version("1.12.0") .stringConf .createOptional diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/engine/dataagent/DataAgentProcessBuilder.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/engine/dataagent/DataAgentProcessBuilder.scala index c68ec18d8b0..c53b749a8e5 100644 --- a/kyuubi-server/src/main/scala/org/apache/kyuubi/engine/dataagent/DataAgentProcessBuilder.scala +++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/engine/dataagent/DataAgentProcessBuilder.scala @@ -102,7 +102,7 @@ class DataAgentProcessBuilder( } else { redactConfValues( Utils.redactCommandLineArgs(conf, commands), - Set(ENGINE_DATA_AGENT_LLM_API_KEY.key)).map { + Set(ENGINE_DATA_AGENT_OPENAI_API_KEY.key)).map { case arg if arg.startsWith("-") || arg == mainClass => s"\\\n\t$arg" case arg => arg }.mkString(" ") diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/engine/dataagent/DataAgentProcessBuilderSuite.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/engine/dataagent/DataAgentProcessBuilderSuite.scala index 5c21c31d36c..32995b7eb32 100644 --- a/kyuubi-server/src/test/scala/org/apache/kyuubi/engine/dataagent/DataAgentProcessBuilderSuite.scala +++ b/kyuubi-server/src/test/scala/org/apache/kyuubi/engine/dataagent/DataAgentProcessBuilderSuite.scala @@ -34,7 +34,7 @@ class DataAgentProcessBuilderSuite extends KyuubiFunSuite { test("API key is redacted in toString") { val conf = new KyuubiConf(false) - conf.set(ENGINE_DATA_AGENT_LLM_API_KEY.key, "sk-secret-key-12345") + conf.set(ENGINE_DATA_AGENT_OPENAI_API_KEY.key, "sk-secret-key-12345") val builder = new DataAgentProcessBuilder("testUser", doAsEnabled = false, conf) val output = builder.toString assert(!output.contains("sk-secret-key-12345"), "API key should not appear in toString output") From d1777fceab56b3d9623bcc5544651ce80b85f2d0 Mon Sep 17 00:00:00 2001 From: wangzhigang Date: Thu, 30 Apr 2026 23:53:25 +0800 Subject: [PATCH 05/10] [KYUUBI #7379][2b/4][FOLLOWUP] Capitalize SQLite and MySQL in dialect class names Reviewer asked for proper acronym casing in class names. Rename: SqliteDialect -> SQLiteDialect MysqlDialect -> MySQLDialect and update test method names that embed the same tokens (testSqlite*, testMysql*, testDatasourceSqlite, testDatasourceMysql). Addresses review feedback on #7417. --- .../engine/dataagent/datasource/JdbcDialect.java | 8 ++++---- .../{MysqlDialect.java => MySQLDialect.java} | 6 +++--- .../{SqliteDialect.java => SQLiteDialect.java} | 6 +++--- .../dataagent/datasource/JdbcDialectTest.java | 14 +++++++------- .../kyuubi/engine/dataagent/mysql/DialectTest.java | 6 +++--- .../dataagent/prompt/SystemPromptBuilderTest.java | 4 ++-- 6 files changed, 22 insertions(+), 22 deletions(-) rename externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/{MysqlDialect.java => MySQLDialect.java} (89%) rename externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/{SqliteDialect.java => SQLiteDialect.java} (89%) diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialect.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialect.java index c771ad222aa..1b149e81d73 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialect.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialect.java @@ -18,9 +18,9 @@ package org.apache.kyuubi.engine.dataagent.datasource; import org.apache.kyuubi.engine.dataagent.datasource.dialect.GenericDialect; -import org.apache.kyuubi.engine.dataagent.datasource.dialect.MysqlDialect; +import org.apache.kyuubi.engine.dataagent.datasource.dialect.MySQLDialect; +import org.apache.kyuubi.engine.dataagent.datasource.dialect.SQLiteDialect; import org.apache.kyuubi.engine.dataagent.datasource.dialect.SparkDialect; -import org.apache.kyuubi.engine.dataagent.datasource.dialect.SqliteDialect; import org.apache.kyuubi.engine.dataagent.datasource.dialect.TrinoDialect; /** @@ -89,9 +89,9 @@ static JdbcDialect fromUrl(String jdbcUrl) { case "trino": return TrinoDialect.INSTANCE; case "mysql": - return MysqlDialect.INSTANCE; + return MySQLDialect.INSTANCE; case "sqlite": - return SqliteDialect.INSTANCE; + return SQLiteDialect.INSTANCE; default: return new GenericDialect(name); } diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/MysqlDialect.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/MySQLDialect.java similarity index 89% rename from externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/MysqlDialect.java rename to externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/MySQLDialect.java index 350789a6a87..e4dfb27ca2e 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/MysqlDialect.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/MySQLDialect.java @@ -20,11 +20,11 @@ import org.apache.kyuubi.engine.dataagent.datasource.JdbcDialect; /** MySQL dialect. Uses backtick quoting for identifiers. */ -public final class MysqlDialect implements JdbcDialect { +public final class MySQLDialect implements JdbcDialect { - public static final MysqlDialect INSTANCE = new MysqlDialect(); + public static final MySQLDialect INSTANCE = new MySQLDialect(); - private MysqlDialect() {} + private MySQLDialect() {} @Override public String datasourceName() { diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/SqliteDialect.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/SQLiteDialect.java similarity index 89% rename from externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/SqliteDialect.java rename to externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/SQLiteDialect.java index eb98ca8edfa..aa1d7db8d23 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/SqliteDialect.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/SQLiteDialect.java @@ -20,11 +20,11 @@ import org.apache.kyuubi.engine.dataagent.datasource.JdbcDialect; /** SQLite dialect. Uses double-quote quoting for identifiers. */ -public final class SqliteDialect implements JdbcDialect { +public final class SQLiteDialect implements JdbcDialect { - public static final SqliteDialect INSTANCE = new SqliteDialect(); + public static final SQLiteDialect INSTANCE = new SQLiteDialect(); - private SqliteDialect() {} + private SQLiteDialect() {} @Override public String datasourceName() { diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialectTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialectTest.java index c43942a8f35..2b791233983 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialectTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialectTest.java @@ -71,26 +71,26 @@ public void testTrinoQuoteIdentifier() { } @Test - public void testSqlite() { + public void testSQLite() { JdbcDialect d = JdbcDialect.fromUrl("jdbc:sqlite:/tmp/test.db"); assertNotNull(d); assertEquals("sqlite", d.datasourceName()); } @Test - public void testSqliteCaseInsensitive() { + public void testSQLiteCaseInsensitive() { assertNotNull(JdbcDialect.fromUrl("JDBC:SQLITE:test.db")); } @Test - public void testSqliteQuoteIdentifier() { + public void testSQLiteQuoteIdentifier() { JdbcDialect sqlite = JdbcDialect.fromUrl("jdbc:sqlite:test.db"); assertEquals("\"my_table\"", sqlite.quoteIdentifier("my_table")); assertEquals("\" \"\"inject\"\" \"", sqlite.quoteIdentifier(" \"inject\" ")); } @Test - public void testMysql() { + public void testMySQL() { JdbcDialect d = JdbcDialect.fromUrl("jdbc:mysql://localhost:3306"); assertNotNull(d); assertEquals("mysql", d.datasourceName()); @@ -120,13 +120,13 @@ public void testGenericDialectQuoteIdentifierUnsupported() { // --- qualify tests --- @Test - public void testMysqlQualifySchemaAndTable() { + public void testMySQLQualifySchemaAndTable() { JdbcDialect d = JdbcDialect.fromUrl("jdbc:mysql://localhost:3306"); assertEquals("`mydb`.`users`", d.qualify(TableRef.of("mydb", "users"))); } @Test - public void testMysqlQualifyTableOnly() { + public void testMySQLQualifyTableOnly() { JdbcDialect d = JdbcDialect.fromUrl("jdbc:mysql://localhost:3306"); assertEquals("`users`", d.qualify(TableRef.of("users"))); } @@ -152,7 +152,7 @@ public void testSparkQualifyFull() { } @Test - public void testSqliteQualifyTableOnly() { + public void testSQLiteQualifyTableOnly() { JdbcDialect d = JdbcDialect.fromUrl("jdbc:sqlite:test.db"); assertEquals("\"t\"", d.qualify(TableRef.of("t"))); } diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/DialectTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/DialectTest.java index cc45ebdc7e7..db86b92cc80 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/DialectTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/DialectTest.java @@ -20,7 +20,7 @@ import static org.junit.Assert.*; import org.apache.kyuubi.engine.dataagent.datasource.JdbcDialect; -import org.apache.kyuubi.engine.dataagent.datasource.dialect.MysqlDialect; +import org.apache.kyuubi.engine.dataagent.datasource.dialect.MySQLDialect; import org.apache.kyuubi.engine.dataagent.prompt.SystemPromptBuilder; import org.apache.kyuubi.engine.dataagent.tool.ToolContext; import org.apache.kyuubi.engine.dataagent.tool.sql.RunSelectQueryTool; @@ -28,7 +28,7 @@ import org.junit.BeforeClass; import org.junit.Test; -/** Integration tests for {@link MysqlDialect} end-to-end with a real MySQL instance. */ +/** Integration tests for {@link MySQLDialect} end-to-end with a real MySQL instance. */ public class DialectTest extends WithMySQLContainer { private static RunSelectQueryTool selectTool; @@ -42,7 +42,7 @@ public static void setUp() { public void testDialectFromUrl() { JdbcDialect dialect = JdbcDialect.fromUrl(mysql.getJdbcUrl()); assertNotNull(dialect); - assertTrue(dialect instanceof MysqlDialect); + assertTrue(dialect instanceof MySQLDialect); assertEquals("mysql", dialect.datasourceName()); } diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/prompt/SystemPromptBuilderTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/prompt/SystemPromptBuilderTest.java index f3cfa03dc19..f8c1ccc9541 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/prompt/SystemPromptBuilderTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/prompt/SystemPromptBuilderTest.java @@ -47,7 +47,7 @@ public void testToolDescriptionsSubstituted() { } @Test - public void testDatasourceSqlite() { + public void testDatasourceSQLite() { String prompt = SystemPromptBuilder.create().datasource("sqlite").build(); assertTrue(prompt.contains("SQLite SQL compatibility")); assertTrue(prompt.contains("JULIANDAY")); @@ -108,7 +108,7 @@ public void testDatasourceClickhouseUsesGenericDialectSection() { } @Test - public void testDatasourceMysql() { + public void testDatasourceMySQL() { String prompt = SystemPromptBuilder.create().datasource("mysql").build(); assertTrue(prompt.contains("MySQL")); } From 108e9cd24288fd439660df4dd751399a3b65ac90 Mon Sep 17 00:00:00 2001 From: wangzhigang Date: Thu, 30 Apr 2026 23:54:04 +0800 Subject: [PATCH 06/10] [KYUUBI #7379][2b/4][FOLLOWUP] Replace null-as-noop with explicit sentinel actions in AgentMiddleware Reviewer pushed back on null propagation in AgentMiddleware return types. Apply the same sealed-style pattern uniformly across the three hooks that historically used null to mean "do nothing": beforeLlmCall -> LlmCallAction { LlmNoopAction | LlmSkip | LlmModifyMessages } beforeToolCall -> ToolCallAction { ToolCallApproval | ToolCallDenial } afterToolCall -> ToolResultAction { ToolResultUnchanged | ToolResultReplace } Each base type is non-instantiable, the no-op subtype is a singleton (*.INSTANCE), and the active subtype carries its payload. Defaults and all built-in middleware (Logging, Approval, Compaction, ToolResultOffload) return the appropriate sentinel; the ReactAgent dispatchers switch from null checks to instanceof checks. Tests assert on the singleton or on instanceof + cast to read the payload. No behavior change; the goal is just to remove null from the contract. Addresses review feedback on #7417. --- .../engine/dataagent/runtime/ReactAgent.java | 39 ++++---- .../runtime/middleware/AgentMiddleware.java | 90 +++++++++++++++---- .../middleware/ApprovalMiddleware.java | 6 +- .../middleware/CompactionMiddleware.java | 4 +- .../runtime/middleware/LoggingMiddleware.java | 10 +-- .../ToolResultOffloadMiddleware.java | 14 +-- .../middleware/ApprovalMiddlewareTest.java | 47 ++++++---- .../middleware/CompactionMiddlewareTest.java | 12 ++- .../ToolResultOffloadMiddlewareTest.java | 45 ++++++---- 9 files changed, 174 insertions(+), 93 deletions(-) diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgent.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgent.java index 520cd963ef8..1f9b090b642 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgent.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgent.java @@ -268,10 +268,10 @@ private void executeToolCalls( continue; } - AgentMiddleware.ToolCallDenial denial = + AgentMiddleware.ToolCallAction action = dispatchBeforeToolCall(ctx, fnCall.id(), toolName, toolArgs); - if (denial != null) { - String denied = "Tool call denied: " + denial.reason(); + if (action instanceof AgentMiddleware.ToolCallDenial) { + String denied = "Tool call denied: " + ((AgentMiddleware.ToolCallDenial) action).reason(); memory.addToolResult(fnCall.id(), denied); emit(ctx, new ToolResult(fnCall.id(), toolName, denied, true), eventConsumer); continue; @@ -290,11 +290,8 @@ private void executeToolCalls( for (int i = 0; i < approved.size(); i++) { ToolCallEntry entry = approved.get(i); - String output = futures.get(i).join(); - String modified = dispatchAfterToolCall(ctx, entry.toolName, entry.toolArgs, output); - if (modified != null) { - output = modified; - } + String raw = futures.get(i).join(); + String output = dispatchAfterToolCall(ctx, entry.toolName, entry.toolArgs, raw); memory.addToolResult(entry.fnCall.id(), output); emit(ctx, new ToolResult(entry.fnCall.id(), entry.toolName, output, false), eventConsumer); } @@ -508,9 +505,9 @@ private AgentMiddleware.LlmCallAction dispatchBeforeLlmCall( AgentRunContext ctx, List messages) { for (AgentMiddleware mw : middlewares) { AgentMiddleware.LlmCallAction action = mw.beforeLlmCall(ctx, messages); - if (action != null) return action; + if (!(action instanceof AgentMiddleware.LlmNoopAction)) return action; } - return null; + return AgentMiddleware.LlmNoopAction.INSTANCE; } private void dispatchAfterLlmCall( @@ -520,29 +517,27 @@ private void dispatchAfterLlmCall( } } - private AgentMiddleware.ToolCallDenial dispatchBeforeToolCall( + private AgentMiddleware.ToolCallAction dispatchBeforeToolCall( AgentRunContext ctx, String toolCallId, String toolName, Map toolArgs) { for (AgentMiddleware mw : middlewares) { - AgentMiddleware.ToolCallDenial denial = + AgentMiddleware.ToolCallAction action = mw.beforeToolCall(ctx, toolCallId, toolName, toolArgs); - if (denial != null) return denial; + if (action instanceof AgentMiddleware.ToolCallDenial) return action; } - return null; + return AgentMiddleware.ToolCallApproval.INSTANCE; } private String dispatchAfterToolCall( AgentRunContext ctx, String toolName, Map toolArgs, String result) { - String modified = null; + String current = result; for (int i = middlewares.size() - 1; i >= 0; i--) { - String mwResult = - middlewares - .get(i) - .afterToolCall(ctx, toolName, toolArgs, modified != null ? modified : result); - if (mwResult != null) { - modified = mwResult; + AgentMiddleware.ToolResultAction action = + middlewares.get(i).afterToolCall(ctx, toolName, toolArgs, current); + if (action instanceof AgentMiddleware.ToolResultReplace) { + current = ((AgentMiddleware.ToolResultReplace) action).replacement(); } } - return modified; + return current; } // --- Builder --- diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/AgentMiddleware.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/AgentMiddleware.java index c934bb0882a..4d5319536a3 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/AgentMiddleware.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/AgentMiddleware.java @@ -47,37 +47,42 @@ default void onAgentStart(AgentRunContext ctx) {} default void onAgentFinish(AgentRunContext ctx) {} /** - * Called before each LLM invocation. Return non-null to skip or modify the LLM call. Runs - * first-to-last. + * Called before each LLM invocation. Runs first-to-last. * - * @return {@code null} to proceed normally, {@link LlmSkip} to abort, or {@link + * @return {@link LlmNoopAction#INSTANCE} to proceed normally, {@link LlmSkip} to abort, or {@link * LlmModifyMessages} to replace the message list for this call. */ default LlmCallAction beforeLlmCall( AgentRunContext ctx, List messages) { - return null; + return LlmNoopAction.INSTANCE; } /** Called after each LLM invocation. Runs last-to-first. */ default void afterLlmCall(AgentRunContext ctx, ChatCompletionAssistantMessageParam response) {} - /** Called before each tool execution. Return non-null to deny the call. Runs first-to-last. */ - default ToolCallDenial beforeToolCall( + /** + * Called before each tool execution. Runs first-to-last. + * + * @return {@link ToolCallApproval#INSTANCE} to allow the call, {@link ToolCallDenial} to block + * it. + */ + default ToolCallAction beforeToolCall( AgentRunContext ctx, String toolCallId, String toolName, Map toolArgs) { - return null; + return ToolCallApproval.INSTANCE; } /** * Called after each tool execution. Runs last-to-first. * - *

    Returns {@code String} (not {@code void}) so that middlewares can intercept and transform - * the tool result before it is fed back to the LLM — e.g. for data masking, output truncation, or - * injecting metadata. Return {@code null} to keep the original result unchanged; return a - * non-null value to replace it. + *

    Middlewares can intercept and transform the tool result before it is fed back to the LLM — + * e.g. for data masking, output truncation, or injecting metadata. + * + * @return {@link ToolResultUnchanged#INSTANCE} to keep the original result, {@link + * ToolResultReplace} to substitute it. */ - default String afterToolCall( + default ToolResultAction afterToolCall( AgentRunContext ctx, String toolName, Map toolArgs, String result) { - return null; + return ToolResultUnchanged.INSTANCE; } /** @@ -101,13 +106,21 @@ default void onSessionClose(String sessionId) {} default void onStop() {} /** - * Base type for {@code beforeLlmCall} return values. Subtypes: {@link LlmSkip} to abort the LLM - * call, {@link LlmModifyMessages} to replace the message list for this call. + * Base type for {@code beforeLlmCall} return values. Subtypes: {@link LlmNoopAction} to proceed + * normally, {@link LlmSkip} to abort the LLM call, {@link LlmModifyMessages} to replace the + * message list for this call. */ abstract class LlmCallAction { private LlmCallAction() {} } + /** Returned from {@code beforeLlmCall} to proceed without changing the LLM call. */ + final class LlmNoopAction extends LlmCallAction { + public static final LlmNoopAction INSTANCE = new LlmNoopAction(); + + private LlmNoopAction() {} + } + /** Returned from {@code beforeLlmCall} to skip the LLM call and abort the agent loop. */ class LlmSkip extends LlmCallAction { private final String reason; @@ -137,8 +150,23 @@ public List messages() { } } - /** Returned from {@code beforeToolCall} to deny a tool call. Non-null means denied. */ - class ToolCallDenial { + /** + * Base type for {@code beforeToolCall} return values. Subtypes: {@link ToolCallApproval} to allow + * the call, {@link ToolCallDenial} to block it. + */ + abstract class ToolCallAction { + private ToolCallAction() {} + } + + /** Returned from {@code beforeToolCall} to allow the tool call to proceed. */ + final class ToolCallApproval extends ToolCallAction { + public static final ToolCallApproval INSTANCE = new ToolCallApproval(); + + private ToolCallApproval() {} + } + + /** Returned from {@code beforeToolCall} to block a tool call. */ + final class ToolCallDenial extends ToolCallAction { private final String reason; public ToolCallDenial(String reason) { @@ -149,4 +177,32 @@ public String reason() { return reason; } } + + /** + * Base type for {@code afterToolCall} return values. Subtypes: {@link ToolResultUnchanged} to + * pass the result through, {@link ToolResultReplace} to substitute it. + */ + abstract class ToolResultAction { + private ToolResultAction() {} + } + + /** Returned from {@code afterToolCall} to keep the tool result as is. */ + final class ToolResultUnchanged extends ToolResultAction { + public static final ToolResultUnchanged INSTANCE = new ToolResultUnchanged(); + + private ToolResultUnchanged() {} + } + + /** Returned from {@code afterToolCall} to replace the tool result with a new string. */ + final class ToolResultReplace extends ToolResultAction { + private final String replacement; + + public ToolResultReplace(String replacement) { + this.replacement = replacement; + } + + public String replacement() { + return replacement; + } + } } diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddleware.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddleware.java index 92d25b47b9d..f727dba3376 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddleware.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddleware.java @@ -64,12 +64,12 @@ public void onRegister(ToolRegistry registry) { } @Override - public ToolCallDenial beforeToolCall( + public ToolCallAction beforeToolCall( AgentRunContext ctx, String toolCallId, String toolName, Map toolArgs) { ToolRiskLevel riskLevel = toolRegistry.getRiskLevel(toolName); if (shouldAutoApprove(ctx.getApprovalMode(), riskLevel)) { - return null; + return ToolCallApproval.INSTANCE; } String requestId = UUID.randomUUID().toString(); @@ -86,7 +86,7 @@ public ToolCallDenial beforeToolCall( return new ToolCallDenial("User denied execution of " + toolName); } LOG.info("Tool '{}' approved by user (requestId={})", toolName, requestId); - return null; + return ToolCallApproval.INSTANCE; } catch (TimeoutException e) { // Complete the future so that a late resolve() call is a harmless no-op // instead of completing a dangling future. diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddleware.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddleware.java index b89dbe015cc..df0425f5a0a 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddleware.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddleware.java @@ -156,7 +156,7 @@ public LlmCallAction beforeLlmCall( long newTailEstimate = estimateTailAfterLastAssistant(messages); if (lastTotal + newTailEstimate < triggerPromptTokens) { - return null; + return LlmNoopAction.INSTANCE; } List history = mem.getHistory(); @@ -165,7 +165,7 @@ public LlmCallAction beforeLlmCall( // tool_result. Split split = computeSplit(history, KEEP_RECENT_TURNS); if (split.old.isEmpty()) { - return null; + return LlmNoopAction.INSTANCE; } String summary = summarize(mem.getSystemPrompt(), split.old); diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/LoggingMiddleware.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/LoggingMiddleware.java index e0a5c2364eb..ec96b2e71ea 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/LoggingMiddleware.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/LoggingMiddleware.java @@ -98,7 +98,7 @@ public void onAgentFinish(AgentRunContext ctx) { public LlmCallAction beforeLlmCall( AgentRunContext ctx, List messages) { LOG.info("{}LLM call: step={}, messages={}", prefix(), ctx.getIteration(), messages.size()); - return null; + return LlmNoopAction.INSTANCE; } @Override @@ -118,18 +118,18 @@ public void afterLlmCall(AgentRunContext ctx, ChatCompletionAssistantMessagePara } @Override - public ToolCallDenial beforeToolCall( + public ToolCallAction beforeToolCall( AgentRunContext ctx, String toolCallId, String toolName, Map toolArgs) { LOG.info("{}Tool call: id={}, name={}", prefix(), toolCallId, toolName); LOG.debug("{}Tool args: {}", prefix(), toolArgs); - return null; + return ToolCallApproval.INSTANCE; } @Override - public String afterToolCall( + public ToolResultAction afterToolCall( AgentRunContext ctx, String toolName, Map toolArgs, String result) { LOG.info("{}Tool result: {} -> \"{}\"", prefix(), toolName, truncate(result)); - return null; + return ToolResultUnchanged.INSTANCE; } @Override diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddleware.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddleware.java index 87aad9f3255..14b2d4ba7a2 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddleware.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddleware.java @@ -78,21 +78,21 @@ public void onRegister(ToolRegistry registry) { } @Override - public String afterToolCall( + public ToolResultAction afterToolCall( AgentRunContext ctx, String toolName, Map toolArgs, String result) { - if (result.isEmpty()) return null; - if (EXEMPT_TOOLS.contains(toolName)) return null; + if (result.isEmpty()) return ToolResultUnchanged.INSTANCE; + if (EXEMPT_TOOLS.contains(toolName)) return ToolResultUnchanged.INSTANCE; int bytes = result.getBytes(StandardCharsets.UTF_8).length; int lines = countLines(result); if (lines <= MAX_LINES && bytes <= MAX_BYTES) { - return null; + return ToolResultUnchanged.INSTANCE; } // AgentRunContext.sessionId is null in unit-test constructions that don't exercise offload. // In production the provider always threads it through, so treat null as "skip offload". String sessionId = ctx.getSessionId(); - if (sessionId == null) return null; + if (sessionId == null) return ToolResultUnchanged.INSTANCE; long n = counters.computeIfAbsent(sessionId, k -> new AtomicLong()).incrementAndGet(); String toolCallId = toolName + "_" + n; @@ -106,7 +106,7 @@ public String afterToolCall( toolName, sessionId, e); - return null; + return ToolResultUnchanged.INSTANCE; } LOG.info( @@ -116,7 +116,7 @@ public String afterToolCall( lines, bytes, file.getFileName()); - return buildPreview(result, lines, bytes, file); + return new ToolResultReplace(buildPreview(result, lines, bytes, file)); } /** Clean up counter and temp dir for a closed session. Idempotent. */ diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddlewareTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddlewareTest.java index a84bbc25948..ecd30a747d5 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddlewareTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddlewareTest.java @@ -60,8 +60,12 @@ public void testAutoApproveModeSkipsAllApproval() { ApprovalMiddleware mw = newApprovalMiddleware(); AgentRunContext ctx = makeContext(ApprovalMode.AUTO_APPROVE); - assertNull(mw.beforeToolCall(ctx, "tc1", "dangerous_tool", Collections.emptyMap())); - assertNull(mw.beforeToolCall(ctx, "tc2", "safe_tool", Collections.emptyMap())); + assertSame( + AgentMiddleware.ToolCallApproval.INSTANCE, + mw.beforeToolCall(ctx, "tc1", "dangerous_tool", Collections.emptyMap())); + assertSame( + AgentMiddleware.ToolCallApproval.INSTANCE, + mw.beforeToolCall(ctx, "tc2", "safe_tool", Collections.emptyMap())); assertTrue("No approval events should be emitted", emittedEvents.isEmpty()); } @@ -72,7 +76,9 @@ public void testNormalModeAutoApprovesSafeTool() { ApprovalMiddleware mw = newApprovalMiddleware(); AgentRunContext ctx = makeContext(ApprovalMode.NORMAL); - assertNull(mw.beforeToolCall(ctx, "tc1", "safe_tool", Collections.emptyMap())); + assertSame( + AgentMiddleware.ToolCallApproval.INSTANCE, + mw.beforeToolCall(ctx, "tc1", "safe_tool", Collections.emptyMap())); assertTrue(emittedEvents.isEmpty()); } @@ -91,7 +97,7 @@ public void testNormalModeRequiresApprovalForDestructiveTool() throws Exception eventEmitted.countDown(); }); - Future future = + Future future = exec.submit( () -> mw.beforeToolCall(ctx, "tc1", "dangerous_tool", Collections.emptyMap())); @@ -106,7 +112,10 @@ public void testNormalModeRequiresApprovalForDestructiveTool() throws Exception // Approve assertTrue(mw.resolve(req.requestId(), true)); - assertNull("Approved tool should return null (no denial)", future.get(2, TimeUnit.SECONDS)); + assertSame( + "Approved tool should return null (no denial)", + AgentMiddleware.ToolCallApproval.INSTANCE, + future.get(2, TimeUnit.SECONDS)); } finally { exec.shutdownNow(); } @@ -126,7 +135,7 @@ public void testDeniedToolReturnsToolCallDenial() throws Exception { eventEmitted.countDown(); }); - Future future = + Future future = exec.submit( () -> mw.beforeToolCall(ctx, "tc1", "dangerous_tool", Collections.emptyMap())); @@ -135,9 +144,9 @@ public void testDeniedToolReturnsToolCallDenial() throws Exception { // Deny assertTrue(mw.resolve(req.requestId(), false)); - AgentMiddleware.ToolCallDenial denial = future.get(2, TimeUnit.SECONDS); - assertNotNull(denial); - assertTrue(denial.reason().contains("denied")); + AgentMiddleware.ToolCallAction action = future.get(2, TimeUnit.SECONDS); + assertTrue(action instanceof AgentMiddleware.ToolCallDenial); + assertTrue(((AgentMiddleware.ToolCallDenial) action).reason().contains("denied")); } finally { exec.shutdownNow(); } @@ -159,7 +168,7 @@ public void testStrictModeRequiresApprovalForSafeTool() throws Exception { eventEmitted.countDown(); }); - Future future = + Future future = exec.submit(() -> mw.beforeToolCall(ctx, "tc1", "safe_tool", Collections.emptyMap())); assertTrue(eventEmitted.await(2, TimeUnit.SECONDS)); @@ -167,7 +176,7 @@ public void testStrictModeRequiresApprovalForSafeTool() throws Exception { assertEquals("safe_tool", req.toolName()); assertTrue(mw.resolve(req.requestId(), true)); - assertNull(future.get(2, TimeUnit.SECONDS)); + assertSame(AgentMiddleware.ToolCallApproval.INSTANCE, future.get(2, TimeUnit.SECONDS)); } finally { exec.shutdownNow(); } @@ -183,13 +192,14 @@ public void testApprovalTimeoutReturnsDenial() throws Exception { ExecutorService exec = Executors.newSingleThreadExecutor(); try { - Future future = + Future future = exec.submit(() -> mw.beforeToolCall(ctx, "tc1", "safe_tool", Collections.emptyMap())); // Don't resolve — let it time out - AgentMiddleware.ToolCallDenial denial = future.get(5, TimeUnit.SECONDS); - assertNotNull("Timeout should produce a denial", denial); - assertTrue(denial.reason().contains("timed out")); + AgentMiddleware.ToolCallAction action = future.get(5, TimeUnit.SECONDS); + assertTrue( + "Timeout should produce a denial", action instanceof AgentMiddleware.ToolCallDenial); + assertTrue(((AgentMiddleware.ToolCallDenial) action).reason().contains("timed out")); } finally { exec.shutdownNow(); } @@ -206,7 +216,7 @@ public void testOnStopUnblocksPendingRequests() throws Exception { ExecutorService exec = Executors.newSingleThreadExecutor(); try { CountDownLatch started = new CountDownLatch(1); - Future future = + Future future = exec.submit( () -> { started.countDown(); @@ -218,8 +228,9 @@ public void testOnStopUnblocksPendingRequests() throws Exception { mw.onStop(); - AgentMiddleware.ToolCallDenial denial = future.get(2, TimeUnit.SECONDS); - assertNotNull("onStop should unblock with a denial", denial); + AgentMiddleware.ToolCallAction action = future.get(2, TimeUnit.SECONDS); + assertTrue( + "onStop should unblock with a denial", action instanceof AgentMiddleware.ToolCallDenial); } finally { exec.shutdownNow(); } diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareTest.java index d4c0cc581db..2e8940422cc 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareTest.java @@ -172,7 +172,8 @@ public void belowThresholdReturnsNull() { ctx.addTokenUsage(1000, 0, 1000); CompactionMiddleware mw = new CompactionMiddleware(DUMMY_CLIENT, "m", 50_000L); - assertNull(mw.beforeLlmCall(ctx, memory.buildLlmMessages())); + assertSame( + AgentMiddleware.LlmNoopAction.INSTANCE, mw.beforeLlmCall(ctx, memory.buildLlmMessages())); // Nothing was mutated. assertEquals(6, memory.size()); } @@ -188,7 +189,8 @@ public void aboveThresholdButHistoryTooShortReturnsNull() { ctx.addTokenUsage(60_000, 0, 60_000); CompactionMiddleware mw = new CompactionMiddleware(DUMMY_CLIENT, "m", 50_000L); - assertNull(mw.beforeLlmCall(ctx, memory.buildLlmMessages())); + assertSame( + AgentMiddleware.LlmNoopAction.INSTANCE, mw.beforeLlmCall(ctx, memory.buildLlmMessages())); assertEquals(3, memory.size()); } @@ -203,10 +205,12 @@ public void triggerUsesLastCallTotalNotCumulative() { CompactionMiddleware mw = new CompactionMiddleware(DUMMY_CLIENT, "m", 50_000L); ctx.addTokenUsage(4_000, 1_000, 5_000); - assertNull(mw.beforeLlmCall(ctx, memory.buildLlmMessages())); + assertSame( + AgentMiddleware.LlmNoopAction.INSTANCE, mw.beforeLlmCall(ctx, memory.buildLlmMessages())); ctx.addTokenUsage(8_000, 2_000, 10_000); - assertNull(mw.beforeLlmCall(ctx, memory.buildLlmMessages())); + assertSame( + AgentMiddleware.LlmNoopAction.INSTANCE, mw.beforeLlmCall(ctx, memory.buildLlmMessages())); assertEquals(10_000L, memory.getLastTotalTokens()); assertEquals(15_000L, memory.getCumulativeTotalTokens()); diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddlewareTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddlewareTest.java index cb107b775f5..86f67fcc642 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddlewareTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddlewareTest.java @@ -51,9 +51,9 @@ public void tearDown() { @Test public void underThresholdPassesThrough() { String small = "row1\nrow2\nrow3\n"; - String out = - mw.afterToolCall(ctxWithSession, "run_select_query", Collections.emptyMap(), small); - assertNull(out); + assertSame( + AgentMiddleware.ToolResultUnchanged.INSTANCE, + mw.afterToolCall(ctxWithSession, "run_select_query", Collections.emptyMap(), small)); } @Test @@ -61,9 +61,10 @@ public void overLineThresholdTriggersOffload() { StringBuilder sb = new StringBuilder(); for (int i = 0; i < 600; i++) sb.append("row").append(i).append('\n'); String out = - mw.afterToolCall(ctxWithSession, "run_select_query", Collections.emptyMap(), sb.toString()); + replacement( + mw.afterToolCall( + ctxWithSession, "run_select_query", Collections.emptyMap(), sb.toString())); - assertNotNull(out); assertTrue(out, out.contains("Tool output truncated")); assertTrue(out, out.contains("Saved to:")); assertTrue(out, out.contains(ReadToolOutputTool.NAME)); @@ -81,9 +82,9 @@ public void overByteThresholdTriggersOffload() { sb.append('\n'); } String out = - mw.afterToolCall(ctxWithSession, "run_select_query", Collections.emptyMap(), sb.toString()); - - assertNotNull("byte threshold should trigger", out); + replacement( + mw.afterToolCall( + ctxWithSession, "run_select_query", Collections.emptyMap(), sb.toString())); assertTrue(out, out.contains("Tool output truncated")); } @@ -91,10 +92,12 @@ public void overByteThresholdTriggersOffload() { public void retrievalToolsAreExemptFromGate() { StringBuilder sb = new StringBuilder(); for (int i = 0; i < 2000; i++) sb.append("row").append(i).append('\n'); - assertNull( + assertSame( + AgentMiddleware.ToolResultUnchanged.INSTANCE, mw.afterToolCall( ctxWithSession, ReadToolOutputTool.NAME, Collections.emptyMap(), sb.toString())); - assertNull( + assertSame( + AgentMiddleware.ToolResultUnchanged.INSTANCE, mw.afterToolCall( ctxWithSession, GrepToolOutputTool.NAME, Collections.emptyMap(), sb.toString())); } @@ -103,9 +106,10 @@ public void retrievalToolsAreExemptFromGate() { public void missingSessionIdPassesThrough() { StringBuilder sb = new StringBuilder(); for (int i = 0; i < 1000; i++) sb.append("row").append(i).append('\n'); - String out = - mw.afterToolCall(ctxNoSession, "run_select_query", Collections.emptyMap(), sb.toString()); - assertNull("without sessionId, cannot offload safely — pass through", out); + assertSame( + "without sessionId, cannot offload safely — pass through", + AgentMiddleware.ToolResultUnchanged.INSTANCE, + mw.afterToolCall(ctxNoSession, "run_select_query", Collections.emptyMap(), sb.toString())); } @Test @@ -124,15 +128,26 @@ public void multipleOffloadsReuseSameSessionDir() { StringBuilder sb = new StringBuilder(); for (int i = 0; i < 600; i++) sb.append("row").append(i).append('\n'); String out1 = - mw.afterToolCall(ctxWithSession, "run_select_query", Collections.emptyMap(), sb.toString()); + replacement( + mw.afterToolCall( + ctxWithSession, "run_select_query", Collections.emptyMap(), sb.toString())); String out2 = - mw.afterToolCall(ctxWithSession, "run_select_query", Collections.emptyMap(), sb.toString()); + replacement( + mw.afterToolCall( + ctxWithSession, "run_select_query", Collections.emptyMap(), sb.toString())); // Both previews reference the same session dir, different file names. assertNotEquals(extractPath(out1), extractPath(out2)); assertTrue(extractPath(out1).contains("sess-1")); assertTrue(extractPath(out2).contains("sess-1")); } + private static String replacement(AgentMiddleware.ToolResultAction action) { + assertTrue( + "expected ToolResultReplace but got " + action, + action instanceof AgentMiddleware.ToolResultReplace); + return ((AgentMiddleware.ToolResultReplace) action).replacement(); + } + private static String extractPath(String preview) { int i = preview.indexOf("Saved to:"); int eol = preview.indexOf('\n', i); From b9b4208130f156a5c3458ce66001d70984f5bbe2 Mon Sep 17 00:00:00 2001 From: wangzhigang Date: Thu, 30 Apr 2026 23:54:26 +0800 Subject: [PATCH 07/10] [KYUUBI #7379][2b/4][FOLLOWUP] Move ToolContext to first parameter in AgentTool.execute Reviewer asked for the per-invocation context to come first so the parameter list reads context-then-payload, matching the conventional shape of "function(context, args)" used elsewhere in the codebase. Update the AgentTool interface signature, all production tool implementations (ReadToolOutputTool, GrepToolOutputTool, RunSelectQueryTool, RunMutationQueryTool), the ToolRegistry call site, and every test/test-helper call site that exercises tool.execute(...). Addresses review feedback on #7417. --- .../engine/dataagent/tool/AgentTool.java | 4 +- .../engine/dataagent/tool/ToolRegistry.java | 2 +- .../tool/output/GrepToolOutputTool.java | 2 +- .../tool/output/ReadToolOutputTool.java | 2 +- .../tool/sql/RunMutationQueryTool.java | 2 +- .../tool/sql/RunSelectQueryTool.java | 2 +- .../engine/dataagent/mysql/DialectTest.java | 2 +- .../middleware/ApprovalMiddlewareTest.java | 2 +- .../tool/ToolRegistryThreadSafetyTest.java | 6 +-- .../engine/dataagent/tool/ToolTest.java | 4 +- .../tool/sql/RunMutationQueryToolTest.java | 16 +++---- .../tool/sql/RunSelectQueryToolTest.java | 44 +++++++++---------- 12 files changed, 44 insertions(+), 44 deletions(-) diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/AgentTool.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/AgentTool.java index b19f6787c87..55ecf03db64 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/AgentTool.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/AgentTool.java @@ -49,11 +49,11 @@ default ToolRiskLevel riskLevel() { /** * Execute the tool with the given deserialized arguments. * - * @param args the deserialized arguments from the LLM's tool call * @param ctx per-invocation context (session id, etc.); never null — use {@link * ToolContext#EMPTY} for calls without a session. Tools that are session-agnostic may ignore * it. + * @param args the deserialized arguments from the LLM's tool call * @return the result string to feed back to the LLM */ - String execute(T args, ToolContext ctx); + String execute(ToolContext ctx, T args); } diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistry.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistry.java index 3a11bab567c..fd2113e33f8 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistry.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistry.java @@ -203,7 +203,7 @@ public CompletableFuture submitTool(String toolName, String argsJson, To () -> { try { Object args = JSON.readValue(argsJson, tool.argsType()); - String out = tool.execute(args, toolCtx); + String out = tool.execute(toolCtx, args); // When the timeout handler interrupts us, the tool may still unwind cleanly and // produce a stale return value — don't race the scheduler's timeout message with // it. Let the timeout path be the single authority for the final result. diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/GrepToolOutputTool.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/GrepToolOutputTool.java index f94fb21b315..2f42be2c450 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/GrepToolOutputTool.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/GrepToolOutputTool.java @@ -55,7 +55,7 @@ public Class argsType() { } @Override - public String execute(GrepToolOutputArgs args, ToolContext ctx) { + public String execute(ToolContext ctx, GrepToolOutputArgs args) { if (args == null || args.path == null || args.path.isEmpty()) { return "Error: 'path' parameter is required."; } diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/ReadToolOutputTool.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/ReadToolOutputTool.java index 93e837998f3..f9922f8a8fd 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/ReadToolOutputTool.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/ReadToolOutputTool.java @@ -56,7 +56,7 @@ public Class argsType() { } @Override - public String execute(ReadToolOutputArgs args, ToolContext ctx) { + public String execute(ToolContext ctx, ReadToolOutputArgs args) { if (args == null || args.path == null || args.path.isEmpty()) { return "Error: 'path' parameter is required."; } diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryTool.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryTool.java index 88838ca477a..5073a21a772 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryTool.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryTool.java @@ -70,7 +70,7 @@ public Class argsType() { } @Override - public String execute(SqlQueryArgs args, ToolContext ctx) { + public String execute(ToolContext ctx, SqlQueryArgs args) { return SqlExecutor.execute(dataSource, args.sql, queryTimeoutSeconds); } } diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryTool.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryTool.java index 9136b0aa903..b4558527519 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryTool.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryTool.java @@ -70,7 +70,7 @@ public Class argsType() { } @Override - public String execute(SqlQueryArgs args, ToolContext ctx) { + public String execute(ToolContext ctx, SqlQueryArgs args) { String sql = args.sql; if (sql == null || sql.trim().isEmpty()) { return "Error: 'sql' parameter is required."; diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/DialectTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/DialectTest.java index db86b92cc80..47940f7ff25 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/DialectTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/DialectTest.java @@ -67,7 +67,7 @@ public void testBacktickQuotingWithReservedWord() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT " + quotedCol + " FROM " + quotedTable + " WHERE id = 1"; - String result = selectTool.execute(args, ToolContext.EMPTY); + String result = selectTool.execute(ToolContext.EMPTY, args); assertFalse(result.startsWith("Error:")); assertTrue(result.contains("value1")); diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddlewareTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddlewareTest.java index ecd30a747d5..c7a7916db37 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddlewareTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddlewareTest.java @@ -298,7 +298,7 @@ public Class argsType() { } @Override - public String execute(DummyArgs args, ToolContext ctx) { + public String execute(ToolContext ctx, DummyArgs args) { return "ok"; } } diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistryThreadSafetyTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistryThreadSafetyTest.java index 7f790e1238e..fa72c1644e9 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistryThreadSafetyTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistryThreadSafetyTest.java @@ -66,7 +66,7 @@ public Class argsType() { } @Override - public String execute(DummyArgs args, ToolContext ctx) { + public String execute(ToolContext ctx, DummyArgs args) { return "result_" + idx; } }); @@ -118,7 +118,7 @@ public Class argsType() { } @Override - public String execute(DummyArgs args, ToolContext ctx) { + public String execute(ToolContext ctx, DummyArgs args) { return "existing_result"; } }); @@ -159,7 +159,7 @@ public Class argsType() { } @Override - public String execute(DummyArgs args, ToolContext ctx) { + public String execute(ToolContext ctx, DummyArgs args) { return "dynamic"; } }); diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolTest.java index 777017c4438..b2aa3dd7b80 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolTest.java @@ -109,7 +109,7 @@ public Class argsType() { } @Override - public String execute(ToolRegistryThreadSafetyTest.DummyArgs args, ToolContext ctx) { + public String execute(ToolContext ctx, ToolRegistryThreadSafetyTest.DummyArgs args) { try { Thread.sleep(60_000); } catch (InterruptedException e) { @@ -147,7 +147,7 @@ public Class argsType() { } @Override - public String execute(ToolRegistryThreadSafetyTest.DummyArgs args, ToolContext ctx) { + public String execute(ToolContext ctx, ToolRegistryThreadSafetyTest.DummyArgs args) { throw new RuntimeException("intentional failure"); } }); diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryToolTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryToolTest.java index 5384bd937d9..304ff18fc0e 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryToolTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryToolTest.java @@ -61,7 +61,7 @@ public void testRiskLevelDestructive() { public void testInsert() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "INSERT INTO t VALUES (9999, 'hello')"; - String result = tool.execute(args, ToolContext.EMPTY); + String result = tool.execute(ToolContext.EMPTY, args); assertTrue(result.contains("1 row(s) affected")); } @@ -69,21 +69,21 @@ public void testInsert() { public void testUpdate() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "UPDATE t SET v = 'updated' WHERE id = 1"; - assertTrue(tool.execute(args, ToolContext.EMPTY).contains("1 row(s) affected")); + assertTrue(tool.execute(ToolContext.EMPTY, args).contains("1 row(s) affected")); } @Test public void testDelete() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "DELETE FROM t WHERE id = 1"; - assertTrue(tool.execute(args, ToolContext.EMPTY).contains("1 row(s) affected")); + assertTrue(tool.execute(ToolContext.EMPTY, args).contains("1 row(s) affected")); } @Test public void testCreateTable() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "CREATE TABLE new_t (id INTEGER PRIMARY KEY, v TEXT)"; - assertTrue(tool.execute(args, ToolContext.EMPTY).contains("executed successfully")); + assertTrue(tool.execute(ToolContext.EMPTY, args).contains("executed successfully")); } @Test @@ -91,7 +91,7 @@ public void testAlsoAcceptsSelect() { // Mutation tool does not enforce read-only check; SELECT works fine here. SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT v FROM t WHERE id = 1"; - String result = tool.execute(args, ToolContext.EMPTY); + String result = tool.execute(ToolContext.EMPTY, args); assertFalse(result.startsWith("Error:")); } @@ -99,18 +99,18 @@ public void testAlsoAcceptsSelect() { public void testRejectsEmptyAndNullSql() { SqlQueryArgs emptyArgs = new SqlQueryArgs(); emptyArgs.sql = ""; - assertTrue(tool.execute(emptyArgs, ToolContext.EMPTY).startsWith("Error:")); + assertTrue(tool.execute(ToolContext.EMPTY, emptyArgs).startsWith("Error:")); SqlQueryArgs nullArgs = new SqlQueryArgs(); nullArgs.sql = null; - assertTrue(tool.execute(nullArgs, ToolContext.EMPTY).startsWith("Error:")); + assertTrue(tool.execute(ToolContext.EMPTY, nullArgs).startsWith("Error:")); } @Test public void testInvalidSqlReturnsError() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "INSERT INTO nonexistent_table VALUES (1)"; - assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:")); + assertTrue(tool.execute(ToolContext.EMPTY, args).startsWith("Error:")); } // --- Helpers --- diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryToolTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryToolTest.java index d1015ede63d..d94165f9e6f 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryToolTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryToolTest.java @@ -57,7 +57,7 @@ public void tearDown() { public void testRejectsInsert() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "INSERT INTO large_table VALUES (9999, 'x')"; - String result = tool.execute(args, ToolContext.EMPTY); + String result = tool.execute(ToolContext.EMPTY, args); assertTrue(result.startsWith("Error:")); assertTrue(result.contains("read-only")); assertTrue(result.contains("run_mutation_query")); @@ -67,28 +67,28 @@ public void testRejectsInsert() { public void testRejectsUpdate() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "UPDATE large_table SET value = 'x' WHERE id = 1"; - assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:")); + assertTrue(tool.execute(ToolContext.EMPTY, args).startsWith("Error:")); } @Test public void testRejectsDelete() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "DELETE FROM large_table WHERE id = 1"; - assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:")); + assertTrue(tool.execute(ToolContext.EMPTY, args).startsWith("Error:")); } @Test public void testRejectsCreateTable() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "CREATE TABLE x (id INT)"; - assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:")); + assertTrue(tool.execute(ToolContext.EMPTY, args).startsWith("Error:")); } @Test public void testAllowsSelect() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT id FROM large_table LIMIT 100"; - String result = tool.execute(args, ToolContext.EMPTY); + String result = tool.execute(ToolContext.EMPTY, args); assertFalse(result.startsWith("Error:")); assertTrue(result.contains("[100 row(s) returned]")); } @@ -97,7 +97,7 @@ public void testAllowsSelect() { public void testAllowsCte() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "WITH cte AS (SELECT id, value FROM large_table LIMIT 5) SELECT * FROM cte"; - String result = tool.execute(args, ToolContext.EMPTY); + String result = tool.execute(ToolContext.EMPTY, args); assertFalse(result.startsWith("Error:")); assertTrue(result.contains("row(s)")); } @@ -108,7 +108,7 @@ public void testAllowsCte() { public void testRespectsLimitInSql() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT id FROM large_table LIMIT 5"; - assertTrue(tool.execute(args, ToolContext.EMPTY).contains("[5 row(s) returned]")); + assertTrue(tool.execute(ToolContext.EMPTY, args).contains("[5 row(s) returned]")); } @Test @@ -117,7 +117,7 @@ public void testNoClientSideCapWhenLimitOmitted() { // Cap discipline is delegated to the LLM via the system prompt. SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT id FROM large_table"; - assertTrue(tool.execute(args, ToolContext.EMPTY).contains("[1500 row(s) returned]")); + assertTrue(tool.execute(ToolContext.EMPTY, args).contains("[1500 row(s) returned]")); } // --- Zero-row result --- @@ -126,7 +126,7 @@ public void testNoClientSideCapWhenLimitOmitted() { public void testZeroRowsResult() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT id FROM large_table WHERE id < 0"; - String result = tool.execute(args, ToolContext.EMPTY); + String result = tool.execute(ToolContext.EMPTY, args); assertFalse(result.startsWith("Error:")); assertTrue(result.contains("[0 row(s) returned]")); } @@ -137,15 +137,15 @@ public void testZeroRowsResult() { public void testSelectWithLeadingBlockComment() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "/* get count */ SELECT COUNT(*) FROM large_table"; - assertFalse(tool.execute(args, ToolContext.EMPTY).startsWith("Error:")); + assertFalse(tool.execute(ToolContext.EMPTY, args).startsWith("Error:")); } @Test public void testRejectsMutationHiddenBehindComment() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "-- looks innocent\nDROP TABLE large_table"; - assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:")); - assertTrue(tool.execute(args, ToolContext.EMPTY).contains("read-only")); + assertTrue(tool.execute(ToolContext.EMPTY, args).startsWith("Error:")); + assertTrue(tool.execute(ToolContext.EMPTY, args).contains("read-only")); } // --- Edge cases --- @@ -154,25 +154,25 @@ public void testRejectsMutationHiddenBehindComment() { public void testRejectsEmptyAndNullSql() { SqlQueryArgs emptyArgs = new SqlQueryArgs(); emptyArgs.sql = ""; - assertTrue(tool.execute(emptyArgs, ToolContext.EMPTY).startsWith("Error:")); + assertTrue(tool.execute(ToolContext.EMPTY, emptyArgs).startsWith("Error:")); SqlQueryArgs nullArgs = new SqlQueryArgs(); nullArgs.sql = null; - assertTrue(tool.execute(nullArgs, ToolContext.EMPTY).startsWith("Error:")); + assertTrue(tool.execute(ToolContext.EMPTY, nullArgs).startsWith("Error:")); } @Test public void testRejectsWhitespaceOnlySql() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = " \t\n "; - assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:")); + assertTrue(tool.execute(ToolContext.EMPTY, args).startsWith("Error:")); } @Test public void testInvalidSqlReturnsError() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT * FROM nonexistent_table"; - assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:")); + assertTrue(tool.execute(ToolContext.EMPTY, args).startsWith("Error:")); } // --- Output formatting --- @@ -189,7 +189,7 @@ public void testNullValuesRenderedAsNULL() { } SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT id, name FROM nullable_test ORDER BY ROWID"; - String result = tool.execute(args, ToolContext.EMPTY); + String result = tool.execute(ToolContext.EMPTY, args); assertTrue(result.contains("NULL")); assertTrue(result.contains("Alice")); } @@ -205,7 +205,7 @@ public void testPipeCharacterEscapedInOutput() { } SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT val FROM pipe_test"; - String result = tool.execute(args, ToolContext.EMPTY); + String result = tool.execute(ToolContext.EMPTY, args); assertTrue("Pipe should be escaped for markdown table", result.contains("a\\|b\\|c")); } @@ -215,7 +215,7 @@ public void testPipeCharacterEscapedInOutput() { public void testExtractRootCauseFromNestedExceptions() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT * FROM this_table_does_not_exist_at_all"; - String result = tool.execute(args, ToolContext.EMPTY); + String result = tool.execute(ToolContext.EMPTY, args); assertTrue(result.startsWith("Error:")); assertTrue(result.contains("this_table_does_not_exist_at_all")); } @@ -224,7 +224,7 @@ public void testExtractRootCauseFromNestedExceptions() { public void testErrorMessageIsConcise() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELEC INVALID SYNTAX HERE !!!"; - String result = tool.execute(args, ToolContext.EMPTY); + String result = tool.execute(ToolContext.EMPTY, args); assertTrue(result.startsWith("Error:")); long newlines = result.chars().filter(c -> c == '\n').count(); assertTrue("Error should be concise (<=2 newlines), got " + newlines, newlines <= 2); @@ -237,7 +237,7 @@ public void testCustomQueryTimeout() { RunSelectQueryTool customTool = new RunSelectQueryTool(ds, 5); SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT COUNT(*) FROM large_table"; - assertFalse(customTool.execute(args, ToolContext.EMPTY).startsWith("Error:")); + assertFalse(customTool.execute(ToolContext.EMPTY, args).startsWith("Error:")); } @Test @@ -315,7 +315,7 @@ public boolean isWrapperFor(Class iface) { RunSelectQueryTool timeoutTool = new RunSelectQueryTool(slowDs, 1); SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT * FROM large_table"; - String result = timeoutTool.execute(args, ToolContext.EMPTY); + String result = timeoutTool.execute(ToolContext.EMPTY, args); assertTrue("Expected error on timeout", result.startsWith("Error:")); assertTrue("Expected timeout message", result.contains("timed out")); } From 211e86770f63f2845bed60a084c6679cc2dea11e Mon Sep 17 00:00:00 2001 From: wangzhigang Date: Fri, 1 May 2026 01:42:07 +0800 Subject: [PATCH 08/10] [KYUUBI #7379][2b/4][FOLLOWUP] Unify AgentMiddleware hook return types under Decision Collapse the three sealed action hierarchies (LlmCallAction, ToolCallAction, ToolResultAction) plus nullable onEvent into a single generic Decision with proceed / replace / abort. Pack tool-call (id, name, args) into ToolInvocation so beforeToolCall can rewrite args (e.g. inject SQL LIMIT, redact params), and align afterLlmCall by moving its dispatch ahead of the memory write so replace actually rewrites what enters memory and tool-call extraction. --- .../engine/dataagent/runtime/ReactAgent.java | 151 ++++++++++------ .../runtime/middleware/AgentMiddleware.java | 162 +++--------------- .../middleware/ApprovalMiddleware.java | 19 +- .../middleware/CompactionMiddleware.java | 8 +- .../runtime/middleware/Decision.java | 95 ++++++++++ .../runtime/middleware/LoggingMiddleware.java | 29 ++-- .../runtime/middleware/ToolInvocation.java | 50 ++++++ .../ToolResultOffloadMiddleware.java | 17 +- .../middleware/ApprovalMiddlewareTest.java | 69 ++++---- .../CompactionMiddlewareLiveTest.java | 6 +- .../middleware/CompactionMiddlewareTest.java | 12 +- .../ToolResultOffloadMiddlewareTest.java | 55 +++--- 12 files changed, 375 insertions(+), 298 deletions(-) create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/Decision.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolInvocation.java diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgent.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgent.java index 1f9b090b642..4ddf484a44f 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgent.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgent.java @@ -46,6 +46,8 @@ import org.apache.kyuubi.engine.dataagent.runtime.event.ToolResult; import org.apache.kyuubi.engine.dataagent.runtime.middleware.AgentMiddleware; import org.apache.kyuubi.engine.dataagent.runtime.middleware.ApprovalMiddleware; +import org.apache.kyuubi.engine.dataagent.runtime.middleware.Decision; +import org.apache.kyuubi.engine.dataagent.runtime.middleware.ToolInvocation; import org.apache.kyuubi.engine.dataagent.tool.ToolContext; import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry; import org.slf4j.Logger; @@ -175,17 +177,28 @@ public void run( emit(ctx, new ContentComplete(result.content), eventConsumer); } ChatCompletionAssistantMessageParam assistantMsg = buildAssistantMessage(result); + Decision after = + dispatchAfterLlmCall(ctx, assistantMsg); + if (after.kind() == Decision.Kind.ABORT) { + emit( + ctx, + new AgentError("LLM response rejected by middleware: " + after.reason()), + eventConsumer); + emitFinish(ctx, eventConsumer); + return; + } + if (after.kind() == Decision.Kind.REPLACE) assistantMsg = after.replacement(); memory.addAssistantMessage(assistantMsg); - dispatchAfterLlmCall(ctx, assistantMsg); - if (result.toolCalls == null || result.toolCalls.isEmpty()) { + List toolCalls = assistantMsg.toolCalls().orElse(null); + if (toolCalls == null || toolCalls.isEmpty()) { // No tool calls — agent is done. emit(ctx, new StepEnd(step), eventConsumer); emitFinish(ctx, eventConsumer); return; } - executeToolCalls(ctx, memory, result.toolCalls, eventConsumer); + executeToolCalls(ctx, memory, toolCalls, eventConsumer); emit(ctx, new StepEnd(step), eventConsumer); } @@ -212,17 +225,16 @@ private List resolveMessagesForCall( AgentRunContext ctx, List messages, Consumer eventConsumer) { - AgentMiddleware.LlmCallAction action = dispatchBeforeLlmCall(ctx, messages); - if (action instanceof AgentMiddleware.LlmSkip) { - String reason = ((AgentMiddleware.LlmSkip) action).reason(); - emit(ctx, new AgentError("LLM call skipped by middleware: " + reason), eventConsumer); + Decision> decision = dispatchBeforeLlmCall(ctx, messages); + if (decision.kind() == Decision.Kind.ABORT) { + emit( + ctx, + new AgentError("LLM call skipped by middleware: " + decision.reason()), + eventConsumer); emitFinish(ctx, eventConsumer); return null; } - if (action instanceof AgentMiddleware.LlmModifyMessages) { - return ((AgentMiddleware.LlmModifyMessages) action).messages(); - } - return messages; + return decision.kind() == Decision.Kind.REPLACE ? decision.replacement() : messages; } private static ChatCompletionAssistantMessageParam buildAssistantMessage(StreamResult result) { @@ -268,32 +280,36 @@ private void executeToolCalls( continue; } - AgentMiddleware.ToolCallAction action = - dispatchBeforeToolCall(ctx, fnCall.id(), toolName, toolArgs); - if (action instanceof AgentMiddleware.ToolCallDenial) { - String denied = "Tool call denied: " + ((AgentMiddleware.ToolCallDenial) action).reason(); + ToolInvocation invocation = new ToolInvocation(fnCall.id(), toolName, toolArgs); + Decision decision = dispatchBeforeToolCall(ctx, invocation); + if (decision.kind() == Decision.Kind.ABORT) { + String denied = "Tool call denied: " + decision.reason(); memory.addToolResult(fnCall.id(), denied); emit(ctx, new ToolResult(fnCall.id(), toolName, denied, true), eventConsumer); continue; } + boolean rewritten = decision.kind() == Decision.Kind.REPLACE; + ToolInvocation effective = rewritten ? decision.replacement() : invocation; - emit(ctx, new ToolCall(fnCall.id(), toolName, toolArgs), eventConsumer); - approved.add(new ToolCallEntry(fnCall, toolName, toolArgs)); + emit(ctx, new ToolCall(effective.id(), effective.name(), effective.args()), eventConsumer); + approved.add(new ToolCallEntry(fnCall, effective, rewritten)); } ToolContext toolCtx = new ToolContext(ctx.getSessionId()); List> futures = new ArrayList<>(approved.size()); for (ToolCallEntry entry : approved) { - futures.add( - toolRegistry.submitTool(entry.toolName, entry.fnCall.function().arguments(), toolCtx)); + futures.add(toolRegistry.submitTool(entry.invocation.name(), entry.argsJson(), toolCtx)); } for (int i = 0; i < approved.size(); i++) { ToolCallEntry entry = approved.get(i); String raw = futures.get(i).join(); - String output = dispatchAfterToolCall(ctx, entry.toolName, entry.toolArgs, raw); + String output = dispatchAfterToolCall(ctx, entry.invocation, raw); memory.addToolResult(entry.fnCall.id(), output); - emit(ctx, new ToolResult(entry.fnCall.id(), entry.toolName, output, false), eventConsumer); + emit( + ctx, + new ToolResult(entry.fnCall.id(), entry.invocation.name(), output, false), + eventConsumer); } } @@ -312,19 +328,33 @@ boolean isEmpty() { } } - /** Holds an approved tool call's parsed metadata for the 3-phase execution pipeline. */ + /** + * Holds an approved tool call's parsed metadata for the 3-phase execution pipeline. {@code + * rewritten} is {@code true} when middleware replaced the {@link ToolInvocation}; in that case + * args must be re-serialized for {@link ToolRegistry#submitTool}, otherwise the LLM's original + * JSON is reused verbatim. + */ private static class ToolCallEntry { final ChatCompletionMessageFunctionToolCall fnCall; - final String toolName; - final Map toolArgs; + final ToolInvocation invocation; + final boolean rewritten; ToolCallEntry( ChatCompletionMessageFunctionToolCall fnCall, - String toolName, - Map toolArgs) { + ToolInvocation invocation, + boolean rewritten) { this.fnCall = fnCall; - this.toolName = toolName; - this.toolArgs = toolArgs; + this.invocation = invocation; + this.rewritten = rewritten; + } + + String argsJson() { + if (!rewritten) return fnCall.function().arguments(); + try { + return JSON.writeValueAsString(invocation.args()); + } catch (com.fasterxml.jackson.core.JsonProcessingException e) { + throw new IllegalStateException("Failed to serialize rewritten tool args", e); + } } } @@ -477,8 +507,9 @@ private void emitFinish(AgentRunContext ctx, Consumer eventConsumer) private void emit(AgentRunContext ctx, AgentEvent event, Consumer consumer) { AgentEvent filtered = event; for (AgentMiddleware mw : middlewares) { - filtered = mw.onEvent(ctx, filtered); - if (filtered == null) return; + Decision d = mw.onEvent(ctx, filtered); + if (d.kind() == Decision.Kind.ABORT) return; + if (d.kind() == Decision.Kind.REPLACE) filtered = d.replacement(); } consumer.accept(filtered); } @@ -501,40 +532,64 @@ private void dispatchAgentFinish(AgentRunContext ctx) { } } - private AgentMiddleware.LlmCallAction dispatchBeforeLlmCall( + /** + * Run {@code beforeLlmCall} middleware in onion order, folding REPLACE so later middlewares see + * rewritten messages. Returns PROCEED if no middleware touched the value, REPLACE with the final + * value if any did, or ABORT if any middleware short-circuited. + */ + private Decision> dispatchBeforeLlmCall( AgentRunContext ctx, List messages) { + List current = messages; for (AgentMiddleware mw : middlewares) { - AgentMiddleware.LlmCallAction action = mw.beforeLlmCall(ctx, messages); - if (!(action instanceof AgentMiddleware.LlmNoopAction)) return action; + Decision> d = mw.beforeLlmCall(ctx, current); + if (d.kind() == Decision.Kind.ABORT) return d; + if (d.kind() == Decision.Kind.REPLACE) current = d.replacement(); } - return AgentMiddleware.LlmNoopAction.INSTANCE; + return Decision.of(messages, current); } - private void dispatchAfterLlmCall( + /** + * Run {@code afterLlmCall} middleware in reverse onion order, folding REPLACE so earlier + * middlewares see rewritten responses. Returns the final response, or ABORT if any middleware + * short-circuits. + */ + private Decision dispatchAfterLlmCall( AgentRunContext ctx, ChatCompletionAssistantMessageParam response) { + ChatCompletionAssistantMessageParam current = response; for (int i = middlewares.size() - 1; i >= 0; i--) { - middlewares.get(i).afterLlmCall(ctx, response); + Decision d = + middlewares.get(i).afterLlmCall(ctx, current); + if (d.kind() == Decision.Kind.ABORT) return d; + if (d.kind() == Decision.Kind.REPLACE) current = d.replacement(); } + return Decision.of(response, current); } - private AgentMiddleware.ToolCallAction dispatchBeforeToolCall( - AgentRunContext ctx, String toolCallId, String toolName, Map toolArgs) { + /** + * Run {@code beforeToolCall} middleware in onion order, folding REPLACE so later middlewares can + * further rewrite. Returns PROCEED if untouched, REPLACE with the final invocation otherwise, or + * ABORT if any middleware denies the call. + */ + private Decision dispatchBeforeToolCall( + AgentRunContext ctx, ToolInvocation call) { + ToolInvocation current = call; for (AgentMiddleware mw : middlewares) { - AgentMiddleware.ToolCallAction action = - mw.beforeToolCall(ctx, toolCallId, toolName, toolArgs); - if (action instanceof AgentMiddleware.ToolCallDenial) return action; + Decision d = mw.beforeToolCall(ctx, current); + if (d.kind() == Decision.Kind.ABORT) return d; + if (d.kind() == Decision.Kind.REPLACE) current = d.replacement(); } - return AgentMiddleware.ToolCallApproval.INSTANCE; + return Decision.of(call, current); } - private String dispatchAfterToolCall( - AgentRunContext ctx, String toolName, Map toolArgs, String result) { + private String dispatchAfterToolCall(AgentRunContext ctx, ToolInvocation call, String result) { String current = result; for (int i = middlewares.size() - 1; i >= 0; i--) { - AgentMiddleware.ToolResultAction action = - middlewares.get(i).afterToolCall(ctx, toolName, toolArgs, current); - if (action instanceof AgentMiddleware.ToolResultReplace) { - current = ((AgentMiddleware.ToolResultReplace) action).replacement(); + Decision d = middlewares.get(i).afterToolCall(ctx, call, current); + if (d.kind() == Decision.Kind.REPLACE) { + current = d.replacement(); + } else if (d.kind() == Decision.Kind.ABORT) { + // afterToolCall.abort: discard result; reason replaces it so the LLM still sees something. + current = d.reason(); } } return current; diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/AgentMiddleware.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/AgentMiddleware.java index 4d5319536a3..1f27d957186 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/AgentMiddleware.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/AgentMiddleware.java @@ -20,7 +20,6 @@ import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam; import com.openai.models.chat.completions.ChatCompletionMessageParam; import java.util.List; -import java.util.Map; import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext; import org.apache.kyuubi.engine.dataagent.runtime.event.AgentEvent; import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry; @@ -29,7 +28,9 @@ * Middleware interface for the Data Agent ReAct loop. Middlewares are executed in onion-model * order: before_* hooks run first-to-last, after_* hooks run last-to-first. * - *

    All hooks have default no-op implementations. Override only what you need. + *

    All hooks have default no-op implementations. Override only what you need. Every interceptor + * hook returns {@link Decision} — see that class for the {@code proceed / replace / abort} + * vocabulary and per-hook semantics of {@code abort}. */ public interface AgentMiddleware { @@ -46,51 +47,39 @@ default void onAgentStart(AgentRunContext ctx) {} /** Called when the agent finishes. Runs last-to-first (cleanup order). */ default void onAgentFinish(AgentRunContext ctx) {} - /** - * Called before each LLM invocation. Runs first-to-last. - * - * @return {@link LlmNoopAction#INSTANCE} to proceed normally, {@link LlmSkip} to abort, or {@link - * LlmModifyMessages} to replace the message list for this call. - */ - default LlmCallAction beforeLlmCall( + /** Called before each LLM invocation. Runs first-to-last. */ + default Decision> beforeLlmCall( AgentRunContext ctx, List messages) { - return LlmNoopAction.INSTANCE; + return Decision.proceed(); } - /** Called after each LLM invocation. Runs last-to-first. */ - default void afterLlmCall(AgentRunContext ctx, ChatCompletionAssistantMessageParam response) {} - /** - * Called before each tool execution. Runs first-to-last. - * - * @return {@link ToolCallApproval#INSTANCE} to allow the call, {@link ToolCallDenial} to block - * it. + * Called after each LLM invocation, before the assistant message is committed to memory or + * inspected for tool calls. Runs last-to-first. Replacing the response lets middleware rewrite + * what enters memory or strip tool calls before they execute. */ - default ToolCallAction beforeToolCall( - AgentRunContext ctx, String toolCallId, String toolName, Map toolArgs) { - return ToolCallApproval.INSTANCE; + default Decision afterLlmCall( + AgentRunContext ctx, ChatCompletionAssistantMessageParam response) { + return Decision.proceed(); } /** - * Called after each tool execution. Runs last-to-first. - * - *

    Middlewares can intercept and transform the tool result before it is fed back to the LLM — - * e.g. for data masking, output truncation, or injecting metadata. - * - * @return {@link ToolResultUnchanged#INSTANCE} to keep the original result, {@link - * ToolResultReplace} to substitute it. + * Called before each tool execution. Runs first-to-last. Replacing the {@link ToolInvocation} + * lets middleware rewrite the tool name or args (e.g. inject a SQL LIMIT, redact parameters) + * before execution. */ - default ToolResultAction afterToolCall( - AgentRunContext ctx, String toolName, Map toolArgs, String result) { - return ToolResultUnchanged.INSTANCE; + default Decision beforeToolCall(AgentRunContext ctx, ToolInvocation call) { + return Decision.proceed(); } - /** - * Called for every event before it is emitted. Return null to suppress the event. Runs - * first-to-last. - */ - default AgentEvent onEvent(AgentRunContext ctx, AgentEvent event) { - return event; + /** Called after each tool execution. Runs last-to-first. */ + default Decision afterToolCall(AgentRunContext ctx, ToolInvocation call, String result) { + return Decision.proceed(); + } + + /** Called for every event before it is emitted. Runs first-to-last. */ + default Decision onEvent(AgentRunContext ctx, AgentEvent event) { + return Decision.proceed(); } /** @@ -104,105 +93,4 @@ default void onSessionClose(String sessionId) {} * waiting on this middleware. Dispatched by {@code ReactAgent.stop}. */ default void onStop() {} - - /** - * Base type for {@code beforeLlmCall} return values. Subtypes: {@link LlmNoopAction} to proceed - * normally, {@link LlmSkip} to abort the LLM call, {@link LlmModifyMessages} to replace the - * message list for this call. - */ - abstract class LlmCallAction { - private LlmCallAction() {} - } - - /** Returned from {@code beforeLlmCall} to proceed without changing the LLM call. */ - final class LlmNoopAction extends LlmCallAction { - public static final LlmNoopAction INSTANCE = new LlmNoopAction(); - - private LlmNoopAction() {} - } - - /** Returned from {@code beforeLlmCall} to skip the LLM call and abort the agent loop. */ - class LlmSkip extends LlmCallAction { - private final String reason; - - public LlmSkip(String reason) { - this.reason = reason; - } - - public String reason() { - return reason; - } - } - - /** - * Returned from {@code beforeLlmCall} to replace the message list for this LLM invocation. The - * agent loop continues normally with the modified messages. - */ - class LlmModifyMessages extends LlmCallAction { - private final List messages; - - public LlmModifyMessages(List messages) { - this.messages = messages; - } - - public List messages() { - return messages; - } - } - - /** - * Base type for {@code beforeToolCall} return values. Subtypes: {@link ToolCallApproval} to allow - * the call, {@link ToolCallDenial} to block it. - */ - abstract class ToolCallAction { - private ToolCallAction() {} - } - - /** Returned from {@code beforeToolCall} to allow the tool call to proceed. */ - final class ToolCallApproval extends ToolCallAction { - public static final ToolCallApproval INSTANCE = new ToolCallApproval(); - - private ToolCallApproval() {} - } - - /** Returned from {@code beforeToolCall} to block a tool call. */ - final class ToolCallDenial extends ToolCallAction { - private final String reason; - - public ToolCallDenial(String reason) { - this.reason = reason; - } - - public String reason() { - return reason; - } - } - - /** - * Base type for {@code afterToolCall} return values. Subtypes: {@link ToolResultUnchanged} to - * pass the result through, {@link ToolResultReplace} to substitute it. - */ - abstract class ToolResultAction { - private ToolResultAction() {} - } - - /** Returned from {@code afterToolCall} to keep the tool result as is. */ - final class ToolResultUnchanged extends ToolResultAction { - public static final ToolResultUnchanged INSTANCE = new ToolResultUnchanged(); - - private ToolResultUnchanged() {} - } - - /** Returned from {@code afterToolCall} to replace the tool result with a new string. */ - final class ToolResultReplace extends ToolResultAction { - private final String replacement; - - public ToolResultReplace(String replacement) { - this.replacement = replacement; - } - - public String replacement() { - return replacement; - } - } } diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddleware.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddleware.java index f727dba3376..786569dbd12 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddleware.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddleware.java @@ -17,7 +17,6 @@ package org.apache.kyuubi.engine.dataagent.runtime.middleware; -import java.util.Map; import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; @@ -64,41 +63,41 @@ public void onRegister(ToolRegistry registry) { } @Override - public ToolCallAction beforeToolCall( - AgentRunContext ctx, String toolCallId, String toolName, Map toolArgs) { + public Decision beforeToolCall(AgentRunContext ctx, ToolInvocation call) { + String toolName = call.name(); ToolRiskLevel riskLevel = toolRegistry.getRiskLevel(toolName); if (shouldAutoApprove(ctx.getApprovalMode(), riskLevel)) { - return ToolCallApproval.INSTANCE; + return Decision.proceed(); } String requestId = UUID.randomUUID().toString(); CompletableFuture future = new CompletableFuture<>(); pending.put(requestId, future); - ctx.emit(new ApprovalRequest(requestId, toolCallId, toolName, toolArgs, riskLevel)); + ctx.emit(new ApprovalRequest(requestId, call.id(), toolName, call.args(), riskLevel)); LOG.info("Approval requested for tool '{}' (requestId={})", toolName, requestId); try { boolean approved = future.get(timeoutSeconds, TimeUnit.SECONDS); if (!approved) { LOG.info("Tool '{}' denied by user (requestId={})", toolName, requestId); - return new ToolCallDenial("User denied execution of " + toolName); + return Decision.abort("User denied execution of " + toolName); } LOG.info("Tool '{}' approved by user (requestId={})", toolName, requestId); - return ToolCallApproval.INSTANCE; + return Decision.proceed(); } catch (TimeoutException e) { // Complete the future so that a late resolve() call is a harmless no-op // instead of completing a dangling future. future.completeExceptionally(e); LOG.warn("Approval timed out for tool '{}' (requestId={})", toolName, requestId); - return new ToolCallDenial("Approval timed out after " + timeoutSeconds + "s for " + toolName); + return Decision.abort("Approval timed out after " + timeoutSeconds + "s for " + toolName); } catch (InterruptedException e) { Thread.currentThread().interrupt(); - return new ToolCallDenial("Approval interrupted for " + toolName); + return Decision.abort("Approval interrupted for " + toolName); } catch (Exception e) { LOG.error("Unexpected error waiting for approval", e); - return new ToolCallDenial("Approval error: " + e.getMessage()); + return Decision.abort("Approval error: " + e.getMessage()); } finally { pending.remove(requestId); } diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddleware.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddleware.java index df0425f5a0a..c82aa9ad03d 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddleware.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddleware.java @@ -146,7 +146,7 @@ public CompactionMiddleware( } @Override - public LlmCallAction beforeLlmCall( + public Decision> beforeLlmCall( AgentRunContext ctx, List messages) { ConversationMemory mem = ctx.getMemory(); // 1) Real token count of the previous LLM call (prompt + completion, i.e. everything through @@ -156,7 +156,7 @@ public LlmCallAction beforeLlmCall( long newTailEstimate = estimateTailAfterLastAssistant(messages); if (lastTotal + newTailEstimate < triggerPromptTokens) { - return LlmNoopAction.INSTANCE; + return Decision.proceed(); } List history = mem.getHistory(); @@ -165,7 +165,7 @@ public LlmCallAction beforeLlmCall( // tool_result. Split split = computeSplit(history, KEEP_RECENT_TURNS); if (split.old.isEmpty()) { - return LlmNoopAction.INSTANCE; + return Decision.proceed(); } String summary = summarize(mem.getSystemPrompt(), split.old); @@ -187,7 +187,7 @@ public LlmCallAction beforeLlmCall( new Compaction( split.old.size(), split.kept.size(), triggerPromptTokens, lastTotal + newTailEstimate)); - return new LlmModifyMessages(mem.buildLlmMessages()); + return Decision.replace(mem.buildLlmMessages()); } /** Call the LLM to produce a summary of {@code oldMessages}. Failures propagate. */ diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/Decision.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/Decision.java new file mode 100644 index 00000000000..187e30341f1 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/Decision.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime.middleware; + +/** + * Outcome of an {@link AgentMiddleware} interceptor hook. Three arms, uniform across all hooks: + * + *

      + *
    • {@link #proceed()} — pass through with the original value + *
    • {@link #replace(Object)} — substitute the value, then continue + *
    • {@link #abort(String)} — stop processing this item; per-hook semantics: + *
        + *
      • {@code beforeLlmCall}: skip the LLM call and end the loop + *
      • {@code beforeToolCall}: deny the call; reason is fed back to the LLM as the result + *
      • {@code afterToolCall}: discard the result; reason replaces it for the LLM + *
      • {@code onEvent}: drop the event + *
      + *
    + */ +public final class Decision { + + public enum Kind { + PROCEED, + REPLACE, + ABORT + } + + private static final Decision PROCEED = new Decision<>(Kind.PROCEED, null, null); + + private final Kind kind; + private final T replacement; + private final String reason; + + private Decision(Kind kind, T replacement, String reason) { + this.kind = kind; + this.replacement = replacement; + this.reason = reason; + } + + @SuppressWarnings("unchecked") + public static Decision proceed() { + return (Decision) PROCEED; + } + + public static Decision replace(T value) { + if (value == null) { + throw new IllegalArgumentException("replace value must not be null"); + } + return new Decision<>(Kind.REPLACE, value, null); + } + + public static Decision abort(String reason) { + if (reason == null) { + throw new IllegalArgumentException("abort reason must not be null"); + } + return new Decision<>(Kind.ABORT, null, reason); + } + + /** + * Fold helper for middleware dispatchers: PROCEED if {@code current} is still the original + * reference (no middleware replaced anything), REPLACE otherwise. + */ + public static Decision of(T original, T current) { + return current == original ? proceed() : replace(current); + } + + public Kind kind() { + return kind; + } + + /** Non-null only when {@link #kind()} is {@link Kind#REPLACE}. */ + public T replacement() { + return replacement; + } + + /** Non-null only when {@link #kind()} is {@link Kind#ABORT}. */ + public String reason() { + return reason; + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/LoggingMiddleware.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/LoggingMiddleware.java index ec96b2e71ea..ef12e42758a 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/LoggingMiddleware.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/LoggingMiddleware.java @@ -20,7 +20,6 @@ import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam; import com.openai.models.chat.completions.ChatCompletionMessageParam; import java.util.List; -import java.util.Map; import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext; import org.apache.kyuubi.engine.dataagent.runtime.event.AgentError; import org.apache.kyuubi.engine.dataagent.runtime.event.AgentEvent; @@ -95,14 +94,15 @@ public void onAgentFinish(AgentRunContext ctx) { } @Override - public LlmCallAction beforeLlmCall( + public Decision> beforeLlmCall( AgentRunContext ctx, List messages) { LOG.info("{}LLM call: step={}, messages={}", prefix(), ctx.getIteration(), messages.size()); - return LlmNoopAction.INSTANCE; + return Decision.proceed(); } @Override - public void afterLlmCall(AgentRunContext ctx, ChatCompletionAssistantMessageParam response) { + public Decision afterLlmCall( + AgentRunContext ctx, ChatCompletionAssistantMessageParam response) { String content = response.content().map(Object::toString).orElse(""); int toolCallCount = response.toolCalls().map(List::size).orElse(0); LOG.info( @@ -115,25 +115,24 @@ public void afterLlmCall(AgentRunContext ctx, ChatCompletionAssistantMessagePara ctx.getPromptTokens(), ctx.getCompletionTokens(), ctx.getTotalTokens()); + return Decision.proceed(); } @Override - public ToolCallAction beforeToolCall( - AgentRunContext ctx, String toolCallId, String toolName, Map toolArgs) { - LOG.info("{}Tool call: id={}, name={}", prefix(), toolCallId, toolName); - LOG.debug("{}Tool args: {}", prefix(), toolArgs); - return ToolCallApproval.INSTANCE; + public Decision beforeToolCall(AgentRunContext ctx, ToolInvocation call) { + LOG.info("{}Tool call: id={}, name={}", prefix(), call.id(), call.name()); + LOG.debug("{}Tool args: {}", prefix(), call.args()); + return Decision.proceed(); } @Override - public ToolResultAction afterToolCall( - AgentRunContext ctx, String toolName, Map toolArgs, String result) { - LOG.info("{}Tool result: {} -> \"{}\"", prefix(), toolName, truncate(result)); - return ToolResultUnchanged.INSTANCE; + public Decision afterToolCall(AgentRunContext ctx, ToolInvocation call, String result) { + LOG.info("{}Tool result: {} -> \"{}\"", prefix(), call.name(), truncate(result)); + return Decision.proceed(); } @Override - public AgentEvent onEvent(AgentRunContext ctx, AgentEvent event) { + public Decision onEvent(AgentRunContext ctx, AgentEvent event) { switch (event.eventType()) { case STEP_START: LOG.info("{}Step {}", prefix(), ((StepStart) event).stepNumber()); @@ -150,7 +149,7 @@ public AgentEvent onEvent(AgentRunContext ctx, AgentEvent event) { default: break; } - return event; + return Decision.proceed(); } private static String truncate(String s) { diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolInvocation.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolInvocation.java new file mode 100644 index 00000000000..d2909cceeda --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolInvocation.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime.middleware; + +import java.util.Collections; +import java.util.Map; + +/** A single tool invocation: id, name, parsed arguments. Immutable. */ +public final class ToolInvocation { + private final String id; + private final String name; + private final Map args; + + public ToolInvocation(String id, String name, Map args) { + this.id = id; + this.name = name; + this.args = Collections.unmodifiableMap(args); + } + + public String id() { + return id; + } + + public String name() { + return name; + } + + public Map args() { + return args; + } + + public ToolInvocation withArgs(Map newArgs) { + return new ToolInvocation(id, name, newArgs); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddleware.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddleware.java index 14b2d4ba7a2..f3e87b7ecb6 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddleware.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddleware.java @@ -22,7 +22,6 @@ import java.nio.file.Path; import java.util.Arrays; import java.util.HashSet; -import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; @@ -78,21 +77,21 @@ public void onRegister(ToolRegistry registry) { } @Override - public ToolResultAction afterToolCall( - AgentRunContext ctx, String toolName, Map toolArgs, String result) { - if (result.isEmpty()) return ToolResultUnchanged.INSTANCE; - if (EXEMPT_TOOLS.contains(toolName)) return ToolResultUnchanged.INSTANCE; + public Decision afterToolCall(AgentRunContext ctx, ToolInvocation call, String result) { + if (result.isEmpty()) return Decision.proceed(); + String toolName = call.name(); + if (EXEMPT_TOOLS.contains(toolName)) return Decision.proceed(); int bytes = result.getBytes(StandardCharsets.UTF_8).length; int lines = countLines(result); if (lines <= MAX_LINES && bytes <= MAX_BYTES) { - return ToolResultUnchanged.INSTANCE; + return Decision.proceed(); } // AgentRunContext.sessionId is null in unit-test constructions that don't exercise offload. // In production the provider always threads it through, so treat null as "skip offload". String sessionId = ctx.getSessionId(); - if (sessionId == null) return ToolResultUnchanged.INSTANCE; + if (sessionId == null) return Decision.proceed(); long n = counters.computeIfAbsent(sessionId, k -> new AtomicLong()).incrementAndGet(); String toolCallId = toolName + "_" + n; @@ -106,7 +105,7 @@ public ToolResultAction afterToolCall( toolName, sessionId, e); - return ToolResultUnchanged.INSTANCE; + return Decision.proceed(); } LOG.info( @@ -116,7 +115,7 @@ public ToolResultAction afterToolCall( lines, bytes, file.getFileName()); - return new ToolResultReplace(buildPreview(result, lines, bytes, file)); + return Decision.replace(buildPreview(result, lines, bytes, file)); } /** Clean up counter and temp dir for a closed session. Idempotent. */ diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddlewareTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddlewareTest.java index c7a7916db37..dae9c589a19 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddlewareTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddlewareTest.java @@ -60,12 +60,10 @@ public void testAutoApproveModeSkipsAllApproval() { ApprovalMiddleware mw = newApprovalMiddleware(); AgentRunContext ctx = makeContext(ApprovalMode.AUTO_APPROVE); - assertSame( - AgentMiddleware.ToolCallApproval.INSTANCE, - mw.beforeToolCall(ctx, "tc1", "dangerous_tool", Collections.emptyMap())); - assertSame( - AgentMiddleware.ToolCallApproval.INSTANCE, - mw.beforeToolCall(ctx, "tc2", "safe_tool", Collections.emptyMap())); + assertEquals( + Decision.Kind.PROCEED, mw.beforeToolCall(ctx, invocation("tc1", "dangerous_tool")).kind()); + assertEquals( + Decision.Kind.PROCEED, mw.beforeToolCall(ctx, invocation("tc2", "safe_tool")).kind()); assertTrue("No approval events should be emitted", emittedEvents.isEmpty()); } @@ -76,9 +74,8 @@ public void testNormalModeAutoApprovesSafeTool() { ApprovalMiddleware mw = newApprovalMiddleware(); AgentRunContext ctx = makeContext(ApprovalMode.NORMAL); - assertSame( - AgentMiddleware.ToolCallApproval.INSTANCE, - mw.beforeToolCall(ctx, "tc1", "safe_tool", Collections.emptyMap())); + assertEquals( + Decision.Kind.PROCEED, mw.beforeToolCall(ctx, invocation("tc1", "safe_tool")).kind()); assertTrue(emittedEvents.isEmpty()); } @@ -97,9 +94,8 @@ public void testNormalModeRequiresApprovalForDestructiveTool() throws Exception eventEmitted.countDown(); }); - Future future = - exec.submit( - () -> mw.beforeToolCall(ctx, "tc1", "dangerous_tool", Collections.emptyMap())); + Future> future = + exec.submit(() -> mw.beforeToolCall(ctx, invocation("tc1", "dangerous_tool"))); // Wait for the approval request event assertTrue("Approval event should be emitted", eventEmitted.await(2, TimeUnit.SECONDS)); @@ -112,10 +108,10 @@ public void testNormalModeRequiresApprovalForDestructiveTool() throws Exception // Approve assertTrue(mw.resolve(req.requestId(), true)); - assertSame( - "Approved tool should return null (no denial)", - AgentMiddleware.ToolCallApproval.INSTANCE, - future.get(2, TimeUnit.SECONDS)); + assertEquals( + "Approved tool should proceed", + Decision.Kind.PROCEED, + future.get(2, TimeUnit.SECONDS).kind()); } finally { exec.shutdownNow(); } @@ -135,18 +131,17 @@ public void testDeniedToolReturnsToolCallDenial() throws Exception { eventEmitted.countDown(); }); - Future future = - exec.submit( - () -> mw.beforeToolCall(ctx, "tc1", "dangerous_tool", Collections.emptyMap())); + Future> future = + exec.submit(() -> mw.beforeToolCall(ctx, invocation("tc1", "dangerous_tool"))); assertTrue(eventEmitted.await(2, TimeUnit.SECONDS)); ApprovalRequest req = (ApprovalRequest) emittedEvents.get(0); // Deny assertTrue(mw.resolve(req.requestId(), false)); - AgentMiddleware.ToolCallAction action = future.get(2, TimeUnit.SECONDS); - assertTrue(action instanceof AgentMiddleware.ToolCallDenial); - assertTrue(((AgentMiddleware.ToolCallDenial) action).reason().contains("denied")); + Decision decision = future.get(2, TimeUnit.SECONDS); + assertEquals(Decision.Kind.ABORT, decision.kind()); + assertTrue(decision.reason().contains("denied")); } finally { exec.shutdownNow(); } @@ -168,15 +163,15 @@ public void testStrictModeRequiresApprovalForSafeTool() throws Exception { eventEmitted.countDown(); }); - Future future = - exec.submit(() -> mw.beforeToolCall(ctx, "tc1", "safe_tool", Collections.emptyMap())); + Future> future = + exec.submit(() -> mw.beforeToolCall(ctx, invocation("tc1", "safe_tool"))); assertTrue(eventEmitted.await(2, TimeUnit.SECONDS)); ApprovalRequest req = (ApprovalRequest) emittedEvents.get(0); assertEquals("safe_tool", req.toolName()); assertTrue(mw.resolve(req.requestId(), true)); - assertSame(AgentMiddleware.ToolCallApproval.INSTANCE, future.get(2, TimeUnit.SECONDS)); + assertEquals(Decision.Kind.PROCEED, future.get(2, TimeUnit.SECONDS).kind()); } finally { exec.shutdownNow(); } @@ -192,14 +187,13 @@ public void testApprovalTimeoutReturnsDenial() throws Exception { ExecutorService exec = Executors.newSingleThreadExecutor(); try { - Future future = - exec.submit(() -> mw.beforeToolCall(ctx, "tc1", "safe_tool", Collections.emptyMap())); + Future> future = + exec.submit(() -> mw.beforeToolCall(ctx, invocation("tc1", "safe_tool"))); // Don't resolve — let it time out - AgentMiddleware.ToolCallAction action = future.get(5, TimeUnit.SECONDS); - assertTrue( - "Timeout should produce a denial", action instanceof AgentMiddleware.ToolCallDenial); - assertTrue(((AgentMiddleware.ToolCallDenial) action).reason().contains("timed out")); + Decision decision = future.get(5, TimeUnit.SECONDS); + assertEquals("Timeout should produce a denial", Decision.Kind.ABORT, decision.kind()); + assertTrue(decision.reason().contains("timed out")); } finally { exec.shutdownNow(); } @@ -216,11 +210,11 @@ public void testOnStopUnblocksPendingRequests() throws Exception { ExecutorService exec = Executors.newSingleThreadExecutor(); try { CountDownLatch started = new CountDownLatch(1); - Future future = + Future> future = exec.submit( () -> { started.countDown(); - return mw.beforeToolCall(ctx, "tc1", "safe_tool", Collections.emptyMap()); + return mw.beforeToolCall(ctx, invocation("tc1", "safe_tool")); }); assertTrue(started.await(2, TimeUnit.SECONDS)); @@ -228,9 +222,8 @@ public void testOnStopUnblocksPendingRequests() throws Exception { mw.onStop(); - AgentMiddleware.ToolCallAction action = future.get(2, TimeUnit.SECONDS); - assertTrue( - "onStop should unblock with a denial", action instanceof AgentMiddleware.ToolCallDenial); + Decision decision = future.get(2, TimeUnit.SECONDS); + assertEquals("onStop should unblock with a denial", Decision.Kind.ABORT, decision.kind()); } finally { exec.shutdownNow(); } @@ -256,6 +249,10 @@ private AgentRunContext makeContext(ApprovalMode mode) { return ctx; } + private static ToolInvocation invocation(String id, String name) { + return new ToolInvocation(id, name, Collections.emptyMap()); + } + private static AgentTool safeTool(String name) { return new DummyTool(name, ToolRiskLevel.SAFE); } diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareLiveTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareLiveTest.java index 49187623bd1..055e3065828 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareLiveTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareLiveTest.java @@ -80,10 +80,10 @@ public void compactsHistoryWhenThresholdCrossed() { CompactionMiddleware mw = new CompactionMiddleware(client, MODEL_NAME, /* trigger */ 50_000L); - AgentMiddleware.LlmCallAction action = mw.beforeLlmCall(ctx, memory.buildLlmMessages()); + Decision> decision = + mw.beforeLlmCall(ctx, memory.buildLlmMessages()); - assertNotNull("expected compaction to fire", action); - assertTrue(action instanceof AgentMiddleware.LlmModifyMessages); + assertEquals("expected compaction to fire", Decision.Kind.REPLACE, decision.kind()); // History got rewritten: [summary user msg] + kept tail. List hist = memory.getHistory(); diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareTest.java index 2e8940422cc..37b35f04753 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareTest.java @@ -172,8 +172,7 @@ public void belowThresholdReturnsNull() { ctx.addTokenUsage(1000, 0, 1000); CompactionMiddleware mw = new CompactionMiddleware(DUMMY_CLIENT, "m", 50_000L); - assertSame( - AgentMiddleware.LlmNoopAction.INSTANCE, mw.beforeLlmCall(ctx, memory.buildLlmMessages())); + assertEquals(Decision.Kind.PROCEED, mw.beforeLlmCall(ctx, memory.buildLlmMessages()).kind()); // Nothing was mutated. assertEquals(6, memory.size()); } @@ -189,8 +188,7 @@ public void aboveThresholdButHistoryTooShortReturnsNull() { ctx.addTokenUsage(60_000, 0, 60_000); CompactionMiddleware mw = new CompactionMiddleware(DUMMY_CLIENT, "m", 50_000L); - assertSame( - AgentMiddleware.LlmNoopAction.INSTANCE, mw.beforeLlmCall(ctx, memory.buildLlmMessages())); + assertEquals(Decision.Kind.PROCEED, mw.beforeLlmCall(ctx, memory.buildLlmMessages()).kind()); assertEquals(3, memory.size()); } @@ -205,12 +203,10 @@ public void triggerUsesLastCallTotalNotCumulative() { CompactionMiddleware mw = new CompactionMiddleware(DUMMY_CLIENT, "m", 50_000L); ctx.addTokenUsage(4_000, 1_000, 5_000); - assertSame( - AgentMiddleware.LlmNoopAction.INSTANCE, mw.beforeLlmCall(ctx, memory.buildLlmMessages())); + assertEquals(Decision.Kind.PROCEED, mw.beforeLlmCall(ctx, memory.buildLlmMessages()).kind()); ctx.addTokenUsage(8_000, 2_000, 10_000); - assertSame( - AgentMiddleware.LlmNoopAction.INSTANCE, mw.beforeLlmCall(ctx, memory.buildLlmMessages())); + assertEquals(Decision.Kind.PROCEED, mw.beforeLlmCall(ctx, memory.buildLlmMessages()).kind()); assertEquals(10_000L, memory.getLastTotalTokens()); assertEquals(15_000L, memory.getCumulativeTotalTokens()); diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddlewareTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddlewareTest.java index 86f67fcc642..562014f8fd2 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddlewareTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddlewareTest.java @@ -51,9 +51,9 @@ public void tearDown() { @Test public void underThresholdPassesThrough() { String small = "row1\nrow2\nrow3\n"; - assertSame( - AgentMiddleware.ToolResultUnchanged.INSTANCE, - mw.afterToolCall(ctxWithSession, "run_select_query", Collections.emptyMap(), small)); + assertEquals( + Decision.Kind.PROCEED, + mw.afterToolCall(ctxWithSession, invocation("run_select_query"), small).kind()); } @Test @@ -62,8 +62,7 @@ public void overLineThresholdTriggersOffload() { for (int i = 0; i < 600; i++) sb.append("row").append(i).append('\n'); String out = replacement( - mw.afterToolCall( - ctxWithSession, "run_select_query", Collections.emptyMap(), sb.toString())); + mw.afterToolCall(ctxWithSession, invocation("run_select_query"), sb.toString())); assertTrue(out, out.contains("Tool output truncated")); assertTrue(out, out.contains("Saved to:")); @@ -83,8 +82,7 @@ public void overByteThresholdTriggersOffload() { } String out = replacement( - mw.afterToolCall( - ctxWithSession, "run_select_query", Collections.emptyMap(), sb.toString())); + mw.afterToolCall(ctxWithSession, invocation("run_select_query"), sb.toString())); assertTrue(out, out.contains("Tool output truncated")); } @@ -92,31 +90,31 @@ public void overByteThresholdTriggersOffload() { public void retrievalToolsAreExemptFromGate() { StringBuilder sb = new StringBuilder(); for (int i = 0; i < 2000; i++) sb.append("row").append(i).append('\n'); - assertSame( - AgentMiddleware.ToolResultUnchanged.INSTANCE, - mw.afterToolCall( - ctxWithSession, ReadToolOutputTool.NAME, Collections.emptyMap(), sb.toString())); - assertSame( - AgentMiddleware.ToolResultUnchanged.INSTANCE, - mw.afterToolCall( - ctxWithSession, GrepToolOutputTool.NAME, Collections.emptyMap(), sb.toString())); + assertEquals( + Decision.Kind.PROCEED, + mw.afterToolCall(ctxWithSession, invocation(ReadToolOutputTool.NAME), sb.toString()) + .kind()); + assertEquals( + Decision.Kind.PROCEED, + mw.afterToolCall(ctxWithSession, invocation(GrepToolOutputTool.NAME), sb.toString()) + .kind()); } @Test public void missingSessionIdPassesThrough() { StringBuilder sb = new StringBuilder(); for (int i = 0; i < 1000; i++) sb.append("row").append(i).append('\n'); - assertSame( + assertEquals( "without sessionId, cannot offload safely — pass through", - AgentMiddleware.ToolResultUnchanged.INSTANCE, - mw.afterToolCall(ctxNoSession, "run_select_query", Collections.emptyMap(), sb.toString())); + Decision.Kind.PROCEED, + mw.afterToolCall(ctxNoSession, invocation("run_select_query"), sb.toString()).kind()); } @Test public void onSessionCloseClearsCounterAndFiles() { StringBuilder sb = new StringBuilder(); for (int i = 0; i < 600; i++) sb.append("row").append(i).append('\n'); - mw.afterToolCall(ctxWithSession, "run_select_query", Collections.emptyMap(), sb.toString()); + mw.afterToolCall(ctxWithSession, invocation("run_select_query"), sb.toString()); assertEquals(1, mw.trackedSessions()); mw.onSessionClose("sess-1"); @@ -129,23 +127,24 @@ public void multipleOffloadsReuseSameSessionDir() { for (int i = 0; i < 600; i++) sb.append("row").append(i).append('\n'); String out1 = replacement( - mw.afterToolCall( - ctxWithSession, "run_select_query", Collections.emptyMap(), sb.toString())); + mw.afterToolCall(ctxWithSession, invocation("run_select_query"), sb.toString())); String out2 = replacement( - mw.afterToolCall( - ctxWithSession, "run_select_query", Collections.emptyMap(), sb.toString())); + mw.afterToolCall(ctxWithSession, invocation("run_select_query"), sb.toString())); // Both previews reference the same session dir, different file names. assertNotEquals(extractPath(out1), extractPath(out2)); assertTrue(extractPath(out1).contains("sess-1")); assertTrue(extractPath(out2).contains("sess-1")); } - private static String replacement(AgentMiddleware.ToolResultAction action) { - assertTrue( - "expected ToolResultReplace but got " + action, - action instanceof AgentMiddleware.ToolResultReplace); - return ((AgentMiddleware.ToolResultReplace) action).replacement(); + private static ToolInvocation invocation(String name) { + return new ToolInvocation(name + "_id", name, Collections.emptyMap()); + } + + private static String replacement(Decision decision) { + assertEquals( + "expected REPLACE but got " + decision.kind(), Decision.Kind.REPLACE, decision.kind()); + return decision.replacement(); } private static String extractPath(String preview) { From 9fe962509ac1f015a043c1a545c5c60e71b38bea Mon Sep 17 00:00:00 2001 From: wangzhigang Date: Fri, 1 May 2026 02:15:51 +0800 Subject: [PATCH 09/10] [KYUUBI #7379][2b/4][FOLLOWUP] Instruct LLM to SELECT after UPDATE in approval live test testApprovalApproveFlow asked the model to increment a counter and return the new value, but UPDATE returns no value, so weaker models (e.g. kimi-k2.5) hallucinated "0" instead of running a follow-up SELECT. Make the instruction explicit so behavior converges across models. --- .../kyuubi/engine/dataagent/runtime/ReactAgentLiveTest.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgentLiveTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgentLiveTest.java index e21410a99bf..95d0a083cca 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgentLiveTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgentLiveTest.java @@ -356,8 +356,9 @@ public void testApprovalApproveFlow() throws Exception { try { agent.run( new AgentInvocation( - "Increment the 'hits' counter in the counters table by 1, then tell me its" - + " new value. Respond with ONLY the new value, no explanation."), + "Increment the 'hits' counter in the counters table by 1. After the update" + + " succeeds, run a SELECT query to read the new value back from the" + + " database, then respond with ONLY that value, no explanation."), memory, listener); } finally { From 1aac6aaa93f57854f736a2ee3db8e892d842a289 Mon Sep 17 00:00:00 2001 From: wangzhigang Date: Sat, 2 May 2026 13:50:19 +0800 Subject: [PATCH 10/10] [KYUUBI #7379][2b/4][FOLLOWUP] Split ReactAgent into LlmStreamClient + composite MiddlewareDispatcher MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ReactAgent had grown to mix three concerns: the ReAct control loop, OpenAI streaming/chunk assembly, and middleware fold logic. Extract: - LlmStreamClient: owns one streaming chat completion call, accumulates content + tool-call deltas, and exposes StreamResult.toAssistantMessage for SDK message construction. Depends only on the OpenAI SDK and AgentRunContext (emits ContentDelta via ctx.emit, no dispatcher reference). - MiddlewareDispatcher: implements AgentMiddleware as a composite over the configured list. ReactAgent calls onAgentStart / onEvent / beforeLlmCall etc. on the composite the same way it would call any middleware; resolveApproval stays as a non-interface accessor for the approval flow's special case. Also: afterToolCall now returns Decision for symmetry with the other interceptor hooks; ABORT marks ToolResult.isError=true so listeners can distinguish a middleware-vetoed result from a successful one. The emit-then-forward step splits cleanly: the composite runs onEvent, and ReactAgent's ctx.setEventEmitter lambda forwards the filtered event to the user's raw consumer. ReactAgent's run() drops the eventConsumer parameter threading through internal helpers — everywhere downstream uses ctx.emit(). --- .../dataagent/runtime/LlmStreamClient.java | 188 +++++++++ .../runtime/MiddlewareDispatcher.java | 198 +++++++++ .../engine/dataagent/runtime/ReactAgent.java | 397 +++--------------- .../runtime/middleware/Decision.java | 4 +- 4 files changed, 452 insertions(+), 335 deletions(-) create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/LlmStreamClient.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/MiddlewareDispatcher.java diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/LlmStreamClient.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/LlmStreamClient.java new file mode 100644 index 00000000000..8d8d6494aaf --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/LlmStreamClient.java @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime; + +import com.openai.client.OpenAIClient; +import com.openai.core.http.StreamResponse; +import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam; +import com.openai.models.chat.completions.ChatCompletionChunk; +import com.openai.models.chat.completions.ChatCompletionCreateParams; +import com.openai.models.chat.completions.ChatCompletionMessageFunctionToolCall; +import com.openai.models.chat.completions.ChatCompletionMessageParam; +import com.openai.models.chat.completions.ChatCompletionMessageToolCall; +import com.openai.models.chat.completions.ChatCompletionStreamOptions; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.kyuubi.engine.dataagent.runtime.event.ContentDelta; +import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Streams one chat completion call and assembles assistant content plus streamed tool calls. */ +final class LlmStreamClient { + + private static final Logger LOG = LoggerFactory.getLogger(LlmStreamClient.class); + + private final OpenAIClient client; + private final ToolRegistry toolRegistry; + + LlmStreamClient(OpenAIClient client, ToolRegistry toolRegistry) { + this.client = client; + this.toolRegistry = toolRegistry; + } + + /** + * Stream LLM response, emitting ContentDelta through {@code ctx} for each text chunk. Assembles + * tool calls directly from streamed chunks with no non-streaming fallback. + */ + StreamResult stream( + AgentRunContext ctx, List messages, String effectiveModel) { + ChatCompletionCreateParams.Builder paramsBuilder = + ChatCompletionCreateParams.builder() + .model(effectiveModel) + .streamOptions(ChatCompletionStreamOptions.builder().includeUsage(true).build()); + for (ChatCompletionMessageParam msg : messages) { + paramsBuilder.addMessage(msg); + } + toolRegistry.addToolsTo(paramsBuilder); + + LOG.info("LLM request: model={}", effectiveModel); + StreamAccumulator acc = new StreamAccumulator(); + try (StreamResponse stream = + client.chat().completions().createStreaming(paramsBuilder.build())) { + stream.stream().forEach(chunk -> consumeChunk(ctx, chunk, acc)); + } + return new StreamResult(acc.content.toString(), acc.buildToolCalls()); + } + + /** Fold one streaming chunk into {@code acc}, emitting per-token {@link ContentDelta}s. */ + private void consumeChunk(AgentRunContext ctx, ChatCompletionChunk chunk, StreamAccumulator acc) { + if (!acc.serverModelLogged) { + LOG.info("LLM response: server-echoed model={}", chunk.model()); + acc.serverModelLogged = true; + } + chunk + .usage() + .ifPresent(u -> ctx.addTokenUsage(u.promptTokens(), u.completionTokens(), u.totalTokens())); + + for (ChatCompletionChunk.Choice c : chunk.choices()) { + c.delta() + .content() + .ifPresent( + text -> { + acc.content.append(text); + ctx.emit(new ContentDelta(text)); + }); + c.delta().toolCalls().ifPresent(acc::mergeToolCallDeltas); + } + } + + /** + * Mutable accumulator for a single streaming LLM turn. Tool call fields are keyed by the chunk's + * {@code index} because provider SDKs may deliver a single logical call across multiple chunks + * and only surface the {@code id}/{@code name} on the first one. + */ + private static final class StreamAccumulator { + final StringBuilder content = new StringBuilder(); + final Map toolCallIds = new HashMap<>(); + final Map toolCallNames = new HashMap<>(); + final Map toolCallArgs = new HashMap<>(); + boolean serverModelLogged = false; + + void mergeToolCallDeltas(List deltas) { + for (ChatCompletionChunk.Choice.Delta.ToolCall tc : deltas) { + int idx = (int) tc.index(); + tc.id().ifPresent(id -> toolCallIds.put(idx, id)); + tc.function() + .ifPresent( + fn -> { + fn.name().ifPresent(name -> toolCallNames.put(idx, name)); + fn.arguments() + .ifPresent( + args -> + toolCallArgs + .computeIfAbsent(idx, k -> new StringBuilder()) + .append(args)); + }); + } + } + + /** + * Materialize accumulated deltas into SDK tool-call objects. Returns {@code null} (not an empty + * list) if no tool calls were seen, matching the existing {@link StreamResult} contract. + */ + List buildToolCalls() { + if (toolCallIds.isEmpty()) return null; + List out = new ArrayList<>(toolCallIds.size()); + for (Map.Entry e : toolCallIds.entrySet()) { + int idx = e.getKey(); + String id = (e.getValue() == null || e.getValue().isEmpty()) ? synthId() : e.getValue(); + String args = toolCallArgs.containsKey(idx) ? toolCallArgs.get(idx).toString() : "{}"; + out.add( + ChatCompletionMessageToolCall.ofFunction( + ChatCompletionMessageFunctionToolCall.builder() + .id(id) + .function( + ChatCompletionMessageFunctionToolCall.Function.builder() + .name(toolCallNames.getOrDefault(idx, "")) + .arguments(args) + .build()) + .build())); + } + return out; + } + + /** + * Synthesize an id for tool calls whose id never arrived on the stream (some OpenAI-compatible + * providers omit it). The id has to be stable within a turn and unique across turns so the + * assistant/tool_result pairing downstream holds. + */ + private static String synthId() { + return "local_" + java.util.UUID.randomUUID().toString().replace("-", "").substring(0, 24); + } + } + + /** Result of a streaming LLM call, assembled from chunks. */ + static final class StreamResult { + final String content; + final List toolCalls; + + StreamResult(String content, List toolCalls) { + this.content = content; + this.toolCalls = toolCalls; + } + + boolean isEmpty() { + return content.isEmpty() && (toolCalls == null || toolCalls.isEmpty()); + } + + /** Build the SDK assistant message corresponding to this streamed result. */ + ChatCompletionAssistantMessageParam toAssistantMessage() { + ChatCompletionAssistantMessageParam.Builder b = ChatCompletionAssistantMessageParam.builder(); + if (!content.isEmpty()) { + b.content(content); + } + if (toolCalls != null && !toolCalls.isEmpty()) { + b.toolCalls(toolCalls); + } + return b.build(); + } + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/MiddlewareDispatcher.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/MiddlewareDispatcher.java new file mode 100644 index 00000000000..9e406e40e3c --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/MiddlewareDispatcher.java @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime; + +import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam; +import com.openai.models.chat.completions.ChatCompletionMessageParam; +import java.util.List; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentEvent; +import org.apache.kyuubi.engine.dataagent.runtime.middleware.AgentMiddleware; +import org.apache.kyuubi.engine.dataagent.runtime.middleware.ApprovalMiddleware; +import org.apache.kyuubi.engine.dataagent.runtime.middleware.Decision; +import org.apache.kyuubi.engine.dataagent.runtime.middleware.ToolInvocation; +import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Composite {@link AgentMiddleware} — folds a list of middlewares into one. Hook ordering follows + * the onion model: {@code before*} / {@code on*Start} run first-to-last, {@code after*} / {@code + * on*Finish} run last-to-first. + * + *

    Component middlewares are internal framework code. If one throws during ordinary hook + * dispatch, the agent run fails via {@link ReactAgent#run}; lifecycle cleanup hooks ({@link + * #onAgentFinish}, {@link #onSessionClose}, {@link #onStop}) swallow exceptions so later + * middlewares still get a chance to release state. + */ +final class MiddlewareDispatcher implements AgentMiddleware { + + private static final Logger LOG = LoggerFactory.getLogger(MiddlewareDispatcher.class); + + private final List middlewares; + private final ApprovalMiddleware approvalMiddleware; + + MiddlewareDispatcher(List middlewares) { + this.middlewares = middlewares; + this.approvalMiddleware = findApprovalMiddleware(middlewares); + } + + /** + * Resolve a pending approval request. Not part of {@link AgentMiddleware} — special accessor for + * the approval flow. + */ + boolean resolveApproval(String requestId, boolean approved) { + if (approvalMiddleware == null) return false; + return approvalMiddleware.resolve(requestId, approved); + } + + @Override + public void onRegister(ToolRegistry registry) { + for (AgentMiddleware mw : middlewares) { + mw.onRegister(registry); + } + } + + @Override + public void onAgentStart(AgentRunContext ctx) { + for (AgentMiddleware mw : middlewares) { + mw.onAgentStart(ctx); + } + } + + @Override + public void onAgentFinish(AgentRunContext ctx) { + // Runs even when the agent body threw, so swallow here to ensure every middleware's cleanup + // gets a chance to run; otherwise we'd leak session state in later middlewares. + for (int i = middlewares.size() - 1; i >= 0; i--) { + try { + middlewares.get(i).onAgentFinish(ctx); + } catch (Exception e) { + LOG.warn("Middleware onAgentFinish error", e); + } + } + } + + @Override + public void onSessionClose(String sessionId) { + for (AgentMiddleware mw : middlewares) { + try { + mw.onSessionClose(sessionId); + } catch (Exception e) { + LOG.warn("Middleware onSessionClose error", e); + } + } + } + + @Override + public void onStop() { + for (AgentMiddleware mw : middlewares) { + try { + mw.onStop(); + } catch (Exception e) { + LOG.warn("Middleware onStop error", e); + } + } + } + + /** + * Fold {@code onEvent} in onion order. Returns PROCEED if untouched, REPLACE with the final event + * if any middleware rewrote it, or ABORT if any short-circuited. + */ + @Override + public Decision onEvent(AgentRunContext ctx, AgentEvent event) { + AgentEvent current = event; + for (AgentMiddleware mw : middlewares) { + Decision d = mw.onEvent(ctx, current); + if (d.kind() == Decision.Kind.ABORT) return d; + if (d.kind() == Decision.Kind.REPLACE) current = d.replacement(); + } + return Decision.of(event, current); + } + + /** + * Fold {@code beforeLlmCall} in onion order so later middlewares see rewritten messages. Returns + * PROCEED if untouched, REPLACE with the final value if any did, or ABORT if any short-circuited. + */ + @Override + public Decision> beforeLlmCall( + AgentRunContext ctx, List messages) { + List current = messages; + for (AgentMiddleware mw : middlewares) { + Decision> d = mw.beforeLlmCall(ctx, current); + if (d.kind() == Decision.Kind.ABORT) return d; + if (d.kind() == Decision.Kind.REPLACE) current = d.replacement(); + } + return Decision.of(messages, current); + } + + /** + * Fold {@code afterLlmCall} in reverse onion order so earlier middlewares see rewritten + * responses. Returns the final response, or ABORT if any middleware short-circuits. + */ + @Override + public Decision afterLlmCall( + AgentRunContext ctx, ChatCompletionAssistantMessageParam response) { + ChatCompletionAssistantMessageParam current = response; + for (int i = middlewares.size() - 1; i >= 0; i--) { + Decision d = + middlewares.get(i).afterLlmCall(ctx, current); + if (d.kind() == Decision.Kind.ABORT) return d; + if (d.kind() == Decision.Kind.REPLACE) current = d.replacement(); + } + return Decision.of(response, current); + } + + /** + * Fold {@code beforeToolCall} in onion order so later middlewares can further rewrite. Returns + * PROCEED if untouched, REPLACE with the final invocation otherwise, or ABORT if any middleware + * denies the call. + */ + @Override + public Decision beforeToolCall(AgentRunContext ctx, ToolInvocation call) { + ToolInvocation current = call; + for (AgentMiddleware mw : middlewares) { + Decision d = mw.beforeToolCall(ctx, current); + if (d.kind() == Decision.Kind.ABORT) return d; + if (d.kind() == Decision.Kind.REPLACE) current = d.replacement(); + } + return Decision.of(call, current); + } + + /** + * Fold {@code afterToolCall} in reverse onion order so earlier middlewares see rewritten results. + * Returns the final result, or ABORT if any middleware short-circuits — caller decides how to + * surface the abort (typically: use {@code reason()} as the result text the LLM sees). + */ + @Override + public Decision afterToolCall(AgentRunContext ctx, ToolInvocation call, String result) { + String current = result; + for (int i = middlewares.size() - 1; i >= 0; i--) { + Decision d = middlewares.get(i).afterToolCall(ctx, call, current); + if (d.kind() == Decision.Kind.ABORT) return d; + if (d.kind() == Decision.Kind.REPLACE) current = d.replacement(); + } + return Decision.of(result, current); + } + + private static ApprovalMiddleware findApprovalMiddleware(List middlewares) { + for (AgentMiddleware mw : middlewares) { + if (mw instanceof ApprovalMiddleware) return (ApprovalMiddleware) mw; + } + return null; + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgent.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgent.java index 4ddf484a44f..b5138e9c661 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgent.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgent.java @@ -20,14 +20,10 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.openai.client.OpenAIClient; -import com.openai.core.http.StreamResponse; import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam; -import com.openai.models.chat.completions.ChatCompletionChunk; -import com.openai.models.chat.completions.ChatCompletionCreateParams; import com.openai.models.chat.completions.ChatCompletionMessageFunctionToolCall; import com.openai.models.chat.completions.ChatCompletionMessageParam; import com.openai.models.chat.completions.ChatCompletionMessageToolCall; -import com.openai.models.chat.completions.ChatCompletionStreamOptions; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -39,13 +35,11 @@ import org.apache.kyuubi.engine.dataagent.runtime.event.AgentFinish; import org.apache.kyuubi.engine.dataagent.runtime.event.AgentStart; import org.apache.kyuubi.engine.dataagent.runtime.event.ContentComplete; -import org.apache.kyuubi.engine.dataagent.runtime.event.ContentDelta; import org.apache.kyuubi.engine.dataagent.runtime.event.StepEnd; import org.apache.kyuubi.engine.dataagent.runtime.event.StepStart; import org.apache.kyuubi.engine.dataagent.runtime.event.ToolCall; import org.apache.kyuubi.engine.dataagent.runtime.event.ToolResult; import org.apache.kyuubi.engine.dataagent.runtime.middleware.AgentMiddleware; -import org.apache.kyuubi.engine.dataagent.runtime.middleware.ApprovalMiddleware; import org.apache.kyuubi.engine.dataagent.runtime.middleware.Decision; import org.apache.kyuubi.engine.dataagent.runtime.middleware.ToolInvocation; import org.apache.kyuubi.engine.dataagent.tool.ToolContext; @@ -65,11 +59,10 @@ public class ReactAgent { private static final Logger LOG = LoggerFactory.getLogger(ReactAgent.class); private static final ObjectMapper JSON = new ObjectMapper(); - private final OpenAIClient client; private final String defaultModelName; private final ToolRegistry toolRegistry; - private final List middlewares; - private final ApprovalMiddleware approvalMiddleware; + private final MiddlewareDispatcher dispatcher; + private final LlmStreamClient llmStreamClient; private final int maxIterations; private final String systemPrompt; @@ -80,48 +73,27 @@ public ReactAgent( List middlewares, int maxIterations, String systemPrompt) { - this.client = client; this.defaultModelName = modelName; this.toolRegistry = toolRegistry; - this.middlewares = middlewares; - this.approvalMiddleware = findApprovalMiddleware(middlewares); + this.dispatcher = new MiddlewareDispatcher(middlewares); + this.llmStreamClient = new LlmStreamClient(client, toolRegistry); this.maxIterations = maxIterations; this.systemPrompt = systemPrompt; } - private static ApprovalMiddleware findApprovalMiddleware(List middlewares) { - for (AgentMiddleware mw : middlewares) { - if (mw instanceof ApprovalMiddleware) return (ApprovalMiddleware) mw; - } - return null; - } - /** Resolve a pending approval request. Returns false if no pending request matches. */ public boolean resolveApproval(String requestId, boolean approved) { - if (approvalMiddleware == null) return false; - return approvalMiddleware.resolve(requestId, approved); + return dispatcher.resolveApproval(requestId, approved); } /** Fan out session-close to every middleware. Errors in one middleware don't block others. */ public void closeSession(String sessionId) { - for (AgentMiddleware mw : middlewares) { - try { - mw.onSessionClose(sessionId); - } catch (Exception e) { - LOG.warn("Middleware onSessionClose error", e); - } - } + dispatcher.onSessionClose(sessionId); } /** Fan out engine-stop to every middleware. Errors in one middleware don't block others. */ public void stop() { - for (AgentMiddleware mw : middlewares) { - try { - mw.onStop(); - } catch (Exception e) { - LOG.warn("Middleware onStop error", e); - } - } + dispatcher.onStop(); } /** @@ -150,41 +122,45 @@ public void run( memory.addUserMessage(userInput); AgentRunContext ctx = new AgentRunContext(memory, approvalMode, request.getSessionId()); - ctx.setEventEmitter(event -> emit(ctx, event, eventConsumer)); - dispatchAgentStart(ctx); - emit(ctx, new AgentStart(), eventConsumer); + // Wire the event pipeline: ctx.emit -> middleware.onEvent -> raw consumer. + // Splitting filter and forward keeps the middleware composite ignorant of the consumer. + ctx.setEventEmitter( + event -> { + Decision d = dispatcher.onEvent(ctx, event); + if (d.kind() == Decision.Kind.ABORT) return; + eventConsumer.accept(d.kind() == Decision.Kind.REPLACE ? d.replacement() : event); + }); + dispatcher.onAgentStart(ctx); + ctx.emit(new AgentStart()); try { for (int step = 1; step <= maxIterations; step++) { ctx.setIteration(step); - emit(ctx, new StepStart(step), eventConsumer); + ctx.emit(new StepStart(step)); List messages = - resolveMessagesForCall(ctx, memory.buildLlmMessages(), eventConsumer); + resolveMessagesForCall(ctx, memory.buildLlmMessages()); if (messages == null) { // Middleware asked us to skip — AgentError + AgentFinish have already been emitted. return; } - StreamResult result = streamLlmResponse(ctx, messages, effectiveModel, eventConsumer); + LlmStreamClient.StreamResult result = llmStreamClient.stream(ctx, messages, effectiveModel); if (result.isEmpty()) { - emit(ctx, new AgentError("LLM returned empty response"), eventConsumer); - emitFinish(ctx, eventConsumer); + ctx.emit(new AgentError("LLM returned empty response")); + emitFinish(ctx); return; } if (!result.content.isEmpty()) { - emit(ctx, new ContentComplete(result.content), eventConsumer); + ctx.emit(new ContentComplete(result.content)); } - ChatCompletionAssistantMessageParam assistantMsg = buildAssistantMessage(result); + ChatCompletionAssistantMessageParam assistantMsg = result.toAssistantMessage(); Decision after = - dispatchAfterLlmCall(ctx, assistantMsg); + dispatcher.afterLlmCall(ctx, assistantMsg); if (after.kind() == Decision.Kind.ABORT) { - emit( - ctx, - new AgentError("LLM response rejected by middleware: " + after.reason()), - eventConsumer); - emitFinish(ctx, eventConsumer); + ctx.emit(new AgentError("LLM response rejected by middleware: " + after.reason())); + emitFinish(ctx); return; } if (after.kind() == Decision.Kind.REPLACE) assistantMsg = after.replacement(); @@ -193,61 +169,52 @@ public void run( List toolCalls = assistantMsg.toolCalls().orElse(null); if (toolCalls == null || toolCalls.isEmpty()) { // No tool calls — agent is done. - emit(ctx, new StepEnd(step), eventConsumer); - emitFinish(ctx, eventConsumer); + ctx.emit(new StepEnd(step)); + emitFinish(ctx); return; } - executeToolCalls(ctx, memory, toolCalls, eventConsumer); - emit(ctx, new StepEnd(step), eventConsumer); + executeToolCalls(ctx, memory, toolCalls); + ctx.emit(new StepEnd(step)); } - emit( - ctx, new AgentError("Reached maximum iterations (" + maxIterations + ")"), eventConsumer); - emitFinish(ctx, eventConsumer); + ctx.emit(new AgentError("Reached maximum iterations (" + maxIterations + ")")); + emitFinish(ctx); } catch (Exception e) { LOG.error("Agent run failed", e); - emit( - ctx, new AgentError(e.getClass().getSimpleName() + ": " + e.getMessage()), eventConsumer); - emitFinish(ctx, eventConsumer); + ctx.emit(new AgentError(e.getClass().getSimpleName() + ": " + e.getMessage())); + emitFinish(ctx); } finally { - dispatchAgentFinish(ctx); + dispatcher.onAgentFinish(ctx); } } + private static void emitFinish(AgentRunContext ctx) { + ctx.emit( + new AgentFinish( + ctx.getIteration(), + ctx.getPromptTokens(), + ctx.getCompletionTokens(), + ctx.getTotalTokens())); + } + /** * Run {@code beforeLlmCall} middleware against {@code messages}. Returns the messages to send, * possibly rewritten by middleware, or {@code null} if middleware aborted the call (in which case * this method has already emitted the terminal events). */ private List resolveMessagesForCall( - AgentRunContext ctx, - List messages, - Consumer eventConsumer) { - Decision> decision = dispatchBeforeLlmCall(ctx, messages); + AgentRunContext ctx, List messages) { + Decision> decision = dispatcher.beforeLlmCall(ctx, messages); if (decision.kind() == Decision.Kind.ABORT) { - emit( - ctx, - new AgentError("LLM call skipped by middleware: " + decision.reason()), - eventConsumer); - emitFinish(ctx, eventConsumer); + ctx.emit(new AgentError("LLM call skipped by middleware: " + decision.reason())); + emitFinish(ctx); return null; } return decision.kind() == Decision.Kind.REPLACE ? decision.replacement() : messages; } - private static ChatCompletionAssistantMessageParam buildAssistantMessage(StreamResult result) { - ChatCompletionAssistantMessageParam.Builder b = ChatCompletionAssistantMessageParam.builder(); - if (!result.content.isEmpty()) { - b.content(result.content); - } - if (result.toolCalls != null && !result.toolCalls.isEmpty()) { - b.toolCalls(result.toolCalls); - } - return b.build(); - } - /** * Execute the assistant's tool calls in 3 phases: * @@ -262,8 +229,7 @@ private static ChatCompletionAssistantMessageParam buildAssistantMessage(StreamR private void executeToolCalls( AgentRunContext ctx, ConversationMemory memory, - List toolCalls, - Consumer eventConsumer) { + List toolCalls) { List approved = new ArrayList<>(); for (ChatCompletionMessageToolCall toolCall : toolCalls) { ChatCompletionMessageFunctionToolCall fnCall = toolCall.asFunction(); @@ -276,22 +242,22 @@ private void executeToolCalls( // assistant/tool_result pairing the next API call needs) and let the loop self-correct. String err = "Tool call failed: " + e.getMessage(); memory.addToolResult(fnCall.id(), err); - emit(ctx, new ToolResult(fnCall.id(), toolName, err, true), eventConsumer); + ctx.emit(new ToolResult(fnCall.id(), toolName, err, true)); continue; } ToolInvocation invocation = new ToolInvocation(fnCall.id(), toolName, toolArgs); - Decision decision = dispatchBeforeToolCall(ctx, invocation); + Decision decision = dispatcher.beforeToolCall(ctx, invocation); if (decision.kind() == Decision.Kind.ABORT) { String denied = "Tool call denied: " + decision.reason(); memory.addToolResult(fnCall.id(), denied); - emit(ctx, new ToolResult(fnCall.id(), toolName, denied, true), eventConsumer); + ctx.emit(new ToolResult(fnCall.id(), toolName, denied, true)); continue; } boolean rewritten = decision.kind() == Decision.Kind.REPLACE; ToolInvocation effective = rewritten ? decision.replacement() : invocation; - emit(ctx, new ToolCall(effective.id(), effective.name(), effective.args()), eventConsumer); + ctx.emit(new ToolCall(effective.id(), effective.name(), effective.args())); approved.add(new ToolCallEntry(fnCall, effective, rewritten)); } @@ -304,27 +270,16 @@ private void executeToolCalls( for (int i = 0; i < approved.size(); i++) { ToolCallEntry entry = approved.get(i); String raw = futures.get(i).join(); - String output = dispatchAfterToolCall(ctx, entry.invocation, raw); + Decision after = dispatcher.afterToolCall(ctx, entry.invocation, raw); + // ABORT.afterToolCall: discard the result; surface reason() to the LLM and mark the event + // as an error so listeners can distinguish it from a successful tool result. + boolean isError = after.kind() == Decision.Kind.ABORT; + String output = + after.kind() == Decision.Kind.ABORT + ? after.reason() + : (after.kind() == Decision.Kind.REPLACE ? after.replacement() : raw); memory.addToolResult(entry.fnCall.id(), output); - emit( - ctx, - new ToolResult(entry.fnCall.id(), entry.invocation.name(), output, false), - eventConsumer); - } - } - - /** Result of a streaming LLM call, assembled from chunks. */ - private static class StreamResult { - final String content; - final List toolCalls; - - StreamResult(String content, List toolCalls) { - this.content = content; - this.toolCalls = toolCalls; - } - - boolean isEmpty() { - return content.isEmpty() && (toolCalls == null || toolCalls.isEmpty()); + ctx.emit(new ToolResult(entry.fnCall.id(), entry.invocation.name(), output, isError)); } } @@ -358,125 +313,6 @@ String argsJson() { } } - /** - * Stream LLM response, emitting ContentDelta for each text chunk. Assembles tool calls directly - * from streamed chunks — no non-streaming fallback. Exceptions propagate to the top-level handler - * in {@link #run}. - */ - private StreamResult streamLlmResponse( - AgentRunContext ctx, - List messages, - String effectiveModel, - Consumer eventConsumer) { - ChatCompletionCreateParams.Builder paramsBuilder = - ChatCompletionCreateParams.builder() - .model(effectiveModel) - .streamOptions(ChatCompletionStreamOptions.builder().includeUsage(true).build()); - for (ChatCompletionMessageParam msg : messages) { - paramsBuilder.addMessage(msg); - } - toolRegistry.addToolsTo(paramsBuilder); - - LOG.info("LLM request: model={}", effectiveModel); - StreamAccumulator acc = new StreamAccumulator(); - try (StreamResponse stream = - client.chat().completions().createStreaming(paramsBuilder.build())) { - stream.stream().forEach(chunk -> consumeChunk(ctx, chunk, acc, eventConsumer)); - } - return new StreamResult(acc.content.toString(), acc.buildToolCalls()); - } - - /** Fold one streaming chunk into {@code acc}, emitting per-token {@link ContentDelta}s. */ - private void consumeChunk( - AgentRunContext ctx, - ChatCompletionChunk chunk, - StreamAccumulator acc, - Consumer eventConsumer) { - if (!acc.serverModelLogged) { - LOG.info("LLM response: server-echoed model={}", chunk.model()); - acc.serverModelLogged = true; - } - chunk - .usage() - .ifPresent(u -> ctx.addTokenUsage(u.promptTokens(), u.completionTokens(), u.totalTokens())); - - for (ChatCompletionChunk.Choice c : chunk.choices()) { - c.delta() - .content() - .ifPresent( - text -> { - acc.content.append(text); - emit(ctx, new ContentDelta(text), eventConsumer); - }); - c.delta().toolCalls().ifPresent(acc::mergeToolCallDeltas); - } - } - - /** - * Mutable accumulator for a single streaming LLM turn. Tool call fields are keyed by the chunk's - * {@code index} because provider SDKs may deliver a single logical call across multiple chunks - * and only surface the {@code id}/{@code name} on the first one. - */ - private static final class StreamAccumulator { - final StringBuilder content = new StringBuilder(); - final Map toolCallIds = new HashMap<>(); - final Map toolCallNames = new HashMap<>(); - final Map toolCallArgs = new HashMap<>(); - boolean serverModelLogged = false; - - void mergeToolCallDeltas(List deltas) { - for (ChatCompletionChunk.Choice.Delta.ToolCall tc : deltas) { - int idx = (int) tc.index(); - tc.id().ifPresent(id -> toolCallIds.put(idx, id)); - tc.function() - .ifPresent( - fn -> { - fn.name().ifPresent(name -> toolCallNames.put(idx, name)); - fn.arguments() - .ifPresent( - args -> - toolCallArgs - .computeIfAbsent(idx, k -> new StringBuilder()) - .append(args)); - }); - } - } - - /** - * Materialize accumulated deltas into SDK tool-call objects. Returns {@code null} (not an empty - * list) if no tool calls were seen, matching the existing {@link StreamResult} contract. - */ - List buildToolCalls() { - if (toolCallIds.isEmpty()) return null; - List out = new ArrayList<>(toolCallIds.size()); - for (Map.Entry e : toolCallIds.entrySet()) { - int idx = e.getKey(); - String id = (e.getValue() == null || e.getValue().isEmpty()) ? synthId() : e.getValue(); - String args = toolCallArgs.containsKey(idx) ? toolCallArgs.get(idx).toString() : "{}"; - out.add( - ChatCompletionMessageToolCall.ofFunction( - ChatCompletionMessageFunctionToolCall.builder() - .id(id) - .function( - ChatCompletionMessageFunctionToolCall.Function.builder() - .name(toolCallNames.getOrDefault(idx, "")) - .arguments(args) - .build()) - .build())); - } - return out; - } - - /** - * Synthesize an id for tool calls whose id never arrived on the stream (some OpenAI-compatible - * providers omit it). The id has to be stable within a turn and unique across turns so the - * assistant/tool_result pairing downstream holds. - */ - private static String synthId() { - return "local_" + java.util.UUID.randomUUID().toString().replace("-", "").substring(0, 24); - } - } - private static Map parseToolArgs(String json) { if (json == null || json.isEmpty()) { return new HashMap<>(); @@ -488,113 +324,6 @@ private static Map parseToolArgs(String json) { } } - // --- Middleware dispatch methods --- - // - // Middlewares are internal framework code. If one throws, the agent run fails via the - // top-level catch in run() — we do not wrap individual dispatch calls in try/catch. - - private void emitFinish(AgentRunContext ctx, Consumer eventConsumer) { - emit( - ctx, - new AgentFinish( - ctx.getIteration(), - ctx.getPromptTokens(), - ctx.getCompletionTokens(), - ctx.getTotalTokens()), - eventConsumer); - } - - private void emit(AgentRunContext ctx, AgentEvent event, Consumer consumer) { - AgentEvent filtered = event; - for (AgentMiddleware mw : middlewares) { - Decision d = mw.onEvent(ctx, filtered); - if (d.kind() == Decision.Kind.ABORT) return; - if (d.kind() == Decision.Kind.REPLACE) filtered = d.replacement(); - } - consumer.accept(filtered); - } - - private void dispatchAgentStart(AgentRunContext ctx) { - for (AgentMiddleware mw : middlewares) { - mw.onAgentStart(ctx); - } - } - - private void dispatchAgentFinish(AgentRunContext ctx) { - // Runs even when the agent body threw, so swallow here to ensure every middleware's cleanup - // gets a chance to run; otherwise we'd leak session state in later middlewares. - for (int i = middlewares.size() - 1; i >= 0; i--) { - try { - middlewares.get(i).onAgentFinish(ctx); - } catch (Exception e) { - LOG.warn("Middleware onAgentFinish error", e); - } - } - } - - /** - * Run {@code beforeLlmCall} middleware in onion order, folding REPLACE so later middlewares see - * rewritten messages. Returns PROCEED if no middleware touched the value, REPLACE with the final - * value if any did, or ABORT if any middleware short-circuited. - */ - private Decision> dispatchBeforeLlmCall( - AgentRunContext ctx, List messages) { - List current = messages; - for (AgentMiddleware mw : middlewares) { - Decision> d = mw.beforeLlmCall(ctx, current); - if (d.kind() == Decision.Kind.ABORT) return d; - if (d.kind() == Decision.Kind.REPLACE) current = d.replacement(); - } - return Decision.of(messages, current); - } - - /** - * Run {@code afterLlmCall} middleware in reverse onion order, folding REPLACE so earlier - * middlewares see rewritten responses. Returns the final response, or ABORT if any middleware - * short-circuits. - */ - private Decision dispatchAfterLlmCall( - AgentRunContext ctx, ChatCompletionAssistantMessageParam response) { - ChatCompletionAssistantMessageParam current = response; - for (int i = middlewares.size() - 1; i >= 0; i--) { - Decision d = - middlewares.get(i).afterLlmCall(ctx, current); - if (d.kind() == Decision.Kind.ABORT) return d; - if (d.kind() == Decision.Kind.REPLACE) current = d.replacement(); - } - return Decision.of(response, current); - } - - /** - * Run {@code beforeToolCall} middleware in onion order, folding REPLACE so later middlewares can - * further rewrite. Returns PROCEED if untouched, REPLACE with the final invocation otherwise, or - * ABORT if any middleware denies the call. - */ - private Decision dispatchBeforeToolCall( - AgentRunContext ctx, ToolInvocation call) { - ToolInvocation current = call; - for (AgentMiddleware mw : middlewares) { - Decision d = mw.beforeToolCall(ctx, current); - if (d.kind() == Decision.Kind.ABORT) return d; - if (d.kind() == Decision.Kind.REPLACE) current = d.replacement(); - } - return Decision.of(call, current); - } - - private String dispatchAfterToolCall(AgentRunContext ctx, ToolInvocation call, String result) { - String current = result; - for (int i = middlewares.size() - 1; i >= 0; i--) { - Decision d = middlewares.get(i).afterToolCall(ctx, call, current); - if (d.kind() == Decision.Kind.REPLACE) { - current = d.replacement(); - } else if (d.kind() == Decision.Kind.ABORT) { - // afterToolCall.abort: discard result; reason replaces it so the LLM still sees something. - current = d.reason(); - } - } - return current; - } - // --- Builder --- public static Builder builder() { diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/Decision.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/Decision.java index 187e30341f1..7fd72726b4d 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/Decision.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/Decision.java @@ -27,7 +27,9 @@ *

      *
    • {@code beforeLlmCall}: skip the LLM call and end the loop *
    • {@code beforeToolCall}: deny the call; reason is fed back to the LLM as the result - *
    • {@code afterToolCall}: discard the result; reason replaces it for the LLM + *
    • {@code afterToolCall}: short-circuit the chain; outer middlewares are not invoked, + * reason replaces the result for the LLM, and the emitted {@code ToolResult} is marked + * {@code isError=true} *
    • {@code onEvent}: drop the event *
    *