
Commit 5b5ee84

corgy-w and Hisoka-X authored
[Improve][Transform] Add LLM model provider microsoft (#7778)
Co-authored-by: Jia Fan <[email protected]>
1 parent 03d325e commit 5b5ee84

10 files changed: +277 -6 lines


docs/en/transform-v2/llm.md

+5 -2
@@ -11,7 +11,7 @@ more.
 ## Options

 | name                    | type   | required | default value |
-|------------------------| ------ | -------- |---------------|
+|------------------------|--------|----------|---------------|
 | model_provider          | enum   | yes      |               |
 | output_data_type        | enum   | no       | String        |
 | output_column_name      | string | no       | llm_output    |
@@ -28,7 +28,9 @@ more.
 ### model_provider

 The model provider to use. The available options are:
-OPENAI, DOUBAO, KIMIAI, CUSTOM
+OPENAI, DOUBAO, KIMIAI, MICROSOFT, CUSTOM
+
+> tips: If you use Microsoft, please make sure api_path cannot be empty

 ### output_data_type

@@ -254,6 +256,7 @@ sink {
   }
 }
 ```
+
 ### Customize the LLM model

 ```hocon
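For context, a minimal sketch of the documented MICROSOFT provider usage, adapted from the llm_microsoft_transform.conf e2e fixture added in this commit (the mockserver host, gpt-35-turbo deployment name, and api-version value come from that test fixture, not from a real Azure endpoint):

```hocon
transform {
  LLM {
    source_table_name = "fake"
    model_provider = MICROSOFT
    model = gpt-35-turbo
    api_key = sk-xxx
    prompt = "Determine whether someone is Chinese or American by their name"
    # MICROSOFT has no built-in endpoint, so api_path must be set (deployment-style URL)
    api_path = "http://mockserver:1080/openai/deployments/${model}/chat/completions?api-version=2024-02-01"
    result_table_name = "llm_output"
  }
}
```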

docs/zh/transform-v2/llm.md

+3 -1
@@ -26,7 +26,9 @@
 ### model_provider

 要使用的模型提供者。可用选项为:
-OPENAI、DOUBAO、KIMIAI、CUSTOM
+OPENAI、DOUBAO、KIMIAI、MICROSOFT, CUSTOM
+
+> tips: 如果使用 Microsoft, 请确保 api_path 配置不能为空

 ### output_data_type

seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java

+7
@@ -88,6 +88,13 @@ public void testLLMWithOpenAI(TestContainer container)
         Assertions.assertEquals(0, execResult.getExitCode());
     }

+    @TestTemplate
+    public void testLLMWithMicrosoft(TestContainer container)
+            throws IOException, InterruptedException {
+        Container.ExecResult execResult = container.executeJob("/llm_microsoft_transform.conf");
+        Assertions.assertEquals(0, execResult.getExitCode());
+    }
+
     @TestTemplate
     public void testLLMWithOpenAIBoolean(TestContainer container)
             throws IOException, InterruptedException {
seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/llm_microsoft_transform.conf

+75
@@ -0,0 +1,75 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+######
+###### This config file is a demonstration of streaming processing in seatunnel config
+######
+
+env {
+  job.mode = "BATCH"
+}
+
+source {
+  FakeSource {
+    row.num = 5
+    schema = {
+      fields {
+        id = "int"
+        name = "string"
+      }
+    }
+    rows = [
+      {fields = [1, "Jia Fan"], kind = INSERT}
+      {fields = [2, "Hailin Wang"], kind = INSERT}
+      {fields = [3, "Tomas"], kind = INSERT}
+      {fields = [4, "Eric"], kind = INSERT}
+      {fields = [5, "Guangdong Liu"], kind = INSERT}
+    ]
+    result_table_name = "fake"
+  }
+}
+
+transform {
+  LLM {
+    source_table_name = "fake"
+    model_provider = MICROSOFT
+    model = gpt-35-turbo
+    api_key = sk-xxx
+    prompt = "Determine whether someone is Chinese or American by their name"
+    api_path = "http://mockserver:1080/openai/deployments/${model}/chat/completions?api-version=2024-02-01"
+    result_table_name = "llm_output"
+  }
+}
+
+sink {
+  Assert {
+    source_table_name = "llm_output"
+    rules =
+      {
+        field_rules = [
+          {
+            field_name = llm_output
+            field_type = string
+            field_value = [
+              {
+                rule_type = NOT_NULL
+              }
+            ]
+          }
+        ]
+      }
+  }
+}

seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/mockserver-config.json

+32
@@ -104,5 +104,37 @@
         "Content-Type": "application/json"
       }
     }
+  },
+  {
+    "httpRequest": {
+      "method": "POST",
+      "path": "/openai/deployments/gpt-35-turbo/chat/.*"
+    },
+    "httpResponse": {
+      "body": {
+        "id": "chatcmpl-6v7mkQj980V1yBec6ETrKPRqFjNw9",
+        "object": "chat.completion",
+        "created": 1679072642,
+        "model": "gpt-35-turbo",
+        "usage": {
+          "prompt_tokens": 58,
+          "completion_tokens": 68,
+          "total_tokens": 126
+        },
+        "choices": [
+          {
+            "message": {
+              "role": "assistant",
+              "content": "[\"Chinese\"]"
+            },
+            "finish_reason": "stop",
+            "index": 0
+          }
+        ]
+      },
+      "headers": {
+        "Content-Type": "application/json"
+      }
+    }
+  }
   }
 ]

seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java

+1
@@ -26,6 +26,7 @@ public enum ModelProvider {
             "https://ark.cn-beijing.volces.com/api/v3/embeddings"),
     QIANFAN("", "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings"),
     KIMIAI("https://api.moonshot.cn/v1/chat/completions", ""),
+    MICROSOFT("", ""),
     CUSTOM("", ""),
     LOCAL("", "");

seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java

+12
@@ -31,6 +31,7 @@
 import org.apache.seatunnel.transform.nlpmodel.llm.remote.Model;
 import org.apache.seatunnel.transform.nlpmodel.llm.remote.custom.CustomModel;
 import org.apache.seatunnel.transform.nlpmodel.llm.remote.kimiai.KimiAIModel;
+import org.apache.seatunnel.transform.nlpmodel.llm.remote.microsoft.MicrosoftModel;
 import org.apache.seatunnel.transform.nlpmodel.llm.remote.openai.OpenAIModel;

 import lombok.NonNull;
@@ -94,6 +95,17 @@ public void open() {
                                 LLMTransformConfig.CustomRequestConfig
                                         .CUSTOM_RESPONSE_PARSE));
                 break;
+            case MICROSOFT:
+                model =
+                        new MicrosoftModel(
+                                inputCatalogTable.getSeaTunnelRowType(),
+                                outputDataType.getSqlType(),
+                                config.get(LLMTransformConfig.INFERENCE_COLUMNS),
+                                config.get(LLMTransformConfig.PROMPT),
+                                config.get(LLMTransformConfig.MODEL),
+                                config.get(LLMTransformConfig.API_KEY),
+                                provider.usedLLMPath(config.get(LLMTransformConfig.API_PATH)));
+                break;
             case OPENAI:
             case DOUBAO:
                 model =

seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransformFactory.java

+5 -3
@@ -26,7 +26,6 @@
 import org.apache.seatunnel.api.table.factory.TableTransformFactory;
 import org.apache.seatunnel.api.table.factory.TableTransformFactoryContext;
 import org.apache.seatunnel.transform.nlpmodel.ModelProvider;
-import org.apache.seatunnel.transform.nlpmodel.ModelTransformConfig;

 import com.google.auto.service.AutoService;

@@ -50,14 +49,17 @@ public OptionRule optionRule() {
                         LLMTransformConfig.PROCESS_BATCH_SIZE)
                 .conditional(
                         LLMTransformConfig.MODEL_PROVIDER,
-                        Lists.newArrayList(ModelProvider.OPENAI, ModelProvider.DOUBAO),
+                        Lists.newArrayList(
+                                ModelProvider.OPENAI,
+                                ModelProvider.DOUBAO,
+                                ModelProvider.MICROSOFT),
                         LLMTransformConfig.API_KEY)
                 .conditional(
                         LLMTransformConfig.MODEL_PROVIDER,
                         ModelProvider.QIANFAN,
                         LLMTransformConfig.API_KEY,
                         LLMTransformConfig.SECRET_KEY,
-                        ModelTransformConfig.OAUTH_PATH)
+                        LLMTransformConfig.OAUTH_PATH)
                 .conditional(
                         LLMTransformConfig.MODEL_PROVIDER,
                         ModelProvider.CUSTOM,
seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/microsoft/MicrosoftModel.java

+103
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.transform.nlpmodel.llm.remote.microsoft;
+
+import org.apache.seatunnel.shade.com.fasterxml.jackson.core.type.TypeReference;
+import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.JsonNode;
+import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ArrayNode;
+import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ObjectNode;
+
+import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
+import org.apache.seatunnel.api.table.type.SqlType;
+import org.apache.seatunnel.transform.nlpmodel.CustomConfigPlaceholder;
+import org.apache.seatunnel.transform.nlpmodel.llm.remote.AbstractModel;
+
+import org.apache.http.client.config.RequestConfig;
+import org.apache.http.client.methods.CloseableHttpResponse;
+import org.apache.http.client.methods.HttpPost;
+import org.apache.http.entity.StringEntity;
+import org.apache.http.impl.client.CloseableHttpClient;
+import org.apache.http.impl.client.HttpClients;
+import org.apache.http.util.EntityUtils;
+
+import com.google.common.annotations.VisibleForTesting;
+
+import java.io.IOException;
+import java.util.List;
+
+public class MicrosoftModel extends AbstractModel {
+
+    private final CloseableHttpClient client;
+    private final String apiKey;
+    private final String model;
+    private final String apiPath;
+
+    public MicrosoftModel(
+            SeaTunnelRowType rowType,
+            SqlType outputType,
+            List<String> projectionColumns,
+            String prompt,
+            String model,
+            String apiKey,
+            String apiPath) {
+        super(rowType, outputType, projectionColumns, prompt);
+        this.model = model;
+        this.apiKey = apiKey;
+        this.apiPath =
+                CustomConfigPlaceholder.replacePlaceholders(
+                        apiPath, CustomConfigPlaceholder.REPLACE_PLACEHOLDER_MODEL, model, null);
+        this.client = HttpClients.createDefault();
+    }
+
+    @Override
+    protected List<String> chatWithModel(String prompt, String data) throws IOException {
+        HttpPost post = new HttpPost(apiPath);
+        post.setHeader("Authorization", "Bearer " + apiKey);
+        post.setHeader("Content-Type", "application/json");
+        ObjectNode objectNode = createJsonNodeFromData(prompt, data);
+        post.setEntity(new StringEntity(OBJECT_MAPPER.writeValueAsString(objectNode), "UTF-8"));
+        post.setConfig(
+                RequestConfig.custom().setConnectTimeout(20000).setSocketTimeout(20000).build());
+        CloseableHttpResponse response = client.execute(post);
+        String responseStr = EntityUtils.toString(response.getEntity());
+        if (response.getStatusLine().getStatusCode() != 200) {
+            throw new IOException("Failed to chat with model, response: " + responseStr);
+        }
+
+        JsonNode result = OBJECT_MAPPER.readTree(responseStr);
+        String resultData = result.get("choices").get(0).get("message").get("content").asText();
+        return OBJECT_MAPPER.readValue(
+                convertData(resultData), new TypeReference<List<String>>() {});
+    }
+
+    @VisibleForTesting
+    public ObjectNode createJsonNodeFromData(String prompt, String data) {
+        ObjectNode objectNode = OBJECT_MAPPER.createObjectNode();
+        ArrayNode messages = objectNode.putArray("messages");
+        messages.addObject().put("role", "system").put("content", prompt);
+        messages.addObject().put("role", "user").put("content", data);
+        return objectNode;
+    }
+
+    @Override
+    public void close() throws IOException {
+        if (client != null) {
+            client.close();
+        }
+    }
+}

seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java

+34
@@ -28,6 +28,7 @@
 import org.apache.seatunnel.format.json.RowToJsonConverters;
 import org.apache.seatunnel.transform.nlpmodel.llm.remote.custom.CustomModel;
 import org.apache.seatunnel.transform.nlpmodel.llm.remote.kimiai.KimiAIModel;
+import org.apache.seatunnel.transform.nlpmodel.llm.remote.microsoft.MicrosoftModel;
 import org.apache.seatunnel.transform.nlpmodel.llm.remote.openai.OpenAIModel;

 import org.junit.jupiter.api.Assertions;
@@ -36,6 +37,7 @@
 import com.google.common.collect.Lists;

 import java.io.IOException;
+import java.lang.reflect.Field;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
@@ -130,6 +132,38 @@ void testKimiAIRequestJson() throws IOException {
         model.close();
     }

+    @Test
+    void testMicrosoftRequestJson() throws Exception {
+        SeaTunnelRowType rowType =
+                new SeaTunnelRowType(
+                        new String[] {"id", "name"},
+                        new SeaTunnelDataType[] {BasicType.INT_TYPE, BasicType.STRING_TYPE});
+        MicrosoftModel model =
+                new MicrosoftModel(
+                        rowType,
+                        SqlType.STRING,
+                        null,
+                        "Determine whether someone is Chinese or American by their name",
+                        "gpt-35-turbo",
+                        "sk-xxx",
+                        "https://api.moonshot.cn/openai/deployments/${model}/chat/completions?api-version=2024-02-01");
+        Field apiPathField = model.getClass().getDeclaredField("apiPath");
+        apiPathField.setAccessible(true);
+        String apiPath = (String) apiPathField.get(model);
+        Assertions.assertEquals(
+                "https://api.moonshot.cn/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-02-01",
+                apiPath);
+
+        ObjectNode node =
+                model.createJsonNodeFromData(
+                        "Determine whether someone is Chinese or American by their name",
+                        "{\"id\":1, \"name\":\"John\"}");
+        Assertions.assertEquals(
+                "{\"messages\":[{\"role\":\"system\",\"content\":\"Determine whether someone is Chinese or American by their name\"},{\"role\":\"user\",\"content\":\"{\\\"id\\\":1, \\\"name\\\":\\\"John\\\"}\"}]}",
+                OBJECT_MAPPER.writeValueAsString(node));
+        model.close();
+    }
+
     @Test
     void testCustomRequestJson() throws IOException {
         SeaTunnelRowType rowType =
