|
25 | 25 | import com.google.api.gax.core.GaxProperties;
|
26 | 26 | import com.google.api.gax.core.GoogleCredentialsProvider;
|
27 | 27 | import com.google.auth.oauth2.GoogleCredentials;
|
| 28 | +import com.google.cloud.vertexai.api.LlmUtilityServiceClient; |
| 29 | +import com.google.cloud.vertexai.api.LlmUtilityServiceSettings; |
28 | 30 | import com.google.cloud.vertexai.api.PredictionServiceClient;
|
29 | 31 | import com.google.cloud.vertexai.api.PredictionServiceSettings;
|
30 | 32 | import com.google.common.collect.ImmutableList;
|
@@ -58,6 +60,8 @@ public final class VertexAITest {
|
58 | 60 |
|
59 | 61 | @Mock private PredictionServiceClient mockPredictionServiceClient;
|
60 | 62 |
|
| 63 | + @Mock private LlmUtilityServiceClient mockLlmUtilityServiceClient; |
| 64 | + |
61 | 65 | @Mock private GoogleCredentialsProvider.Builder mockCredentialsProviderBuilder;
|
62 | 66 |
|
63 | 67 | @Mock private GoogleCredentialsProvider mockCredentialsProvider;
|
@@ -425,6 +429,20 @@ public void testInstantiateVertexAI_builderWithCustomHeaders_shouldContainRightF
|
425 | 429 | Constants.USER_AGENT_HEADER,
|
426 | 430 | GaxProperties.getLibraryVersion(PredictionServiceSettings.class)));
|
427 | 431 | assertThat(vertexAi.getHeaders()).isEqualTo(expectedHeaders);
|
| 432 | + |
| 433 | + // make sure the custom headers are set correctly in the prediction service client |
| 434 | + try (MockedStatic mockStatic = mockStatic(PredictionServiceClient.class)) { |
| 435 | + mockStatic |
| 436 | + .when(() -> PredictionServiceClient.create(any(PredictionServiceSettings.class))) |
| 437 | + .thenReturn(mockPredictionServiceClient); |
| 438 | + PredictionServiceClient unused = vertexAi.getPredictionServiceClient(); |
| 439 | + |
| 440 | + ArgumentCaptor<PredictionServiceSettings> settings = |
| 441 | + ArgumentCaptor.forClass(PredictionServiceSettings.class); |
| 442 | + mockStatic.verify(() -> PredictionServiceClient.create(settings.capture())); |
| 443 | + |
| 444 | + assertThat(settings.getValue().getHeaderProvider().getHeaders()).isEqualTo(expectedHeaders); |
| 445 | + } |
428 | 446 | }
|
429 | 447 |
|
430 | 448 | @Test
|
@@ -454,5 +472,19 @@ public void testInstantiateVertexAI_builderWithCustomHeaders_shouldContainRightF
|
454 | 472 | GaxProperties.getLibraryVersion(PredictionServiceSettings.class),
|
455 | 473 | "test_value"));
|
456 | 474 | assertThat(vertexAi.getHeaders()).isEqualTo(expectedHeaders);
|
| 475 | + |
| 476 | + // make sure the custom headers are set correctly in the llm utility service client |
| 477 | + try (MockedStatic mockStatic = mockStatic(LlmUtilityServiceClient.class)) { |
| 478 | + mockStatic |
| 479 | + .when(() -> LlmUtilityServiceClient.create(any(LlmUtilityServiceSettings.class))) |
| 480 | + .thenReturn(mockLlmUtilityServiceClient); |
| 481 | + LlmUtilityServiceClient unused = vertexAi.getLlmUtilityClient(); |
| 482 | + |
| 483 | + ArgumentCaptor<LlmUtilityServiceSettings> settings = |
| 484 | + ArgumentCaptor.forClass(LlmUtilityServiceSettings.class); |
| 485 | + mockStatic.verify(() -> LlmUtilityServiceClient.create(settings.capture())); |
| 486 | + |
| 487 | + assertThat(settings.getValue().getHeaderProvider().getHeaders()).isEqualTo(expectedHeaders); |
| 488 | + } |
457 | 489 | }
|
458 | 490 | }
|
0 commit comments