Skip to content

Enable updating adaptive_allocations for ElasticsearchInternalService #127994

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/127994.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 127994
summary: Enable updating `adaptive_allocations` for `ElasticsearchInternalService`
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference;

import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
 * Integration tests that update the {@code num_allocations} and {@code adaptive_allocations} service settings of an
 * inference endpoint created with the {@code elasticsearch} service, and check that the updated endpoint
 * configuration reports the new values.
 */
public class InferenceUpdateElasticsearchInternalServiceModelIT extends CustomElandModelIT {

    /**
     * Partial and complete adaptive allocations settings to send as updates. This is an instance field rather than a
     * constant (hence the lowerCamelCase name): the values are drawn from the random helpers for every test instance.
     */
    private final List<AdaptiveAllocationsSettings> adaptiveAllocationsSettingsCases = List.of(
        new AdaptiveAllocationsSettings(randomBoolean(), null, null),
        new AdaptiveAllocationsSettings(null, randomIntBetween(1, 10), null),
        new AdaptiveAllocationsSettings(null, null, randomIntBetween(1, 10)),
        new AdaptiveAllocationsSettings(randomBoolean(), randomIntBetween(1, 10), randomIntBetween(11, 20))
    );

    /** Updates only {@code num_allocations} and leaves the adaptive allocations settings untouched. */
    public void testUpdateNumAllocations() throws IOException {
        updateEndpointAndVerify(Optional.of(randomIntBetween(2, 10)), Optional.empty(), Optional.empty(), Optional.empty());
    }

    /** Updates only the adaptive allocations settings, once for each entry of {@link #adaptiveAllocationsSettingsCases}. */
    public void testUpdateAdaptiveAllocationsSettings() throws IOException {
        // NOTE(review): every iteration re-creates the same model id and inference endpoint id through
        // setupInferenceEndpoint; confirm that repeated creation within a single test is accepted.
        for (AdaptiveAllocationsSettings settings : adaptiveAllocationsSettingsCases) {
            updateEndpointAndVerify(
                Optional.empty(),
                Optional.ofNullable(settings.getEnabled()),
                Optional.ofNullable(settings.getMinNumberOfAllocations()),
                Optional.ofNullable(settings.getMaxNumberOfAllocations())
            );
        }
    }

    /** Updates {@code num_allocations} and all three adaptive allocations fields in a single request. */
    public void testUpdateNumAllocationsAndAdaptiveAllocationsSettings() throws IOException {
        updateEndpointAndVerify(
            Optional.of(randomIntBetween(2, 10)),
            Optional.of(randomBoolean()),
            Optional.of(randomIntBetween(1, 10)),
            Optional.of(randomIntBetween(11, 20))
        );
    }

    /**
     * Creates the endpoint, checks its initial configuration, applies an update built from the given values and
     * checks the configuration returned by the update. An empty {@link Optional} means the field is left out of the
     * update request; for {@code num_allocations} the initial value of 1 is then expected to be reported.
     */
    private void updateEndpointAndVerify(
        Optional<Integer> updatedNumAllocations,
        Optional<Boolean> updatedAdaptiveAllocationsEnabled,
        Optional<Integer> updatedMinNumberOfAllocations,
        Optional<Integer> updatedMaxNumberOfAllocations
    ) throws IOException {
        var inferenceId = "update-adaptive-allocations-inference";
        var originalEndpoint = setupInferenceEndpoint(inferenceId);
        verifyEndpointConfig(originalEndpoint, 1, Optional.empty(), Optional.empty(), Optional.empty());

        var updateConfig = generateUpdateConfig(
            updatedNumAllocations,
            updatedAdaptiveAllocationsEnabled,
            updatedMinNumberOfAllocations,
            updatedMaxNumberOfAllocations
        );
        var updatedEndpoint = updateEndpoint(inferenceId, updateConfig, TaskType.SPARSE_EMBEDDING);
        verifyEndpointConfig(
            updatedEndpoint,
            updatedNumAllocations.orElse(1),
            updatedAdaptiveAllocationsEnabled,
            updatedMinNumberOfAllocations,
            updatedMaxNumberOfAllocations
        );
    }

    /**
     * Creates the text expansion model and a sparse embedding endpoint that uses it with one allocation and one
     * thread.
     *
     * @param inferenceId the id of the inference endpoint to create
     * @return the endpoint configuration returned by the create request
     */
    private Map<String, Object> setupInferenceEndpoint(String inferenceId) throws IOException {
        String modelId = "custom-text-expansion-model";
        createMlNodeTextExpansionModel(modelId, client());

        // The model_id below must stay in sync with modelId above.
        var inferenceConfig = """
            {
              "service": "elasticsearch",
              "service_settings": {
                "model_id": "custom-text-expansion-model",
                "num_allocations": 1,
                "num_threads": 1
              }
            }
            """;

        return putModel(inferenceId, inferenceConfig, TaskType.SPARSE_EMBEDDING);
    }

    /**
     * Builds the JSON body of an update request. Only the values that are present are written; the
     * {@code adaptive_allocations} object is written when at least one of its three fields is present.
     *
     * @param numAllocations             value for {@code num_allocations}
     * @param adaptiveAllocationsEnabled value for {@code adaptive_allocations.enabled}
     * @param minNumberOfAllocations     value for {@code adaptive_allocations.min_number_of_allocations}
     * @param maxNumberOfAllocations     value for {@code adaptive_allocations.max_number_of_allocations}
     * @return the request body as a JSON string
     */
    public static String generateUpdateConfig(
        Optional<Integer> numAllocations,
        Optional<Boolean> adaptiveAllocationsEnabled,
        Optional<Integer> minNumberOfAllocations,
        Optional<Integer> maxNumberOfAllocations
    ) {
        StringBuilder requestBodyBuilder = new StringBuilder("{ \"service_settings\": {");

        numAllocations.ifPresent(value -> appendJsonField(requestBodyBuilder, "num_allocations", value));

        if (adaptiveAllocationsEnabled.isPresent() || minNumberOfAllocations.isPresent() || maxNumberOfAllocations.isPresent()) {
            appendSeparatorIfNeeded(requestBodyBuilder);
            requestBodyBuilder.append("\"adaptive_allocations\": {");
            adaptiveAllocationsEnabled.ifPresent(value -> appendJsonField(requestBodyBuilder, "enabled", value));
            minNumberOfAllocations.ifPresent(value -> appendJsonField(requestBodyBuilder, "min_number_of_allocations", value));
            maxNumberOfAllocations.ifPresent(value -> appendJsonField(requestBodyBuilder, "max_number_of_allocations", value));
            requestBodyBuilder.append("}");
        }

        return requestBodyBuilder.append("} }").toString();
    }

    /** Appends a comma unless the builder currently ends with the opening brace of a JSON object. */
    private static void appendSeparatorIfNeeded(StringBuilder json) {
        if (json.charAt(json.length() - 1) != '{') {
            json.append(',');
        }
    }

    /** Appends {@code "name": value}, preceded by a comma when another member was already written. */
    private static void appendJsonField(StringBuilder json, String name, Object value) {
        appendSeparatorIfNeeded(json);
        json.append('"').append(name).append("\": ").append(value);
    }

    /**
     * Asserts that the endpoint configuration reports the expected {@code num_allocations} and, for every adaptive
     * allocations value that is present, the expected field under {@code adaptive_allocations}. Absent values are
     * not checked.
     */
    @SuppressWarnings("unchecked")
    private void verifyEndpointConfig(
        Map<String, Object> endpointConfig,
        int expectedNumAllocations,
        Optional<Boolean> adaptiveAllocationsEnabled,
        Optional<Integer> minNumberOfAllocations,
        Optional<Integer> maxNumberOfAllocations
    ) {
        var serviceSettings = (Map<String, Object>) endpointConfig.get("service_settings");
        assertEquals(expectedNumAllocations, serviceSettings.get("num_allocations"));

        boolean expectsAdaptiveAllocations = adaptiveAllocationsEnabled.isPresent()
            || minNumberOfAllocations.isPresent()
            || maxNumberOfAllocations.isPresent();
        if (expectsAdaptiveAllocations == false) {
            return;
        }

        var adaptiveAllocations = (Map<String, Object>) serviceSettings.get("adaptive_allocations");
        adaptiveAllocationsEnabled.ifPresent(enabled -> assertEquals(enabled, adaptiveAllocations.get("enabled")));
        minNumberOfAllocations.ifPresent(min -> assertEquals(min, adaptiveAllocations.get("min_number_of_allocations")));
        maxNumberOfAllocations.ifPresent(max -> assertEquals(max, adaptiveAllocations.get("max_number_of_allocations")));
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.elasticsearch.xpack.core.inference.action.UpdateInferenceModelAction;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentUtils;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
Expand All @@ -63,6 +64,7 @@

import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.resolveTaskType;
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings.ADAPTIVE_ALLOCATIONS;
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS;

public class TransportUpdateInferenceModelAction extends TransportMasterNodeAction<
Expand Down Expand Up @@ -220,12 +222,17 @@ private Model combineExistingModelWithNewSettings(
if (settingsToUpdate.serviceSettings() != null && existingSecretSettings != null) {
newSecretSettings = existingSecretSettings.newSecretSettings(settingsToUpdate.serviceSettings());
}
if (settingsToUpdate.serviceSettings() != null && settingsToUpdate.serviceSettings().containsKey(NUM_ALLOCATIONS)) {
if (settingsToUpdate.serviceSettings() != null
&& (settingsToUpdate.serviceSettings().containsKey(NUM_ALLOCATIONS)
|| settingsToUpdate.serviceSettings().containsKey(ADAPTIVE_ALLOCATIONS))) {
// In cluster services can only have their num_allocations and adaptive_allocations updated, so this is a special case
if (newServiceSettings instanceof ElasticsearchInternalServiceSettings elasticServiceSettings) {
newServiceSettings = new ElasticsearchInternalServiceSettings(
elasticServiceSettings,
(Integer) settingsToUpdate.serviceSettings().get(NUM_ALLOCATIONS)
settingsToUpdate.serviceSettings().containsKey(NUM_ALLOCATIONS)
? (Integer) settingsToUpdate.serviceSettings().get(NUM_ALLOCATIONS)
: null,
getAdaptiveAllocationsSettingsFromMap(settingsToUpdate.serviceSettings())
);
}
}
Expand Down Expand Up @@ -259,10 +266,15 @@ private void updateInClusterEndpoint(
throwIfTrainedModelDoesntExist(request.getInferenceEntityId(), deploymentId);

Map<String, Object> serviceSettings = request.getContentAsSettings().serviceSettings();
if (serviceSettings != null && serviceSettings.get(NUM_ALLOCATIONS) instanceof Integer numAllocations) {
if (serviceSettings != null
&& (serviceSettings.get(NUM_ALLOCATIONS) instanceof Integer || serviceSettings.containsKey(ADAPTIVE_ALLOCATIONS))) {
var numAllocations = (Integer) serviceSettings.get(NUM_ALLOCATIONS);
var adaptiveAllocationsSettings = getAdaptiveAllocationsSettingsFromMap(serviceSettings);
// TODO: Figure out how to deep clone the adaptive allocations settings as they are already removed at this point.

UpdateTrainedModelDeploymentAction.Request updateRequest = new UpdateTrainedModelDeploymentAction.Request(deploymentId);
updateRequest.setNumberOfAllocations(numAllocations);
updateRequest.setAdaptiveAllocationsSettings(adaptiveAllocationsSettings);

var delegate = listener.<CreateTrainedModelAssignmentAction.Response>delegateFailure((l2, response) -> {
modelRegistry.updateModelTransaction(newModel, existingParsedModel, l2);
Expand Down Expand Up @@ -339,6 +351,36 @@ private void checkEndpointExists(String inferenceEntityId, ActionListener<Unpars
}));
}

/**
 * Builds the {@link AdaptiveAllocationsSettings} described by the {@code adaptive_allocations} entry of a service
 * settings map and validates them.
 * <p>
 * The map originates from request content, so the entry and its fields are type-checked explicitly; a wrong type is
 * reported as an {@link IllegalArgumentException} instead of surfacing as a {@link ClassCastException} or
 * {@link NullPointerException}.
 *
 * @param settings the service settings to read from; may be {@code null}
 * @return the validated settings, or {@code null} when {@code settings} is {@code null} or has no
 *         {@code adaptive_allocations} entry
 * @throws IllegalArgumentException if the entry is not an object, or one of its fields has an unexpected type;
 *         if {@link AdaptiveAllocationsSettings#validate()} returns an exception, that exception is thrown as is
 */
private AdaptiveAllocationsSettings getAdaptiveAllocationsSettingsFromMap(Map<String, Object> settings) {
    if (settings == null || settings.containsKey(ADAPTIVE_ALLOCATIONS) == false) {
        return null;
    }

    Object rawAdaptiveAllocations = settings.get(ADAPTIVE_ALLOCATIONS);
    if ((rawAdaptiveAllocations instanceof Map<?, ?>) == false) {
        throw new IllegalArgumentException("[" + ADAPTIVE_ALLOCATIONS + "] must be an object");
    }
    Map<?, ?> adaptiveAllocationsSettingsMap = (Map<?, ?>) rawAdaptiveAllocations;

    // TODO: Test if updating causes any issues with the UI
    var adaptiveAllocationsSettingsBuilder = new AdaptiveAllocationsSettings.Builder();
    adaptiveAllocationsSettingsBuilder.setEnabled(
        adaptiveAllocationsField(adaptiveAllocationsSettingsMap, AdaptiveAllocationsSettings.ENABLED.getPreferredName(), Boolean.class)
    );
    adaptiveAllocationsSettingsBuilder.setMinNumberOfAllocations(
        adaptiveAllocationsField(
            adaptiveAllocationsSettingsMap,
            AdaptiveAllocationsSettings.MIN_NUMBER_OF_ALLOCATIONS.getPreferredName(),
            Integer.class
        )
    );
    adaptiveAllocationsSettingsBuilder.setMaxNumberOfAllocations(
        adaptiveAllocationsField(
            adaptiveAllocationsSettingsMap,
            AdaptiveAllocationsSettings.MAX_NUMBER_OF_ALLOCATIONS.getPreferredName(),
            Integer.class
        )
    );

    var adaptiveAllocationsSettings = adaptiveAllocationsSettingsBuilder.build();
    var validationException = adaptiveAllocationsSettings.validate();

    if (validationException != null) {
        throw validationException;
    }

    return adaptiveAllocationsSettings;
}

/**
 * Reads an optional field of the {@code adaptive_allocations} object.
 *
 * @return the value cast to {@code expectedType}, or {@code null} when the field is absent or {@code null}
 * @throws IllegalArgumentException if the value is present but is not an instance of {@code expectedType}
 */
private static <T> T adaptiveAllocationsField(Map<?, ?> adaptiveAllocations, String fieldName, Class<T> expectedType) {
    Object value = adaptiveAllocations.get(fieldName);
    if (value == null) {
        return null;
    }
    if (expectedType.isInstance(value) == false) {
        throw new IllegalArgumentException(
            "["
                + ADAPTIVE_ALLOCATIONS
                + "."
                + fieldName
                + "] must be of type ["
                + expectedType.getSimpleName()
                + "] but was ["
                + value.getClass().getSimpleName()
                + "]"
        );
    }
    return expectedType.cast(value);
}

/**
 * Creates an {@link XContentParser} over the content of the update request, using the request's own content type
 * and the empty parser configuration.
 *
 * @param request the update request whose content should be parsed
 * @return a parser positioned over the request content
 * @throws IOException if the parser cannot be created
 */
private static XContentParser getParser(UpdateInferenceModelAction.Request request) throws IOException {
    var parserConfig = XContentParserConfiguration.EMPTY;
    var requestContent = request.getContent();
    var requestContentType = request.getContentType();
    return XContentHelper.createParser(parserConfig, requestContent, requestContentType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,20 @@ protected ElasticsearchInternalServiceSettings(ElasticsearchInternalServiceSetti
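* Copy constructor with the ability to set the number of allocations and the adaptive allocations settings.
* Used for Update API.
* @param other the existing settings
* @param numAllocations the new number of allocations, or {@code null} to keep the existing value
* @param adaptiveAllocationsSettings the new adaptive allocations settings; merged into the existing settings when those are present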
*/
public ElasticsearchInternalServiceSettings(ElasticsearchInternalServiceSettings other, int numAllocations) {
this.numAllocations = numAllocations;
public ElasticsearchInternalServiceSettings(
ElasticsearchInternalServiceSettings other,
Integer numAllocations,
AdaptiveAllocationsSettings adaptiveAllocationsSettings
) {
this.numAllocations = numAllocations == null ? other.numAllocations : numAllocations;
// TODO: Should we block numAllocations < minNumberOfAllocations? Also, does this get updated by adaptive allocations?
this.numThreads = other.numThreads;
this.modelId = other.modelId;
this.adaptiveAllocationsSettings = other.adaptiveAllocationsSettings;
this.adaptiveAllocationsSettings = other.adaptiveAllocationsSettings == null
? adaptiveAllocationsSettings
: other.adaptiveAllocationsSettings.merge(adaptiveAllocationsSettings);
this.deploymentId = other.deploymentId;
}

Expand Down
Loading