Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions metaflow/cli_components/run_cmds.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,24 @@ def wrapper(*args, **kwargs):
return wrapper


def _get_origin_run_tags(flow_name, origin_run_id):
"""
Retrieve user tags from the origin run for tag propagation on resume.

Returns a list of user tags, or an empty list if the origin run
cannot be found or has no tags.
"""
try:
from ..client.core import Run

origin_run = Run("%s/%s" % (flow_name, origin_run_id), _namespace_check=False)
return list(origin_run.user_tags)
except Exception:
# If we can't read the origin run's tags (e.g. metadata service
Comment on lines +192 to +198
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Overly broad exception suppression

The bare except Exception swallows everything — including AttributeError, ImportError, or other programming errors that signal a real bug. Consider catching only the expected failure modes (e.g. MetaflowNotFound, network errors) and letting unexpected exceptions surface, or at minimum logging the failure (as the debug-level log in the suggestion below does) so callers know tags were silently dropped.

Suggested change
try:
from ..client.core import Run
origin_run = Run("%s/%s" % (flow_name, origin_run_id), _namespace_check=False)
return list(origin_run.user_tags)
except Exception:
# If we can't read the origin run's tags (e.g. metadata service
except Exception as e:
# If we can't read the origin run's tags (e.g. metadata service
# unavailable), we proceed without propagating tags.
import logging
logging.getLogger(__name__).debug(
"Could not retrieve tags for origin run %s/%s: %s",
flow_name,
origin_run_id,
e,
)
return []

# unavailable), we proceed without propagating tags.
return []


