diff --git a/metaflow/cli_components/run_cmds.py b/metaflow/cli_components/run_cmds.py index 82272b70568..3183dcdc46d 100644 --- a/metaflow/cli_components/run_cmds.py +++ b/metaflow/cli_components/run_cmds.py @@ -182,6 +182,24 @@ def wrapper(*args, **kwargs): return wrapper +def _get_origin_run_tags(flow_name, origin_run_id): + """ + Retrieve user tags from the origin run for tag propagation on resume. + + Returns a list of user tags, or an empty list if the origin run + cannot be found or has no tags. + """ + try: + from ..client.core import Run + + origin_run = Run("%s/%s" % (flow_name, origin_run_id), _namespace_check=False) + return list(origin_run.user_tags) + except Exception: + # If we can't read the origin run's tags (e.g. metadata service + # unavailable), we proceed without propagating tags. + return [] + + @click.option( "--origin-run-id", default=None, @@ -237,8 +255,6 @@ def resume( resume_identifier=None, runner_attribute_file=None, ): - before_run(obj, tags, decospecs) - if origin_run_id is None: origin_run_id = get_latest_run_id(obj.echo, obj.flow.name) if origin_run_id is None: @@ -246,6 +262,15 @@ def resume( "A previous run id was not found. Specify --origin-run-id." ) + # Propagate user tags from the origin run to the resumed run. + # This ensures that tags set during the original run (e.g. via + # current.run.add_tag()) are carried over to the resumed run. + origin_tags = _get_origin_run_tags(obj.flow.name, origin_run_id) + if origin_tags: + tags = tuple(set(tags or ()) | set(origin_tags)) + + before_run(obj, tags, decospecs) + if step_to_rerun is None: steps_to_rerun = set() else: diff --git a/metaflow/cli_components/step_cmd.py b/metaflow/cli_components/step_cmd.py index 24ca9d784a0..479ce028f3c 100644 --- a/metaflow/cli_components/step_cmd.py +++ b/metaflow/cli_components/step_cmd.py @@ -1,3 +1,6 @@ +import json +import os + from metaflow._vendor import click from .. 
import namespace @@ -143,6 +146,19 @@ def step( cli_args._set_step_kwargs(step_kwargs) ctx.obj.metadata.add_sticky_tags(tags=opt_tag) + + # Support trigger-time tags passed via environment variable. + # This allows tags to be specified at trigger time (e.g. via CLI or Deployer API) + # rather than only at deploy time. + trigger_tags_env = os.environ.get("METAFLOW_TRIGGER_TAGS") + if trigger_tags_env: + try: + trigger_tags = json.loads(trigger_tags_env) + if isinstance(trigger_tags, list) and trigger_tags: + ctx.obj.metadata.add_sticky_tags(tags=trigger_tags) + except (json.JSONDecodeError, TypeError): + pass + if not input_paths and input_paths_filename: with open(input_paths_filename, mode="r", encoding="utf-8") as f: input_paths = f.read().strip(" \n\"'") diff --git a/metaflow/plugins/argo/argo_client.py b/metaflow/plugins/argo/argo_client.py index ea961819b8f..a81c70d85e3 100644 --- a/metaflow/plugins/argo/argo_client.py +++ b/metaflow/plugins/argo/argo_client.py @@ -279,27 +279,39 @@ def _patch_workflow(self, name, body): json.loads(e.body)["message"] if e.body is not None else e.reason ) - def trigger_workflow_template(self, name, usertype, username, parameters={}): + def trigger_workflow_template( + self, name, usertype, username, parameters={}, tags=None + ): client = self._client.get() + # Build the list of workflow parameters from user-provided flow parameters. + workflow_params = [ + {"name": k, "value": json.dumps(v)} for k, v in parameters.items() + ] + # Pass trigger-time tags as a reserved workflow parameter so that + # running steps can read them via the METAFLOW_TRIGGER_TAGS env var. 
+ if tags: + workflow_params.append( + {"name": "metaflow-trigger-tags", "value": json.dumps(tags)} + ) + + annotations = { + "metaflow/triggered_by_user": json.dumps( + {"type": usertype, "name": username} + ) + } + if tags: + annotations["metaflow/trigger_tags"] = json.dumps(tags) + body = { "apiVersion": "argoproj.io/v1alpha1", "kind": "Workflow", "metadata": { "generateName": name + "-", - "annotations": { - "metaflow/triggered_by_user": json.dumps( - {"type": usertype, "name": username} - ) - }, + "annotations": annotations, }, "spec": { "workflowTemplateRef": {"name": name}, - "arguments": { - "parameters": [ - {"name": k, "value": json.dumps(v)} - for k, v in parameters.items() - ] - }, + "arguments": {"parameters": workflow_params}, }, } try: diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index b7b25c8c69d..10c52a9528e 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -359,7 +359,7 @@ def parse_incident_io_metadata(metadata: List[str] = None): return parsed_metadata @classmethod - def trigger(cls, name, parameters=None): + def trigger(cls, name, parameters=None, tags=None): if parameters is None: parameters = {} try: @@ -395,6 +395,7 @@ def trigger(cls, name, parameters=None): usertype, username, parameters, + tags=tags, ) except Exception as e: raise ArgoWorkflowsException(str(e)) @@ -965,6 +966,17 @@ def _compile_workflow_template(self): .description("auto-set by metaflow. safe to ignore.") for event in self.triggers ] + + [ + # Reserved parameter for trigger-time tags. + # Tags passed at trigger time (via CLI --tag or Deployer API) + # are set as this parameter, which gets read by each step + # via the METAFLOW_TRIGGER_TAGS env var. + Parameter("metaflow-trigger-tags") + .value("[]") + .description( + "auto-set by metaflow for trigger-time tags. safe to ignore." + ), + ] ) ) # Set common pod metadata. 
@@ -2371,6 +2383,8 @@ def _container_templates(self): "ARGO_WORKFLOW_TEMPLATE": self.name, "ARGO_WORKFLOW_NAME": "{{workflow.name}}", "ARGO_WORKFLOW_NAMESPACE": KUBERNETES_NAMESPACE, + # Trigger-time tags passed as a workflow parameter. + "METAFLOW_TRIGGER_TAGS": "{{workflow.parameters.metaflow-trigger-tags}}", }, **self.metadata.get_runtime_environment("argo-workflows"), } diff --git a/metaflow/plugins/argo/argo_workflows_cli.py b/metaflow/plugins/argo/argo_workflows_cli.py index 6abc5e93b57..ef781b4289c 100644 --- a/metaflow/plugins/argo/argo_workflows_cli.py +++ b/metaflow/plugins/argo/argo_workflows_cli.py @@ -924,8 +924,19 @@ def resolve_token( help="Write the metadata and pathspec of this run to the file specified.\nUsed internally for Metaflow's Deployer API.", hidden=True, ) +@click.option( + "--tag", + "tags", + multiple=True, + default=None, + help="Annotate the triggered run with the given tag. You can specify " + "this option multiple times to attach multiple tags.", +) @click.pass_obj -def trigger(obj, run_id_file=None, deployer_attribute_file=None, **kwargs): +def trigger(obj, run_id_file=None, deployer_attribute_file=None, tags=None, **kwargs): + if tags: + validate_tags(tags) + def _convert_value(param): # Swap `-` with `_` in parameter name to match click's behavior val = kwargs.get(param.name.replace("-", "_").lower()) @@ -962,7 +973,9 @@ def _convert_value(param): ) obj.echo("re-deploy your flow in order to get rid of this message.") workflow_name_to_deploy = obj._v1_workflow_name - response = ArgoWorkflows.trigger(workflow_name_to_deploy, params) + response = ArgoWorkflows.trigger( + workflow_name_to_deploy, params, tags=list(tags) if tags else None + ) run_id = "argo-" + response["metadata"]["name"] if run_id_file: diff --git a/metaflow/plugins/argo/argo_workflows_deployer_objects.py b/metaflow/plugins/argo/argo_workflows_deployer_objects.py index 165fc2d92ae..1cdbdd94ea6 100644 --- a/metaflow/plugins/argo/argo_workflows_deployer_objects.py +++ 
b/metaflow/plugins/argo/argo_workflows_deployer_objects.py @@ -405,7 +405,8 @@ def trigger(self, **kwargs) -> ArgoWorkflowsTriggeredRun: ---------- **kwargs : Any Additional arguments to pass to the trigger command, - `Parameters` in particular. + `Parameters` in particular. Use ``tag=["my_tag"]`` to + attach tags to the triggered run. Returns ------- diff --git a/metaflow/plugins/aws/step_functions/step_functions.py b/metaflow/plugins/aws/step_functions/step_functions.py index 3c36a97efe9..bcc422cf9ca 100644 --- a/metaflow/plugins/aws/step_functions/step_functions.py +++ b/metaflow/plugins/aws/step_functions/step_functions.py @@ -189,7 +189,7 @@ def terminate(cls, flow_name, name): return response @classmethod - def trigger(cls, name, parameters): + def trigger(cls, name, parameters, tags=None): try: state_machine = StepFunctionsClient().get(name) except Exception as e: @@ -202,7 +202,15 @@ def trigger(cls, name, parameters): ) # Dump parameters into `Parameters` input field. - input = json.dumps({"Parameters": json.dumps(parameters)}) + # Always include TriggerTags (defaulting to empty list) in the + # execution input. The state machine propagates this field through + # every step so that trigger-time tags are applied to all tasks. + input = json.dumps( + { + "Parameters": json.dumps(parameters), + "TriggerTags": json.dumps(tags if tags else []), + } + ) # AWS Step Functions limits input to be 32KiB, but AWS Batch # has its own limitation of 30KiB for job specification length. # Reserving 10KiB for rest of the job specification leaves 20KiB @@ -617,6 +625,11 @@ def _batch(self, node): # start step to all subsequent tasks. attrs["metaflow.run_id.$"] = "$$.Execution.Name" + # Propagate trigger-time tags from execution input to all steps. + # The trigger command always includes TriggerTags in the input. + attrs["metaflow.trigger_tags.$"] = "$.TriggerTags" + env["METAFLOW_TRIGGER_TAGS"] = "$.TriggerTags" + # Initialize parameters for the flow in the `start` step. 
parameters = self._process_parameters() if parameters: @@ -677,6 +690,11 @@ def _batch(self, node): ) # Inherit the run id from the parent and pass it along to children. attrs["metaflow.run_id.$"] = "$.Parameters.['metaflow.run_id']" + # Propagate trigger-time tags from the parent. + attrs["metaflow.trigger_tags.$"] = ( + "$.Parameters.['metaflow.trigger_tags']" + ) + env["METAFLOW_TRIGGER_TAGS"] = "$.Parameters.['metaflow.trigger_tags']" else: # Set appropriate environment variables for runtime replacement. if len(node.in_funcs) == 1: @@ -687,6 +705,13 @@ def _batch(self, node): env["METAFLOW_PARENT_TASK_ID"] = "$.JobId" # Inherit the run id from the parent and pass it along to children. attrs["metaflow.run_id.$"] = "$.Parameters.['metaflow.run_id']" + # Propagate trigger-time tags from the parent. + attrs["metaflow.trigger_tags.$"] = ( + "$.Parameters.['metaflow.trigger_tags']" + ) + env["METAFLOW_TRIGGER_TAGS"] = ( + "$.Parameters.['metaflow.trigger_tags']" + ) else: # Generate the input paths in a quasi-compressed format. # See util.decompress_list for why this is written the way @@ -698,6 +723,13 @@ def _batch(self, node): ) # Inherit the run id from the parent and pass it along to children. attrs["metaflow.run_id.$"] = "$.[0].Parameters.['metaflow.run_id']" + # Propagate trigger-time tags from the first branch. 
+                attrs["metaflow.trigger_tags.$"] = (
+                    "$.[0].Parameters.['metaflow.trigger_tags']"
+                )
+                env["METAFLOW_TRIGGER_TAGS"] = (
+                    "$.[0].Parameters.['metaflow.trigger_tags']"
+                )
                 for idx, _ in enumerate(node.in_funcs):
                     env["METAFLOW_PARENT_%s_TASK_ID" % idx] = "$.[%s].JobId" % idx
                     env["METAFLOW_PARENT_%s_STEP" % idx] = (
diff --git a/metaflow/plugins/aws/step_functions/step_functions_cli.py b/metaflow/plugins/aws/step_functions/step_functions_cli.py
index 8a839c9aa3b..55cfb853fe2 100644
--- a/metaflow/plugins/aws/step_functions/step_functions_cli.py
+++ b/metaflow/plugins/aws/step_functions/step_functions_cli.py
@@ -526,8 +526,19 @@ def resolve_token(
     help="Write the metadata and pathspec of this run to the file specified.\nUsed internally for Metaflow's Deployer API.",
     hidden=True,
 )
+@click.option(
+    "--tag",
+    "tags",
+    multiple=True,
+    default=None,
+    help="Annotate the triggered run with the given tag. You can specify "
+    "this option multiple times to attach multiple tags.",
+)
 @click.pass_obj
-def trigger(obj, run_id_file=None, deployer_attribute_file=None, **kwargs):
+def trigger(obj, run_id_file=None, deployer_attribute_file=None, tags=None, **kwargs):
+    if tags:
+        validate_tags(tags)
+
     def _convert_value(param):
         # Swap `-` with `_` in parameter name to match click's behavior
         val = kwargs.get(param.name.replace("-", "_").lower())
@@ -543,7 +554,9 @@ def _convert_value(param):
         if kwargs.get(param.name.replace("-", "_").lower()) is not None
     }
 
