Skip to content

Commit ad268e0

Browse files
Add configurable max_batch_size for GoogleVertexAI embedding service settings (elastic#138047)
* Add configurable max_batch_size for GoogleVertexAI embedding service settings * Update docs/changelog/138047.yaml --------- Co-authored-by: Elastic Machine <[email protected]>
1 parent a642a2f commit ad268e0

File tree

15 files changed

+370
-66
lines changed

15 files changed

+370
-66
lines changed

docs/changelog/138047.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 138047
2+
summary: Add configurable `max_batch_size` for `GoogleVertexAI` embedding service
3+
settings
4+
area: Machine Learning
5+
type: bug
6+
issues: []

server/src/main/java/org/elasticsearch/inference/ServiceSettings.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
1515
import org.elasticsearch.xcontent.ToXContentObject;
1616

17+
import java.util.Map;
18+
1719
public interface ServiceSettings extends ToXContentObject, VersionedNamedWriteable, FilteredXContent {
1820

1921
/**
@@ -61,4 +63,8 @@ default DenseVectorFieldMapper.ElementType elementType() {
6163
*/
6264
@Nullable
6365
String modelId();
66+
67+
default ServiceSettings updateServiceSettings(Map<String, Object> serviceSettings) {
68+
return this;
69+
}
6470
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
9241000
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
esql_exponential_histogram_supported_version,9240000
1+
google_vertex_ai_configurable_max_batch_size,9241000

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelAction.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -223,10 +223,7 @@ private Model combineExistingModelWithNewSettings(
223223
newSecretSettings = existingSecretSettings.newSecretSettings(settingsToUpdate.serviceSettings());
224224
}
225225
if (settingsToUpdate.serviceSettings() != null) {
226-
// In cluster services can have their deployment settings updated, so this is a special case
227-
if (newServiceSettings instanceof ElasticsearchInternalServiceSettings elasticServiceSettings) {
228-
newServiceSettings = elasticServiceSettings.updateServiceSettings(settingsToUpdate.serviceSettings());
229-
}
226+
newServiceSettings = newServiceSettings.updateServiceSettings(settingsToUpdate.serviceSettings());
230227
}
231228
if (settingsToUpdate.taskSettings() != null && existingTaskSettings != null) {
232229
newTaskSettings = existingTaskSettings.updatedTaskSettings(settingsToUpdate.taskSettings());

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -748,6 +748,22 @@ public static Integer extractOptionalPositiveInteger(
748748
return extractOptionalInteger(map, settingName, scope, validationException, true);
749749
}
750750

751+
public static Integer extractOptionalPositiveIntegerLessThanOrEqualToMax(
752+
Map<String, Object> map,
753+
String settingName,
754+
int maxValue,
755+
String scope,
756+
ValidationException validationException
757+
) {
758+
Integer optionalField = extractOptionalPositiveInteger(map, settingName, scope, validationException);
759+
760+
if (optionalField != null && optionalField > maxValue) {
761+
validationException.addValidationError(mustBeLessThanOrEqualNumberErrorMessage(settingName, scope, optionalField, maxValue));
762+
}
763+
764+
return optionalField;
765+
}
766+
751767
public static Integer extractOptionalInteger(
752768
Map<String, Object> map,
753769
String settingName,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,8 @@ public TransportVersion getMinimalSupportedVersion() {
227227
return TransportVersion.minimumCompatible();
228228
}
229229

230-
public ElasticsearchInternalServiceSettings updateServiceSettings(Map<String, Object> serviceSettings) {
230+
@Override
231+
public ServiceSettings updateServiceSettings(Map<String, Object> serviceSettings) {
231232
var validationException = new ValidationException();
232233
var mutableServiceSettings = new HashMap<>(serviceSettings);
233234

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,11 +280,13 @@ protected void doChunkedInfer(
280280
ActionListener<List<ChunkedInference>> listener
281281
) {
282282
GoogleVertexAiModel googleVertexAiModel = (GoogleVertexAiModel) model;
283+
GoogleVertexAiEmbeddingsServiceSettings serviceSettings = (GoogleVertexAiEmbeddingsServiceSettings) googleVertexAiModel
284+
.getServiceSettings();
283285
var actionCreator = new GoogleVertexAiActionCreator(getSender(), getServiceComponents());
284286

285287
List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
286288
inputs,
287-
EMBEDDING_MAX_BATCH_SIZE,
289+
serviceSettings.maxBatchSize() == null ? EMBEDDING_MAX_BATCH_SIZE : serviceSettings.maxBatchSize(),
288290
googleVertexAiModel.getConfigurations().getChunkingSettings()
289291
).batchRequestsWithListeners(listener);
290292

@@ -306,6 +308,7 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
306308
serviceSettings.dimensionsSetByUser(),
307309
serviceSettings.maxInputTokens(),
308310
embeddingSize,
311+
serviceSettings.maxBatchSize(),
309312
serviceSettings.similarity(),
310313
serviceSettings.rateLimitSettings()
311314
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceFields.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@ public class GoogleVertexAiServiceFields {
1515
public static final String URL_SETTING_NAME = "url";
1616
public static final String STREAMING_URL_SETTING_NAME = "streaming_url";
1717
public static final String PROVIDER_SETTING_NAME = "provider";
18+
public static final String MAX_BATCH_SIZE = "max_batch_size";
1819

1920
/**
2021
* According to https://cloud.google.com/vertex-ai/docs/quotas#text-embedding-limits the limit is `250`.
2122
*/
22-
static final int EMBEDDING_MAX_BATCH_SIZE = 250;
23+
public static final int EMBEDDING_MAX_BATCH_SIZE = 250;
2324

2425
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsServiceSettings.java

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
2828

2929
import java.io.IOException;
30+
import java.util.HashMap;
3031
import java.util.Map;
3132
import java.util.Objects;
3233

@@ -36,9 +37,12 @@
3637
import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
3738
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
3839
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
40+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveIntegerLessThanOrEqualToMax;
3941
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
4042
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;
43+
import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.EMBEDDING_MAX_BATCH_SIZE;
4144
import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.LOCATION;
45+
import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.MAX_BATCH_SIZE;
4246
import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.PROJECT_ID;
4347

4448
public class GoogleVertexAiEmbeddingsServiceSettings extends FilteredXContentObject
@@ -53,6 +57,10 @@ public class GoogleVertexAiEmbeddingsServiceSettings extends FilteredXContentObj
5357
// See online prediction requests per minute: https://cloud.google.com/vertex-ai/docs/quotas.
5458
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(30_000);
5559

60+
protected static final TransportVersion GOOGLE_VERTEX_AI_CONFIGURABLE_MAX_BATCH_SIZE = TransportVersion.fromName(
61+
"google_vertex_ai_configurable_max_batch_size"
62+
);
63+
5664
public static GoogleVertexAiEmbeddingsServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
5765
ValidationException validationException = new ValidationException();
5866

@@ -67,6 +75,13 @@ public static GoogleVertexAiEmbeddingsServiceSettings fromMap(Map<String, Object
6775
);
6876
SimilarityMeasure similarityMeasure = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
6977
Integer dims = extractOptionalPositiveInteger(map, DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS, validationException);
78+
Integer maxBatchSize = extractOptionalPositiveIntegerLessThanOrEqualToMax(
79+
map,
80+
MAX_BATCH_SIZE,
81+
EMBEDDING_MAX_BATCH_SIZE,
82+
ModelConfigurations.SERVICE_SETTINGS,
83+
validationException
84+
);
7085
RateLimitSettings rateLimitSettings = RateLimitSettings.of(
7186
map,
7287
DEFAULT_RATE_LIMIT_SETTINGS,
@@ -106,11 +121,32 @@ public static GoogleVertexAiEmbeddingsServiceSettings fromMap(Map<String, Object
106121
dimensionsSetByUser,
107122
maxInputTokens,
108123
dims,
124+
maxBatchSize,
109125
similarityMeasure,
110126
rateLimitSettings
111127
);
112128
}
113129

130+
@Override
131+
public ServiceSettings updateServiceSettings(Map<String, Object> serviceSettings) {
132+
var validationException = new ValidationException();
133+
serviceSettings = new HashMap<>(serviceSettings);
134+
135+
Integer maxBatchSize = extractOptionalPositiveIntegerLessThanOrEqualToMax(
136+
serviceSettings,
137+
MAX_BATCH_SIZE,
138+
EMBEDDING_MAX_BATCH_SIZE,
139+
ModelConfigurations.SERVICE_SETTINGS,
140+
validationException
141+
);
142+
143+
if (validationException.validationErrors().isEmpty() == false) {
144+
throw validationException;
145+
}
146+
147+
return new GoogleVertexAiEmbeddingsServiceSettings(this, maxBatchSize);
148+
}
149+
114150
private final String location;
115151

116152
private final String projectId;
@@ -119,6 +155,8 @@ public static GoogleVertexAiEmbeddingsServiceSettings fromMap(Map<String, Object
119155

120156
private final Integer dims;
121157

158+
private final Integer maxBatchSize;
159+
122160
private final SimilarityMeasure similarity;
123161
private final Integer maxInputTokens;
124162

@@ -133,6 +171,7 @@ public GoogleVertexAiEmbeddingsServiceSettings(
133171
Boolean dimensionsSetByUser,
134172
@Nullable Integer maxInputTokens,
135173
@Nullable Integer dims,
174+
@Nullable Integer maxBatchSize,
136175
@Nullable SimilarityMeasure similarity,
137176
@Nullable RateLimitSettings rateLimitSettings
138177
) {
@@ -142,17 +181,35 @@ public GoogleVertexAiEmbeddingsServiceSettings(
142181
this.dimensionsSetByUser = dimensionsSetByUser;
143182
this.maxInputTokens = maxInputTokens;
144183
this.dims = dims;
184+
this.maxBatchSize = maxBatchSize;
145185
this.similarity = Objects.requireNonNullElse(similarity, SimilarityMeasure.DOT_PRODUCT);
146186
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
147187
}
148188

189+
public GoogleVertexAiEmbeddingsServiceSettings(GoogleVertexAiEmbeddingsServiceSettings original, @Nullable Integer maxBatchSize) {
190+
this.location = original.location;
191+
this.projectId = original.projectId;
192+
this.modelId = original.modelId;
193+
this.dimensionsSetByUser = original.dimensionsSetByUser;
194+
this.maxInputTokens = original.maxInputTokens;
195+
this.dims = original.dims;
196+
this.maxBatchSize = maxBatchSize != null ? maxBatchSize : original.maxBatchSize;
197+
this.similarity = original.similarity;
198+
this.rateLimitSettings = original.rateLimitSettings;
199+
}
200+
149201
public GoogleVertexAiEmbeddingsServiceSettings(StreamInput in) throws IOException {
150202
this.location = in.readString();
151203
this.projectId = in.readString();
152204
this.modelId = in.readString();
153205
this.dimensionsSetByUser = in.readBoolean();
154206
this.maxInputTokens = in.readOptionalVInt();
155207
this.dims = in.readOptionalVInt();
208+
if (in.getTransportVersion().supports(GOOGLE_VERTEX_AI_CONFIGURABLE_MAX_BATCH_SIZE)) {
209+
this.maxBatchSize = in.readOptionalVInt();
210+
} else {
211+
this.maxBatchSize = null;
212+
}
156213
this.similarity = in.readOptionalEnum(SimilarityMeasure.class);
157214
this.rateLimitSettings = new RateLimitSettings(in);
158215
}
@@ -189,6 +246,10 @@ public Integer dimensions() {
189246
return dims;
190247
}
191248

249+
public Integer maxBatchSize() {
250+
return maxBatchSize;
251+
}
252+
192253
@Override
193254
public SimilarityMeasure similarity() {
194255
return similarity;
@@ -228,6 +289,9 @@ public void writeTo(StreamOutput out) throws IOException {
228289
out.writeBoolean(dimensionsSetByUser);
229290
out.writeOptionalVInt(maxInputTokens);
230291
out.writeOptionalVInt(dims);
292+
if (out.getTransportVersion().supports(GOOGLE_VERTEX_AI_CONFIGURABLE_MAX_BATCH_SIZE)) {
293+
out.writeOptionalVInt(maxBatchSize);
294+
}
231295
out.writeOptionalEnum(similarity);
232296
rateLimitSettings.writeTo(out);
233297
}
@@ -246,6 +310,10 @@ protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder buil
246310
builder.field(DIMENSIONS, dims);
247311
}
248312

313+
if (maxBatchSize != null) {
314+
builder.field(MAX_BATCH_SIZE, maxBatchSize);
315+
}
316+
249317
if (similarity != null) {
250318
builder.field(SIMILARITY, similarity);
251319
}
@@ -264,6 +332,7 @@ public boolean equals(Object object) {
264332
&& Objects.equals(projectId, that.projectId)
265333
&& Objects.equals(modelId, that.modelId)
266334
&& Objects.equals(dims, that.dims)
335+
&& Objects.equals(maxBatchSize, that.maxBatchSize)
267336
&& similarity == that.similarity
268337
&& Objects.equals(maxInputTokens, that.maxInputTokens)
269338
&& Objects.equals(rateLimitSettings, that.rateLimitSettings)
@@ -272,6 +341,16 @@ public boolean equals(Object object) {
272341

273342
@Override
274343
public int hashCode() {
275-
return Objects.hash(location, projectId, modelId, dims, similarity, maxInputTokens, rateLimitSettings, dimensionsSetByUser);
344+
return Objects.hash(
345+
location,
346+
projectId,
347+
modelId,
348+
dims,
349+
maxBatchSize,
350+
similarity,
351+
maxInputTokens,
352+
rateLimitSettings,
353+
dimensionsSetByUser
354+
);
276355
}
277356
}

0 commit comments

Comments
 (0)