Skip to content

Commit ca209f1

Browse files
[KYUUBI #7379][PR 2b/4] Data Agent Engine: agent runtime, middleware, and OpenAI provider
1 parent ae16f6c commit ca209f1

19 files changed

Lines changed: 2867 additions & 0 deletions
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<!--
3+
~ Licensed to the Apache Software Foundation (ASF) under one or more
4+
~ contributor license agreements. See the NOTICE file distributed with
5+
~ this work for additional information regarding copyright ownership.
6+
~ The ASF licenses this file to You under the Apache License, Version 2.0
7+
~ (the "License"); you may not use this file except in compliance with
8+
~ the License. You may obtain a copy of the License at
9+
~
10+
~ http://www.apache.org/licenses/LICENSE-2.0
11+
~
12+
~ Unless required by applicable law or agreed to in writing, software
13+
~ distributed under the License is distributed on an "AS IS" BASIS,
14+
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
~ See the License for the specific language governing permissions and
16+
~ limitations under the License.
17+
-->
18+
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <!-- Inherits from the Kyuubi parent POM located two directories up. -->
    <parent>
        <groupId>org.apache.kyuubi</groupId>
        <artifactId>kyuubi-parent</artifactId>
        <version>1.12.0-SNAPSHOT</version>
        <relativePath>../../pom.xml</relativePath>
    </parent>

    <!-- Module coordinates; the artifact id carries the Scala binary version suffix. -->
    <artifactId>kyuubi-data-agent-engine_${scala.binary.version}</artifactId>
    <packaging>jar</packaging>
    <name>Kyuubi Project Engine Data Agent</name>
    <url>https://kyuubi.apache.org/</url>

    <dependencies>
        <!-- Kyuubi modules, pinned to this project's own version -->
        <dependency>
            <groupId>org.apache.kyuubi</groupId>
            <artifactId>kyuubi-common_${scala.binary.version}</artifactId>
            <version>${project.version}</version>
        </dependency>

        <dependency>
            <groupId>org.apache.kyuubi</groupId>
            <artifactId>kyuubi-ha_${scala.binary.version}</artifactId>
            <version>${project.version}</version>
        </dependency>

        <!-- Artifact id comes from the hive.jdbc.artifact property, which is not defined in this file -->
        <dependency>
            <groupId>org.apache.kyuubi</groupId>
            <artifactId>${hive.jdbc.artifact}</artifactId>
            <version>${project.version}</version>
        </dependency>

        <!-- Official OpenAI Java SDK (LLM client) -->
        <dependency>
            <groupId>com.openai</groupId>
            <artifactId>openai-java</artifactId>
            <version>${openai.sdk.version}</version>
        </dependency>

        <!-- Generates JSON Schema from Jackson-annotated classes; both artifacts share one version property -->
        <dependency>
            <groupId>com.github.victools</groupId>
            <artifactId>jsonschema-generator</artifactId>
            <version>${victools.jsonschema.version}</version>
        </dependency>
        <dependency>
            <groupId>com.github.victools</groupId>
            <artifactId>jsonschema-module-jackson</artifactId>
            <version>${victools.jsonschema.version}</version>
        </dependency>

        <!-- JDBC driver: SQLite -->
        <dependency>
            <groupId>org.xerial</groupId>
            <artifactId>sqlite-jdbc</artifactId>
            <version>${sqlite.version}</version>
        </dependency>

        <!-- JDBC driver: MySQL (also usable against StarRocks); no version here, so it must be managed upstream -->
        <dependency>
            <groupId>com.mysql</groupId>
            <artifactId>mysql-connector-j</artifactId>
        </dependency>

        <!-- JDBC driver: Trino; no version here, so it must be managed upstream -->
        <dependency>
            <groupId>io.trino</groupId>
            <artifactId>trino-jdbc</artifactId>
        </dependency>

        <!-- JDBC connection pooling; no version here, so it must be managed upstream -->
        <dependency>
            <groupId>com.zaxxer</groupId>
            <artifactId>HikariCP</artifactId>
        </dependency>

        <!-- Test-only dependencies -->
        <dependency>
            <groupId>org.apache.kyuubi</groupId>
            <artifactId>kyuubi-common_${scala.binary.version}</artifactId>
            <version>${project.version}</version>
            <type>test-jar</type>
            <scope>test</scope>
        </dependency>

        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <scope>test</scope>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <!-- Test execution: skip flag and extra JVM arguments are taken from build properties -->
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-surefire-plugin</artifactId>
                <configuration>
                    <skipTests>${skipTests}</skipTests>
                    <argLine>${extraJavaTestArgs}</argLine>
                </configuration>
            </plugin>
            <!-- Packages this module's test classes as a test-jar during the test-compile phase -->
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-jar-plugin</artifactId>
                <executions>
                    <execution>
                        <id>prepare-test-jar</id>
                        <goals>
                            <goal>test-jar</goal>
                        </goals>
                        <phase>test-compile</phase>
                    </execution>
                </executions>
            </plugin>
        </plugins>
        <!-- Compiled classes are written under a Scala-binary-version-specific directory -->
        <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
        <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
    </build>

</project>
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.kyuubi.engine.dataagent.provider.openai;
19+
20+
import com.openai.client.OpenAIClient;
21+
import com.openai.client.okhttp.OpenAIOkHttpClient;
22+
import com.zaxxer.hikari.HikariDataSource;
23+
import java.time.Duration;
24+
import java.util.concurrent.ConcurrentHashMap;
25+
import java.util.function.Consumer;
26+
import javax.sql.DataSource;
27+
import org.apache.kyuubi.config.KyuubiConf;
28+
import org.apache.kyuubi.config.KyuubiReservedKeys;
29+
import org.apache.kyuubi.engine.dataagent.datasource.DataSourceFactory;
30+
import org.apache.kyuubi.engine.dataagent.prompt.SystemPromptBuilder;
31+
import org.apache.kyuubi.engine.dataagent.provider.DataAgentProvider;
32+
import org.apache.kyuubi.engine.dataagent.provider.ProviderRunRequest;
33+
import org.apache.kyuubi.engine.dataagent.runtime.AgentRunRequest;
34+
import org.apache.kyuubi.engine.dataagent.runtime.ApprovalMode;
35+
import org.apache.kyuubi.engine.dataagent.runtime.ConversationMemory;
36+
import org.apache.kyuubi.engine.dataagent.runtime.ReactAgent;
37+
import org.apache.kyuubi.engine.dataagent.runtime.event.AgentError;
38+
import org.apache.kyuubi.engine.dataagent.runtime.event.AgentEvent;
39+
import org.apache.kyuubi.engine.dataagent.runtime.middleware.ApprovalMiddleware;
40+
import org.apache.kyuubi.engine.dataagent.runtime.middleware.LoggingMiddleware;
41+
import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry;
42+
import org.apache.kyuubi.engine.dataagent.tool.sql.SqlQueryTool;
43+
import org.slf4j.Logger;
44+
import org.slf4j.LoggerFactory;
45+
46+
/**
47+
* An OpenAI-compatible provider that wires up the full ReactAgent with streaming LLM, tools, and
48+
* middleware pipeline. Uses the official OpenAI Java SDK.
49+
*
50+
* <p>The ReactAgent, DataSource, and ToolRegistry are shared across all sessions within this engine
51+
* instance. Each session only maintains its own {@link ConversationMemory}. This works because each
52+
* engine is bound to one user + one datasource (via USER share level + subdomain isolation), so all
53+
* sessions within the engine naturally share the same data connection.
54+
*/
55+
public class OpenAiProvider implements DataAgentProvider {
56+
57+
private static final Logger LOG = LoggerFactory.getLogger(OpenAiProvider.class);
58+
59+
private final ReactAgent agent;
60+
private final ApprovalMiddleware approvalMiddleware;
61+
private final DataSource dataSource;
62+
private final ConcurrentHashMap<String, ConversationMemory> sessions = new ConcurrentHashMap<>();
63+
64+
public OpenAiProvider(KyuubiConf conf) {
65+
scala.Option<String> apiKeyOpt = conf.get(KyuubiConf.ENGINE_DATA_AGENT_LLM_API_KEY());
66+
if (apiKeyOpt.isEmpty()) {
67+
throw new IllegalArgumentException(
68+
KyuubiConf.ENGINE_DATA_AGENT_LLM_API_KEY().key()
69+
+ " is required for OPENAI_COMPATIBLE provider");
70+
}
71+
scala.Option<String> apiUrlOpt = conf.get(KyuubiConf.ENGINE_DATA_AGENT_LLM_API_URL());
72+
if (apiUrlOpt.isEmpty()) {
73+
throw new IllegalArgumentException(
74+
KyuubiConf.ENGINE_DATA_AGENT_LLM_API_URL().key()
75+
+ " is required for OPENAI_COMPATIBLE provider");
76+
}
77+
String apiKey = apiKeyOpt.get();
78+
String baseUrl = apiUrlOpt.get();
79+
String modelName = conf.get(KyuubiConf.ENGINE_DATA_AGENT_LLM_MODEL());
80+
81+
OpenAIClient client =
82+
OpenAIOkHttpClient.builder()
83+
.apiKey(apiKey)
84+
.baseUrl(baseUrl)
85+
.maxRetries(3)
86+
.timeout(Duration.ofSeconds(120))
87+
.build();
88+
89+
int maxIterations = (int) conf.get(KyuubiConf.ENGINE_DATA_AGENT_MAX_ITERATIONS());
90+
int queryTimeoutSeconds = (int) conf.get(KyuubiConf.ENGINE_DATA_AGENT_QUERY_TIMEOUT());
91+
92+
// Register tools and build prompt from JDBC URL
93+
DataSource ds = null;
94+
ReactAgent builtAgent = null;
95+
try {
96+
ToolRegistry toolRegistry = new ToolRegistry();
97+
SystemPromptBuilder promptBuilder = SystemPromptBuilder.create();
98+
scala.Option<String> jdbcUrlOpt = conf.get(KyuubiConf.ENGINE_DATA_AGENT_JDBC_URL());
99+
if (jdbcUrlOpt.isDefined()) {
100+
String jdbcUrl = jdbcUrlOpt.get();
101+
LOG.info(
102+
"Data Agent JDBC URL configured ({})", jdbcUrl.replaceAll("//.*@", "//<redacted>@"));
103+
scala.Option<String> userOpt = conf.getOption(KyuubiReservedKeys.KYUUBI_SESSION_USER_KEY());
104+
String sessionUser = userOpt.isDefined() ? userOpt.get() : null;
105+
ds = DataSourceFactory.create(jdbcUrl, sessionUser);
106+
toolRegistry.register(new SqlQueryTool(ds, queryTimeoutSeconds));
107+
promptBuilder.jdbcUrl(jdbcUrl);
108+
}
109+
110+
ApprovalMiddleware approval = new ApprovalMiddleware(toolRegistry);
111+
112+
builtAgent =
113+
ReactAgent.builder()
114+
.client(client)
115+
.modelName(modelName)
116+
.toolRegistry(toolRegistry)
117+
.addMiddleware(new LoggingMiddleware())
118+
.addMiddleware(approval)
119+
.maxIterations(maxIterations)
120+
.toolTimeoutSeconds(queryTimeoutSeconds)
121+
.systemPrompt(promptBuilder.build())
122+
.build();
123+
124+
this.agent = builtAgent;
125+
this.approvalMiddleware = approval;
126+
this.dataSource = ds;
127+
} catch (Exception e) {
128+
if (builtAgent != null) {
129+
try {
130+
builtAgent.close();
131+
} catch (Exception ex) {
132+
LOG.warn("Error closing ReactAgent during constructor cleanup", ex);
133+
}
134+
}
135+
if (ds instanceof HikariDataSource) {
136+
((HikariDataSource) ds).close();
137+
}
138+
throw e;
139+
}
140+
}
141+
142+
@Override
143+
public void open(String sessionId, String user) {
144+
sessions.put(sessionId, new ConversationMemory());
145+
LOG.info("Opened Data Agent session {} for user {}", sessionId, user);
146+
}
147+
148+
@Override
149+
public void run(String sessionId, ProviderRunRequest request, Consumer<AgentEvent> onEvent) {
150+
ConversationMemory memory = sessions.get(sessionId);
151+
if (memory == null) {
152+
onEvent.accept(new AgentError("Session not found. Please reconnect."));
153+
return;
154+
}
155+
156+
AgentRunRequest agentRequest =
157+
new AgentRunRequest(request.getQuestion()).modelName(request.getModelName());
158+
String modeStr = request.getApprovalMode();
159+
if (modeStr != null && !modeStr.isEmpty()) {
160+
try {
161+
agentRequest.approvalMode(ApprovalMode.valueOf(modeStr.toUpperCase()));
162+
} catch (IllegalArgumentException e) {
163+
LOG.warn("Unknown approval mode '{}', using default NORMAL", modeStr);
164+
}
165+
}
166+
agent.run(agentRequest, memory, onEvent);
167+
}
168+
169+
@Override
170+
public boolean resolveApproval(String requestId, boolean approved) {
171+
return approvalMiddleware.resolve(requestId, approved);
172+
}
173+
174+
@Override
175+
public void close(String sessionId) {
176+
sessions.remove(sessionId);
177+
LOG.info("Closed Data Agent session {}", sessionId);
178+
}
179+
180+
@Override
181+
public void stop() {
182+
approvalMiddleware.cancelAll();
183+
try {
184+
agent.close();
185+
} catch (Exception e) {
186+
LOG.warn("Error closing ReactAgent", e);
187+
}
188+
if (dataSource instanceof HikariDataSource) {
189+
((HikariDataSource) dataSource).close();
190+
LOG.info("Closed Data Agent connection pool");
191+
}
192+
}
193+
}

0 commit comments

Comments
 (0)