Skip to content

Commit 55bb607

Browse files
fix: customHeaders are not passed to the VertexAI clients (#11506)
PiperOrigin-RevId: 738911474 Co-authored-by: Jaycee Li <[email protected]>
1 parent 831c574 commit 55bb607

File tree

2 files changed

+32
-14
lines changed

2 files changed

+32
-14
lines changed

java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java

-14
Original file line numberDiff line numberDiff line change
@@ -360,13 +360,6 @@ private PredictionServiceSettings getPredictionServiceSettings() throws IOExcept
360360
builder.setEndpoint(String.format("%s:443", apiEndpoint));
361361
builder.setCredentialsProvider(credentialsProvider);
362362

363-
HeaderProvider headerProvider =
364-
FixedHeaderProvider.create(
365-
"user-agent",
366-
String.format(
367-
"%s/%s",
368-
Constants.USER_AGENT_HEADER,
369-
GaxProperties.getLibraryVersion(PredictionServiceSettings.class)));
370363
builder.setHeaderProvider(headerProvider);
371364
return builder.build();
372365
}
@@ -435,13 +428,6 @@ private LlmUtilityServiceSettings getLlmUtilityServiceClientSettings() throws IO
435428
settingsBuilder.setEndpoint(String.format("%s:443", apiEndpoint));
436429
settingsBuilder.setCredentialsProvider(credentialsProvider);
437430

438-
HeaderProvider headerProvider =
439-
FixedHeaderProvider.create(
440-
"user-agent",
441-
String.format(
442-
"%s/%s",
443-
Constants.USER_AGENT_HEADER,
444-
GaxProperties.getLibraryVersion(LlmUtilityServiceSettings.class)));
445431
settingsBuilder.setHeaderProvider(headerProvider);
446432
return settingsBuilder.build();
447433
}

java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/VertexAITest.java

+32
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
import com.google.api.gax.core.GaxProperties;
2626
import com.google.api.gax.core.GoogleCredentialsProvider;
2727
import com.google.auth.oauth2.GoogleCredentials;
28+
import com.google.cloud.vertexai.api.LlmUtilityServiceClient;
29+
import com.google.cloud.vertexai.api.LlmUtilityServiceSettings;
2830
import com.google.cloud.vertexai.api.PredictionServiceClient;
2931
import com.google.cloud.vertexai.api.PredictionServiceSettings;
3032
import com.google.common.collect.ImmutableList;
@@ -58,6 +60,8 @@ public final class VertexAITest {
5860

5961
@Mock private PredictionServiceClient mockPredictionServiceClient;
6062

63+
@Mock private LlmUtilityServiceClient mockLlmUtilityServiceClient;
64+
6165
@Mock private GoogleCredentialsProvider.Builder mockCredentialsProviderBuilder;
6266

6367
@Mock private GoogleCredentialsProvider mockCredentialsProvider;
@@ -425,6 +429,20 @@ public void testInstantiateVertexAI_builderWithCustomHeaders_shouldContainRightF
425429
Constants.USER_AGENT_HEADER,
426430
GaxProperties.getLibraryVersion(PredictionServiceSettings.class)));
427431
assertThat(vertexAi.getHeaders()).isEqualTo(expectedHeaders);
432+
433+
// make sure the custom headers are set correctly in the prediction service client
434+
try (MockedStatic mockStatic = mockStatic(PredictionServiceClient.class)) {
435+
mockStatic
436+
.when(() -> PredictionServiceClient.create(any(PredictionServiceSettings.class)))
437+
.thenReturn(mockPredictionServiceClient);
438+
PredictionServiceClient unused = vertexAi.getPredictionServiceClient();
439+
440+
ArgumentCaptor<PredictionServiceSettings> settings =
441+
ArgumentCaptor.forClass(PredictionServiceSettings.class);
442+
mockStatic.verify(() -> PredictionServiceClient.create(settings.capture()));
443+
444+
assertThat(settings.getValue().getHeaderProvider().getHeaders()).isEqualTo(expectedHeaders);
445+
}
428446
}
429447

430448
@Test
@@ -454,5 +472,19 @@ public void testInstantiateVertexAI_builderWithCustomHeaders_shouldContainRightF
454472
GaxProperties.getLibraryVersion(PredictionServiceSettings.class),
455473
"test_value"));
456474
assertThat(vertexAi.getHeaders()).isEqualTo(expectedHeaders);
475+
476+
// make sure the custom headers are set correctly in the llm utility service client
477+
try (MockedStatic mockStatic = mockStatic(LlmUtilityServiceClient.class)) {
478+
mockStatic
479+
.when(() -> LlmUtilityServiceClient.create(any(LlmUtilityServiceSettings.class)))
480+
.thenReturn(mockLlmUtilityServiceClient);
481+
LlmUtilityServiceClient unused = vertexAi.getLlmUtilityClient();
482+
483+
ArgumentCaptor<LlmUtilityServiceSettings> settings =
484+
ArgumentCaptor.forClass(LlmUtilityServiceSettings.class);
485+
mockStatic.verify(() -> LlmUtilityServiceClient.create(settings.capture()));
486+
487+
assertThat(settings.getValue().getHeaderProvider().getHeaders()).isEqualTo(expectedHeaders);
488+
}
457489
}
458490
}

0 commit comments

Comments
 (0)