
Commit 51ffc5a

[Feature][Transform-v2] Add support for Zhipu AI in Embedding and LLM module (#8790)
1 parent 36a7533 commit 51ffc5a

17 files changed

+238
-58
lines changed

Diff for: docs/en/transform-v2/embedding.md

+15-14
Original file line number | Diff line number | Diff line change
@@ -10,20 +10,21 @@ different API endpoints.
1010

1111
## Options
1212

13-
| Name | Type | Required | Default Value | Description |
14-
|--------------------------------|--------|----------|---------------|-------------------------------------------------------------------------------------------------------------|
15-
| model_provider | enum | yes | - | The model provider for embedding. Options may include `QIANFAN`, `OPENAI`, etc. |
16-
| api_key | string | yes | - | The API key required to authenticate with the embedding service. |
17-
| secret_key | string | yes | - | The secret key required for additional authentication with the embedding service. |
18-
| single_vectorized_input_number | int | no | 1 | The number of inputs vectorized in one request. Default is 1. |
19-
| vectorization_fields | map | yes | - | A mapping between input fields and their corresponding output vector fields. |
20-
| model | string | yes | - | The specific model to use for embedding (e.g: `text-embedding-3-small` for OPENAI). |
21-
| api_path | string | no | - | The API endpoint for the embedding service. Typically provided by the model provider. |
22-
| oauth_path | string | no | - | The API endpoint for the oauth service. |
23-
| custom_config | map | no | | Custom configurations for the model. |
24-
| custom_response_parse | string | no | | Specifies how to parse the response from the model using JsonPath. Example: `$.choices[*].message.content`. |
25-
| custom_request_headers | map | no | | Custom headers for the request to the model. |
26-
| custom_request_body | map | no | | Custom body for the request. Supports placeholders like `${model}`, `${input}`. |
13+
| Name | Type | Required | Default Value | Description |
14+
|----------------------------------|--------|----------|---------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
15+
| model_provider | enum | yes | - | The model provider for embedding. Options may include `QIANFAN`, `OPENAI`, etc. |
16+
| api_key | string | yes | - | The API key required to authenticate with the embedding service. |
17+
| secret_key | string | yes | - | The secret key required for additional authentication with the embedding service. |
18+
| single_vectorized_input_number | int | no | 1 | The number of inputs vectorized in one request. Default is 1. |
19+
| vectorization_fields | map | yes | - | A mapping between input fields and their corresponding output vector fields. |
20+
| model | string | yes | - | The specific model to use for embedding (e.g: `text-embedding-3-small` for OPENAI). |
21+
| api_path | string | no | - | The API endpoint for the embedding service. Typically provided by the model provider. |
22+
| dimension | int | no | 2048 | The vector dimension; defaults to 2048. The Embedding-3 model supports custom vector dimensions; the recommended values are 256, 512, 1024, or 2048. |
23+
| oauth_path | string | no | - | The API endpoint for the oauth service. |
24+
| custom_config | map | no | | Custom configurations for the model. |
25+
| custom_response_parse | string | no | | Specifies how to parse the response from the model using JsonPath. Example: `$.choices[*].message.content`. |
26+
| custom_request_headers | map | no | | Custom headers for the request to the model. |
27+
| custom_request_body | map | no | | Custom body for the request. Supports placeholders like `${model}`, `${input}`. |
2728

2829
### model_provider
2930

Diff for: docs/en/transform-v2/llm.md

+1-1
@@ -28,7 +28,7 @@ more.
2828
### model_provider
2929

3030
The model provider to use. The available options are:
31-
OPENAI, DOUBAO, DEEPSEEK, KIMIAI, MICROSOFT, CUSTOM
31+
OPENAI, DOUBAO, DEEPSEEK, KIMIAI, MICROSOFT, ZHIPU, CUSTOM
3232

3333
> tips: If you use Microsoft, make sure api_path is not empty
3434

Diff for: docs/zh/transform-v2/embedding.md

+15-14
@@ -8,20 +8,21 @@
88

99
## Configuration Options
1010

11-
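| Name | Type | Required | Default Value | Description |
12-
|--------------------------------|--------|----------|---------------|------------------------------------------------------------------|
13-
| model_provider | enum | | - | The provider of the embedding model. Options include `QIANFAN`, `OPENAI`, etc. |
14-
| api_key | string | | - | The API key used to authenticate with the embedding service. |
15-
| secret_key | string | | - | The secret key used for additional authentication. Some providers may require it for secure API requests. |
16-
| single_vectorized_input_number | int | | 1 | The number of inputs vectorized in a single request. Default is 1. |
17-
| vectorization_fields | map | | - | A mapping between input fields and their corresponding output vector fields. |
18-
| model | string | | - | The specific embedding model to use. For example, if the provider is OPENAI, you can specify `text-embedding-3-small`. |
19-
| api_path | string | | - | The API endpoint of the embedding service. Usually supplied by the model provider. |
20-
| oauth_path | string | | - | The API endpoint of the oauth service. |
21-
| custom_config | map | | | Custom configuration for the model. |
22-
| custom_response_parse | string | | | How to parse the model response using JsonPath. Example: `$.choices[*].message.content`. |
23-
| custom_request_headers | map | | | Custom headers for the request sent to the model. |
24-
| custom_request_body | map | | | Custom configuration for the request body. Supports placeholders such as `${model}` and `${input}`. |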
11+
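| Name | Type | Required | Default Value | Description |
12+
|----------------------------------|--------|----------|---------------|--------------------------------------------------------------------|
13+
| model_provider | enum | | - | The provider of the embedding model. Options include `QIANFAN`, `OPENAI`, etc. |
14+
| api_key | string | | - | The API key used to authenticate with the embedding service. |
15+
| secret_key | string | | - | The secret key used for additional authentication. Some providers may require it for secure API requests. |
16+
| single_vectorized_input_number | int | | 1 | The number of inputs vectorized in a single request. Default is 1. |
17+
| vectorization_fields | map | | - | A mapping between input fields and their corresponding output vector fields. |
18+
| model | string | | - | The specific embedding model to use. For example, if the provider is OPENAI, you can specify `text-embedding-3-small`. |
19+
| api_path | string | | - | The API endpoint of the embedding service. Usually supplied by the model provider. |
20+
| dimension | int | | 2048 | The vector dimension; defaults to 2048. The Embedding-3 model supports custom vector dimensions; the recommended values are 256, 512, 1024, or 2048. |
21+
| oauth_path | string | | - | The API endpoint of the oauth service. |
22+
| custom_config | map | | | Custom configuration for the model. |
23+
| custom_response_parse | string | | | How to parse the model response using JsonPath. Example: `$.choices[*].message.content`. |
24+
| custom_request_headers | map | | | Custom headers for the request sent to the model. |
25+
| custom_request_body | map | | | Custom configuration for the request body. Supports placeholders such as `${model}` and `${input}`. |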
2526

2627
### embedding_model_provider
2728

Diff for: docs/zh/transform-v2/llm.md

+1-1
@@ -26,7 +26,7 @@
2626
### model_provider
2727

2828
The model provider to use. The available options are:
29-
OPENAI, DOUBAO, DEEPSEEK, KIMIAI, MICROSOFT, CUSTOM
29+
OPENAI, DOUBAO, DEEPSEEK, KIMIAI, MICROSOFT, ZHIPU, CUSTOM
3030

3131
> tips: If you use Microsoft, make sure api_path is not empty
3232

Diff for: seatunnel-transforms-v2/pom.xml

+3-1
@@ -32,6 +32,8 @@
3232
<properties>
3333
<httpclient.version>4.5.13</httpclient.version>
3434
<httpcore.version>4.4.4</httpcore.version>
35+
<mockwebserver.version>3.6.0</mockwebserver.version>
36+
<zhipu.version>release-V4-2.3.0</zhipu.version>
3537
</properties>
3638

3739
<dependencyManagement>
@@ -95,7 +97,7 @@
9597
<dependency>
9698
<groupId>com.squareup.okhttp3</groupId>
9799
<artifactId>mockwebserver</artifactId>
98-
<version>3.6.0</version>
100+
<version>${mockwebserver.version}</version>
99101
<scope>test</scope>
100102
</dependency>
101103
</dependencies>

Diff for: seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java

+3
@@ -28,6 +28,9 @@ public enum ModelProvider {
2828
KIMIAI("https://api.moonshot.cn/v1/chat/completions", ""),
2929
DEEPSEEK("https://api.deepseek.com/chat/completions", ""),
3030
MICROSOFT("", ""),
31+
ZHIPU(
32+
"https://open.bigmodel.cn/api/paas/v4/chat/completions",
33+
"https://open.bigmodel.cn/api/paas/v4/embeddings"),
3134
CUSTOM("", ""),
3235
LOCAL("", "");
3336
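
Review note: the new `ZHIPU` constant carries two default endpoints, and `EmbeddingTransform` passes the configured `api_path` through `provider.usedEmbeddingPath(...)`. The body of that method is not part of this diff. The following is a minimal sketch of one plausible fallback rule, assuming an empty or missing `api_path` means "use the provider default". `ProviderSketch` and `resolveEmbeddingPath` are hypothetical names for illustration, not SeaTunnel's API.

```java
/** Hypothetical stand-in for the provider enum; only the fallback rule is of interest. */
enum ProviderSketch {
    ZHIPU(
            "https://open.bigmodel.cn/api/paas/v4/chat/completions",
            "https://open.bigmodel.cn/api/paas/v4/embeddings"),
    CUSTOM("", "");

    private final String defaultChatPath;
    private final String defaultEmbeddingPath;

    ProviderSketch(String defaultChatPath, String defaultEmbeddingPath) {
        this.defaultChatPath = defaultChatPath;
        this.defaultEmbeddingPath = defaultEmbeddingPath;
    }

    /** Assumed rule: a non-blank configured path wins, otherwise use the default. */
    String resolveEmbeddingPath(String configuredPath) {
        if (configuredPath != null && !configuredPath.trim().isEmpty()) {
            return configuredPath;
        }
        return defaultEmbeddingPath;
    }

    String defaultChatPath() {
        return defaultChatPath;
    }
}
```

Under this assumption a user only needs to set `api_path` when routing through a proxy or a private deployment, while `CUSTOM` still resolves to an empty string and must be configured explicitly.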

Diff for: seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelTransformConfig.java

+3
@@ -79,6 +79,9 @@ public class ModelTransformConfig implements Serializable {
7979
.withFallbackKeys("inference_batch_size")
8080
.withDescription("The row batch size of each process");
8181

82+
public static final Option<Integer> DIMENSION =
83+
Options.key("dimension").intType().defaultValue(2048).withDescription("dimension");
84+
8285
public static class CustomRequestConfig {
8386

8487
// Custom response parsing
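
Review note: the new `dimension` option defaults to 2048, and the docs recommend 256, 512, 1024, or 2048 for the Embedding-3 model. The `ZhipuModel` class that consumes the option is not shown in this part of the diff, so the sketch below is only an illustration of how the value could be checked and placed into a request body. The class name `ZhipuRequestSketch`, both method names, and the JSON field name `dimensions` are assumptions; the hand-rolled quoting handles only backslashes and double quotes, and real code should use a JSON library.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

/** Hypothetical helper; not part of the commit. */
final class ZhipuRequestSketch {
    // Values recommended by the option description in the docs.
    static final List<Integer> RECOMMENDED_DIMENSIONS = Arrays.asList(256, 512, 1024, 2048);

    static boolean isRecommended(int dimension) {
        return RECOMMENDED_DIMENSIONS.contains(dimension);
    }

    /** Builds a JSON body; the "dimensions" field name is an assumption. */
    static String buildBody(String model, int dimension, List<String> inputs) {
        String inputArray =
                inputs.stream()
                        .map(ZhipuRequestSketch::quote)
                        .collect(Collectors.joining(",", "[", "]"));
        return "{\"model\":" + quote(model)
                + ",\"input\":" + inputArray
                + ",\"dimensions\":" + dimension + "}";
    }

    // Minimal escaping for illustration: backslash and double quote only.
    private static String quote(String s) {
        return "\"" + s.replace("\\", "\\\\").replace("\"", "\\\"") + "\"";
    }
}
```

A caller might log a warning when `isRecommended` returns false rather than rejecting the value, since the docs phrase the list as a recommendation and not a hard constraint.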

Diff for: seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransform.java

+13
@@ -33,6 +33,7 @@
3333
import org.apache.seatunnel.transform.nlpmodel.embedding.remote.doubao.DoubaoModel;
3434
import org.apache.seatunnel.transform.nlpmodel.embedding.remote.openai.OpenAIModel;
3535
import org.apache.seatunnel.transform.nlpmodel.embedding.remote.qianfan.QianfanModel;
36+
import org.apache.seatunnel.transform.nlpmodel.embedding.remote.zhipu.ZhipuModel;
3637
import org.apache.seatunnel.transform.nlpmodel.llm.LLMTransformConfig;
3738

3839
import lombok.NonNull;
@@ -136,6 +137,18 @@ public void open() {
136137
EmbeddingTransformConfig
137138
.SINGLE_VECTORIZED_INPUT_NUMBER));
138139
break;
140+
case ZHIPU:
141+
model =
142+
new ZhipuModel(
143+
config.get(ModelTransformConfig.API_KEY),
144+
config.get(ModelTransformConfig.MODEL),
145+
provider.usedEmbeddingPath(
146+
config.get(ModelTransformConfig.API_PATH)),
147+
config.get(ModelTransformConfig.DIMENSION),
148+
config.get(
149+
EmbeddingTransformConfig
150+
.SINGLE_VECTORIZED_INPUT_NUMBER));
151+
break;
139152
case LOCAL:
140153
default:
141154
throw new IllegalArgumentException("Unsupported model provider: " + provider);

Diff for: seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransformFactory.java

+4
@@ -62,6 +62,10 @@ public OptionRule optionRule() {
6262
LLMTransformConfig.MODEL_PROVIDER,
6363
ModelProvider.CUSTOM,
6464
LLMTransformConfig.CustomRequestConfig.CUSTOM_CONFIG)
65+
.conditional(
66+
EmbeddingTransformConfig.MODEL_PROVIDER,
67+
ModelProvider.ZHIPU,
68+
EmbeddingTransformConfig.DIMENSION)
6569
.optional(TransformCommonOptions.MULTI_TABLES)
6670
.optional(TransformCommonOptions.TABLE_MATCH_REGEX)
6771
.build();

Diff for: seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/AbstractModel.java

+7-7
@@ -42,23 +42,23 @@ protected AbstractModel(Integer singleVectorizedInputNumber) {
4242
public List<ByteBuffer> vectorization(Object[] fields) throws IOException {
4343
List<ByteBuffer> result = new ArrayList<>();
4444

45-
List<List<Float>> vectors = batchProcess(fields, singleVectorizedInputNumber);
46-
for (List<Float> vector : vectors) {
47-
result.add(BufferUtils.toByteBuffer(vector.toArray(new Float[0])));
45+
List<List<Double>> vectors = batchProcess(fields, singleVectorizedInputNumber);
46+
for (List<Double> vector : vectors) {
47+
result.add(BufferUtils.toByteBuffer(vector.toArray(new Double[0])));
4848
}
4949
return result;
5050
}
5151

52-
protected abstract List<List<Float>> vector(Object[] fields) throws IOException;
52+
protected abstract List<List<Double>> vector(Object[] fields) throws IOException;
5353

54-
public List<List<Float>> batchProcess(Object[] array, int batchSize) throws IOException {
55-
List<List<Float>> merged = new ArrayList<>();
54+
public List<List<Double>> batchProcess(Object[] array, int batchSize) throws IOException {
55+
List<List<Double>> merged = new ArrayList<>();
5656
if (array == null || array.length == 0) {
5757
return merged;
5858
}
5959
for (int i = 0; i < array.length; i += batchSize) {
6060
Object[] batch = ArrayUtils.subarray(array, i, i + batchSize);
61-
List<List<Float>> vector = vector(batch);
61+
List<List<Double>> vector = vector(batch);
6262
merged.addAll(vector);
6363
}
6464
if (array.length != merged.size()) {
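
Review note: `batchProcess` splits the input into chunks of `batchSize`, vectorizes each chunk, and then compares the merged result size with the input size. The sketch below reproduces that pattern with the JDK only, so it can be run in isolation. `BatchSketch` and the `worker` parameter are hypothetical; throwing on a size mismatch is an assumption, because the hunk ends before the body of that `if`. Unlike Commons Lang's `ArrayUtils.subarray`, `Arrays.copyOfRange` pads with `null` when the end index exceeds the array length, so the end index is clamped explicitly.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;

/** Hypothetical, JDK-only illustration of the chunk-and-merge pattern. */
final class BatchSketch {
    static <T, R> List<R> batchProcess(
            T[] items, int batchSize, Function<T[], List<R>> worker) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        List<R> merged = new ArrayList<>();
        if (items == null || items.length == 0) {
            return merged;
        }
        for (int i = 0; i < items.length; i += batchSize) {
            // Clamp the end index so the last chunk is not padded with nulls.
            int end = Math.min(i + batchSize, items.length);
            T[] batch = Arrays.copyOfRange(items, i, end);
            merged.addAll(worker.apply(batch));
        }
        // Assumed behavior: one output per input, otherwise fail fast.
        if (merged.size() != items.length) {
            throw new IllegalStateException(
                    "Expected " + items.length + " results but got " + merged.size());
        }
        return merged;
    }
}
```

With five inputs and a batch size of 2, the worker is invoked three times with chunks of 2, 2, and 1 elements, and the outputs keep the input order.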

Diff for: seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/custom/CustomModel.java

+3-3
@@ -67,7 +67,7 @@ public CustomModel(
6767
}
6868

6969
@Override
70-
protected List<List<Float>> vector(Object[] fields) throws IOException {
70+
protected List<List<Double>> vector(Object[] fields) throws IOException {
7171
return vectorGeneration(fields);
7272
}
7373

@@ -76,7 +76,7 @@ public Integer dimension() throws IOException {
7676
return vectorGeneration(new Object[] {DIMENSION_EXAMPLE}).size();
7777
}
7878

79-
private List<List<Float>> vectorGeneration(Object[] fields) throws IOException {
79+
private List<List<Double>> vectorGeneration(Object[] fields) throws IOException {
8080
HttpPost post = new HttpPost(apiPath);
8181
// Construct a request with custom parameters
8282
for (Map.Entry<String, String> entry : header.entrySet()) {
@@ -96,7 +96,7 @@ private List<List<Float>> vectorGeneration(Object[] fields) throws IOException {
9696
}
9797

9898
return OBJECT_MAPPER.convertValue(
99-
parseResponse(responseStr), new TypeReference<List<List<Float>>>() {});
99+
parseResponse(responseStr), new TypeReference<List<List<Double>>>() {});
100100
}
101101

102102
@VisibleForTesting

Diff for: seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java

+5-5
@@ -54,7 +54,7 @@ public DoubaoModel(String apiKey, String model, String apiPath, Integer vectoriz
5454
}
5555

5656
@Override
57-
protected List<List<Float>> vector(Object[] fields) throws IOException {
57+
protected List<List<Double>> vector(Object[] fields) throws IOException {
5858
return vectorGeneration(fields);
5959
}
6060

@@ -63,7 +63,7 @@ public Integer dimension() throws IOException {
6363
return vectorGeneration(new Object[] {DIMENSION_EXAMPLE}).size();
6464
}
6565

66-
private List<List<Float>> vectorGeneration(Object[] fields) throws IOException {
66+
private List<List<Double>> vectorGeneration(Object[] fields) throws IOException {
6767
HttpPost post = new HttpPost(apiPath);
6868
post.setHeader("Authorization", "Bearer " + apiKey);
6969
post.setHeader("Content-Type", "application/json");
@@ -82,14 +82,14 @@ private List<List<Float>> vectorGeneration(Object[] fields) throws IOException {
8282
}
8383

8484
JsonNode data = OBJECT_MAPPER.readTree(responseStr).get("data");
85-
List<List<Float>> embeddings = new ArrayList<>();
85+
List<List<Double>> embeddings = new ArrayList<>();
8686

8787
if (data.isArray()) {
8888
for (JsonNode node : data) {
8989
JsonNode embeddingNode = node.get("embedding");
90-
List<Float> embedding =
90+
List<Double> embedding =
9191
OBJECT_MAPPER.readValue(
92-
embeddingNode.traverse(), new TypeReference<List<Float>>() {});
92+
embeddingNode.traverse(), new TypeReference<List<Double>>() {});
9393
embeddings.add(embedding);
9494
}
9595
}
