Skip to content

Commit fd6ce1e

Browse files
committed
Move discovery endpoint type inference to backend
Fixes #11
1 parent 7ab15e1 commit fd6ce1e

28 files changed

Lines changed: 830 additions & 113 deletions

File tree

api-docs/openapi/v3_0/aiFoundationApis.json

Lines changed: 77 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -563,10 +563,10 @@
563563
"tags" : [ "console.api.aifoundation.halo.run/v1alpha1/Model" ]
564564
}
565565
},
566-
"/apis/console.api.aifoundation.halo.run/v1alpha1/models/{name}/test-chat" : {
566+
"/apis/console.api.aifoundation.halo.run/v1alpha1/models/{name}/test-chat/stream" : {
567567
"post" : {
568-
"description" : "Test chat completion with a specific model.",
569-
"operationId" : "TestModelChat",
568+
"description" : "Test chat completion with streaming response.",
569+
"operationId" : "TestModelChatStream",
570570
"parameters" : [ {
571571
"description" : "Model name (AiModel.metadata.name)",
572572
"in" : "path",
@@ -590,7 +590,7 @@
590590
"content" : {
591591
"*/*" : {
592592
"schema" : {
593-
"type" : "string"
593+
"$ref" : "#/components/schemas/ChatChunk"
594594
}
595595
}
596596
},
@@ -801,7 +801,7 @@
801801
"content" : {
802802
"*/*" : {
803803
"schema" : {
804-
"type" : "string"
804+
"$ref" : "#/components/schemas/ProviderModelDiscoveryResponse"
805805
}
806806
}
807807
},
@@ -1066,6 +1066,27 @@
10661066
}
10671067
}
10681068
},
1069+
"ChatChunk" : {
1070+
"type" : "object",
1071+
"properties" : {
1072+
"content" : {
1073+
"type" : "string"
1074+
},
1075+
"finishReason" : {
1076+
"type" : "string"
1077+
},
1078+
"last" : {
1079+
"type" : "boolean"
1080+
},
1081+
"type" : {
1082+
"type" : "string",
1083+
"enum" : [ "TEXT", "REASONING", "TOOL_CALL", "ERROR", "FINISH" ]
1084+
},
1085+
"usage" : {
1086+
"$ref" : "#/components/schemas/Usage"
1087+
}
1088+
}
1089+
},
10691090
"CopyOperation" : {
10701091
"required" : [ "op", "from", "path" ],
10711092
"type" : "object",
@@ -1088,6 +1109,29 @@
10881109
}
10891110
}
10901111
},
1112+
"DiscoveredModelItem" : {
1113+
"type" : "object",
1114+
"properties" : {
1115+
"capabilities" : {
1116+
"type" : "array",
1117+
"items" : {
1118+
"type" : "string"
1119+
}
1120+
},
1121+
"displayName" : {
1122+
"type" : "string"
1123+
},
1124+
"modelId" : {
1125+
"type" : "string"
1126+
},
1127+
"name" : {
1128+
"type" : "string"
1129+
},
1130+
"suggestedEndpointType" : {
1131+
"type" : "string"
1132+
}
1133+
}
1134+
},
10911135
"JsonPatch" : {
10921136
"minItems" : 1,
10931137
"uniqueItems" : true,
@@ -1187,6 +1231,20 @@
11871231
}
11881232
}
11891233
},
1234+
"ProviderModelDiscoveryResponse" : {
1235+
"type" : "object",
1236+
"properties" : {
1237+
"models" : {
1238+
"type" : "array",
1239+
"items" : {
1240+
"$ref" : "#/components/schemas/DiscoveredModelItem"
1241+
}
1242+
},
1243+
"providerName" : {
1244+
"type" : "string"
1245+
}
1246+
}
1247+
},
11901248
"ProviderTypeInfo" : {
11911249
"type" : "object",
11921250
"properties" : {
@@ -1289,6 +1347,19 @@
12891347
"description" : "Value can be any JSON value"
12901348
}
12911349
}
1350+
},
1351+
"Usage" : {
1352+
"type" : "object",
1353+
"properties" : {
1354+
"completionTokens" : {
1355+
"type" : "integer",
1356+
"format" : "int32"
1357+
},
1358+
"promptTokens" : {
1359+
"type" : "integer",
1360+
"format" : "int32"
1361+
}
1362+
}
12921363
}
12931364
},
12941365
"securitySchemes" : {
@@ -1303,4 +1374,4 @@
13031374
}
13041375
}
13051376
}
1306-
}
1377+
}
app/src/main/java/run/halo/aifoundation/endpoint/DiscoveredModelItem.java

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,14 @@
1+
package run.halo.aifoundation.endpoint;
2+
3+
import java.util.List;
4+
import org.springframework.lang.Nullable;
5+
6+
public record DiscoveredModelItem(
7+
String modelId,
8+
String displayName,
9+
String name,
10+
List<String> capabilities,
11+
@Nullable
12+
String suggestedEndpointType
13+
) {
14+
}

