Skip to content

Commit 583fc92

Browse files
talsperreclaude
andcommitted
Fix start/end step regressions: events, SFN, cards, argo, graph, lint
Core fixes: - events.py: use end_task.parent.id instead of hardcoded run_obj["end"] - step_functions.py: use StartAt from definition JSON (both get_existing_deployment and get_execution) instead of hardcoded ["States"]["start"] - argo_workflows.py: _matching_conditional_join uses self.graph.end_step instead of hardcoded "end" fallback - graph.py: None guard in output_steps() raises clear ValueError - client/core.py: narrow _graph_endpoints exception caching to (KeyError, MetaflowNotFound); transient errors return uncached fallback - runtime.py: guard metadata registration with is_cloned check + try/except Cards: - card_cli.py: wrap graph_dict in payload with start_step/end_step metadata - basic.py: accept new payload format, pass through start/end - dag.svelte, step-wrapper.svelte, types.ts: dynamic start/end props Lint: - Improved check_basic_steps to distinguish "no start step" from "multiple @step(start=True) annotations" - New check_annotation_name_conflict: warns when @step(start=True) coexists with a step named "start" (and likewise for end) Tests: - Negative-path tests for malformed annotation patterns (8 cases) - Card rendering with custom endpoints - Trigger.from_runs() with custom terminal step - Step Functions StartAt lookup Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 021c5a8 commit 583fc92

22 files changed

Lines changed: 958 additions & 412 deletions

File tree

metaflow/client/core.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2192,8 +2192,13 @@ def _graph_endpoints(self):
21922192
params_meta = self["_parameters"].task.metadata_dict
21932193
start = params_meta.get("start_step", "start")
21942194
end = params_meta.get("end_step", "end")
2195-
except Exception:
2195+
except (KeyError, MetaflowNotFound):
2196+
# Expected for old runs without _parameters or metadata.
21962197
pass
2198+
except Exception:
2199+
# Transient error (network, metadata service) -- do NOT cache
2200+
# the fallback so a subsequent access can retry.
2201+
return (start, end)
21972202
self._cached_endpoints = (start, end)
21982203
return self._cached_endpoints
21992204

metaflow/events.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,18 +57,24 @@ def __init__(self, _meta=None):
5757
@classmethod
5858
def from_runs(cls, run_objs: List["metaflow.Run"]):
5959
run_objs.sort(key=lambda x: x.finished_at, reverse=True)
60-
trigger = Trigger(
61-
[
60+
valid_runs = []
61+
meta = []
62+
for run_obj in run_objs:
63+
end_task = run_obj.end_task
64+
if end_task is None:
65+
continue
66+
valid_runs.append(run_obj)
67+
meta.append(
6268
{
6369
"type": "run",
6470
"timestamp": run_obj.finished_at,
65-
"name": "metaflow.%s.%s" % (run_obj.parent.id, run_obj["end"].id),
66-
"id": run_obj.end_task.pathspec,
71+
"name": "metaflow.%s.%s" % (run_obj.parent.id, end_task.parent.id),
72+
"id": end_task.pathspec,
6773
}
68-
for run_obj in run_objs
69-
]
70-
)
71-
trigger._runs = run_objs
74+
)
75+
76+
trigger = Trigger(meta)
77+
trigger._runs = valid_runs
7278
return trigger
7379

7480
@property

metaflow/flowspec.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -164,21 +164,11 @@ def _merge_value(self, inherited_value, self_value):
164164

165165

166166
class FlowSpecMeta(type):
167-
_registry = {} # {class_name: class} for all FlowSpec/StepSpec subclasses
168-
169-
# Names of base classes that should NOT be registered or have their
170-
# graph/attrs initialized. Subclasses of FlowSpecMeta (like StepSpecMeta)
171-
# can extend this set.
172-
_base_class_names = frozenset({"FlowSpec"})
173-
174167
def __init__(cls, name, bases, attrs):
175168
super().__init__(name, bases, attrs)
176-
# Check against the metaclass's own _base_class_names — this
177-
# allows StepSpecMeta to add "StepSpec" to the set.
178-
if name in type(cls)._base_class_names:
169+
if name == "FlowSpec":
179170
return
180171

181-
type(cls)._registry[name] = cls
182172
cls._init_attrs()
183173

184174
def _init_attrs(cls):

metaflow/graph.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,18 @@ def populate_block(start_name, end_name):
568568
break
569569
return resulting_list
570570

571+
if self.start_step is None or self.end_step is None:
572+
missing = []
573+
if self.start_step is None:
574+
missing.append("start")
575+
if self.end_step is None:
576+
missing.append("end")
577+
raise ValueError(
578+
"Cannot compute graph structure: no %s step identified. "
579+
"Use @step(start=True)/@step(end=True) or name your steps "
580+
"'start'/'end'." % " or ".join(missing)
581+
)
582+
571583
if self.start_step == self.end_step:
572584
# Single-step flow
573585
graph_structure = []