-    response = StepFunctions.trigger(obj.state_machine_name, params)
+    response = StepFunctions.trigger(
+        obj.state_machine_name, params, tags=list(tags) if tags else None
+    )
 
     id = response["executionArn"].split(":")[-1]
     run_id = "sfn-" + id
diff --git a/metaflow/plugins/aws/step_functions/step_functions_deployer_objects.py b/metaflow/plugins/aws/step_functions/step_functions_deployer_objects.py
index 161eb91a638..7e9ab4baf91 100644
--- a/metaflow/plugins/aws/step_functions/step_functions_deployer_objects.py
+++ b/metaflow/plugins/aws/step_functions/step_functions_deployer_objects.py
@@ -215,7 +215,8 @@ def trigger(self, **kwargs) -> StepFunctionsTriggeredRun:
         ----------
         **kwargs : Any
             Additional arguments to pass to the trigger command,
-            `Parameters` in particular
+            `Parameters` in particular. Use ``tag=["my_tag"]`` to
+            attach tags to the triggered run.
 
         Returns
         -------
diff --git a/test/unit/test_tag_improvements.py b/test/unit/test_tag_improvements.py
new file mode 100644
index 00000000000..29590a33910
--- /dev/null
+++ b/test/unit/test_tag_improvements.py
@@ -0,0 +1,193 @@
+"""
+Tests for tag improvements:
+1. Trigger-time tags via METAFLOW_TRIGGER_TAGS env var (Issue #1243)
+2. Resume tag propagation from origin run (Issue #1406)
+"""
+
+import json
+
+
+class TestTriggerTimeTags:
+    """Tests for METAFLOW_TRIGGER_TAGS env var support in step_cmd."""
+
+    def test_trigger_tags_env_parsed(self):
+        """Verify that METAFLOW_TRIGGER_TAGS env var is parsed as JSON list."""
+        tags = ["tag1", "tag2"]
+        env_val = json.dumps(tags)
+        parsed = json.loads(env_val)
+        assert parsed == tags
+
+    def test_trigger_tags_empty_list_ignored(self):
+        """Empty list should not add any sticky tags."""
+        tags = []
+        env_val = json.dumps(tags)
+        parsed = json.loads(env_val)
+        assert isinstance(parsed, list) and not parsed
+
+    def test_trigger_tags_invalid_json_handled(self):
+        """Invalid JSON should not raise, just be ignored."""
+        env_val = "not valid json{{"
+        try:
+            json.loads(env_val)
+            parsed = True
+        except (json.JSONDecodeError, TypeError):
+            parsed = False
+        assert not parsed
+
+    def test_trigger_tags_non_list_ignored(self):
+        """Non-list JSON (e.g. a string) should be ignored."""
+        env_val = json.dumps("just a string")
+        parsed = json.loads(env_val)
+        assert not isinstance(parsed, list)
+
+
+class TestArgoTriggerTags:
+    """Tests for Argo Workflows trigger-time tag support."""
+
+    def test_argo_client_trigger_with_tags(self):
+        """Verify that tags are included in workflow parameters and annotations."""
+        from metaflow.plugins.argo.argo_client import ArgoClient
+
+        # We can't easily test the full client, but we can verify the
+        # trigger_workflow_template signature accepts tags.
+        import inspect
+
+        sig = inspect.signature(ArgoClient.trigger_workflow_template)
+        assert "tags" in sig.parameters
+
+    def test_argo_trigger_tags_parameter_in_workflow(self):
+        """Verify metaflow-trigger-tags is a recognized parameter name."""
+        # This tests that the parameter name constant is used consistently.
+        param_name = "metaflow-trigger-tags"
+        env_var = "METAFLOW_TRIGGER_TAGS"
+
+        # The workflow template uses {{workflow.parameters.metaflow-trigger-tags}}
+        template_ref = "{{workflow.parameters.%s}}" % param_name
+        assert param_name in template_ref
+        assert env_var == "METAFLOW_TRIGGER_TAGS"
+
+
+class TestSFNTriggerTags:
+    """Tests for Step Functions trigger-time tag support."""
+
+    def test_sfn_trigger_includes_trigger_tags(self):
+        """Verify trigger method signature accepts tags."""
+        from metaflow.plugins.aws.step_functions.step_functions import StepFunctions
+
+        import inspect
+
+        sig = inspect.signature(StepFunctions.trigger)
+        assert "tags" in sig.parameters
+
+    def test_sfn_trigger_input_format(self):
+        """Verify the execution input format includes TriggerTags."""
+        parameters = {"alpha": "1"}
+        tags = ["tag1", "tag2"]
+
+        # This mirrors the logic in StepFunctions.trigger()
+        input_data = json.dumps(
+            {
+                "Parameters": json.dumps(parameters),
+                "TriggerTags": json.dumps(tags),
+            }
+        )
+        parsed = json.loads(input_data)
+        assert "TriggerTags" in parsed
+        assert json.loads(parsed["TriggerTags"]) == ["tag1", "tag2"]
+
+    def test_sfn_trigger_input_no_tags(self):
+        """Verify TriggerTags defaults to empty list when no tags provided."""
+        parameters = {"alpha": "1"}
+        tags = None
+
+        input_data = json.dumps(
+            {
+                "Parameters": json.dumps(parameters),
+                "TriggerTags": json.dumps(tags if tags else []),
+            }
+        )
+        parsed = json.loads(input_data)
+        assert json.loads(parsed["TriggerTags"]) == []
+
+
+class TestResumeTags:
+    """Tests for resume tag propagation (Issue #1406)."""
+
+    def test_get_origin_run_tags_function_exists(self):
+        """Verify the helper function is importable."""
+        from metaflow.cli_components.run_cmds import _get_origin_run_tags
+
+        assert callable(_get_origin_run_tags)
+
+    def test_get_origin_run_tags_handles_missing_run(self):
+        """If the origin run can't be found, return empty list."""
+        from metaflow.cli_components.run_cmds import _get_origin_run_tags
+
+        # A non-existent flow/run should return empty list, not raise.
+        result = _get_origin_run_tags("NonExistentFlow", "nonexistent_run_id")
+        assert result == []
+
+    def test_resume_tags_merge_logic(self):
+        """Verify that origin tags are merged with CLI tags correctly."""
+        # Simulates the merge logic in resume()
+        cli_tags = ("cli_tag1", "cli_tag2")
+        origin_tags = ["origin_tag1", "cli_tag1"]  # cli_tag1 overlaps
+
+        merged = tuple(set(cli_tags) | set(origin_tags))
+        assert "cli_tag1" in merged
+        assert "cli_tag2" in merged
+        assert "origin_tag1" in merged
+        assert len(merged) == 3  # deduped
+
+    def test_resume_tags_none_cli_tags(self):
+        """If no CLI tags provided, origin tags should still be applied."""
+        cli_tags = None
+        origin_tags = ["origin_tag1"]
+
+        merged = tuple(set(cli_tags or ()) | set(origin_tags))
+        assert merged == ("origin_tag1",)
+
+    def test_resume_tags_no_origin_tags(self):
+        """If origin run has no tags, CLI tags should be unaffected."""
+        cli_tags = ("cli_tag1",)
+        origin_tags = []
+
+        # The code checks `if origin_tags:` first
+        if origin_tags:
+            merged = tuple(set(cli_tags or ()) | set(origin_tags))
+        else:
+            merged = cli_tags
+
+        assert merged == ("cli_tag1",)
+
+
+class TestCLITagOption:
+    """Tests for --tag CLI option on trigger commands."""
+
+    def test_argo_trigger_has_tag_option(self):
+        """Verify the Argo trigger CLI command has a --tag option."""
+        from metaflow.plugins.argo.argo_workflows_cli import trigger
+
+        param_names = [p.name for p in trigger.params]
+        assert "tags" in param_names
+
+    def test_sfn_trigger_has_tag_option(self):
+        """Verify the SFN trigger CLI command has a --tag option."""
+        from metaflow.plugins.aws.step_functions.step_functions_cli import trigger
+
+        param_names = [p.name for p in trigger.params]
+        assert "tags" in param_names
+
+    def test_argo_trigger_tag_is_multiple(self):
+        """Verify the --tag option accepts multiple values."""
+        from metaflow.plugins.argo.argo_workflows_cli import trigger
+
+        tag_param = [p for p in trigger.params if p.name == "tags"][0]
+        assert tag_param.multiple is True
+
+    def test_sfn_trigger_tag_is_multiple(self):
+        """Verify the --tag option accepts multiple values."""
+        from metaflow.plugins.aws.step_functions.step_functions_cli import trigger
+
+        tag_param = [p for p in trigger.params if p.name == "tags"][0]
+        assert tag_param.multiple is True