app/src/main/java/run/halo/aifoundation/endpoint/ModelConsoleEndpoint.java

Lines changed: 42 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,9 @@
66
import static org.springdoc.webflux.core.fn.SpringdocRouteBuilder.route;
77

88
import io.swagger.v3.oas.annotations.enums.ParameterIn;
9+
import java.util.LinkedHashSet;
910
import java.util.List;
11+
import java.util.Locale;
1012
import java.util.Map;
1113
import lombok.Data;
1214
import lombok.RequiredArgsConstructor;
@@ -28,6 +30,8 @@
2830
import run.halo.aifoundation.Message;
2931
import run.halo.aifoundation.extension.AiModel;
3032
import run.halo.aifoundation.extension.AiProvider;
33+
import run.halo.aifoundation.provider.AiProviderType;
34+
import run.halo.aifoundation.provider.support.ModelCapability;
3135
import run.halo.aifoundation.provider.support.ProviderClientCache;
3236
import run.halo.app.core.extension.endpoint.CustomEndpoint;
3337
import run.halo.app.extension.GroupVersion;
@@ -229,7 +233,6 @@ private Mono<Void> validateModel(AiModel model) {
229233
}
230234
var providerName = model.getSpec().getProviderName();
231235
var modelId = model.getSpec().getModelId();
232-
var endpointType = model.getSpec().getEndpointType();
233236

