Skip to content

Add Retry Logic for ElasticsearchInternalService#chunkedInfer #127812

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/127812.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 127812
summary: Add Retry Logic for `ElasticsearchInternalService#chunkedInfer`
area: Machine Learning
type: feature
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,12 @@
public class ModelRegistryIT extends ESSingleNodeTestCase {
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);

private ClusterService clusterService;
private ModelRegistry modelRegistry;

@Before
public void createComponents() {
clusterService = node().injector().getInstance(ClusterService.class);
modelRegistry = node().injector().getInstance(ModelRegistry.class);
modelRegistry.clearDefaultIds();
}
Expand Down Expand Up @@ -123,12 +125,11 @@ public void testGetModel() throws Exception {
assertThat(modelHolder.get(), not(nullValue()));

assertEquals(model.getConfigurations().getService(), modelHolder.get().service());

var elserService = new ElasticsearchInternalService(
new InferenceServiceExtension.InferenceServiceFactoryContext(
mock(Client.class),
mock(ThreadPool.class),
mock(ClusterService.class),
clusterService,
Settings.EMPTY
)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

public class RetrySettings {

static final Setting<TimeValue> RETRY_INITIAL_DELAY_SETTING = Setting.timeSetting(
public static final Setting<TimeValue> RETRY_INITIAL_DELAY_SETTING = Setting.timeSetting(
"xpack.inference.http.retry.initial_delay",
TimeValue.timeValueSeconds(1),
Setting.Property.NodeScope,
Expand All @@ -30,7 +30,7 @@ public class RetrySettings {
Setting.Property.Dynamic
);

static final Setting<TimeValue> RETRY_TIMEOUT_SETTING = Setting.timeSetting(
public static final Setting<TimeValue> RETRY_TIMEOUT_SETTING = Setting.timeSetting(
"xpack.inference.http.retry.timeout",
TimeValue.timeValueSeconds(30),
Setting.Property.NodeScope,
Expand Down Expand Up @@ -106,15 +106,15 @@ public static List<Setting<?>> getSettingsDefinitions() {
);
}

TimeValue getInitialDelay() {
public TimeValue getInitialDelay() {
return initialDelay;
}

TimeValue getMaxDelayBound() {
return maxDelayBound;
}

TimeValue getTimeout() {
public TimeValue getTimeout() {
return timeout;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,18 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.internal.hppc.IntIntHashMap;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.RetryableAction;
import org.elasticsearch.common.logging.DeprecationCategory;
import org.elasticsearch.common.logging.DeprecationLogger;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
Expand All @@ -36,6 +40,7 @@
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.XPackSettings;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
Expand All @@ -57,6 +62,7 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfigUpdate;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.ServiceUtils;

Expand All @@ -68,9 +74,12 @@
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.IntUnaryOperator;

import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
Expand Down Expand Up @@ -121,10 +130,14 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
private static final String OLD_MODEL_ID_FIELD_NAME = "model_version";

private final Settings settings;
private final ThreadPool threadPool;
private final RetrySettings retrySettings;

public ElasticsearchInternalService(InferenceServiceExtension.InferenceServiceFactoryContext context) {
super(context);
this.settings = context.settings();
this.threadPool = context.threadPool();
this.retrySettings = new RetrySettings(context.settings(), context.clusterService());
}

// for testing
Expand All @@ -134,6 +147,8 @@ public ElasticsearchInternalService(InferenceServiceExtension.InferenceServiceFa
) {
super(context, platformArch);
this.settings = context.settings();
this.threadPool = context.threadPool();
this.retrySettings = new RetrySettings(context.settings(), context.clusterService());
}

@Override
Expand Down Expand Up @@ -1126,10 +1141,150 @@ private void executeRequest(int batchIndex, boolean maybeDeploy, Runnable runAft
if (maybeDeploy) {
listener = listener.delegateResponse((l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, l));
}
client.execute(InferModelAction.INSTANCE, inferenceRequest, listener);

new BatchExecutor(retrySettings.getInitialDelay(), retrySettings.getTimeout(), inferenceRequest, listener, inferenceExecutor)
.run();
}
}

private static final Set<RestStatus> RETRYABLE_STATUS = Set.of(
RestStatus.INTERNAL_SERVER_ERROR,
RestStatus.TOO_MANY_REQUESTS,
RestStatus.REQUEST_TIMEOUT
);

private class BatchExecutor extends RetryableAction<InferModelAction.Response> {
private final RetryState state;

BatchExecutor(
TimeValue initialDelay,
TimeValue timeoutValue,
InferModelAction.Request request,
ActionListener<InferModelAction.Response> listener,
Executor executor
) {
this(initialDelay, timeoutValue, new RetryState(request), listener, executor);
}

private BatchExecutor(
TimeValue initialDelay,
TimeValue timeoutValue,
RetryState state,
ActionListener<InferModelAction.Response> listener,
Executor executor
) {
super(logger, threadPool, initialDelay, timeoutValue, new ActionListener<>() {
@Override
public void onResponse(InferModelAction.Response response) {
listener.onResponse(state.getAccumulatedResponse(null));
}

@Override
public void onFailure(Exception exc) {
if (state.hasPartialResponse()) {
listener.onResponse(state.getAccumulatedResponse(exc instanceof RetryableException ? null : exc));
} else {
listener.onFailure(exc);
}
}
}, executor);
this.state = state;
}

@Override
public void tryAction(ActionListener<InferModelAction.Response> listener) {
client.execute(InferModelAction.INSTANCE, state.getCurrentRequest(), new ActionListener<>() {
@Override
public void onResponse(InferModelAction.Response response) {
if (state.consumeResponse(response)) {
listener.onResponse(response);
} else {
listener.onFailure(new RetryableException());
}
}

@Override
public void onFailure(Exception exc) {
listener.onFailure(exc);
}
});
}

@Override
public boolean shouldRetry(Exception exc) {
return exc instanceof RetryableException
|| RETRYABLE_STATUS.contains(ExceptionsHelper.status(ExceptionsHelper.unwrapCause(exc)));
}
}

private static class RetryState {
private final InferModelAction.Request originalRequest;
private InferModelAction.Request currentRequest;

private IntUnaryOperator currentToOriginalIndex;
private final AtomicArray<InferenceResults> inferenceResults;
private final AtomicBoolean hasPartialResponse;

private RetryState(InferModelAction.Request originalRequest) {
this.originalRequest = originalRequest;
this.currentRequest = originalRequest;
this.currentToOriginalIndex = index -> index;
this.inferenceResults = new AtomicArray<>(originalRequest.getTextInput().size());
this.hasPartialResponse = new AtomicBoolean();
}

boolean hasPartialResponse() {
return hasPartialResponse.get();
}

InferModelAction.Request getCurrentRequest() {
return currentRequest;
}

InferModelAction.Response getAccumulatedResponse(@Nullable Exception exc) {
List<InferenceResults> finalResults = new ArrayList<>();
for (int i = 0; i < inferenceResults.length(); i++) {
var result = inferenceResults.get(i);
if (exc != null && result instanceof ErrorInferenceResults) {
finalResults.add(new ErrorInferenceResults(exc));
} else {
finalResults.add(result);
}
}
return new InferModelAction.Response(finalResults, originalRequest.getId(), originalRequest.isPreviouslyLicensed());
}

private boolean consumeResponse(InferModelAction.Response response) {
hasPartialResponse.set(true);
List<String> retryInputs = new ArrayList<>();
IntIntHashMap newIndexMap = new IntIntHashMap();
for (int i = 0; i < response.getInferenceResults().size(); i++) {
var result = response.getInferenceResults().get(i);
int index = currentToOriginalIndex.applyAsInt(i);
inferenceResults.set(index, result);
if (result instanceof ErrorInferenceResults error
&& RETRYABLE_STATUS.contains(ExceptionsHelper.status(ExceptionsHelper.unwrapCause(error.getException())))) {
newIndexMap.put(retryInputs.size(), index);
retryInputs.add(originalRequest.getTextInput().get(index));
}
}
if (retryInputs.isEmpty()) {
return true;
}
currentRequest = InferModelAction.Request.forTextInput(
originalRequest.getId(),
originalRequest.getUpdate(),
retryInputs,
originalRequest.isPreviouslyLicensed(),
originalRequest.getInferenceTimeout()
);
currentToOriginalIndex = newIndexMap::get;
return false;
}
}

private static class RetryableException extends Exception {}

public static class Configuration {
public static InferenceServiceConfiguration get() {
return configuration.getOrCompute();
Expand Down
Loading
Loading