Skip to content

Commit 2d2b74a

Browse files
committed
Address comments
1 parent 5251c5e commit 2d2b74a

14 files changed

Lines changed: 284 additions & 291 deletions

metaflow/decorators.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -957,7 +957,10 @@ def step(
957957

958958

959959
def step(
960-
f: Union[Callable[[FlowSpecDerived], None], Callable[[FlowSpecDerived, Any], None]],
960+
f=None,
961+
*,
962+
start=False,
963+
end=False,
961964
):
962965
"""
963966
Marks a method in a FlowSpec as a Metaflow Step. Note that this
@@ -982,20 +985,36 @@ def foo(self):
982985
983986
Parameters
984987
----------
985-
f : Union[Callable[[FlowSpecDerived], None], Callable[[FlowSpecDerived, Any], None]]
986-
Function to make into a Metaflow Step
988+
f : callable, optional
989+
Function to make into a Metaflow Step. When using keyword arguments
990+
(e.g. ``@step(start=True)``), this is ``None`` and a decorator
991+
function is returned instead.
992+
start : bool, default False
993+
Mark this step as the start (entry) step of the flow.
994+
end : bool, default False
995+
Mark this step as the end (terminal) step of the flow.
987996
988997
Returns
989998
-------
990-
Union[Callable[[FlowSpecDerived, StepFlag], None], Callable[[FlowSpecDerived, Any, StepFlag], None]]
991-
Function that is a Metaflow Step
999+
callable
1000+
The decorated function, or a decorator if keyword arguments were used.
9921001
"""
993-
f.is_step = True
994-
f.decorators = []
995-
f.config_decorators = []
996-
f.wrappers = []
997-
f.name = f.__name__
998-
return f
1002+
1003+
def _apply(func):
1004+
func.is_step = True
1005+
func.decorators = []
1006+
func.config_decorators = []
1007+
func.wrappers = []
1008+
func.name = func.__name__
1009+
func.is_start_step = start
1010+
func.is_end_step = end
1011+
return func
1012+
1013+
if f is not None:
1014+
# Called as @step (no parens)
1015+
return _apply(f)
1016+
# Called as @step(start=True) etc.
1017+
return _apply
9991018

10001019

10011020
def _import_plugin_decorators(globals_dict):

metaflow/graph.py

Lines changed: 45 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,16 @@ def deindent_docstring(doc):
4848

4949
class DAGNode(object):
5050
def __init__(
51-
self, func_ast, decos, wrappers, config_decorators, doc, source_file, lineno
51+
self,
52+
func_ast,
53+
decos,
54+
wrappers,
55+
config_decorators,
56+
doc,
57+
source_file,
58+
lineno,
59+
is_start_step=False,
60+
is_end_step=False,
5261
):
5362
self.name = func_ast.name
5463
self.source_file = source_file
@@ -60,6 +69,9 @@ def __init__(
6069
self.config_decorators = config_decorators
6170
self.doc = deindent_docstring(doc)
6271
self.parallel_step = any(getattr(deco, "IS_PARALLEL", False) for deco in decos)
72+
# Explicit start/end annotations from @step(start=True) / @step(end=True)
73+
self.is_start_step = is_start_step
74+
self.is_end_step = is_end_step
6375

6476
# these attributes are populated by _parse
6577
self.tail_next_lineno = 0
@@ -264,41 +276,45 @@ def __init__(self, flow):
264276

265277
def _identify_start_end(self):
266278
"""
267-
Determine the start and end steps from graph structure.
279+
Determine the start and end steps.
268280
269-
Start step: the unique node with zero in-degree (no other node
270-
transitions to it). End step: the unique node with zero out-degree
271-
(it has no self.next() transitions). Internal steps (names starting
272-
with '_') are excluded.
281+
Uses explicit ``@step(start=True)`` / ``@step(end=True)`` annotations
282+
if present. Falls back to looking for steps named ``"start"`` /
283+
``"end"`` for backward compatibility.
273284
274-
Sets self.start_step and self.end_step to the step name strings,
275-
or None if the graph is malformed (validated later by lint).
276-
Also assigns the "start" and "end" node types based on structure.
285+
Sets ``self.start_step`` and ``self.end_step`` to step name strings,
286+
or ``None`` if the graph is malformed (validated later by lint).
287+
Also assigns the ``"start"`` and ``"end"`` node types.
277288
"""
278-
# Compute in-degree from out_funcs (already set by _parse)
279-
in_degree = {name: 0 for name in self.nodes}
280-
for node in self.nodes.values():
281-
for target in node.out_funcs:
282-
if target in in_degree:
283-
in_degree[target] += 1
284-
285-
# Start = zero in-degree (exclude internal steps starting with _)
286-
candidates_start = [
289+
# 1. Look for explicit annotations
290+
annotated_start = [
287291
name
288-
for name, deg in in_degree.items()
289-
if deg == 0 and not name.startswith("_")
292+
for name, node in self.nodes.items()
293+
if node.is_start_step and not name.startswith("_")
290294
]
291-
# End = zero out-degree (exclude internal steps starting with _)
292-
candidates_end = [
295+
annotated_end = [
293296
name
294-
for name in self.nodes
295-
if not self.nodes[name].out_funcs and not name.startswith("_")
297+
for name, node in self.nodes.items()
298+
if node.is_end_step and not name.startswith("_")
296299
]
297300

298-
self.start_step = candidates_start[0] if len(candidates_start) == 1 else None
299-
self.end_step = candidates_end[0] if len(candidates_end) == 1 else None
301+
# 2. Determine start step (annotation first, then name fallback)
302+
if len(annotated_start) == 1:
303+
self.start_step = annotated_start[0]
304+
elif len(annotated_start) == 0:
305+
self.start_step = "start" if "start" in self.nodes else None
306+
else:
307+
self.start_step = None # Multiple annotated — lint will catch
308+
309+
# 3. Determine end step (annotation first, then name fallback)
310+
if len(annotated_end) == 1:
311+
self.end_step = annotated_end[0]
312+
elif len(annotated_end) == 0:
313+
self.end_step = "end" if "end" in self.nodes else None
314+
else:
315+
self.end_step = None # Multiple annotated — lint will catch
300316

301-
# Assign types based on structure.
317+
# 4. Assign types based on identified start/end.
302318
# Only upgrade "linear" → "start" for the entry point; do NOT override
303319
# "split", "foreach", etc. since those types are needed for
304320
# split/join balance checking.
@@ -332,6 +348,8 @@ def _create_nodes(self, flow):
332348
func.__doc__,
333349
source_file,
334350
lineno,
351+
is_start_step=getattr(func, "is_start_step", False),
352+
is_end_step=getattr(func, "is_end_step", False),
335353
)
336354
nodes[element] = node
337355
return nodes

metaflow/lint.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,41 @@ def check_reserved_words(graph):
6060
def check_basic_steps(graph):
6161
if graph.start_step is None:
6262
raise LintWarn(
63-
"Your flow must have exactly one step with no incoming transitions "
64-
"(the entry point). Found zero or multiple candidates."
63+
"Your flow must have exactly one start step. Either name a step "
64+
"'start' or use @step(start=True)."
6565
)
6666
if graph.end_step is None:
6767
raise LintWarn(
68-
"Your flow must have exactly one step with no outgoing transitions "
69-
"(the terminal step). Found zero or multiple candidates."
68+
"Your flow must have exactly one end step. Either name a step "
69+
"'end' or use @step(end=True)."
70+
)
71+
72+
73+
@linter.ensure_static_graph
74+
@linter.check
75+
def check_start_end_degree(graph):
76+
"""Validate that the start step has no incoming and the end step has no outgoing."""
77+
if graph.start_step is None or graph.end_step is None:
78+
return
79+
80+
start_node = graph[graph.start_step]
81+
if start_node.in_funcs:
82+
raise LintWarn(
83+
"The start step *%s* has incoming transitions from %s. "
84+
"A start step must have no incoming transitions."
85+
% (graph.start_step, ", ".join(start_node.in_funcs)),
86+
start_node.func_lineno,
87+
start_node.source_file,
88+
)
89+
90+
end_node = graph[graph.end_step]
91+
if end_node.out_funcs:
92+
raise LintWarn(
93+
"The end step *%s* has outgoing transitions. "
94+
"An end step must have no outgoing transitions (no self.next())."
95+
% graph.end_step,
96+
end_node.func_lineno,
97+
end_node.source_file,
7098
)
7199

72100

metaflow/parameters.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -520,13 +520,7 @@ def wrapper(cmd):
520520

521521

522522
class InitParameter(Parameter):
523-
"""
524-
Parameter for StepSpec init() phase.
525-
526-
Behaves identically to Parameter in CLI mode. In direct invocation,
527-
InitParameter values are passed to the constructor (alongside Config)
528-
rather than to __call__.
529-
"""
523+
"""Parameter for StepSpec init() phase."""
530524

531525
IS_INIT_PARAMETER = True
532526

0 commit comments

Comments
 (0)