@click.option(
"--origin-run-id",
default=None,
Expand Down Expand Up @@ -237,15 +255,22 @@ def resume(
resume_identifier=None,
runner_attribute_file=None,
):
before_run(obj, tags, decospecs)

if origin_run_id is None:
origin_run_id = get_latest_run_id(obj.echo, obj.flow.name)
if origin_run_id is None:
raise CommandException(
"A previous run id was not found. Specify --origin-run-id."
)

# Propagate user tags from the origin run to the resumed run.
# This ensures that tags set during the original run (e.g. via
# current.run.add_tag()) are carried over to the resumed run.
origin_tags = _get_origin_run_tags(obj.flow.name, origin_run_id)
if origin_tags:
tags = tuple(set(tags or ()) | set(origin_tags))

before_run(obj, tags, decospecs)

if step_to_rerun is None:
steps_to_rerun = set()
else:
Expand Down
16 changes: 16 additions & 0 deletions metaflow/cli_components/step_cmd.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import json
import os

from metaflow._vendor import click

from .. import namespace
Expand Down Expand Up @@ -143,6 +146,19 @@ def step(
cli_args._set_step_kwargs(step_kwargs)

ctx.obj.metadata.add_sticky_tags(tags=opt_tag)

# Support trigger-time tags passed via environment variable.
# This allows tags to be specified at trigger time (e.g. via CLI or Deployer API)
# rather than only at deploy time.
trigger_tags_env = os.environ.get("METAFLOW_TRIGGER_TAGS")
if trigger_tags_env:
try:
trigger_tags = json.loads(trigger_tags_env)
if isinstance(trigger_tags, list) and trigger_tags:
ctx.obj.metadata.add_sticky_tags(tags=trigger_tags)
except (json.JSONDecodeError, TypeError):
pass

if not input_paths and input_paths_filename:
with open(input_paths_filename, mode="r", encoding="utf-8") as f:
input_paths = f.read().strip(" \n\"'")
Expand Down
36 changes: 24 additions & 12 deletions metaflow/plugins/argo/argo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,27 +279,39 @@ def _patch_workflow(self, name, body):
json.loads(e.body)["message"] if e.body is not None else e.reason
)

def trigger_workflow_template(self, name, usertype, username, parameters={}):
def trigger_workflow_template(
self, name, usertype, username, parameters={}, tags=None
):
client = self._client.get()
# Build the list of workflow parameters from user-provided flow parameters.
workflow_params = [
{"name": k, "value": json.dumps(v)} for k, v in parameters.items()
]
# Pass trigger-time tags as a reserved workflow parameter so that
# running steps can read them via the METAFLOW_TRIGGER_TAGS env var.
if tags:
workflow_params.append(
{"name": "metaflow-trigger-tags", "value": json.dumps(tags)}
)

annotations = {
"metaflow/triggered_by_user": json.dumps(
{"type": usertype, "name": username}
)
}
if tags:
annotations["metaflow/trigger_tags"] = json.dumps(tags)

body = {
"apiVersion": "argoproj.io/v1alpha1",
"kind": "Workflow",
"metadata": {
"generateName": name + "-",
"annotations": {
"metaflow/triggered_by_user": json.dumps(
{"type": usertype, "name": username}
)
},
"annotations": annotations,
},
"spec": {
"workflowTemplateRef": {"name": name},
"arguments": {
"parameters": [
{"name": k, "value": json.dumps(v)}
for k, v in parameters.items()
]
},
"arguments": {"parameters": workflow_params},
},
}
try:
Expand Down
16 changes: 15 additions & 1 deletion metaflow/plugins/argo/argo_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def parse_incident_io_metadata(metadata: List[str] = None):
return parsed_metadata

@classmethod
def trigger(cls, name, parameters=None):
def trigger(cls, name, parameters=None, tags=None):
if parameters is None:
parameters = {}
try:
Expand Down Expand Up @@ -395,6 +395,7 @@ def trigger(cls, name, parameters=None):
usertype,
username,
parameters,
tags=tags,
)
except Exception as e:
raise ArgoWorkflowsException(str(e))
Expand Down Expand Up @@ -965,6 +966,17 @@ def _compile_workflow_template(self):
.description("auto-set by metaflow. safe to ignore.")
for event in self.triggers
]
+ [
# Reserved parameter for trigger-time tags.
# Tags passed at trigger time (via CLI --tag or Deployer API)
# are set as this parameter, which gets read by each step
# via the METAFLOW_TRIGGER_TAGS env var.
Parameter("metaflow-trigger-tags")
.value("[]")
.description(
"auto-set by metaflow for trigger-time tags. safe to ignore."
),
]
)
)
# Set common pod metadata.
Expand Down Expand Up @@ -2371,6 +2383,8 @@ def _container_templates(self):
"ARGO_WORKFLOW_TEMPLATE": self.name,
"ARGO_WORKFLOW_NAME": "{{workflow.name}}",
"ARGO_WORKFLOW_NAMESPACE": KUBERNETES_NAMESPACE,
# Trigger-time tags passed as a workflow parameter.
"METAFLOW_TRIGGER_TAGS": "{{workflow.parameters.metaflow-trigger-tags}}",
},
**self.metadata.get_runtime_environment("argo-workflows"),
}
Expand Down
17 changes: 15 additions & 2 deletions metaflow/plugins/argo/argo_workflows_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,8 +924,19 @@ def resolve_token(
help="Write the metadata and pathspec of this run to the file specified.\nUsed internally for Metaflow's Deployer API.",
hidden=True,
)
@click.option(
"--tag",
"tags",
multiple=True,
default=None,
help="Annotate the triggered run with the given tag. You can specify "
"this option multiple times to attach multiple tags.",
)
@click.pass_obj
def trigger(obj, run_id_file=None, deployer_attribute_file=None, **kwargs):
def trigger(obj, run_id_file=None, deployer_attribute_file=None, tags=None, **kwargs):
if tags:
validate_tags(tags)

