Skip to content

Commit c509a42

Browse files
authored
refactor: merge default model resolution APIs (#47)
* refactor: merge default model resolution APIs
* docs: update AI model service API guide
* refactor: remove no-arg model service methods
* Revert "refactor: remove no-arg model service methods"

  This reverts commit a90d781.
1 parent 7751bc7 commit c509a42

4 files changed

Lines changed: 95 additions & 171 deletions

File tree

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
11
package run.halo.aifoundation;
22

3-
import java.util.List;
43
import org.pf4j.ExtensionPoint;
54
import reactor.core.publisher.Mono;
65
import run.halo.aifoundation.chat.LanguageModel;
76
import run.halo.aifoundation.embedding.EmbeddingModel;
8-
import run.halo.aifoundation.model.ModelInfo;
9-
import run.halo.aifoundation.model.ProviderInfo;
107

118
/**
129
* Cross-plugin entry point for resolving AI models managed by plugin-ai-foundation.
@@ -17,33 +14,25 @@
1714
*/
1815
public interface AiModelService extends ExtensionPoint {
1916

20-
/**
21-
* Resolves an enabled language model by {@code AiModel.metadata.name}.
22-
*/
23-
Mono<LanguageModel> languageModel(String modelName);
24-
25-
/**
26-
* Resolves an enabled embedding model by {@code AiModel.metadata.name}.
27-
*/
28-
Mono<EmbeddingModel> embeddingModel(String modelName);
29-
3017
/**
3118
* Resolves the configured default language model.
3219
*/
33-
Mono<LanguageModel> defaultLanguageModel();
20+
Mono<LanguageModel> languageModel();
3421

3522
/**
36-
* Resolves the configured default embedding model.
23+
* Resolves an enabled language model by {@code AiModel.metadata.name}. When {@code modelName}
24+
* is {@code null} or blank, resolves the configured default language model.
3725
*/
38-
Mono<EmbeddingModel> defaultEmbeddingModel();
26+
Mono<LanguageModel> languageModel(String modelName);
3927

4028
/**
41-
* Lists configured model resources visible through this service.
29+
* Resolves the configured default embedding model.
4230
*/
43-
Mono<List<ModelInfo>> listModels();
31+
Mono<EmbeddingModel> embeddingModel();
4432

4533
/**
46-
* Lists configured provider resources visible through this service.
34+
* Resolves an enabled embedding model by {@code AiModel.metadata.name}. When {@code modelName}
35+
* is {@code null} or blank, resolves the configured default embedding model.
4736
*/
48-
Mono<List<ProviderInfo>> listProviders();
37+
Mono<EmbeddingModel> embeddingModel(String modelName);
4938
}
Lines changed: 20 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,90 +1,52 @@
11
package run.halo.aifoundation.service;
22

3-
import java.util.List;
43
import lombok.RequiredArgsConstructor;
54
import lombok.extern.slf4j.Slf4j;
6-
import org.springframework.data.domain.Sort;
75
import org.springframework.stereotype.Component;
6+
import org.springframework.util.StringUtils;
87
import reactor.core.publisher.Mono;
98
import run.halo.aifoundation.AiModelService;
10-
import run.halo.aifoundation.embedding.EmbeddingModel;
119
import run.halo.aifoundation.chat.LanguageModel;
12-
import run.halo.aifoundation.model.ModelInfo;
13-
import run.halo.aifoundation.model.ProviderInfo;
14-
import run.halo.aifoundation.extension.AiModel;
15-
import run.halo.aifoundation.extension.AiProvider;
10+
import run.halo.aifoundation.embedding.EmbeddingModel;
1611
import run.halo.aifoundation.provider.support.ModelType;
17-
import run.halo.app.extension.ListOptions;
18-
import run.halo.app.extension.ReactiveExtensionClient;
1912

2013
@Slf4j
2114
@Component
2215
@RequiredArgsConstructor
2316
public class AiModelServiceImpl implements AiModelService {
2417

25-
private final ReactiveExtensionClient client;
2618
private final AiModelResolver modelResolver;
2719
private final LanguageModelFactory languageModelFactory;
2820
private final EmbeddingModelFactory embeddingModelFactory;
2921

3022
@Override
31-
public Mono<LanguageModel> languageModel(String modelName) {
32-
return modelResolver.resolve(modelName, ModelType.LANGUAGE)
33-
.map(languageModelFactory::create);
34-
}
35-
36-
@Override
37-
public Mono<EmbeddingModel> embeddingModel(String modelName) {
38-
return modelResolver.resolve(modelName, ModelType.EMBEDDING)
39-
.map(embeddingModelFactory::create);
40-
}
41-
42-
@Override
43-
public Mono<LanguageModel> defaultLanguageModel() {
44-
return modelResolver.defaultLanguageModelName()
45-
.flatMap(this::languageModel);
23+
public Mono<LanguageModel> languageModel() {
24+
return languageModel(null);
4625
}
4726

4827
@Override
49-
public Mono<EmbeddingModel> defaultEmbeddingModel() {
50-
return modelResolver.defaultEmbeddingModelName()
51-
.flatMap(this::embeddingModel);
28+
public Mono<LanguageModel> languageModel(String modelName) {
29+
var resolvedModelName = StringUtils.hasText(modelName)
30+
? Mono.just(modelName)
31+
: modelResolver.defaultLanguageModelName();
32+
return resolvedModelName
33+
.flatMap(name -> modelResolver.resolve(name, ModelType.LANGUAGE))
34+
.map(languageModelFactory::create);
5235
}
5336

5437
@Override
55-
public Mono<List<ModelInfo>> listModels() {
56-
return client.listAll(AiModel.class, new ListOptions(),
57-
Sort.by("metadata.creationTimestamp").descending())
58-
.map(model -> ModelInfo.builder()
59-
.name(model.getMetadata().getName())
60-
.providerName(model.getSpec().getProviderName())
61-
.modelId(model.getSpec().getModelId())
62-
.displayName(model.getSpec().getDisplayName())
63-
.enabled(model.getSpec().isEnabled())
64-
.build())
65-
.collectList();
38+
public Mono<EmbeddingModel> embeddingModel() {
39+
return embeddingModel(null);
6640
}
6741

6842
@Override
69-
public Mono<List<ProviderInfo>> listProviders() {
70-
return client.listAll(AiProvider.class, new ListOptions(),
71-
Sort.by("metadata.creationTimestamp").descending())
72-
.map(provider -> {
73-
var status = provider.getStatus();
74-
var phase = status != null && status.getPhase() != null
75-
? status.getPhase().name() : "UNKNOWN";
76-
var lastCheckedAt = status != null && status.getLastCheckedAt() != null
77-
? status.getLastCheckedAt().toString() : null;
78-
return ProviderInfo.builder()
79-
.name(provider.getMetadata().getName())
80-
.displayName(provider.getSpec().getDisplayName())
81-
.providerType(provider.getSpec().getProviderType())
82-
.enabled(provider.getSpec().isEnabled())
83-
.phase(phase)
84-
.lastCheckedAt(lastCheckedAt)
85-
.build();
86-
})
87-
.collectList();
43+
public Mono<EmbeddingModel> embeddingModel(String modelName) {
44+
var resolvedModelName = StringUtils.hasText(modelName)
45+
? Mono.just(modelName)
46+
: modelResolver.defaultEmbeddingModelName();
47+
return resolvedModelName
48+
.flatMap(name -> modelResolver.resolve(name, ModelType.EMBEDDING))
49+
.map(embeddingModelFactory::create);
8850
}
8951

9052
}

app/src/test/java/run/halo/aifoundation/service/AiModelServiceImplTest.java

Lines changed: 56 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,15 @@
22

33
import static org.assertj.core.api.Assertions.assertThat;
44
import static org.mockito.ArgumentMatchers.any;
5-
import static org.mockito.ArgumentMatchers.eq;
65
import static org.mockito.Mockito.mock;
76
import static org.mockito.Mockito.when;
87

9-
import java.util.List;
108
import org.springframework.ai.chat.model.ChatModel;
119
import org.junit.jupiter.api.BeforeEach;
1210
import org.junit.jupiter.api.Test;
1311
import org.junit.jupiter.api.extension.ExtendWith;
1412
import org.mockito.Mock;
1513
import org.mockito.junit.jupiter.MockitoExtension;
16-
import org.springframework.data.domain.Sort;
17-
import reactor.core.publisher.Flux;
1814
import reactor.core.publisher.Mono;
1915
import reactor.test.StepVerifier;
2016
import run.halo.aifoundation.exception.DefaultModelNotConfiguredException;
@@ -33,7 +29,6 @@
3329
import run.halo.aifoundation.service.model.DefaultAiModelResolver;
3430
import run.halo.aifoundation.setting.DefaultModelSlotStore;
3531
import run.halo.aifoundation.setting.DefaultModelSlots;
36-
import run.halo.app.extension.ListOptions;
3732
import run.halo.app.extension.Metadata;
3833
import run.halo.app.extension.ReactiveExtensionClient;
3934

@@ -57,82 +52,13 @@ class AiModelServiceImplTest {
5752
@BeforeEach
5853
void setUp() {
5954
service = new AiModelServiceImpl(
60-
client,
6155
new DefaultAiModelResolver(client, providerClientCache, secretResolver,
6256
defaultModelSlotStore),
6357
new DefaultLanguageModelFactory(providerClientCache),
6458
new DefaultEmbeddingModelFactory(providerClientCache)
6559
);
6660
}
6761

68-
// ---- listModels ----
69-
70-
@Test
71-
void listModels_returnsAllModels() {
72-
when(client.listAll(eq(AiModel.class), any(ListOptions.class), any(Sort.class)))
73-
.thenReturn(Flux.just(
74-
aiModel("openai-prod-gpt-4-abc", "provider-a", "gpt-4", "GPT-4", true),
75-
aiModel("ollama-local-claude-3-xyz", "provider-b", "claude-3", "Claude 3", true)
76-
));
77-
78-
StepVerifier.create(service.listModels())
79-
.assertNext(models -> {
80-
assertThat(models).hasSize(2);
81-
assertThat(models.get(0).getName()).isEqualTo("openai-prod-gpt-4-abc");
82-
assertThat(models.get(0).getModelId()).isEqualTo("gpt-4");
83-
assertThat(models.get(0).getProviderName()).isEqualTo("provider-a");
84-
assertThat(models.get(1).getName()).isEqualTo("ollama-local-claude-3-xyz");
85-
assertThat(models.get(1).getModelId()).isEqualTo("claude-3");
86-
})
87-
.verifyComplete();
88-
}
89-
90-
@Test
91-
void listModels_emptyResult_returnsEmptyList() {
92-
when(client.listAll(eq(AiModel.class), any(ListOptions.class), any(Sort.class)))
93-
.thenReturn(Flux.empty());
94-
95-
StepVerifier.create(service.listModels())
96-
.assertNext(models -> assertThat(models).isEmpty())
97-
.verifyComplete();
98-
}
99-
100-
// ---- listProviders ----
101-
102-
@Test
103-
void listProviders_returnsAllProviders() {
104-
var provider1 = aiProvider("openai-prod", "openai", true);
105-
provider1.setStatus(statusWithPhase(AiProvider.AiProviderStatus.Phase.OK));
106-
var provider2 = aiProvider("ollama-local", "ollama", false);
107-
108-
when(client.listAll(eq(AiProvider.class), any(ListOptions.class), any(Sort.class)))
109-
.thenReturn(Flux.just(provider1, provider2));
110-
111-
StepVerifier.create(service.listProviders())
112-
.assertNext(providers -> {
113-
assertThat(providers).hasSize(2);
114-
assertThat(providers.get(0).getName()).isEqualTo("openai-prod");
115-
assertThat(providers.get(0).getProviderType()).isEqualTo("openai");
116-
assertThat(providers.get(0).isEnabled()).isTrue();
117-
assertThat(providers.get(0).getPhase()).isEqualTo("OK");
118-
assertThat(providers.get(1).isEnabled()).isFalse();
119-
assertThat(providers.get(1).getPhase()).isEqualTo("UNKNOWN");
120-
})
121-
.verifyComplete();
122-
}
123-
124-
@Test
125-
void listProviders_nullStatus_showsUnknownPhase() {
126-
var provider = aiProvider("my-provider", "openai", true);
127-
provider.setStatus(null);
128-
when(client.listAll(eq(AiProvider.class), any(ListOptions.class), any(Sort.class)))
129-
.thenReturn(Flux.just(provider));
130-
131-
StepVerifier.create(service.listProviders())
132-
.assertNext(providers -> assertThat(providers.get(0).getPhase()).isEqualTo("UNKNOWN"))
133-
.verifyComplete();
134-
}
135-
13662
// ---- languageModel — fetch by metadata.name ----
13763

13864
@Test
@@ -194,16 +120,37 @@ void languageModel_wrongModelType_emitsIncompatibleModelTypeException() {
194120
}
195121

196122
@Test
197-
void defaultLanguageModel_missingSlot_emitsDefaultModelNotConfiguredException() {
123+
void languageModel_withoutNameAndMissingSlot_emitsDefaultModelNotConfiguredException() {
198124
when(defaultModelSlotStore.get()).thenReturn(Mono.just(new DefaultModelSlots()));
199125

200-
StepVerifier.create(service.defaultLanguageModel())
126+
StepVerifier.create(service.languageModel())
201127
.expectError(DefaultModelNotConfiguredException.class)
202128
.verify();
203129
}
204130

205131
@Test
206-
void defaultLanguageModel_resolvesConfiguredModel() {
132+
void languageModel_withoutName_resolvesConfiguredModel() {
133+
var slots = defaultSlots("openai-prod-gpt-4-abc", null);
134+
var model = aiModel("openai-prod-gpt-4-abc", "openai-prod", "gpt-4", "GPT-4", true);
135+
var provider = aiProvider("openai-prod", "openai", true);
136+
var chatModel = mock(ChatModel.class);
137+
var providerType = languageProviderType();
138+
139+
when(defaultModelSlotStore.get()).thenReturn(Mono.just(slots));
140+
when(client.fetch(AiModel.class, "openai-prod-gpt-4-abc")).thenReturn(Mono.just(model));
141+
when(client.fetch(AiProvider.class, "openai-prod")).thenReturn(Mono.just(provider));
142+
when(secretResolver.resolveApiKey(null)).thenReturn(Mono.just("sk-test"));
143+
when(providerClientCache.getProviderType("openai")).thenReturn(providerType);
144+
when(providerClientCache.getOrCreateChatModel(provider, "sk-test", "gpt-4"))
145+
.thenReturn(chatModel);
146+
147+
StepVerifier.create(service.languageModel())
148+
.assertNext(languageModel -> assertThat(languageModel).isNotNull())
149+
.verifyComplete();
150+
}
151+
152+
@Test
153+
void languageModel_blankName_resolvesConfiguredModel() {
207154
var slots = defaultSlots("openai-prod-gpt-4-abc", null);
208155
var model = aiModel("openai-prod-gpt-4-abc", "openai-prod", "gpt-4", "GPT-4", true);
209156
var provider = aiProvider("openai-prod", "openai", true);
@@ -218,11 +165,42 @@ void defaultLanguageModel_resolvesConfiguredModel() {
218165
when(providerClientCache.getOrCreateChatModel(provider, "sk-test", "gpt-4"))
219166
.thenReturn(chatModel);
220167

221-
StepVerifier.create(service.defaultLanguageModel())
168+
StepVerifier.create(service.languageModel(" "))
222169
.assertNext(languageModel -> assertThat(languageModel).isNotNull())
223170
.verifyComplete();
224171
}
225172

173+
@Test
174+
void embeddingModel_withoutName_resolvesConfiguredModel() {
175+
var slots = defaultSlots(null, "openai-prod-embedding");
176+
var model = aiModel("openai-prod-embedding", "openai-prod",
177+
"text-embedding-3-small", "Embedding", true, ModelType.EMBEDDING);
178+
var provider = aiProvider("openai-prod", "openai", true);
179+
var springEmbeddingModel = mock(org.springframework.ai.embedding.EmbeddingModel.class);
180+
var providerType = mock(AiProviderType.class);
181+
182+
when(defaultModelSlotStore.get()).thenReturn(Mono.just(slots));
183+
when(client.fetch(AiModel.class, "openai-prod-embedding")).thenReturn(Mono.just(model));
184+
when(client.fetch(AiProvider.class, "openai-prod")).thenReturn(Mono.just(provider));
185+
when(secretResolver.resolveApiKey(null)).thenReturn(Mono.just("sk-test"));
186+
when(providerClientCache.getProviderType("openai")).thenReturn(providerType);
187+
when(providerClientCache.getOrCreateEmbeddingModel(provider, "sk-test",
188+
"text-embedding-3-small")).thenReturn(springEmbeddingModel);
189+
190+
StepVerifier.create(service.embeddingModel())
191+
.assertNext(embeddingModel -> assertThat(embeddingModel).isNotNull())
192+
.verifyComplete();
193+
}
194+
195+
@Test
196+
void embeddingModel_blankNameAndMissingSlot_emitsDefaultModelNotConfiguredException() {
197+
when(defaultModelSlotStore.get()).thenReturn(Mono.just(new DefaultModelSlots()));
198+
199+
StepVerifier.create(service.embeddingModel(""))
200+
.expectError(DefaultModelNotConfiguredException.class)
201+
.verify();
202+
}
203+
226204
// ---- helpers ----
227205

228206
private AiModel aiModel(String name, String providerName, String modelId,
@@ -259,12 +237,6 @@ private AiProvider aiProvider(String name, String providerType, boolean enabled)
259237
return provider;
260238
}
261239

262-
private AiProvider.AiProviderStatus statusWithPhase(AiProvider.AiProviderStatus.Phase phase) {
263-
var status = new AiProvider.AiProviderStatus();
264-
status.setPhase(phase);
265-
return status;
266-
}
267-
268240
private DefaultModelSlots defaultSlots(String languageModelName, String embeddingModelName) {
269241
var slots = new DefaultModelSlots();
270242
slots.setLanguageModelName(languageModelName);

0 commit comments

Comments
 (0)