From 14943d915d3c892bf40e7c4dd9d73e3af3e89166 Mon Sep 17 00:00:00 2001 From: Madhur Tandon Date: Wed, 24 Jan 2024 21:08:29 +0530 Subject: [PATCH 1/5] add ability to terminate execution of a step-fn state machine --- metaflow/plugins/argo/argo_workflows.py | 4 +- .../aws/step_functions/step_functions.py | 49 +++++++++ .../aws/step_functions/step_functions_cli.py | 102 ++++++++++++++++++ .../step_functions/step_functions_client.py | 11 +- 4 files changed, 161 insertions(+), 5 deletions(-) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index 1f1da1b6d09..cb3a01e520d 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -227,8 +227,8 @@ def delete(name): return schedule_deleted, sensor_deleted, workflow_deleted - @staticmethod - def terminate(flow_name, name): + @classmethod + def terminate(cls, flow_name, name): client = ArgoClient(namespace=KUBERNETES_NAMESPACE) response = client.terminate_workflow(name) diff --git a/metaflow/plugins/aws/step_functions/step_functions.py b/metaflow/plugins/aws/step_functions/step_functions.py index 7dd6854b725..9b648084062 100644 --- a/metaflow/plugins/aws/step_functions/step_functions.py +++ b/metaflow/plugins/aws/step_functions/step_functions.py @@ -166,6 +166,13 @@ def delete(cls, name): return schedule_deleted, sfn_deleted + @classmethod + def terminate(cls, flow_name, name): + client = StepFunctionsClient() + execution_arn, _, _, _, _, _ = cls.get_execution(flow_name, name) + response = client.terminate_execution(execution_arn) + return response + @classmethod def trigger(cls, name, parameters): try: @@ -234,6 +241,48 @@ def get_existing_deployment(cls, name): ) return None + @classmethod + def get_execution(cls, state_machine_name, name): + client = StepFunctionsClient() + try: + state_machine = client.get(state_machine_name) + except Exception as e: + raise StepFunctionsException(repr(e)) + if state_machine is None: + raise StepFunctionsException( + "The workflow *%s* doesn't exist " "on AWS Step Functions." % name + ) + try: + state_machine_arn = state_machine.get("stateMachineArn") + parameters = ( + json.loads(state_machine.get("definition")) + .get("States") + .get("start") + .get("Parameters") + .get("Parameters") + ) + + executions = client.list_executions(state_machine_arn, states=["RUNNING"]) + for execution in executions: + if execution.get("name") == name: + try: + return ( + execution.get("executionArn"), + parameters.get("metaflow.owner"), + parameters.get("metaflow.production_token"), + parameters.get("metaflow.flow_name"), + parameters.get("metaflow.branch_name", None), + parameters.get("metaflow.project_name", None), + ) + except KeyError: + raise StepFunctionsException( + "A non-metaflow workflow *%s* already exists in AWS Step Functions." + % name + ) + return None + except Exception as e: + raise StepFunctionsException(repr(e)) + def _compile(self): if self.flow._flow_decorators.get("trigger") or self.flow._flow_decorators.get( "trigger_on_finish" diff --git a/metaflow/plugins/aws/step_functions/step_functions_cli.py b/metaflow/plugins/aws/step_functions/step_functions_cli.py index 63b1b645424..a7b3c4a1d86 100644 --- a/metaflow/plugins/aws/step_functions/step_functions_cli.py +++ b/metaflow/plugins/aws/step_functions/step_functions_cli.py @@ -26,6 +26,10 @@ class IncorrectProductionToken(MetaflowException): headline = "Incorrect production token" +class RunIdMismatch(MetaflowException): + headline = "Run ID mismatch" + + class IncorrectMetadataServiceVersion(MetaflowException): headline = "Incorrect version for metaflow service" @@ -614,6 +618,104 @@ def _token_instructions(flow_name, prev_user): ) +@step_functions.command(help="Terminate flow execution on Step Functions.") +@click.option( + "--authorize", + default=None, + type=str, + help="Authorize the termination with a production token", +) +@click.argument("run-id", required=True, type=str) +@click.pass_obj +def terminate(obj, run_id, authorize=None): + def _token_instructions(flow_name, prev_user): + obj.echo( + "There is an existing version of *%s* on AWS Step Functions which was " + "deployed by the user *%s*." % (flow_name, prev_user) + ) + obj.echo( + "To terminate this flow, you need to use the same production token that they used." + ) + obj.echo( + "Please reach out to them to get the token. Once you have it, call " + "this command:" + ) + obj.echo(" step-functions terminate --authorize MY_TOKEN RUN_ID", fg="green") + obj.echo( + 'See "Organizing Results" at docs.metaflow.org for more information ' + "about production tokens." + ) + + validate_run_id( + obj.state_machine_name, obj.token_prefix, authorize, run_id, _token_instructions + ) + + # Trim prefix from run_id + name = run_id[4:] + obj.echo( + "Terminating run *{run_id}* for {flow_name} ...".format( + run_id=run_id, flow_name=obj.flow.name + ), + bold=True, + ) + + terminated = StepFunctions.terminate(obj.state_machine_name, name) + if terminated: + obj.echo(f"\nRun terminated at {terminated.get('stopDate')}.") + + +def validate_run_id( + state_machine_name, token_prefix, authorize, run_id, instructions_fn=None +): + if not run_id.startswith("sfn-"): + raise RunIdMismatch( + "Run IDs for flows executed through AWS Step Functions begin with 'sfn-'" + ) + + name = run_id[4:] + execution = StepFunctions.get_execution(state_machine_name, name) + if execution is None: + raise MetaflowException( + "Could not find the execution *%s* (in RUNNING state) on AWS Step Functions" + % name + ) + + _, owner, token, flow_name, branch_name, project_name = execution + + if current.flow_name != flow_name: + raise RunIdMismatch( + "The workflow with the run_id *%s* belongs to the flow *%s*, not for the flow *%s*." + % (run_id, flow_name, current.flow_name) + ) + + if project_name is not None: + if current.get("project_name") != project_name: + raise RunIdMismatch( + "The workflow belongs to the project *%s*. " + "Please use the project decorator or --name to target the correct project" + % project_name + ) + + if current.get("branch_name") != branch_name: + raise RunIdMismatch( + "The workflow belongs to the branch *%s*. " + "Please use --branch, --production or --name to target the correct branch" + % branch_name + ) + + if authorize is None: + authorize = load_token(token_prefix) + elif authorize.startswith("production:"): + authorize = authorize[11:] + + if owner != get_username() and authorize != token: + if instructions_fn: + instructions_fn(flow_name=name, prev_user=owner) + raise IncorrectProductionToken("Try again with the correct production token.") + + return True + + def validate_token(name, token_prefix, authorize, instruction_fn=None): """ Validate that the production token matches that of the deployed flow. diff --git a/metaflow/plugins/aws/step_functions/step_functions_client.py b/metaflow/plugins/aws/step_functions/step_functions_client.py index f7418f15427..f3187634c95 100644 --- a/metaflow/plugins/aws/step_functions/step_functions_client.py +++ b/metaflow/plugins/aws/step_functions/step_functions_client.py @@ -81,9 +81,14 @@ def list_executions(self, state_machine_arn, states): for execution in page["executions"] ) - def terminate_execution(self, state_machine_arn, execution_arn): - # TODO - pass + def terminate_execution(self, execution_arn): + try: + response = self._client.stop_execution(executionArn=execution_arn) + return response + except self._client.exceptions.ExecutionDoesNotExist: + raise ValueError(f"The execution ARN {execution_arn} does not exist.") + except Exception as e: + raise e def _default_logging_configuration(self, log_execution_history): if log_execution_history: From ebc530e76aca91488871b62c5e48dc1b58705bee Mon Sep 17 00:00:00 2001 From: Madhur Tandon Date: Thu, 25 Jan 2024 00:05:13 +0530 Subject: [PATCH 2/5] don't use f-strings --- metaflow/plugins/aws/step_functions/step_functions_cli.py | 2 +- metaflow/plugins/aws/step_functions/step_functions_client.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/metaflow/plugins/aws/step_functions/step_functions_cli.py b/metaflow/plugins/aws/step_functions/step_functions_cli.py index a7b3c4a1d86..7e5ea7a6e94 100644 --- a/metaflow/plugins/aws/step_functions/step_functions_cli.py +++ b/metaflow/plugins/aws/step_functions/step_functions_cli.py @@ -661,7 +661,7 @@ def _token_instructions(flow_name, prev_user): terminated = StepFunctions.terminate(obj.state_machine_name, name) if terminated: - obj.echo(f"\nRun terminated at {terminated.get('stopDate')}.") + obj.echo("\nRun terminated at %s." % terminated.get("stopDate")) def validate_run_id( diff --git a/metaflow/plugins/aws/step_functions/step_functions_client.py b/metaflow/plugins/aws/step_functions/step_functions_client.py index f3187634c95..ceec8e4d0ce 100644 --- a/metaflow/plugins/aws/step_functions/step_functions_client.py +++ b/metaflow/plugins/aws/step_functions/step_functions_client.py @@ -86,7 +86,7 @@ def terminate_execution(self, execution_arn): response = self._client.stop_execution(executionArn=execution_arn) return response except self._client.exceptions.ExecutionDoesNotExist: - raise ValueError(f"The execution ARN {execution_arn} does not exist.") + raise ValueError("The execution ARN %s does not exist." % execution_arn) except Exception as e: raise e From 923eec4f369475061494359dfb23a465075d1bbc Mon Sep 17 00:00:00 2001 From: Madhur Tandon Date: Tue, 30 Jan 2024 02:19:46 +0530 Subject: [PATCH 3/5] remove checks for project, branch --- .../aws/step_functions/step_functions.py | 21 +++++++++-------- .../aws/step_functions/step_functions_cli.py | 23 ++++--------------- 2 files changed, 15 insertions(+), 29 deletions(-) diff --git a/metaflow/plugins/aws/step_functions/step_functions.py b/metaflow/plugins/aws/step_functions/step_functions.py index 9b648084062..f0e9d9e56fa 100644 --- a/metaflow/plugins/aws/step_functions/step_functions.py +++ b/metaflow/plugins/aws/step_functions/step_functions.py @@ -169,7 +169,7 @@ def delete(cls, name): @classmethod def terminate(cls, flow_name, name): client = StepFunctionsClient() - execution_arn, _, _, _, _, _ = cls.get_execution(flow_name, name) + execution_arn, _, _, _ = cls.get_execution(flow_name, name) response = client.terminate_execution(execution_arn) return response @@ -250,29 +250,30 @@ def get_execution(cls, state_machine_name, name): raise StepFunctionsException(repr(e)) if state_machine is None: raise StepFunctionsException( - "The workflow *%s* doesn't exist " "on AWS Step Functions." % name + "The workflow *%s* doesn't exist on AWS Step Functions." % name ) try: state_machine_arn = state_machine.get("stateMachineArn") - parameters = ( + environment_vars = ( json.loads(state_machine.get("definition")) .get("States") .get("start") .get("Parameters") - .get("Parameters") + .get("ContainerOverrides") + .get("Environment") ) - + parameters = { + item.get("Name"): item.get("Value") for item in environment_vars + } executions = client.list_executions(state_machine_arn, states=["RUNNING"]) for execution in executions: if execution.get("name") == name: try: return ( execution.get("executionArn"), - parameters.get("metaflow.owner"), - parameters.get("metaflow.production_token"), - parameters.get("metaflow.flow_name"), - parameters.get("metaflow.branch_name", None), - parameters.get("metaflow.project_name", None), + parameters.get("METAFLOW_OWNER"), + parameters.get("METAFLOW_PRODUCTION_TOKEN"), + parameters.get("SFN_STATE_MACHINE"), ) except KeyError: raise StepFunctionsException( diff --git a/metaflow/plugins/aws/step_functions/step_functions_cli.py b/metaflow/plugins/aws/step_functions/step_functions_cli.py index 7e5ea7a6e94..6aba4042b18 100644 --- a/metaflow/plugins/aws/step_functions/step_functions_cli.py +++ b/metaflow/plugins/aws/step_functions/step_functions_cli.py @@ -680,29 +680,14 @@ def validate_run_id( % name ) - _, owner, token, flow_name, branch_name, project_name = execution + _, owner, token, sfn_state_machine = execution - if current.flow_name != flow_name: + if state_machine_name != sfn_state_machine: raise RunIdMismatch( - "The workflow with the run_id *%s* belongs to the flow *%s*, not for the flow *%s*." - % (run_id, flow_name, current.flow_name) + "The workflow with the run_id *%s* belongs to the state machine *%s*, not for the state machine *%s*." + % (run_id, sfn_state_machine, state_machine_name) ) - if project_name is not None: - if current.get("project_name") != project_name: - raise RunIdMismatch( - "The workflow belongs to the project *%s*. " - "Please use the project decorator or --name to target the correct project" - % project_name - ) - - if current.get("branch_name") != branch_name: - raise RunIdMismatch( - "The workflow belongs to the branch *%s*. " - "Please use --branch, --production or --name to target the correct branch" - % branch_name - ) - if authorize is None: authorize = load_token(token_prefix) elif authorize.startswith("production:"): From 750c0c34bf49d61667820ea45fa579cdf41b25cd Mon Sep 17 00:00:00 2001 From: Madhur Tandon Date: Tue, 30 Jan 2024 03:00:42 +0530 Subject: [PATCH 4/5] add comment --- metaflow/plugins/aws/step_functions/step_functions.py | 3 ++- metaflow/plugins/aws/step_functions/step_functions_cli.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/metaflow/plugins/aws/step_functions/step_functions.py b/metaflow/plugins/aws/step_functions/step_functions.py index f0e9d9e56fa..261dc98fdcd 100644 --- a/metaflow/plugins/aws/step_functions/step_functions.py +++ b/metaflow/plugins/aws/step_functions/step_functions.py @@ -250,7 +250,8 @@ def get_execution(cls, state_machine_name, name): raise StepFunctionsException(repr(e)) if state_machine is None: raise StepFunctionsException( - "The workflow *%s* doesn't exist on AWS Step Functions." % name + "The state machine *%s* doesn't exist on AWS Step Functions." + % state_machine_name ) try: state_machine_arn = state_machine.get("stateMachineArn") diff --git a/metaflow/plugins/aws/step_functions/step_functions_cli.py b/metaflow/plugins/aws/step_functions/step_functions_cli.py index 6aba4042b18..2bd4faaef7f 100644 --- a/metaflow/plugins/aws/step_functions/step_functions_cli.py +++ b/metaflow/plugins/aws/step_functions/step_functions_cli.py @@ -682,6 +682,8 @@ def validate_run_id( _, owner, token, sfn_state_machine = execution + # this snippet is probably never triggered since we fail early with not + # being able to find the state_machine itself if state_machine_name != sfn_state_machine: raise RunIdMismatch( "The workflow with the run_id *%s* belongs to the state machine *%s*, not for the state machine *%s*." From d7400b16873f4475a12a95f15543ec3c465b5f35 Mon Sep 17 00:00:00 2001 From: Madhur Tandon Date: Wed, 31 Jan 2024 01:53:59 +0530 Subject: [PATCH 5/5] suggested changes --- .../aws/step_functions/step_functions_cli.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/metaflow/plugins/aws/step_functions/step_functions_cli.py b/metaflow/plugins/aws/step_functions/step_functions_cli.py index 2bd4faaef7f..3ee971dd0f8 100644 --- a/metaflow/plugins/aws/step_functions/step_functions_cli.py +++ b/metaflow/plugins/aws/step_functions/step_functions_cli.py @@ -676,19 +676,11 @@ def validate_run_id( execution = StepFunctions.get_execution(state_machine_name, name) if execution is None: raise MetaflowException( - "Could not find the execution *%s* (in RUNNING state) on AWS Step Functions" - % name + "Could not find the execution *%s* (in RUNNING state) for the state machine *%s* on AWS Step Functions" + % (name, state_machine_name) ) - _, owner, token, sfn_state_machine = execution - - # this snippet is probably never triggered since we fail early with not - # being able to find the state_machine itself - if state_machine_name != sfn_state_machine: - raise RunIdMismatch( - "The workflow with the run_id *%s* belongs to the state machine *%s*, not for the state machine *%s*." - % (run_id, sfn_state_machine, state_machine_name) - ) + _, owner, token, _ = execution if authorize is None: authorize = load_token(token_prefix)