def _convert_value(param):
# Swap `-` with `_` in parameter name to match click's behavior
val = kwargs.get(param.name.replace("-", "_").lower())
Expand Down Expand Up @@ -962,7 +973,9 @@ def _convert_value(param):
)
obj.echo("re-deploy your flow in order to get rid of this message.")
workflow_name_to_deploy = obj._v1_workflow_name
response = ArgoWorkflows.trigger(workflow_name_to_deploy, params)
response = ArgoWorkflows.trigger(
workflow_name_to_deploy, params, tags=list(tags) if tags else None
)
run_id = "argo-" + response["metadata"]["name"]

if run_id_file:
Expand Down
3 changes: 2 additions & 1 deletion metaflow/plugins/argo/argo_workflows_deployer_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,8 @@ def trigger(self, **kwargs) -> ArgoWorkflowsTriggeredRun:
----------
**kwargs : Any
Additional arguments to pass to the trigger command,
`Parameters` in particular.
`Parameters` in particular. Use ``tag=["my_tag"]`` to
attach tags to the triggered run.

Returns
-------
Expand Down
36 changes: 34 additions & 2 deletions metaflow/plugins/aws/step_functions/step_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def terminate(cls, flow_name, name):
return response

@classmethod
def trigger(cls, name, parameters):
def trigger(cls, name, parameters, tags=None):
try:
state_machine = StepFunctionsClient().get(name)
except Exception as e:
Expand All @@ -202,7 +202,15 @@ def trigger(cls, name, parameters):
)

# Dump parameters into `Parameters` input field.
input = json.dumps({"Parameters": json.dumps(parameters)})
# Always include TriggerTags (defaulting to empty list) in the
# execution input. The state machine propagates this field through
# every step so that trigger-time tags are applied to all tasks.
input = json.dumps(
{
"Parameters": json.dumps(parameters),
"TriggerTags": json.dumps(tags if tags else []),
Comment on lines 204 to +211
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 TriggerTags double-JSON-encodes the list

tags if tags else [] is a list, and json.dumps(...) of that list gives a string like '["t1","t2"]'. That string is then embedded in the outer json.dumps(...), so what Step Functions receives for $.TriggerTags is a JSON-encoded string, not a list. When a step reads METAFLOW_TRIGGER_TAGS it gets '["t1","t2"]' and must call json.loads to recover the list — which the step-cmd code does, so it works. However this is an unusual convention worth a comment, and it must stay consistent with the Argo side where the parameter value is also json.dumps(tags). A brief comment here would clarify the intentional double-encoding.

}
)
# AWS Step Functions limits input to be 32KiB, but AWS Batch
# has its own limitation of 30KiB for job specification length.
# Reserving 10KiB for rest of the job specification leaves 20KiB
Expand Down Expand Up @@ -617,6 +625,11 @@ def _batch(self, node):
# start step to all subsequent tasks.
attrs["metaflow.run_id.$"] = "$$.Execution.Name"

# Propagate trigger-time tags from execution input to all steps.
# The trigger command always includes TriggerTags in the input.
attrs["metaflow.trigger_tags.$"] = "$.TriggerTags"
env["METAFLOW_TRIGGER_TAGS"] = "$.TriggerTags"
Comment on lines +628 to +631
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 EventBridge-scheduled executions will fail after redeployment

After redeploying a state machine with these changes, any execution that doesn't include TriggerTags in the input will throw a States.Runtime error when Step Functions tries to resolve $.TriggerTags. This includes all EventBridge-scheduled runs: event_bridge_client.py (line 57) passes {"Parameters": json.dumps({})} as the execution input — no TriggerTags key — so the JsonPath lookup fails immediately.

The Argo path avoids this by setting .value("[]") as a default on the workflow template parameter. SFN needs the same defensive default, which means updating EventBridgeClient._set() to include "TriggerTags": json.dumps([]) in its Input:

"Input": json.dumps({"Parameters": json.dumps({}), "TriggerTags": json.dumps([])}),


