diff --git a/CHANGELOG.md b/CHANGELOG.md index f9a0033b..d98ac883 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) ### Bug Fixes - Change REST status codes for RBAC and provisioning ([#1083](https://github.com/opensearch-project/flow-framework/pull/1083)) - Fix Config parser does not handle tenant_id field ([#1096](https://github.com/opensearch-project/flow-framework/pull/1096)) +- Complete action listener on failed synchronous workflow provisioning ([#1098](https://github.com/opensearch-project/opensearch-remote-metadata-sdk/pull/1098)) ### Infrastructure ### Documentation diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index 93e0a5f6..cc784425 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -352,11 +352,11 @@ public void onResponse(WorkflowResponse workflowResponse) { @Override public void onFailure(Exception e) { - WorkflowTimeoutUtility.handleFailure(workflowId, e, isResponseSent, listener); + WorkflowTimeoutUtility.handleFailure(workflowId, e, listener); } }, true); } catch (Exception ex) { - WorkflowTimeoutUtility.handleFailure(workflowId, ex, isResponseSent, listener); + WorkflowTimeoutUtility.handleFailure(workflowId, ex, listener); } }, client.threadPool().executor(PROVISION_WORKFLOW_THREAD_POOL)); @@ -469,9 +469,25 @@ private void executeWorkflow( ), ActionListener.wrap(updateResponse -> { logger.info("updated workflow {} state to {}", workflowId, State.FAILED); - }, exceptionState -> { logger.error("Failed to update workflow state for workflow {}", workflowId, exceptionState); }) + if (isSyncExecution) { + listener.onFailure(new FlowFrameworkException(errorMessage, status)); + } else { + TenantAwareHelper.releaseProvision(tenantId); + } + }, exceptionState -> { + logger.error("Failed to update workflow state for workflow {}", workflowId, exceptionState); + if (isSyncExecution) { + listener.onFailure( + new FlowFrameworkException( + errorMessage + ". Failed to update workflow state after execution failure.", + RestStatus.INTERNAL_SERVER_ERROR + ) + ); + } else { + TenantAwareHelper.releaseProvision(tenantId); + } + }) ); } } - } diff --git a/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java index d5752e25..e3710083 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java @@ -78,7 +78,6 @@ public class ReprovisionWorkflowTransportAction extends HandledTransportAction { logger.info("updated workflow {} state to {}", workflowId, State.FAILED); - }, exceptionState -> { logger.error("Failed to update workflow state for workflow {}", workflowId, exceptionState); }) + if (isSyncExecution) { + listener.onFailure(new FlowFrameworkException(errorMessage, status)); + } else { + TenantAwareHelper.releaseProvision(template.getTenantId()); + } + }, exceptionState -> { + logger.error("Failed to update workflow state for workflow {}", workflowId, exceptionState); + if (isSyncExecution) { + listener.onFailure( + new FlowFrameworkException( + errorMessage + ". Failed to update workflow state after execution failure.", + RestStatus.INTERNAL_SERVER_ERROR + ) + ); + } else { + TenantAwareHelper.releaseProvision(template.getTenantId()); + } + }) ); } } diff --git a/src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java b/src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java index a724a450..fc23bcaa 100644 --- a/src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java +++ b/src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java @@ -14,6 +14,7 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.model.WorkflowState; import org.opensearch.flowframework.transport.GetWorkflowStateAction; import org.opensearch.flowframework.transport.GetWorkflowStateRequest; import org.opensearch.flowframework.transport.WorkflowResponse; @@ -164,25 +165,13 @@ public static void handleResponse( * * @param workflowId The unique identifier of the workflow. * @param e The exception that occurred during workflow execution. - * @param isResponseSent An atomic boolean to ensure the response is sent only once. * @param listener The listener to notify of the workflow failure. */ - public static void handleFailure( - String workflowId, - Exception e, - AtomicBoolean isResponseSent, - ActionListener listener - ) { - // Check if the failure has already been reported, and report it only if it hasn't been reported yet. - if (isResponseSent.compareAndSet(false, true)) { - FlowFrameworkException exception = new FlowFrameworkException( - "Failed to execute workflow " + workflowId, - ExceptionsHelper.status(e) - ); - listener.onFailure(exception); - } else { - logger.info("Ignoring onFailure for workflowId: {} as timeout already occurred", workflowId); - } + public static void handleFailure(String workflowId, Exception e, ActionListener listener) { + FlowFrameworkException exception = e instanceof FlowFrameworkException + ? (FlowFrameworkException) e + : new FlowFrameworkException("Failed to execute workflow " + workflowId, ExceptionsHelper.status(e)); + listener.onFailure(exception); } /** @@ -207,8 +196,16 @@ public static void fetchWorkflowStateAfterTimeout( new GetWorkflowStateRequest(workflowId, false, tenantId), ActionListener.wrap( response -> listener.onResponse(new WorkflowResponse(workflowId, response.getWorkflowState())), - exception -> listener.onFailure( - new FlowFrameworkException("Failed to get workflow state after timeout", ExceptionsHelper.status(exception)) + // we don't want to fail the listener as provisioning is still ongoing + exception -> listener.onResponse( + new WorkflowResponse( + workflowId, + WorkflowState.builder() + .workflowId(workflowId) + .error("Workflow timed out, failed to fetch current state") + .state("UNKNOWN") + .build() + ) ) ) ); diff --git a/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java index dddce34e..916c429f 100644 --- a/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java @@ -13,10 +13,12 @@ import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.update.UpdateResponse; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentFactory; @@ -27,6 +29,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.flowframework.TestHelpers; import org.opensearch.flowframework.common.FlowFrameworkSettings; +import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.model.ProvisioningProgress; import org.opensearch.flowframework.model.Template; @@ -34,6 +37,8 @@ import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; import org.opensearch.flowframework.util.EncryptorUtils; +import org.opensearch.flowframework.workflow.ProcessNode; +import org.opensearch.flowframework.workflow.WorkflowData; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.index.get.GetResult; import org.opensearch.plugins.PluginsService; @@ -257,4 +262,58 @@ public void testFailedToRetrieveTemplateFromGlobalContext() { assertEquals("Failed to get template 1", exceptionCaptor.getValue().getMessage()); } + public void testProvisionWorkflowExecutionException() { + + String workflowId = "1"; + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null, Collections.emptyMap(), TimeValue.timeValueSeconds(5)); + + // Bypass client.get and stub success case + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + this.template.toXContent(builder, null); + BytesReference templateBytesRef = BytesReference.bytes(builder); + GetResult getResult = new GetResult(GLOBAL_CONTEXT_INDEX, workflowId, 1, 1, 1, true, templateBytesRef, null, null); + responseListener.onResponse(new GetResponse(getResult)); + return null; + }).when(client).get(any(GetRequest.class), any()); + + when(encryptorUtils.decryptTemplateCredentials(any())).thenReturn(template); + + // Bypass isWorkflowNotStarted and force true response + doAnswer(invocation -> { + Consumer> progressConsumer = invocation.getArgument(2); + progressConsumer.accept(Optional.of(ProvisioningProgress.NOT_STARTED)); + return null; + }).when(flowFrameworkIndicesHandler).getProvisioningProgress(any(), any(), any(), any()); + + // Bypass updateFlowFrameworkSystemIndexDoc and stub on response + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(3); + actionListener.onResponse(mock(UpdateResponse.class)); + return null; + }).when(flowFrameworkIndicesHandler).updateFlowFrameworkSystemIndexDoc(any(), nullable(String.class), anyMap(), any()); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onResponse(new IndexResponse(new ShardId(GLOBAL_CONTEXT_INDEX, "", 1), "1", 1L, 1L, 1L, true)); + return null; + }).when(flowFrameworkIndicesHandler).updateTemplateInGlobalContext(any(), any(Template.class), any(), anyBoolean()); + + // Create a failed future for the workflow execution + PlainActionFuture failedFuture = PlainActionFuture.newFuture(); + failedFuture.onFailure(new RuntimeException("Simulated failure during workflow execution")); + ProcessNode failedProcessNode = mock(ProcessNode.class); + when(failedProcessNode.execute()).thenReturn(failedFuture); + when(workflowProcessSorter.sortProcessNodes(any(), any(), any(), any())).thenReturn(Collections.singletonList(failedProcessNode)); + + provisionWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); + + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(FlowFrameworkException.class); + verify(listener, times(1)).onFailure(responseCaptor.capture()); + assertTrue(responseCaptor.getValue().getMessage().startsWith("Simulated failure during workflow execution")); + } } diff --git a/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportActionTests.java index d07f48be..2409a4cd 100644 --- a/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportActionTests.java @@ -9,6 +9,7 @@ package org.opensearch.flowframework.transport; import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.update.UpdateResponse; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; @@ -27,7 +28,9 @@ import org.opensearch.flowframework.model.WorkflowState; import org.opensearch.flowframework.util.EncryptorUtils; import org.opensearch.flowframework.workflow.ProcessNode; +import org.opensearch.flowframework.workflow.WorkflowData; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; +import org.opensearch.flowframework.workflow.WorkflowStep; import org.opensearch.flowframework.workflow.WorkflowStepFactory; import org.opensearch.plugins.PluginsService; import org.opensearch.remote.metadata.client.SdkClient; @@ -335,4 +338,65 @@ public void testFailedWorkflowStateRetrieval() throws Exception { assertEquals("Failed to get workflow state for workflow 1", exceptionCaptor.getValue().getMessage()); } + public void testReprovisionWorkflowExecutionException() throws Exception { + String workflowId = "1"; + + Template mockTemplate = mock(Template.class); + Workflow mockWorkflow = mock(Workflow.class); + Map mockWorkflows = new HashMap<>(); + mockWorkflows.put(PROVISION_WORKFLOW, mockWorkflow); + + // Stub validations + when(mockTemplate.workflows()).thenReturn(mockWorkflows); + when(workflowProcessSorter.sortProcessNodes(any(), any(), any(), any())).thenReturn(List.of()); + doNothing().when(workflowProcessSorter).validate(any(), any()); + when(encryptorUtils.decryptTemplateCredentials(any())).thenReturn(mockTemplate); + + // Stub state and resources created + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + WorkflowState state = mock(WorkflowState.class); + ResourceCreated resourceCreated = new ResourceCreated("stepName", workflowId, "resourceType", "resourceId"); + when(state.getState()).thenReturn(State.COMPLETED.toString()); + when(state.resourcesCreated()).thenReturn(List.of(resourceCreated)); + when(state.getError()).thenReturn(null); + listener.onResponse(new GetWorkflowStateResponse(state, true)); + return null; + }).when(client).execute(any(), any(GetWorkflowStateRequest.class), any()); + + // Create a failed future for the workflow execution + PlainActionFuture failedFuture = PlainActionFuture.newFuture(); + failedFuture.onFailure(new RuntimeException("Simulated failure during workflow execution")); + ProcessNode failedProcessNode = mock(ProcessNode.class); + when(failedProcessNode.execute()).thenReturn(failedFuture); + WorkflowStep mockStep = mock(WorkflowStep.class); + when(mockStep.getName()).thenReturn("FakeStep"); + when(failedProcessNode.workflowStep()).thenReturn(mockStep); + + // Stub reprovision sequence creation with the failed process node + when(workflowProcessSorter.createReprovisionSequence(any(), any(), any(), any(), any())).thenReturn(List.of(failedProcessNode)); + + // Bypass updateFlowFrameworkSystemIndexDoc and stub on response + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(3); + actionListener.onResponse(mock(UpdateResponse.class)); + return null; + }).when(flowFrameworkIndicesHandler).updateFlowFrameworkSystemIndexDoc(any(), nullable(String.class), anyMap(), any()); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + ReprovisionWorkflowRequest request = new ReprovisionWorkflowRequest( + workflowId, + mockTemplate, + mockTemplate, + TimeValue.timeValueSeconds(5) + ); + + reprovisionWorkflowTransportAction.doExecute(mock(Task.class), request, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertTrue(exceptionCaptor.getValue().getMessage().startsWith("Simulated failure during workflow execution")); + } + }