diff --git a/docs/changelog/127812.yaml b/docs/changelog/127812.yaml new file mode 100644 index 0000000000000..45272134659da --- /dev/null +++ b/docs/changelog/127812.yaml @@ -0,0 +1,5 @@ +pr: 127812 +summary: Add Retry Logic for `ElasticsearchInternalService#chunkedInfer` +area: Machine Learning +type: feature +issues: [] diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index e56782bd00ef5..0cdc8fb640f93 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -77,10 +77,12 @@ public class ModelRegistryIT extends ESSingleNodeTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private ClusterService clusterService; private ModelRegistry modelRegistry; @Before public void createComponents() { + clusterService = node().injector().getInstance(ClusterService.class); modelRegistry = node().injector().getInstance(ModelRegistry.class); modelRegistry.clearDefaultIds(); } @@ -123,12 +125,11 @@ public void testGetModel() throws Exception { assertThat(modelHolder.get(), not(nullValue())); assertEquals(model.getConfigurations().getService(), modelHolder.get().service()); - var elserService = new ElasticsearchInternalService( new InferenceServiceExtension.InferenceServiceFactoryContext( mock(Client.class), mock(ThreadPool.class), - mock(ClusterService.class), + clusterService, Settings.EMPTY ) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetrySettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetrySettings.java index 35e50e557cc83..da04b05a5a332 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetrySettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetrySettings.java @@ -16,7 +16,7 @@ public class RetrySettings { - static final Setting RETRY_INITIAL_DELAY_SETTING = Setting.timeSetting( + public static final Setting RETRY_INITIAL_DELAY_SETTING = Setting.timeSetting( "xpack.inference.http.retry.initial_delay", TimeValue.timeValueSeconds(1), Setting.Property.NodeScope, @@ -30,7 +30,7 @@ public class RetrySettings { Setting.Property.Dynamic ); - static final Setting RETRY_TIMEOUT_SETTING = Setting.timeSetting( + public static final Setting RETRY_TIMEOUT_SETTING = Setting.timeSetting( "xpack.inference.http.retry.timeout", TimeValue.timeValueSeconds(30), Setting.Property.NodeScope, @@ -106,7 +106,7 @@ public static List> getSettingsDefinitions() { ); } - TimeValue getInitialDelay() { + public TimeValue getInitialDelay() { return initialDelay; } @@ -114,7 +114,7 @@ TimeValue getMaxDelayBound() { return maxDelayBound; } - TimeValue getTimeout() { + public TimeValue getTimeout() { return timeout; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 7435430f8af43..fa744d0ab18b9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -9,14 +9,18 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.apache.lucene.internal.hppc.IntIntHashMap; import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.RetryableAction; import org.elasticsearch.common.logging.DeprecationCategory; import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.LazyInitializable; +import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; @@ -36,6 +40,7 @@ import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.XPackSettings; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; @@ -57,6 +62,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfigUpdate; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; +import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceUtils; @@ -68,9 +74,12 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.IntUnaryOperator; import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; @@ -121,10 +130,14 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi private static final String OLD_MODEL_ID_FIELD_NAME = "model_version"; private final Settings settings; + private final ThreadPool threadPool; + private final RetrySettings retrySettings; public ElasticsearchInternalService(InferenceServiceExtension.InferenceServiceFactoryContext context) { super(context); this.settings = context.settings(); + this.threadPool = context.threadPool(); + this.retrySettings = new RetrySettings(context.settings(), context.clusterService()); } // for testing @@ -134,6 +147,8 @@ public ElasticsearchInternalService(InferenceServiceExtension.InferenceServiceFa ) { super(context, platformArch); this.settings = context.settings(); + this.threadPool = context.threadPool(); + this.retrySettings = new RetrySettings(context.settings(), context.clusterService()); } @Override @@ -1126,10 +1141,150 @@ private void executeRequest(int batchIndex, boolean maybeDeploy, Runnable runAft if (maybeDeploy) { listener = listener.delegateResponse((l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, l)); } - client.execute(InferModelAction.INSTANCE, inferenceRequest, listener); + + new BatchExecutor(retrySettings.getInitialDelay(), retrySettings.getTimeout(), inferenceRequest, listener, inferenceExecutor) + .run(); + } + } + + private static final Set RETRYABLE_STATUS = Set.of( + RestStatus.INTERNAL_SERVER_ERROR, + RestStatus.TOO_MANY_REQUESTS, + RestStatus.REQUEST_TIMEOUT + ); + + private class BatchExecutor extends RetryableAction { + private final RetryState state; + + BatchExecutor( + TimeValue initialDelay, + TimeValue timeoutValue, + InferModelAction.Request request, + ActionListener listener, + Executor executor + ) { + this(initialDelay, timeoutValue, new RetryState(request), listener, executor); + } + + private BatchExecutor( + TimeValue initialDelay, + TimeValue timeoutValue, + RetryState state, + ActionListener listener, + Executor executor + ) { + super(logger, threadPool, initialDelay, timeoutValue, new ActionListener<>() { + @Override + public void onResponse(InferModelAction.Response response) { + listener.onResponse(state.getAccumulatedResponse(null)); + } + + @Override + public void onFailure(Exception exc) { + if (state.hasPartialResponse()) { + listener.onResponse(state.getAccumulatedResponse(exc instanceof RetryableException ? null : exc)); + } else { + listener.onFailure(exc); + } + } + }, executor); + this.state = state; + } + + @Override + public void tryAction(ActionListener listener) { + client.execute(InferModelAction.INSTANCE, state.getCurrentRequest(), new ActionListener<>() { + @Override + public void onResponse(InferModelAction.Response response) { + if (state.consumeResponse(response)) { + listener.onResponse(response); + } else { + listener.onFailure(new RetryableException()); + } + } + + @Override + public void onFailure(Exception exc) { + listener.onFailure(exc); + } + }); + } + + @Override + public boolean shouldRetry(Exception exc) { + return exc instanceof RetryableException + || RETRYABLE_STATUS.contains(ExceptionsHelper.status(ExceptionsHelper.unwrapCause(exc))); } } + private static class RetryState { + private final InferModelAction.Request originalRequest; + private InferModelAction.Request currentRequest; + + private IntUnaryOperator currentToOriginalIndex; + private final AtomicArray inferenceResults; + private final AtomicBoolean hasPartialResponse; + + private RetryState(InferModelAction.Request originalRequest) { + this.originalRequest = originalRequest; + this.currentRequest = originalRequest; + this.currentToOriginalIndex = index -> index; + this.inferenceResults = new AtomicArray<>(originalRequest.getTextInput().size()); + this.hasPartialResponse = new AtomicBoolean(); + } + + boolean hasPartialResponse() { + return hasPartialResponse.get(); + } + + InferModelAction.Request getCurrentRequest() { + return currentRequest; + } + + InferModelAction.Response getAccumulatedResponse(@Nullable Exception exc) { + List finalResults = new ArrayList<>(); + for (int i = 0; i < inferenceResults.length(); i++) { + var result = inferenceResults.get(i); + if (exc != null && result instanceof ErrorInferenceResults) { + finalResults.add(new ErrorInferenceResults(exc)); + } else { + finalResults.add(result); + } + } + return new InferModelAction.Response(finalResults, originalRequest.getId(), originalRequest.isPreviouslyLicensed()); + } + + private boolean consumeResponse(InferModelAction.Response response) { + hasPartialResponse.set(true); + List retryInputs = new ArrayList<>(); + IntIntHashMap newIndexMap = new IntIntHashMap(); + for (int i = 0; i < response.getInferenceResults().size(); i++) { + var result = response.getInferenceResults().get(i); + int index = currentToOriginalIndex.applyAsInt(i); + inferenceResults.set(index, result); + if (result instanceof ErrorInferenceResults error + && RETRYABLE_STATUS.contains(ExceptionsHelper.status(ExceptionsHelper.unwrapCause(error.getException())))) { + newIndexMap.put(retryInputs.size(), index); + retryInputs.add(originalRequest.getTextInput().get(index)); + } + } + if (retryInputs.isEmpty()) { + return true; + } + currentRequest = InferModelAction.Request.forTextInput( + originalRequest.getId(), + originalRequest.getUpdate(), + retryInputs, + originalRequest.isPreviouslyLicensed(), + originalRequest.getInferenceTimeout() + ); + currentToOriginalIndex = newIndexMap::get; + return false; + } + } + + private static class RetryableException extends Exception {} + public static class Configuration { public static InferenceServiceConfiguration get() { return configuration.getOrCompute(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index 27709c2067a26..870e0800fda1f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -21,6 +21,7 @@ import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.common.settings.ClusterSettings; +import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; @@ -76,6 +77,7 @@ import org.elasticsearch.xpack.inference.ModelConfigurationsTests; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings; +import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.hamcrest.Matchers; import org.junit.After; @@ -89,6 +91,7 @@ import java.util.Collections; import java.util.EnumSet; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; @@ -1036,28 +1039,19 @@ private void testChunkInfer_e5(ChunkingSettings chunkingSettings) throws Interru assertTrue("Listener not called", gotResults.get()); } - public void testChunkInfer_SparseWithNullChunkingSettings() throws InterruptedException { + public void testChunkInfer_SparseWithNullChunkingSettings() throws Exception { testChunkInfer_Sparse(null); } - public void testChunkInfer_SparseWithChunkingSettingsSet() throws InterruptedException { + public void testChunkInfer_SparseWithChunkingSettingsSet() throws Exception { testChunkInfer_Sparse(ChunkingSettingsTests.createRandomChunkingSettings()); } @SuppressWarnings("unchecked") - private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) throws InterruptedException { + private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) throws Exception { var mlTrainedModelResults = new ArrayList(); mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults()); mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults()); - var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true); - - Client client = mock(Client.class); - when(client.threadPool()).thenReturn(threadPool); - doAnswer(invocationOnMock -> { - var listener = (ActionListener) invocationOnMock.getArguments()[2]; - listener.onResponse(response); - return null; - }).when(client).execute(same(InferModelAction.INSTANCE), any(InferModelAction.Request.class), any(ActionListener.class)); var model = new CustomElandModel( "foo", @@ -1066,46 +1060,30 @@ private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) throws Int new ElasticsearchInternalServiceSettings(1, 1, "model-id", null, null), chunkingSettings ); - var service = createService(client); - var gotResults = new AtomicBoolean(); - - var resultsListener = ActionListener.>wrap(chunkedResponse -> { - assertThat(chunkedResponse, hasSize(2)); - assertThat(chunkedResponse.get(0), instanceOf(ChunkedInferenceEmbedding.class)); - var result1 = (ChunkedInferenceEmbedding) chunkedResponse.get(0); - assertThat(result1.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); - assertEquals( - ((TextExpansionResults) mlTrainedModelResults.get(0)).getWeightedTokens(), - ((SparseEmbeddingResults.Embedding) result1.chunks().get(0).embedding()).tokens() - ); - assertEquals(new ChunkedInference.TextOffset(0, 1), result1.chunks().get(0).offset()); - assertThat(chunkedResponse.get(1), instanceOf(ChunkedInferenceEmbedding.class)); - var result2 = (ChunkedInferenceEmbedding) chunkedResponse.get(1); - assertThat(result2.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); - assertEquals( - ((TextExpansionResults) mlTrainedModelResults.get(1)).getWeightedTokens(), - ((SparseEmbeddingResults.Embedding) result2.chunks().get(0).embedding()).tokens() - ); - assertEquals(new ChunkedInference.TextOffset(0, 2), result2.chunks().get(0).offset()); - gotResults.set(true); - }, ESTestCase::fail); - - var latch = new CountDownLatch(1); - var latchedListener = new LatchedActionListener<>(resultsListener, latch); - - service.chunkedInfer( + var chunkedResponse = chunkedInferWithFailures( model, - null, List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), - Map.of(), - InputType.SEARCH, - InferenceAction.Request.DEFAULT_TIMEOUT, - latchedListener + mlTrainedModelResults ); - latch.await(); - assertTrue("Listener not called", gotResults.get()); + assertThat(chunkedResponse, hasSize(2)); + assertThat(chunkedResponse.get(0), instanceOf(ChunkedInferenceEmbedding.class)); + var result1 = (ChunkedInferenceEmbedding) chunkedResponse.get(0); + assertThat(result1.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); + assertEquals( + ((TextExpansionResults) mlTrainedModelResults.get(0)).getWeightedTokens(), + ((SparseEmbeddingResults.Embedding) result1.chunks().get(0).embedding()).tokens() + ); + assertEquals(new ChunkedInference.TextOffset(0, 1), result1.chunks().get(0).offset()); + assertThat(chunkedResponse.get(1), instanceOf(ChunkedInferenceEmbedding.class)); + var result2 = (ChunkedInferenceEmbedding) chunkedResponse.get(1); + assertThat(result2.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); + assertEquals( + ((TextExpansionResults) mlTrainedModelResults.get(1)).getWeightedTokens(), + ((SparseEmbeddingResults.Embedding) result2.chunks().get(0).embedding()).tokens() + ); + assertEquals(new ChunkedInference.TextOffset(0, 2), result2.chunks().get(0).offset()); } public void testChunkInfer_ElserWithNullChunkingSettings() throws InterruptedException { @@ -1802,9 +1780,7 @@ private void testUpdateModelsWithDynamicFields(Map> modelsBy } public void testUpdateWithoutMlEnabled() throws IOException, InterruptedException { - var cs = mock(ClusterService.class); - var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES)); - when(cs.getClusterSettings()).thenReturn(cSettings); + var cs = createClusterService(); var context = new InferenceServiceExtension.InferenceServiceFactoryContext( mock(), threadPool, @@ -1844,9 +1820,7 @@ public void testUpdateWithMlEnabled() throws IOException, InterruptedException { }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any()); when(client.threadPool()).thenReturn(threadPool); - var cs = mock(ClusterService.class); - var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES)); - when(cs.getClusterSettings()).thenReturn(cSettings); + var cs = createClusterService(); var context = new InferenceServiceExtension.InferenceServiceFactoryContext( client, threadPool, @@ -1905,11 +1879,64 @@ public void testStart_OnFailure_WhenTimeoutOccurs() throws IOException { } } + @SuppressWarnings("unchecked") + private List chunkedInferWithFailures( + Model model, + List inputs, + List finalResults + ) throws Exception { + var response = new InferModelAction.Response(finalResults, "foo", true); + + Client client = mock(Client.class); + when(client.threadPool()).thenReturn(threadPool); + + AtomicInteger retryCount = new AtomicInteger(); + doAnswer(invocationOnMock -> { + var listener = (ActionListener) invocationOnMock.getArguments()[2]; + int count = retryCount.incrementAndGet(); + if (count < 2) { + listener.onFailure(new Exception("boom")); + } else if (count < 3) { + listener.onResponse(replaceWithFailures(response)); + } else { + listener.onResponse(response); + } + return null; + }).when(client).execute(same(InferModelAction.INSTANCE), any(InferModelAction.Request.class), any(ActionListener.class)); + + try (var service = createService(client)) { + var actualResponse = new AtomicReference>(); + var resultsListener = ActionListener.>wrap( + chunkedResponse -> { actualResponse.set(chunkedResponse); }, + ESTestCase::fail + ); + + var latch = new CountDownLatch(1); + var latchedListener = new LatchedActionListener<>(resultsListener, latch); + + service.chunkedInfer(model, null, inputs, Map.of(), InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, latchedListener); + + latch.await(); + assertNotNull("Listener not called", actualResponse.get()); + return actualResponse.get(); + } + } + + private InferModelAction.Response replaceWithFailures(InferModelAction.Response current) { + List newResults = new ArrayList<>(); + for (var result : current.getInferenceResults()) { + if (result instanceof ErrorInferenceResults) { + newResults.add(result); + } else { + newResults.add(new ErrorInferenceResults(new Exception("boom"))); + } + } + return new InferModelAction.Response(newResults, "foo", true); + } + private ElasticsearchInternalService createService(Client client) { - var cs = mock(ClusterService.class); - var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES)); - when(cs.getClusterSettings()).thenReturn(cSettings); - var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client, threadPool, cs, Settings.EMPTY); + var cs = createClusterService(); + var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client, threadPool, cs, cs.getSettings()); return new ElasticsearchInternalService(context); } @@ -1917,9 +1944,24 @@ private ElasticsearchInternalService createService(Client client, BaseElasticsea var context = new InferenceServiceExtension.InferenceServiceFactoryContext( client, threadPool, - mock(ClusterService.class), + createClusterService(), Settings.EMPTY ); return new ElasticsearchInternalService(context, l -> l.onResponse(modelVariant)); } + + private static ClusterService createClusterService() { + Set> addSettings = new HashSet<>(RetrySettings.getSettingsDefinitions()); + addSettings.add(MachineLearningField.MAX_LAZY_ML_NODES); + var cs = mock(ClusterService.class); + + var settings = Settings.builder() + .put(RetrySettings.RETRY_INITIAL_DELAY_SETTING.getKey(), "1ms") + .put(RetrySettings.RETRY_TIMEOUT_SETTING.getKey(), "100ms") + .build(); + when(cs.getSettings()).thenReturn(settings); + var cSettings = new ClusterSettings(settings, addSettings); + when(cs.getClusterSettings()).thenReturn(cSettings); + return cs; + } }