234237
if (providerName == null || providerName.isBlank()) {
235238
return Mono.error(
@@ -239,10 +242,6 @@ private Mono<Void> validateModel(AiModel model) {
239242
return Mono.error(
240243
new ResponseStatusException(HttpStatus.BAD_REQUEST, "modelId is required"));
241244
}
242-
if (endpointType == null || endpointType.isBlank()) {
243-
return Mono.error(
244-
new ResponseStatusException(HttpStatus.BAD_REQUEST, "endpointType is required"));
245-
}
246245

247246
return client.fetch(AiProvider.class, providerName)
248247
.switchIfEmpty(Mono.error(new ResponseStatusException(HttpStatus.BAD_REQUEST,
@@ -254,7 +253,14 @@ private Mono<Void> validateModel(AiModel model) {
254253
return Mono.error(new ResponseStatusException(HttpStatus.BAD_REQUEST,
255254
"Unsupported provider type: " + providerType));
256255
}
257-
var supportedTypes = type.getSupportedEndpointTypes();
256+
applyDefaultEndpointType(model, type);
257+
var endpointType = model.getSpec().getEndpointType();
258+
if (endpointType == null || endpointType.isBlank()) {
259+
return Mono.error(new ResponseStatusException(HttpStatus.BAD_REQUEST,
260+
"endpointType is required and no supported default could be recommended"));
261+
}
262+
var supportedTypes = type.getSupportedEndpointTypes() != null
263+
? type.getSupportedEndpointTypes() : List.<String>of();
258264
if (!supportedTypes.contains(endpointType)) {
259265
return Mono.error(new ResponseStatusException(HttpStatus.BAD_REQUEST,
260266
"Endpoint type '" + endpointType + "' is not supported by provider type '"
@@ -264,6 +270,36 @@ private Mono<Void> validateModel(AiModel model) {
264270
});
265271
}
266272

273+
private void applyDefaultEndpointType(AiModel model, AiProviderType providerType) {
274+
var spec = model.getSpec();
275+
var endpointType = spec.getEndpointType();
276+
if (endpointType != null && !endpointType.isBlank()) {
277+
return;
278+
}
279+
providerType.recommendEndpointType(spec.getModelId(), modelCapabilities(model))
280+
.ifPresent(spec::setEndpointType);
281+
}
282+
283+
private List<ModelCapability> modelCapabilities(AiModel model) {
284+
var capabilities = new LinkedHashSet<ModelCapability>();
285+
var labels = model.getSpec().getCapabilities();
286+
if (labels == null) {
287+
return List.of();
288+
}
289+
for (var label : labels) {
290+
if (label == null) {
291+
continue;
292+
}
293+
switch (label.toLowerCase(Locale.ROOT)) {
294+
case "chat" -> capabilities.add(ModelCapability.CHAT);
295+
case "embedding" -> capabilities.add(ModelCapability.EMBEDDING);
296+
default -> {
297+
}
298+
}
299+
}
300+
return List.copyOf(capabilities);
301+
}
302+
267303
private Mono<Void> checkModelUniqueness(AiModel model, String excludeName) {
268304
var providerName = model.getSpec().getProviderName();
269305
var modelId = model.getSpec().getModelId();

app/src/main/java/run/halo/aifoundation/endpoint/ProviderConsoleEndpoint.java

Lines changed: 9 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -112,7 +112,7 @@ public RouterFunction<ServerResponse> endpoint() {
112112
.implementation(String.class)
113113
.required(true))
114114
.response(responseBuilder()
115-
.implementation(Map.class))
115+
.implementation(ProviderModelDiscoveryResponse.class))
116116
)
117117
.POST("providers/{name}/connectivity", this::testConnectivity,
118118
builder -> builder.operationId("TestProviderConnectivity")
@@ -246,22 +246,23 @@ private Mono<ServerResponse> discoverModelsViaProviderType(AiProvider provider,
246246
}
247247
return providerType.discoverModels(provider, apiKey)
248248
.map(models -> models.stream()
249-
.map(dm -> Map.<String, Object>of(
250-
"modelId", dm.modelId(),
251-
"displayName", dm.displayName(),
252-
"name", "",
253-
"capabilities", dm.capabilities().stream()
249+
.map(dm -> new DiscoveredModelItem(
250+
dm.modelId(),
251+
dm.displayName(),
252+
"",
253+
dm.capabilities().stream()
254254
.map(ModelCapability::name)
255255
.map(String::toLowerCase)
256-
.toList()
256+
.toList(),
257+
providerType.recommendEndpointType(dm).orElse(null)
257258
))
258259
.toList()
259260
)
260261
.flatMap(models -> {
261262
log.info("Discovered {} models for provider {}", models.size(), providerName);
262263
return ServerResponse.ok()
263264
.contentType(MediaType.APPLICATION_JSON)
264-
.bodyValue(Map.of("models", models, "providerName", providerName));
265+
.bodyValue(new ProviderModelDiscoveryResponse(providerName, models));
265266
});
266267
}
267268

app/src/main/java/run/halo/aifoundation/endpoint/ProviderModelDiscoveryResponse.java

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,9 @@
1+
package run.halo.aifoundation.endpoint;
2+
3+
import java.util.List;
4+
5+
public record ProviderModelDiscoveryResponse(
6+
String providerName,
7+
List<DiscoveredModelItem> models
8+
) {
9+
}

app/src/main/java/run/halo/aifoundation/provider/AiProviderType.java

Lines changed: 46 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,18 @@
11
package run.halo.aifoundation.provider;
22

3+
import java.util.Collection;
4+
import java.util.LinkedHashSet;
35
import java.util.List;
6+
import java.util.Locale;
7+
import java.util.Optional;
8+
import java.util.Set;
49
import org.springframework.ai.chat.model.ChatModel;
510
import org.springframework.ai.embedding.EmbeddingModel;
611
import org.springframework.lang.Nullable;
712
import reactor.core.publisher.Mono;
813
import run.halo.aifoundation.extension.AiProvider;
914
import run.halo.aifoundation.provider.support.DiscoveredModel;
15+
import run.halo.aifoundation.provider.support.ModelCapability;
1016

1117
public interface AiProviderType {
1218

@@ -62,6 +68,46 @@ default EmbeddingModel buildEmbeddingModel(AiProvider provider, String apiKey, S
6268

6369
Mono<List<DiscoveredModel>> discoverModels(AiProvider provider, String apiKey);
6470

71+
default Optional<String> recommendEndpointType(DiscoveredModel model) {
72+
return recommendEndpointType(model.modelId(), model.capabilities());
73+
}
74+
75+
default Optional<String> recommendEndpointType(String modelId,
76+
Collection<ModelCapability> capabilities) {
77+
var supportedTypes = getSupportedEndpointTypes();
78+
if (supportedTypes == null || supportedTypes.isEmpty()) {
79+
return Optional.empty();
80+
}
81+
82+
var normalizedCapabilities = new LinkedHashSet<>(
83+
capabilities != null ? capabilities : Set.<ModelCapability>of());
84+
if (normalizedCapabilities.isEmpty()) {
85+
var normalizedModelId = modelId != null ? modelId.toLowerCase(Locale.ROOT) : "";
86+
normalizedCapabilities.add(normalizedModelId.contains("embed")
87+
? ModelCapability.EMBEDDING : ModelCapability.CHAT);
88+
}
89+
90+
if (normalizedCapabilities.contains(ModelCapability.EMBEDDING)) {
91+
var embeddingEndpoint = findSupportedEndpointType("embedding");
92+
if (embeddingEndpoint.isPresent()) {
93+
return embeddingEndpoint;
94+
}
95+
}
96+
if (normalizedCapabilities.contains(ModelCapability.CHAT)) {
97+
var chatEndpoint = findSupportedEndpointType("chat");
98+
if (chatEndpoint.isPresent()) {
99+
return chatEndpoint;
100+
}
101+
}
102+
return Optional.empty();
103+
}
104+
105+
private Optional<String> findSupportedEndpointType(String token) {
106+
return getSupportedEndpointTypes().stream()
107+
.filter(endpointType -> endpointType.toLowerCase(Locale.ROOT).contains(token))
108+
.findFirst();
109+
}
110+
65111
default int maxEmbeddingsPerCall() {
66112
return 96;
67113
}

0 commit comments

Comments
 (0)