66import static org .springdoc .webflux .core .fn .SpringdocRouteBuilder .route ;
77
88import io .swagger .v3 .oas .annotations .enums .ParameterIn ;
9+ import java .util .LinkedHashSet ;
910import java .util .List ;
11+ import java .util .Locale ;
1012import java .util .Map ;
1113import lombok .Data ;
1214import lombok .RequiredArgsConstructor ;
2830import run .halo .aifoundation .Message ;
2931import run .halo .aifoundation .extension .AiModel ;
3032import run .halo .aifoundation .extension .AiProvider ;
33+ import run .halo .aifoundation .provider .AiProviderType ;
34+ import run .halo .aifoundation .provider .support .ModelCapability ;
3135import run .halo .aifoundation .provider .support .ProviderClientCache ;
3236import run .halo .app .core .extension .endpoint .CustomEndpoint ;
3337import run .halo .app .extension .GroupVersion ;
@@ -229,7 +233,6 @@ private Mono<Void> validateModel(AiModel model) {
229233 }
230234 var providerName = model .getSpec ().getProviderName ();
231235 var modelId = model .getSpec ().getModelId ();
232- var endpointType = model .getSpec ().getEndpointType ();
233236
234237 if (providerName == null || providerName .isBlank ()) {
235238 return Mono .error (
@@ -239,10 +242,6 @@ private Mono<Void> validateModel(AiModel model) {
239242 return Mono .error (
240243 new ResponseStatusException (HttpStatus .BAD_REQUEST , "modelId is required" ));
241244 }
242- if (endpointType == null || endpointType .isBlank ()) {
243- return Mono .error (
244- new ResponseStatusException (HttpStatus .BAD_REQUEST , "endpointType is required" ));
245- }
246245
247246 return client .fetch (AiProvider .class , providerName )
248247 .switchIfEmpty (Mono .error (new ResponseStatusException (HttpStatus .BAD_REQUEST ,
@@ -254,7 +253,14 @@ private Mono<Void> validateModel(AiModel model) {
254253 return Mono .error (new ResponseStatusException (HttpStatus .BAD_REQUEST ,
255254 "Unsupported provider type: " + providerType ));
256255 }
257- var supportedTypes = type .getSupportedEndpointTypes ();
256+ applyDefaultEndpointType (model , type );
257+ var endpointType = model .getSpec ().getEndpointType ();
258+ if (endpointType == null || endpointType .isBlank ()) {
259+ return Mono .error (new ResponseStatusException (HttpStatus .BAD_REQUEST ,
260+ "endpointType is required and no supported default could be recommended" ));
261+ }
262+ var supportedTypes = type .getSupportedEndpointTypes () != null
263+ ? type .getSupportedEndpointTypes () : List .<String >of ();
258264 if (!supportedTypes .contains (endpointType )) {
259265 return Mono .error (new ResponseStatusException (HttpStatus .BAD_REQUEST ,
260266 "Endpoint type '" + endpointType + "' is not supported by provider type '"
@@ -264,6 +270,36 @@ private Mono<Void> validateModel(AiModel model) {
264270 });
265271 }
266272
273+ private void applyDefaultEndpointType (AiModel model , AiProviderType providerType ) {
274+ var spec = model .getSpec ();
275+ var endpointType = spec .getEndpointType ();
276+ if (endpointType != null && !endpointType .isBlank ()) {
277+ return ;
278+ }
279+ providerType .recommendEndpointType (spec .getModelId (), modelCapabilities (model ))
280+ .ifPresent (spec ::setEndpointType );
281+ }
282+
283+ private List <ModelCapability > modelCapabilities (AiModel model ) {
284+ var capabilities = new LinkedHashSet <ModelCapability >();
285+ var labels = model .getSpec ().getCapabilities ();
286+ if (labels == null ) {
287+ return List .of ();
288+ }
289+ for (var label : labels ) {
290+ if (label == null ) {
291+ continue ;
292+ }
293+ switch (label .toLowerCase (Locale .ROOT )) {
294+ case "chat" -> capabilities .add (ModelCapability .CHAT );
295+ case "embedding" -> capabilities .add (ModelCapability .EMBEDDING );
296+ default -> {
297+ }
298+ }
299+ }
300+ return List .copyOf (capabilities );
301+ }
302+
267303 private Mono <Void > checkModelUniqueness (AiModel model , String excludeName ) {
268304 var providerName = model .getSpec ().getProviderName ();
269305 var modelId = model .getSpec ().getModelId ();
0 commit comments