metaflow/lint.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,17 +59,60 @@ def check_reserved_words(graph):
5959
@linter.check
6060
def check_basic_steps(graph):
6161
if graph.start_step is None:
62+
annotated = [
63+
name
64+
for name, node in graph.nodes.items()
65+
if node.is_start_step and not name.startswith("_")
66+
]
67+
if len(annotated) > 1:
68+
raise LintWarn(
69+
"Multiple steps annotated with @step(start=True): %s. "
70+
"Exactly one is allowed." % ", ".join(sorted(annotated))
71+
)
6272
raise LintWarn(
6373
"Your flow must have exactly one start step. Either name a step "
6474
"'start' or use @step(start=True)."
6575
)
6676
if graph.end_step is None:
77+
annotated = [
78+
name
79+
for name, node in graph.nodes.items()
80+
if node.is_end_step and not name.startswith("_")
81+
]
82+
if len(annotated) > 1:
83+
raise LintWarn(
84+
"Multiple steps annotated with @step(end=True): %s. "
85+
"Exactly one is allowed." % ", ".join(sorted(annotated))
86+
)
6787
raise LintWarn(
6888
"Your flow must have exactly one end step. Either name a step "
6989
"'end' or use @step(end=True)."
7090
)
7191

7292

93+
@linter.ensure_fundamentals
94+
@linter.check
95+
def check_annotation_name_conflict(graph):
96+
"""Detect conflict between @step(start/end=True) and legacy step names."""
97+
if (
98+
graph.start_step is not None
99+
and graph.start_step != "start"
100+
and "start" in graph.nodes
101+
):
102+
raise LintWarn(
103+
"Ambiguous start step: step '%s' is annotated with @step(start=True) "
104+
"but a step named 'start' also exists. Remove the 'start' name or "
105+
"the @step(start=True) annotation." % graph.start_step
106+
)
107+
108+
if graph.end_step is not None and graph.end_step != "end" and "end" in graph.nodes:
109+
raise LintWarn(
110+
"Ambiguous end step: step '%s' is annotated with @step(end=True) "
111+
"but a step named 'end' also exists. Remove the 'end' name or "
112+
"the @step(end=True) annotation." % graph.end_step
113+
)
114+
115+
73116
@linter.ensure_static_graph
74117
@linter.check
75118
def check_start_end_degree(graph):

metaflow/plugins/argo/argo_workflows.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1201,8 +1201,9 @@ def _is_recursive_node(self, node):
12011201
return node.name in self.recursive_nodes
12021202

12031203
def _matching_conditional_join(self, node):
1204-
# If no earlier conditional join step is found during parsing, then 'end' is always one.
1205-
return self.matching_conditional_join_dict.get(node.name, "end")
1204+
# If no earlier conditional join step is found during parsing,
1205+
# fall back to the graph's terminal step.
1206+
return self.matching_conditional_join_dict.get(node.name, self.graph.end_step)
12061207

12071208
# Visit every node and yield the uber DAGTemplate(s).
12081209
def _dag_templates(self):

metaflow/plugins/aws/step_functions/step_functions.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -240,12 +240,14 @@ def get_existing_deployment(cls, name):
240240
workflow = StepFunctionsClient().get(name)
241241
if workflow is not None:
242242
try:
243-
start = json.loads(workflow["definition"])["States"]["start"]
243+
definition = json.loads(workflow["definition"])
244+
start_state_name = definition.get("StartAt", "start")
245+
start = definition["States"][start_state_name]
244246
parameters = start["Parameters"]["Parameters"]
245247
return parameters.get("metaflow.owner"), parameters.get(
246248
"metaflow.production_token"
247249
)
248-
except KeyError:
250+
except (KeyError, TypeError, AttributeError):
249251
raise StepFunctionsException(
250252
"An existing non-metaflow "
251253
"workflow with the same name as "
@@ -271,17 +273,21 @@ def get_execution(cls, state_machine_name, name):
271273
)
272274
try:
273275
state_machine_arn = state_machine.get("stateMachineArn")
274-
environment_vars = (
275-
json.loads(state_machine.get("definition"))
276-
.get("States")
277-
.get("start")
278-
.get("Parameters")
279-
.get("ContainerOverrides")
280-
.get("Environment")
281-
)
282-
parameters = {
283-
item.get("Name"): item.get("Value") for item in environment_vars
284-
}
276+
definition = json.loads(state_machine.get("definition"))
277+
start_state_name = definition.get("StartAt", "start")
278+
try:
279+
start = definition["States"][start_state_name]
280+
environment_vars = start["Parameters"]["ContainerOverrides"][
281+
"Environment"
282+
]
283+
parameters = {
284+
item.get("Name"): item.get("Value") for item in environment_vars
285+
}
286+
except (KeyError, TypeError, AttributeError):
287+
raise StepFunctionsException(
288+
"A non-metaflow workflow *%s* already exists in AWS Step Functions."
289+
% state_machine_name
290+
)
285291
executions = client.list_executions(state_machine_arn, states=["RUNNING"])
286292
for execution in executions:
287293
if execution.get("name") == name:
@@ -295,9 +301,11 @@ def get_execution(cls, state_machine_name, name):
295301
except KeyError:
296302
raise StepFunctionsException(
297303
"A non-metaflow workflow *%s* already exists in AWS Step Functions."
298-
% name
304+
% state_machine_name
299305
)
300306
return None
307+
except StepFunctionsException:
308+
raise
301309
except Exception as e:
302310
raise StepFunctionsException(repr(e))
303311

