Skip to content

Commit e0c99ac

Browse files
[Feature][Transforms-V2] Handling LLM non-standard format responses (#8551)
1 parent c439b99 commit e0c99ac

File tree

3 files changed

+92
-3
lines changed

3 files changed

+92
-3
lines changed

seatunnel-transforms-v2/pom.xml

+6
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,12 @@
9292
<artifactId>httpcore</artifactId>
9393
<version>${httpcore.version}</version>
9494
</dependency>
95+
<dependency>
96+
<groupId>com.squareup.okhttp3</groupId>
97+
<artifactId>mockwebserver</artifactId>
98+
<version>3.6.0</version>
99+
<scope>test</scope>
100+
</dependency>
95101
</dependencies>
96102

97103
<build>

seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/custom/CustomModel.java

+10-2
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import com.jayway.jsonpath.JsonPath;
4141

4242
import java.io.IOException;
43+
import java.util.Collections;
4344
import java.util.Iterator;
4445
import java.util.List;
4546
import java.util.Map;
@@ -94,8 +95,15 @@ protected List<String> chatWithModel(String promptWithLimit, String rowsJson)
9495
if (response.getStatusLine().getStatusCode() != 200) {
9596
throw new IOException("Failed to get vector from custom, response: " + responseStr);
9697
}
97-
return OBJECT_MAPPER.convertValue(
98-
parseResponse(responseStr), new TypeReference<List<String>>() {});
98+
try {
99+
return OBJECT_MAPPER.convertValue(
100+
parseResponse(responseStr), new TypeReference<List<String>>() {});
101+
} catch (Exception e) {
102+
String result =
103+
OBJECT_MAPPER.convertValue(
104+
parseResponse(responseStr), new TypeReference<String>() {});
105+
return Collections.singletonList(result);
106+
}
99107
}
100108

101109
@VisibleForTesting

seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java

+76-1
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,13 @@
3535
import org.junit.jupiter.api.Assertions;
3636
import org.junit.jupiter.api.Test;
3737

38+
import okhttp3.mockwebserver.MockResponse;
39+
import okhttp3.mockwebserver.MockWebServer;
40+
3841
import java.io.IOException;
3942
import java.lang.reflect.Field;
4043
import java.util.ArrayList;
44+
import java.util.Collections;
4145
import java.util.HashMap;
4246
import java.util.List;
4347
import java.util.Map;
@@ -172,7 +176,6 @@ void testCustomRequestJson() throws IOException {
172176

173177
Map<String, String> header = new HashMap<>();
174178
header.put("Content-Type", "application/json");
175-
header.put("Authorization", "Bearer " + "apikey");
176179

177180
List<Map<String, String>> messagesList = new ArrayList<>();
178181

@@ -209,4 +212,76 @@ void testCustomRequestJson() throws IOException {
209212
"{\"messages\":[{\"role\":\"system\",\"content\":\"Determine whether someone is Chinese or American by their name\"},{\"role\":\"user\",\"content\":\"{\\\"id\\\":1, \\\"name\\\":\\\"John\\\"}\"}],\"model\":\"custom-model\"}",
210213
OBJECT_MAPPER.writeValueAsString(node));
211214
}
215+
216+
@Test
217+
void testCustomOllamaRequestJson() throws IOException {
218+
219+
MockWebServer mockWebServer = new MockWebServer();
220+
mockWebServer.start(11434);
221+
String jsonResponse =
222+
"{\n"
223+
+ " \"model\": \"qwen:7b\",\n"
224+
+ " \"created_at\": \"2025-02-07T01:22:46.589856Z\",\n"
225+
+ " \"message\": {\n"
226+
+ " \"role\": \"assistant\",\n"
227+
+ " \"content\": \"Based on the information provided in the JSON object, \\\"John\\\" does not inherently indicate if the person is Chinese or American. The name \\\"John\\\" is commonly used across many cultures. To determine a person's nationality based solely on their name, more context would be needed.\"\n"
228+
+ " },\n"
229+
+ " \"done_reason\": \"stop\",\n"
230+
+ " \"done\": true,\n"
231+
+ " \"total_duration\": 14435322300,\n"
232+
+ " \"load_duration\": 28998200,\n"
233+
+ " \"prompt_eval_count\": 34,\n"
234+
+ " \"prompt_eval_duration\": 302000000,\n"
235+
+ " \"eval_count\": 56,\n"
236+
+ " \"eval_duration\": 14102000000\n"
237+
+ "}";
238+
239+
mockWebServer.enqueue(
240+
new MockResponse()
241+
.setBody(jsonResponse)
242+
.addHeader("Content-Type", "application/json"));
243+
244+
SeaTunnelRowType rowType =
245+
new SeaTunnelRowType(
246+
new String[] {"id", "name"},
247+
new SeaTunnelDataType[] {BasicType.INT_TYPE, BasicType.STRING_TYPE});
248+
249+
Map<String, String> header = new HashMap<>();
250+
header.put("Content-Type", "application/json");
251+
252+
List<Map<String, String>> messagesList = new ArrayList<>();
253+
254+
Map<String, String> systemMessage = new HashMap<>();
255+
systemMessage.put("role", "system");
256+
systemMessage.put("content", "${prompt}");
257+
messagesList.add(systemMessage);
258+
259+
Map<String, String> userMessage = new HashMap<>();
260+
userMessage.put("role", "user");
261+
userMessage.put("content", "${input}");
262+
messagesList.add(userMessage);
263+
264+
Map<String, Object> resultMap = new HashMap<>();
265+
resultMap.put("model", "${model}");
266+
resultMap.put("stream", false);
267+
resultMap.put("messages", messagesList);
268+
269+
CustomModel model =
270+
new CustomModel(
271+
rowType,
272+
SqlType.STRING,
273+
null,
274+
"Determine whether someone is Chinese or American by their name",
275+
"qwen:7b",
276+
"http://localhost:11434/api/chat",
277+
header,
278+
resultMap,
279+
"$.message.content");
280+
281+
SeaTunnelRow row = new SeaTunnelRow(rowType.getFieldTypes().length);
282+
row.setField(0, 1);
283+
row.setField(1, "John");
284+
List<String> successResult = model.inference(Collections.singletonList(row));
285+
Assertions.assertFalse(successResult.isEmpty());
286+
}
212287
}

0 commit comments

Comments
 (0)