22
33import static org .assertj .core .api .Assertions .assertThat ;
44import static org .mockito .ArgumentMatchers .any ;
5- import static org .mockito .ArgumentMatchers .eq ;
65import static org .mockito .Mockito .mock ;
76import static org .mockito .Mockito .when ;
87
9- import java .util .List ;
108import org .springframework .ai .chat .model .ChatModel ;
119import org .junit .jupiter .api .BeforeEach ;
1210import org .junit .jupiter .api .Test ;
1311import org .junit .jupiter .api .extension .ExtendWith ;
1412import org .mockito .Mock ;
1513import org .mockito .junit .jupiter .MockitoExtension ;
16- import org .springframework .data .domain .Sort ;
17- import reactor .core .publisher .Flux ;
1814import reactor .core .publisher .Mono ;
1915import reactor .test .StepVerifier ;
2016import run .halo .aifoundation .exception .DefaultModelNotConfiguredException ;
3329import run .halo .aifoundation .service .model .DefaultAiModelResolver ;
3430import run .halo .aifoundation .setting .DefaultModelSlotStore ;
3531import run .halo .aifoundation .setting .DefaultModelSlots ;
36- import run .halo .app .extension .ListOptions ;
3732import run .halo .app .extension .Metadata ;
3833import run .halo .app .extension .ReactiveExtensionClient ;
3934
@@ -57,82 +52,13 @@ class AiModelServiceImplTest {
5752 @ BeforeEach
5853 void setUp () {
5954 service = new AiModelServiceImpl (
60- client ,
6155 new DefaultAiModelResolver (client , providerClientCache , secretResolver ,
6256 defaultModelSlotStore ),
6357 new DefaultLanguageModelFactory (providerClientCache ),
6458 new DefaultEmbeddingModelFactory (providerClientCache )
6559 );
6660 }
6761
68- // ---- listModels ----
69-
70- @ Test
71- void listModels_returnsAllModels () {
72- when (client .listAll (eq (AiModel .class ), any (ListOptions .class ), any (Sort .class )))
73- .thenReturn (Flux .just (
74- aiModel ("openai-prod-gpt-4-abc" , "provider-a" , "gpt-4" , "GPT-4" , true ),
75- aiModel ("ollama-local-claude-3-xyz" , "provider-b" , "claude-3" , "Claude 3" , true )
76- ));
77-
78- StepVerifier .create (service .listModels ())
79- .assertNext (models -> {
80- assertThat (models ).hasSize (2 );
81- assertThat (models .get (0 ).getName ()).isEqualTo ("openai-prod-gpt-4-abc" );
82- assertThat (models .get (0 ).getModelId ()).isEqualTo ("gpt-4" );
83- assertThat (models .get (0 ).getProviderName ()).isEqualTo ("provider-a" );
84- assertThat (models .get (1 ).getName ()).isEqualTo ("ollama-local-claude-3-xyz" );
85- assertThat (models .get (1 ).getModelId ()).isEqualTo ("claude-3" );
86- })
87- .verifyComplete ();
88- }
89-
90- @ Test
91- void listModels_emptyResult_returnsEmptyList () {
92- when (client .listAll (eq (AiModel .class ), any (ListOptions .class ), any (Sort .class )))
93- .thenReturn (Flux .empty ());
94-
95- StepVerifier .create (service .listModels ())
96- .assertNext (models -> assertThat (models ).isEmpty ())
97- .verifyComplete ();
98- }
99-
100- // ---- listProviders ----
101-
102- @ Test
103- void listProviders_returnsAllProviders () {
104- var provider1 = aiProvider ("openai-prod" , "openai" , true );
105- provider1 .setStatus (statusWithPhase (AiProvider .AiProviderStatus .Phase .OK ));
106- var provider2 = aiProvider ("ollama-local" , "ollama" , false );
107-
108- when (client .listAll (eq (AiProvider .class ), any (ListOptions .class ), any (Sort .class )))
109- .thenReturn (Flux .just (provider1 , provider2 ));
110-
111- StepVerifier .create (service .listProviders ())
112- .assertNext (providers -> {
113- assertThat (providers ).hasSize (2 );
114- assertThat (providers .get (0 ).getName ()).isEqualTo ("openai-prod" );
115- assertThat (providers .get (0 ).getProviderType ()).isEqualTo ("openai" );
116- assertThat (providers .get (0 ).isEnabled ()).isTrue ();
117- assertThat (providers .get (0 ).getPhase ()).isEqualTo ("OK" );
118- assertThat (providers .get (1 ).isEnabled ()).isFalse ();
119- assertThat (providers .get (1 ).getPhase ()).isEqualTo ("UNKNOWN" );
120- })
121- .verifyComplete ();
122- }
123-
124- @ Test
125- void listProviders_nullStatus_showsUnknownPhase () {
126- var provider = aiProvider ("my-provider" , "openai" , true );
127- provider .setStatus (null );
128- when (client .listAll (eq (AiProvider .class ), any (ListOptions .class ), any (Sort .class )))
129- .thenReturn (Flux .just (provider ));
130-
131- StepVerifier .create (service .listProviders ())
132- .assertNext (providers -> assertThat (providers .get (0 ).getPhase ()).isEqualTo ("UNKNOWN" ))
133- .verifyComplete ();
134- }
135-
13662 // ---- languageModel — fetch by metadata.name ----
13763
13864 @ Test
@@ -194,16 +120,37 @@ void languageModel_wrongModelType_emitsIncompatibleModelTypeException() {
194120 }
195121
196122 @ Test
197- void defaultLanguageModel_missingSlot_emitsDefaultModelNotConfiguredException () {
123+ void languageModel_withoutNameAndMissingSlot_emitsDefaultModelNotConfiguredException () {
198124 when (defaultModelSlotStore .get ()).thenReturn (Mono .just (new DefaultModelSlots ()));
199125
200- StepVerifier .create (service .defaultLanguageModel ())
126+ StepVerifier .create (service .languageModel ())
201127 .expectError (DefaultModelNotConfiguredException .class )
202128 .verify ();
203129 }
204130
205131 @ Test
206- void defaultLanguageModel_resolvesConfiguredModel () {
132+ void languageModel_withoutName_resolvesConfiguredModel () {
133+ var slots = defaultSlots ("openai-prod-gpt-4-abc" , null );
134+ var model = aiModel ("openai-prod-gpt-4-abc" , "openai-prod" , "gpt-4" , "GPT-4" , true );
135+ var provider = aiProvider ("openai-prod" , "openai" , true );
136+ var chatModel = mock (ChatModel .class );
137+ var providerType = languageProviderType ();
138+
139+ when (defaultModelSlotStore .get ()).thenReturn (Mono .just (slots ));
140+ when (client .fetch (AiModel .class , "openai-prod-gpt-4-abc" )).thenReturn (Mono .just (model ));
141+ when (client .fetch (AiProvider .class , "openai-prod" )).thenReturn (Mono .just (provider ));
142+ when (secretResolver .resolveApiKey (null )).thenReturn (Mono .just ("sk-test" ));
143+ when (providerClientCache .getProviderType ("openai" )).thenReturn (providerType );
144+ when (providerClientCache .getOrCreateChatModel (provider , "sk-test" , "gpt-4" ))
145+ .thenReturn (chatModel );
146+
147+ StepVerifier .create (service .languageModel ())
148+ .assertNext (languageModel -> assertThat (languageModel ).isNotNull ())
149+ .verifyComplete ();
150+ }
151+
152+ @ Test
153+ void languageModel_blankName_resolvesConfiguredModel () {
207154 var slots = defaultSlots ("openai-prod-gpt-4-abc" , null );
208155 var model = aiModel ("openai-prod-gpt-4-abc" , "openai-prod" , "gpt-4" , "GPT-4" , true );
209156 var provider = aiProvider ("openai-prod" , "openai" , true );
@@ -218,11 +165,42 @@ void defaultLanguageModel_resolvesConfiguredModel() {
218165 when (providerClientCache .getOrCreateChatModel (provider , "sk-test" , "gpt-4" ))
219166 .thenReturn (chatModel );
220167
221- StepVerifier .create (service .defaultLanguageModel ( ))
168+ StepVerifier .create (service .languageModel ( " " ))
222169 .assertNext (languageModel -> assertThat (languageModel ).isNotNull ())
223170 .verifyComplete ();
224171 }
225172
173+ @ Test
174+ void embeddingModel_withoutName_resolvesConfiguredModel () {
175+ var slots = defaultSlots (null , "openai-prod-embedding" );
176+ var model = aiModel ("openai-prod-embedding" , "openai-prod" ,
177+ "text-embedding-3-small" , "Embedding" , true , ModelType .EMBEDDING );
178+ var provider = aiProvider ("openai-prod" , "openai" , true );
179+ var springEmbeddingModel = mock (org .springframework .ai .embedding .EmbeddingModel .class );
180+ var providerType = mock (AiProviderType .class );
181+
182+ when (defaultModelSlotStore .get ()).thenReturn (Mono .just (slots ));
183+ when (client .fetch (AiModel .class , "openai-prod-embedding" )).thenReturn (Mono .just (model ));
184+ when (client .fetch (AiProvider .class , "openai-prod" )).thenReturn (Mono .just (provider ));
185+ when (secretResolver .resolveApiKey (null )).thenReturn (Mono .just ("sk-test" ));
186+ when (providerClientCache .getProviderType ("openai" )).thenReturn (providerType );
187+ when (providerClientCache .getOrCreateEmbeddingModel (provider , "sk-test" ,
188+ "text-embedding-3-small" )).thenReturn (springEmbeddingModel );
189+
190+ StepVerifier .create (service .embeddingModel ())
191+ .assertNext (embeddingModel -> assertThat (embeddingModel ).isNotNull ())
192+ .verifyComplete ();
193+ }
194+
195+ @ Test
196+ void embeddingModel_blankNameAndMissingSlot_emitsDefaultModelNotConfiguredException () {
197+ when (defaultModelSlotStore .get ()).thenReturn (Mono .just (new DefaultModelSlots ()));
198+
199+ StepVerifier .create (service .embeddingModel ("" ))
200+ .expectError (DefaultModelNotConfiguredException .class )
201+ .verify ();
202+ }
203+
226204 // ---- helpers ----
227205
228206 private AiModel aiModel (String name , String providerName , String modelId ,
@@ -259,12 +237,6 @@ private AiProvider aiProvider(String name, String providerType, boolean enabled)
259237 return provider ;
260238 }
261239
262- private AiProvider .AiProviderStatus statusWithPhase (AiProvider .AiProviderStatus .Phase phase ) {
263- var status = new AiProvider .AiProviderStatus ();
264- status .setPhase (phase );
265- return status ;
266- }
267-
268240 private DefaultModelSlots defaultSlots (String languageModelName , String embeddingModelName ) {
269241 var slots = new DefaultModelSlots ();
270242 slots .setLanguageModelName (languageModelName );
0 commit comments