@@ -795,6 +803,10 @@ def _batch(self, node):
795803
metaflow_version["production_token"] = self.production_token
796804
env["METAFLOW_VERSION"] = json.dumps(metaflow_version)
797805

806+
multiflow_name = os.environ.get("METAFLOW_MULTIFLOW_NAME")
807+
if multiflow_name:
808+
env["METAFLOW_MULTIFLOW_NAME"] = multiflow_name
809+
798810
# map config values
799811
cfg_env = {param["name"]: param["kv_name"] for param in self.config_parameters}
800812
if cfg_env:
@@ -926,6 +938,9 @@ def _step_cli(self, node, paths, code_package_url, user_code_retries):
926938
entrypoint = [R.entrypoint()]
927939
else:
928940
entrypoint = [executable, script_name]
941+
multiflow_name = os.environ.get("METAFLOW_MULTIFLOW_NAME")
942+
if multiflow_name:
943+
entrypoint.append(multiflow_name)
929944

930945
# Use AWS Batch job identifier as the globally unique task identifier.
931946
task_id = "${AWS_BATCH_JOB_ID}"

metaflow/plugins/cards/card_cli.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,11 @@ def create(
649649
full_pathspec = "/".join([flowname, pathspec])
650650

651651
graph_dict, _ = ctx.obj.graph.output_steps()
652+
graph_payload = {
653+
"steps": graph_dict,
654+
"start_step": ctx.obj.graph.start_step,
655+
"end_step": ctx.obj.graph.end_step,
656+
}
652657

653658
if card_uuid is None:
654659
card_uuid = str(uuid.uuid4()).replace("-", "")
@@ -702,12 +707,12 @@ def create(
702707
mf_card = filtered_card(
703708
options=options,
704709
components=component_arr,
705-
graph=graph_dict,
710+
graph=graph_payload,
706711
flow=ctx.obj.flow,
707712
)
708713
else:
709714
mf_card = filtered_card(
710-
components=component_arr, graph=graph_dict, flow=ctx.obj.flow
715+
components=component_arr, graph=graph_payload, flow=ctx.obj.flow
711716
)
712717
except TypeError as e:
713718
if render_error_card:

metaflow/plugins/cards/card_modules/basic.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,22 @@
1212
CSS_PATH = os.path.join(ABS_DIR_PATH, "bundle.css")
1313

1414

15-
def transform_flow_graph(step_info):
15+
def transform_flow_graph(graph):
16+
if (
17+
isinstance(graph, dict)
18+
and "steps" in graph
19+
and "start_step" in graph
20+
and "end_step" in graph
21+
and isinstance(graph["steps"], dict)
22+
):
23+
step_info = graph["steps"]
24+
start_step = graph.get("start_step")
25+
end_step = graph.get("end_step")
26+
else:
27+
step_info = graph
28+
start_step = "start" if "start" in graph else None
29+
end_step = "end" if "end" in graph else None
30+
1631
def node_to_type(node_type):
1732
if node_type in ["linear", "start", "end", "join"]:
1833
return node_type
@@ -47,7 +62,11 @@ def node_to_type(node_type):
4762

4863
graph_dict[stepname] = node_info
4964

50-
return graph_dict
65+
return {
66+
"steps": graph_dict,
67+
"start_step": start_step,
68+
"end_step": end_step,
69+
}
5170

5271

5372
def read_file(path):

metaflow/plugins/cards/ui/src/components/dag/dag.svelte

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
1212
export let componentData: DagComponent;
1313
14-
const { data: steps } = componentData;
14+
const { steps, start_step: startStep, end_step: endStep } = componentData.data;
1515
let el: HTMLElement;
1616
let dagStructure: DagStructure = {};
1717
@@ -42,8 +42,14 @@
4242
style="position: relative; line-height: 1"
4343
data-component="dag"
4444
>
45-
{#if steps?.start}
46-
<StepWrapper {steps} stepName="start" bind:dagStructure />
45+
{#if startStep && steps?.[startStep]}
46+
<StepWrapper
47+
{steps}
48+
stepName={startStep}
49+
{startStep}
50+
{endStep}
51+
bind:dagStructure
52+
/>
4753
{:else}
4854
<p>No start step</p>
4955
{/if}

0 commit comments

Comments
 (0)