|
35 | 35 | import org.junit.jupiter.api.Assertions;
|
36 | 36 | import org.junit.jupiter.api.Test;
|
37 | 37 |
|
| 38 | +import okhttp3.mockwebserver.MockResponse; |
| 39 | +import okhttp3.mockwebserver.MockWebServer; |
| 40 | + |
38 | 41 | import java.io.IOException;
|
39 | 42 | import java.lang.reflect.Field;
|
40 | 43 | import java.util.ArrayList;
|
| 44 | +import java.util.Collections; |
41 | 45 | import java.util.HashMap;
|
42 | 46 | import java.util.List;
|
43 | 47 | import java.util.Map;
|
@@ -172,7 +176,6 @@ void testCustomRequestJson() throws IOException {
|
172 | 176 |
|
173 | 177 | Map<String, String> header = new HashMap<>();
|
174 | 178 | header.put("Content-Type", "application/json");
|
175 |
| - header.put("Authorization", "Bearer " + "apikey"); |
176 | 179 |
|
177 | 180 | List<Map<String, String>> messagesList = new ArrayList<>();
|
178 | 181 |
|
@@ -209,4 +212,76 @@ void testCustomRequestJson() throws IOException {
|
209 | 212 | "{\"messages\":[{\"role\":\"system\",\"content\":\"Determine whether someone is Chinese or American by their name\"},{\"role\":\"user\",\"content\":\"{\\\"id\\\":1, \\\"name\\\":\\\"John\\\"}\"}],\"model\":\"custom-model\"}",
|
210 | 213 | OBJECT_MAPPER.writeValueAsString(node));
|
211 | 214 | }
|
| 215 | + |
| 216 | + @Test |
| 217 | + void testCustomOllamaRequestJson() throws IOException { |
| 218 | + |
| 219 | + MockWebServer mockWebServer = new MockWebServer(); |
| 220 | + mockWebServer.start(11434); |
| 221 | + String jsonResponse = |
| 222 | + "{\n" |
| 223 | + + " \"model\": \"qwen:7b\",\n" |
| 224 | + + " \"created_at\": \"2025-02-07T01:22:46.589856Z\",\n" |
| 225 | + + " \"message\": {\n" |
| 226 | + + " \"role\": \"assistant\",\n" |
| 227 | + + " \"content\": \"Based on the information provided in the JSON object, \\\"John\\\" does not inherently indicate if the person is Chinese or American. The name \\\"John\\\" is commonly used across many cultures. To determine a person's nationality based solely on their name, more context would be needed.\"\n" |
| 228 | + + " },\n" |
| 229 | + + " \"done_reason\": \"stop\",\n" |
| 230 | + + " \"done\": true,\n" |
| 231 | + + " \"total_duration\": 14435322300,\n" |
| 232 | + + " \"load_duration\": 28998200,\n" |
| 233 | + + " \"prompt_eval_count\": 34,\n" |
| 234 | + + " \"prompt_eval_duration\": 302000000,\n" |
| 235 | + + " \"eval_count\": 56,\n" |
| 236 | + + " \"eval_duration\": 14102000000\n" |
| 237 | + + "}"; |
| 238 | + |
| 239 | + mockWebServer.enqueue( |
| 240 | + new MockResponse() |
| 241 | + .setBody(jsonResponse) |
| 242 | + .addHeader("Content-Type", "application/json")); |
| 243 | + |
| 244 | + SeaTunnelRowType rowType = |
| 245 | + new SeaTunnelRowType( |
| 246 | + new String[] {"id", "name"}, |
| 247 | + new SeaTunnelDataType[] {BasicType.INT_TYPE, BasicType.STRING_TYPE}); |
| 248 | + |
| 249 | + Map<String, String> header = new HashMap<>(); |
| 250 | + header.put("Content-Type", "application/json"); |
| 251 | + |
| 252 | + List<Map<String, String>> messagesList = new ArrayList<>(); |
| 253 | + |
| 254 | + Map<String, String> systemMessage = new HashMap<>(); |
| 255 | + systemMessage.put("role", "system"); |
| 256 | + systemMessage.put("content", "${prompt}"); |
| 257 | + messagesList.add(systemMessage); |
| 258 | + |
| 259 | + Map<String, String> userMessage = new HashMap<>(); |
| 260 | + userMessage.put("role", "user"); |
| 261 | + userMessage.put("content", "${input}"); |
| 262 | + messagesList.add(userMessage); |
| 263 | + |
| 264 | + Map<String, Object> resultMap = new HashMap<>(); |
| 265 | + resultMap.put("model", "${model}"); |
| 266 | + resultMap.put("stream", false); |
| 267 | + resultMap.put("messages", messagesList); |
| 268 | + |
| 269 | + CustomModel model = |
| 270 | + new CustomModel( |
| 271 | + rowType, |
| 272 | + SqlType.STRING, |
| 273 | + null, |
| 274 | + "Determine whether someone is Chinese or American by their name", |
| 275 | + "qwen:7b", |
| 276 | + "http://localhost:11434/api/chat", |
| 277 | + header, |
| 278 | + resultMap, |
| 279 | + "$.message.content"); |
| 280 | + |
| 281 | + SeaTunnelRow row = new SeaTunnelRow(rowType.getFieldTypes().length); |
| 282 | + row.setField(0, 1); |
| 283 | + row.setField(1, "John"); |
| 284 | + List<String> successResult = model.inference(Collections.singletonList(row)); |
| 285 | + Assertions.assertFalse(successResult.isEmpty()); |
| 286 | + } |
212 | 287 | }
|
0 commit comments