2727import org .elasticsearch .xpack .inference .services .settings .RateLimitSettings ;
2828
2929import java .io .IOException ;
30+ import java .util .HashMap ;
3031import java .util .Map ;
3132import java .util .Objects ;
3233
3637import static org .elasticsearch .xpack .inference .services .ServiceFields .SIMILARITY ;
3738import static org .elasticsearch .xpack .inference .services .ServiceUtils .extractOptionalBoolean ;
3839import static org .elasticsearch .xpack .inference .services .ServiceUtils .extractOptionalPositiveInteger ;
40+ import static org .elasticsearch .xpack .inference .services .ServiceUtils .extractOptionalPositiveIntegerLessThanOrEqualToMax ;
3941import static org .elasticsearch .xpack .inference .services .ServiceUtils .extractRequiredString ;
4042import static org .elasticsearch .xpack .inference .services .ServiceUtils .extractSimilarity ;
43+ import static org .elasticsearch .xpack .inference .services .googlevertexai .GoogleVertexAiServiceFields .EMBEDDING_MAX_BATCH_SIZE ;
4144import static org .elasticsearch .xpack .inference .services .googlevertexai .GoogleVertexAiServiceFields .LOCATION ;
45+ import static org .elasticsearch .xpack .inference .services .googlevertexai .GoogleVertexAiServiceFields .MAX_BATCH_SIZE ;
4246import static org .elasticsearch .xpack .inference .services .googlevertexai .GoogleVertexAiServiceFields .PROJECT_ID ;
4347
4448public class GoogleVertexAiEmbeddingsServiceSettings extends FilteredXContentObject
@@ -53,6 +57,10 @@ public class GoogleVertexAiEmbeddingsServiceSettings extends FilteredXContentObj
5357 // See online prediction requests per minute: https://cloud.google.com/vertex-ai/docs/quotas.
5458 private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings (30_000 );
5559
60+ protected static final TransportVersion GOOGLE_VERTEX_AI_CONFIGURABLE_MAX_BATCH_SIZE = TransportVersion .fromName (
61+ "google_vertex_ai_configurable_max_batch_size"
62+ );
63+
5664 public static GoogleVertexAiEmbeddingsServiceSettings fromMap (Map <String , Object > map , ConfigurationParseContext context ) {
5765 ValidationException validationException = new ValidationException ();
5866
@@ -67,6 +75,13 @@ public static GoogleVertexAiEmbeddingsServiceSettings fromMap(Map<String, Object
6775 );
6876 SimilarityMeasure similarityMeasure = extractSimilarity (map , ModelConfigurations .SERVICE_SETTINGS , validationException );
6977 Integer dims = extractOptionalPositiveInteger (map , DIMENSIONS , ModelConfigurations .SERVICE_SETTINGS , validationException );
78+ Integer maxBatchSize = extractOptionalPositiveIntegerLessThanOrEqualToMax (
79+ map ,
80+ MAX_BATCH_SIZE ,
81+ EMBEDDING_MAX_BATCH_SIZE ,
82+ ModelConfigurations .SERVICE_SETTINGS ,
83+ validationException
84+ );
7085 RateLimitSettings rateLimitSettings = RateLimitSettings .of (
7186 map ,
7287 DEFAULT_RATE_LIMIT_SETTINGS ,
@@ -106,11 +121,32 @@ public static GoogleVertexAiEmbeddingsServiceSettings fromMap(Map<String, Object
106121 dimensionsSetByUser ,
107122 maxInputTokens ,
108123 dims ,
124+ maxBatchSize ,
109125 similarityMeasure ,
110126 rateLimitSettings
111127 );
112128 }
113129
130+ @ Override
131+ public ServiceSettings updateServiceSettings (Map <String , Object > serviceSettings ) {
132+ var validationException = new ValidationException ();
133+ serviceSettings = new HashMap <>(serviceSettings );
134+
135+ Integer maxBatchSize = extractOptionalPositiveIntegerLessThanOrEqualToMax (
136+ serviceSettings ,
137+ MAX_BATCH_SIZE ,
138+ EMBEDDING_MAX_BATCH_SIZE ,
139+ ModelConfigurations .SERVICE_SETTINGS ,
140+ validationException
141+ );
142+
143+ if (validationException .validationErrors ().isEmpty () == false ) {
144+ throw validationException ;
145+ }
146+
147+ return new GoogleVertexAiEmbeddingsServiceSettings (this , maxBatchSize );
148+ }
149+
114150 private final String location ;
115151
116152 private final String projectId ;
@@ -119,6 +155,8 @@ public static GoogleVertexAiEmbeddingsServiceSettings fromMap(Map<String, Object
119155
120156 private final Integer dims ;
121157
158+ private final Integer maxBatchSize ;
159+
122160 private final SimilarityMeasure similarity ;
123161 private final Integer maxInputTokens ;
124162
@@ -133,6 +171,7 @@ public GoogleVertexAiEmbeddingsServiceSettings(
133171 Boolean dimensionsSetByUser ,
134172 @ Nullable Integer maxInputTokens ,
135173 @ Nullable Integer dims ,
174+ @ Nullable Integer maxBatchSize ,
136175 @ Nullable SimilarityMeasure similarity ,
137176 @ Nullable RateLimitSettings rateLimitSettings
138177 ) {
@@ -142,17 +181,35 @@ public GoogleVertexAiEmbeddingsServiceSettings(
142181 this .dimensionsSetByUser = dimensionsSetByUser ;
143182 this .maxInputTokens = maxInputTokens ;
144183 this .dims = dims ;
184+ this .maxBatchSize = maxBatchSize ;
145185 this .similarity = Objects .requireNonNullElse (similarity , SimilarityMeasure .DOT_PRODUCT );
146186 this .rateLimitSettings = Objects .requireNonNullElse (rateLimitSettings , DEFAULT_RATE_LIMIT_SETTINGS );
147187 }
148188
189+ public GoogleVertexAiEmbeddingsServiceSettings (GoogleVertexAiEmbeddingsServiceSettings original , @ Nullable Integer maxBatchSize ) {
190+ this .location = original .location ;
191+ this .projectId = original .projectId ;
192+ this .modelId = original .modelId ;
193+ this .dimensionsSetByUser = original .dimensionsSetByUser ;
194+ this .maxInputTokens = original .maxInputTokens ;
195+ this .dims = original .dims ;
196+ this .maxBatchSize = maxBatchSize != null ? maxBatchSize : original .maxBatchSize ;
197+ this .similarity = original .similarity ;
198+ this .rateLimitSettings = original .rateLimitSettings ;
199+ }
200+
149201 public GoogleVertexAiEmbeddingsServiceSettings (StreamInput in ) throws IOException {
150202 this .location = in .readString ();
151203 this .projectId = in .readString ();
152204 this .modelId = in .readString ();
153205 this .dimensionsSetByUser = in .readBoolean ();
154206 this .maxInputTokens = in .readOptionalVInt ();
155207 this .dims = in .readOptionalVInt ();
208+ if (in .getTransportVersion ().supports (GOOGLE_VERTEX_AI_CONFIGURABLE_MAX_BATCH_SIZE )) {
209+ this .maxBatchSize = in .readOptionalVInt ();
210+ } else {
211+ this .maxBatchSize = null ;
212+ }
156213 this .similarity = in .readOptionalEnum (SimilarityMeasure .class );
157214 this .rateLimitSettings = new RateLimitSettings (in );
158215 }
@@ -189,6 +246,10 @@ public Integer dimensions() {
189246 return dims ;
190247 }
191248
249+ public Integer maxBatchSize () {
250+ return maxBatchSize ;
251+ }
252+
192253 @ Override
193254 public SimilarityMeasure similarity () {
194255 return similarity ;
@@ -228,6 +289,9 @@ public void writeTo(StreamOutput out) throws IOException {
228289 out .writeBoolean (dimensionsSetByUser );
229290 out .writeOptionalVInt (maxInputTokens );
230291 out .writeOptionalVInt (dims );
292+ if (out .getTransportVersion ().supports (GOOGLE_VERTEX_AI_CONFIGURABLE_MAX_BATCH_SIZE )) {
293+ out .writeOptionalVInt (maxBatchSize );
294+ }
231295 out .writeOptionalEnum (similarity );
232296 rateLimitSettings .writeTo (out );
233297 }
@@ -246,6 +310,10 @@ protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder buil
246310 builder .field (DIMENSIONS , dims );
247311 }
248312
313+ if (maxBatchSize != null ) {
314+ builder .field (MAX_BATCH_SIZE , maxBatchSize );
315+ }
316+
249317 if (similarity != null ) {
250318 builder .field (SIMILARITY , similarity );
251319 }
@@ -264,6 +332,7 @@ public boolean equals(Object object) {
264332 && Objects .equals (projectId , that .projectId )
265333 && Objects .equals (modelId , that .modelId )
266334 && Objects .equals (dims , that .dims )
335+ && Objects .equals (maxBatchSize , that .maxBatchSize )
267336 && similarity == that .similarity
268337 && Objects .equals (maxInputTokens , that .maxInputTokens )
269338 && Objects .equals (rateLimitSettings , that .rateLimitSettings )
@@ -272,6 +341,16 @@ public boolean equals(Object object) {
272341
273342 @ Override
274343 public int hashCode () {
275- return Objects .hash (location , projectId , modelId , dims , similarity , maxInputTokens , rateLimitSettings , dimensionsSetByUser );
344+ return Objects .hash (
345+ location ,
346+ projectId ,
347+ modelId ,
348+ dims ,
349+ maxBatchSize ,
350+ similarity ,
351+ maxInputTokens ,
352+ rateLimitSettings ,
353+ dimensionsSetByUser
354+ );
276355 }
277356}
0 commit comments