# Initialize parameters for the flow in the `start` step.
parameters = self._process_parameters()
if parameters:
Expand Down Expand Up @@ -677,6 +690,11 @@ def _batch(self, node):
)
# Inherit the run id from the parent and pass it along to children.
attrs["metaflow.run_id.$"] = "$.Parameters.['metaflow.run_id']"
# Propagate trigger-time tags from the parent.
attrs["metaflow.trigger_tags.$"] = (
"$.Parameters.['metaflow.trigger_tags']"
)
env["METAFLOW_TRIGGER_TAGS"] = "$.Parameters.['metaflow.trigger_tags']"
else:
# Set appropriate environment variables for runtime replacement.
if len(node.in_funcs) == 1:
Expand All @@ -687,6 +705,13 @@ def _batch(self, node):
env["METAFLOW_PARENT_TASK_ID"] = "$.JobId"
# Inherit the run id from the parent and pass it along to children.
attrs["metaflow.run_id.$"] = "$.Parameters.['metaflow.run_id']"
# Propagate trigger-time tags from the parent.
attrs["metaflow.trigger_tags.$"] = (
"$.Parameters.['metaflow.trigger_tags']"
)
env["METAFLOW_TRIGGER_TAGS"] = (
"$.Parameters.['metaflow.trigger_tags']"
)
else:
# Generate the input paths in a quasi-compressed format.
# See util.decompress_list for why this is written the way
Expand All @@ -698,6 +723,13 @@ def _batch(self, node):
)
# Inherit the run id from the parent and pass it along to children.
attrs["metaflow.run_id.$"] = "$.[0].Parameters.['metaflow.run_id']"
# Propagate trigger-time tags from the first branch.
attrs["metaflow.trigger_tags.$"] = (
"$.[0].Parameters.['metaflow.trigger_tags']"
)
env["METAFLOW_TRIGGER_TAGS"] = (
"$.[0].Parameters.['metaflow.trigger_tags']"
)
for idx, _ in enumerate(node.in_funcs):
env["METAFLOW_PARENT_%s_TASK_ID" % idx] = "$.[%s].JobId" % idx
env["METAFLOW_PARENT_%s_STEP" % idx] = (
Expand Down
18 changes: 15 additions & 3 deletions metaflow/plugins/aws/step_functions/step_functions_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

from .production_token import load_token, new_token, store_token
from .step_functions import StepFunctions
from metaflow.tagging_util import validate_tags
from ..aws_utils import validate_aws_tag

VALID_NAME = re.compile(r"[^a-zA-Z0-9_\-\.]")
Expand Down Expand Up @@ -526,8 +525,19 @@ def resolve_token(
help="Write the metadata and pathspec of this run to the file specified.\nUsed internally for Metaflow's Deployer API.",
hidden=True,
)
@click.option(
"--tag",
"tags",
multiple=True,
default=None,
help="Annotate the triggered run with the given tag. You can specify "
"this option multiple times to attach multiple tags.",
)
@click.pass_obj
def trigger(obj, run_id_file=None, deployer_attribute_file=None, **kwargs):
def trigger(obj, run_id_file=None, deployer_attribute_file=None, tags=None, **kwargs):
if tags:
validate_tags(tags)

def _convert_value(param):
# Swap `-` with `_` in parameter name to match click's behavior
val = kwargs.get(param.name.replace("-", "_").lower())
Expand All @@ -543,7 +553,9 @@ def _convert_value(param):
if kwargs.get(param.name.replace("-", "_").lower()) is not None
}

response = StepFunctions.trigger(obj.state_machine_name, params)
response = StepFunctions.trigger(
obj.state_machine_name, params, tags=list(tags) if tags else None
)

id = response["executionArn"].split(":")[-1]
run_id = "sfn-" + id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,8 @@ def trigger(self, **kwargs) -> StepFunctionsTriggeredRun:
----------
**kwargs : Any
Additional arguments to pass to the trigger command,
`Parameters` in particular
`Parameters` in particular. Use ``tag=["my_tag"]`` to
attach tags to the triggered run.

Returns
-------
Expand Down
Loading
Loading