Skip to content

Commit 853620c

Browse files
committed
Modify the graph to allow for single step flows and custom names
Instead of enforcing a step named "start" and another named "end" and a minimum of two steps, Metaflow now allows for single step flows with any names. The "start" and "end" properties are derived from the structure of the graph. We still require a single entry point and a single exit point but they can be one and the same and do not have to be named something specific.
1 parent e4e0525 commit 853620c

21 files changed

Lines changed: 961 additions & 80 deletions

metaflow/client/core.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2121,8 +2121,9 @@ def parent_steps(self) -> Iterator["Step"]:
21212121
Parent step
21222122
"""
21232123
graph_info = self.task["_graph_info"].data
2124+
start_step = graph_info.get("start_step", "start")
21242125

2125-
if self.id != "start":
2126+
if self.id != start_step:
21262127
flow, run, _ = self.path_components
21272128
for node_name, attributes in graph_info["steps"].items():
21282129
if self.id in attributes["next"]:
@@ -2139,8 +2140,9 @@ def child_steps(self) -> Iterator["Step"]:
21392140
Child step
21402141
"""
21412142
graph_info = self.task["_graph_info"].data
2143+
end_step = graph_info.get("end_step", "end")
21422144

2143-
if self.id != "end":
2145+
if self.id != end_step:
21442146
flow, run, _ = self.path_components
21452147
for next_step in graph_info["steps"][self.id]["next"]:
21462148
yield Step(f"{flow}/{run}/{next_step}", _namespace_check=False)
@@ -2153,7 +2155,7 @@ class Run(MetaflowObject):
21532155
Attributes
21542156
----------
21552157
data : MetaflowData
2156-
a shortcut to run['end'].task.data, i.e. data produced by this run.
2158+
A shortcut to the terminal step's task data produced by this run.
21572159
successful : bool
21582160
True if the run completed successfully.
21592161
finished : bool
@@ -2165,7 +2167,7 @@ class Run(MetaflowObject):
21652167
trigger : MetaflowTrigger
21662168
Information about event(s) that triggered this run (if present). See `MetaflowTrigger`.
21672169
end_task : Task
2168-
`Task` for the end step (if it is present already).
2170+
`Task` for the terminal step (if it is present already).
21692171
"""
21702172

21712173
_NAME = "run"
@@ -2176,6 +2178,25 @@ def _iter_filter(self, x):
21762178
# exclude _parameters step
21772179
return x.id[0] != "_"
21782180

2181+
@property
2182+
def _graph_endpoints(self):
2183+
"""
2184+
Returns (start_step_name, end_step_name) from _parameters metadata.
2185+
2186+
Falls back to ("start", "end") for runs that predate structural
2187+
inference (backward compatibility).
2188+
"""
2189+
if not hasattr(self, "_cached_endpoints"):
2190+
start, end = "start", "end"
2191+
try:
2192+
params_meta = self["_parameters"].task.metadata_dict
2193+
start = params_meta.get("start_step", "start")
2194+
end = params_meta.get("end_step", "end")
2195+
except Exception:
2196+
pass
2197+
self._cached_endpoints = (start, end)
2198+
return self._cached_endpoints
2199+
21792200
def steps(self, *tags: str) -> Iterator[Step]:
21802201
"""
21812202
[Legacy function - do not use]
@@ -2298,17 +2319,18 @@ def finished_at(self) -> Optional[datetime]:
22982319
@property
22992320
def end_task(self) -> Optional[Task]:
23002321
"""
2301-
Returns the Task corresponding to the 'end' step.
2322+
Returns the Task corresponding to the terminal step.
23022323
2303-
This returns None if the end step does not yet exist.
2324+
This returns None if the terminal step does not yet exist.
23042325
23052326
Returns
23062327
-------
23072328
Task, optional
2308-
The 'end' task
2329+
The terminal step's task
23092330
"""
23102331
try:
2311-
end_step = self["end"]
2332+
_, end_step_name = self._graph_endpoints
2333+
end_step = self[end_step_name]
23122334
except KeyError:
23132335
return None
23142336

@@ -2481,8 +2503,9 @@ def trigger(self) -> Optional[Trigger]:
24812503
Trigger, optional
24822504
Container of triggering events
24832505
"""
2484-
if "start" in self and self["start"].task:
2485-
meta = self["start"].task.metadata_dict.get("execution-triggers")
2506+
start_step, _ = self._graph_endpoints
2507+
if start_step in self and self[start_step].task:
2508+
meta = self[start_step].task.metadata_dict.get("execution-triggers")
24862509
if meta:
24872510
return Trigger(json.loads(meta))
24882511
return None

metaflow/decorators.py

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,19 @@ def _base_step_decorator(decotype, *args, **kwargs):
547547
func = args[0]
548548
if isinstance(func, (StepMutator, UserStepDecoratorBase)):
549549
func = func._my_step
550+
551+
# Step decorator applied to a class with a synthesized step method.
552+
# Used by extensions that create a single synthetic @step on a
553+
# FlowSpec subclass. _step_spec_step_name is set by the
554+
# extension's metaclass.
555+
if isinstance(func, type) and hasattr(func, "_step_spec_step_name"):
556+
step_func = getattr(func, func._step_spec_step_name)
557+
if hasattr(step_func, "is_step"):
558+
step_func.decorators.append(
559+
decotype(attributes=kwargs, statically_defined=True)
560+
)
561+
return func
562+
550563
if not hasattr(func, "is_step"):
551564
raise BadStepDecoratorException(decotype.name, func)
552565

@@ -947,7 +960,11 @@ def step(
947960

948961

949962
def step(
950-
f: Union[Callable[[FlowSpecDerived], None], Callable[[FlowSpecDerived, Any], None]],
963+
f=None,
964+
*,
965+
start=False,
966+
end=False,
967+
node_info=None,
951968
):
952969
"""
953970
Marks a method in a FlowSpec as a Metaflow Step. Note that this
@@ -972,20 +989,41 @@ def foo(self):
972989
973990
Parameters
974991
----------
975-
f : Union[Callable[[FlowSpecDerived], None], Callable[[FlowSpecDerived, Any], None]]
976-
Function to make into a Metaflow Step
992+
f : callable, optional
993+
Function to make into a Metaflow Step. When using keyword arguments
994+
(e.g. ``@step(start=True)``), this is ``None`` and a decorator
995+
function is returned instead.
996+
start : bool, default False
997+
Mark this step as the start (entry) step of the flow.
998+
end : bool, default False
999+
Mark this step as the end (terminal) step of the flow.
1000+
node_info : dict, optional
1001+
Extra metadata to attach to this step's DAGNode. Extensions can use
1002+
this to store arbitrary information accessible via ``flow._graph``
1003+
(live references) and ``_graph_info`` (serialized via ``to_pod``).
9771004
9781005
Returns
9791006
-------
980-
Union[Callable[[FlowSpecDerived, StepFlag], None], Callable[[FlowSpecDerived, Any, StepFlag], None]]
981-
Function that is a Metaflow Step
1007+
callable
1008+
The decorated function, or a decorator if keyword arguments were used.
9821009
"""
983-
f.is_step = True
984-
f.decorators = []
985-
f.config_decorators = []
986-
f.wrappers = []
987-
f.name = f.__name__
988-
return f
1010+
1011+
def _apply(func):
1012+
func.is_step = True
1013+
func.decorators = []
1014+
func.config_decorators = []
1015+
func.wrappers = []
1016+
func.name = func.__name__
1017+
func.is_start_step = start
1018+
func.is_end_step = end
1019+
func.node_info = node_info or {}
1020+
return func
1021+
1022+
if f is not None:
1023+
# Called as @step (no parens)
1024+
return _apply(f)
1025+
# Called as @step(start=True) etc.
1026+
return _apply
9891027

9901028

9911029
def _import_plugin_decorators(globals_dict):

metaflow/flowspec.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,11 +164,21 @@ def _merge_value(self, inherited_value, self_value):
164164

165165

166166
class FlowSpecMeta(type):
167+
_registry = {} # {class_name: class} for all FlowSpec/StepSpec subclasses
168+
169+
# Names of base classes that should NOT be registered or have their
170+
# graph/attrs initialized. Subclasses of FlowSpecMeta (like StepSpecMeta)
171+
# can extend this set.
172+
_base_class_names = frozenset({"FlowSpec"})
173+
167174
def __init__(cls, name, bases, attrs):
168175
super().__init__(name, bases, attrs)
169-
if name == "FlowSpec":
176+
# Check against the metaclass's own _base_class_names — this
177+
# allows StepSpecMeta to add "StepSpec" to the set.
178+
if name in type(cls)._base_class_names:
170179
return
171180

181+
type(cls)._registry[name] = cls
172182
cls._init_attrs()
173183

174184
def _init_attrs(cls):
@@ -516,6 +526,8 @@ def _set_constants(self, graph, kwargs, config_options):
516526

517527
graph_info = {
518528
"file": os.path.basename(os.path.abspath(sys.argv[0])),
529+
"start_step": graph.start_step,
530+
"end_step": graph.end_step,
519531
"parameters": parameters_info,
520532
"constants": constants_info,
521533
"steps": steps_info,

metaflow/graph.py

Lines changed: 95 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,17 @@ def deindent_docstring(doc):
4848

4949
class DAGNode(object):
5050
def __init__(
51-
self, func_ast, decos, wrappers, config_decorators, doc, source_file, lineno
51+
self,
52+
func_ast,
53+
decos,
54+
wrappers,
55+
config_decorators,
56+
doc,
57+
source_file,
58+
lineno,
59+
is_start_step=False,
60+
is_end_step=False,
61+
node_info=None,
5262
):
5363
self.name = func_ast.name
5464
self.source_file = source_file
@@ -60,6 +70,12 @@ def __init__(
6070
self.config_decorators = config_decorators
6171
self.doc = deindent_docstring(doc)
6272
self.parallel_step = any(getattr(deco, "IS_PARALLEL", False) for deco in decos)
73+
# Explicit start/end annotations from @step(start=True) / @step(end=True)
74+
self.is_start_step = is_start_step
75+
self.is_end_step = is_end_step
76+
# Generic metadata dict for extensions to attach extra info to this node.
77+
# Serialized to _graph_info via to_pod; live references accessible via flow._graph.
78+
self.node_info = node_info or {}
6379

6480
# these attributes are populated by _parse
6581
self.tail_next_lineno = 0
@@ -140,10 +156,9 @@ def _parse(self, func_ast, lineno):
140156
self.num_args = len(func_ast.args.args)
141157
tail = func_ast.body[-1]
142158

143-
# end doesn't need a transition
144-
if self.name == "end":
145-
# TYPE: end
146-
self.type = "end"
159+
# Note: type assignment for start/end steps is handled by
160+
# FlowGraph._identify_start_end() based on graph structure,
161+
# not by name.
147162

148163
# ensure that the tail an expression
149164
if not isinstance(tail, ast.Expr):
@@ -212,10 +227,10 @@ def _parse(self, func_ast, lineno):
212227
self.type = "split"
213228
self.invalid_tail_next = False
214229
elif len(self.out_funcs) == 1:
215-
# TYPE: linear
216-
if self.name == "start":
217-
self.type = "start"
218-
elif self.num_args > 1:
230+
# TYPE: linear (or join)
231+
# Note: "start" type is assigned later by
232+
# FlowGraph._identify_start_end() based on structure.
233+
if self.num_args > 1:
219234
self.type = "join"
220235
else:
221236
self.type = "linear"
@@ -259,9 +274,65 @@ def __init__(self, flow):
259274
self.doc = deindent_docstring(flow.__doc__)
260275
# nodes sorted in topological order.
261276
self.sorted_nodes = []
277+
self._identify_start_end()
262278
self._traverse_graph()
263279
self._postprocess()
264280

281+
def _identify_start_end(self):
282+
"""
283+
Determine the start and end steps.
284+
285+
Uses explicit ``@step(start=True)`` / ``@step(end=True)`` annotations
286+
if present. Falls back to looking for steps named ``"start"`` /
287+
``"end"`` for backward compatibility.
288+
289+
Sets ``self.start_step`` and ``self.end_step`` to step name strings,
290+
or ``None`` if the graph is malformed (validated later by lint).
291+
Also assigns the ``"start"`` and ``"end"`` node types.
292+
"""
293+
# 1. Look for explicit annotations
294+
annotated_start = [
295+
name
296+
for name, node in self.nodes.items()
297+
if node.is_start_step and not name.startswith("_")
298+
]
299+
annotated_end = [
300+
name
301+
for name, node in self.nodes.items()
302+
if node.is_end_step and not name.startswith("_")
303+
]
304+
305+
# 2. Determine start step (annotation first, then name fallback)
306+
if len(annotated_start) == 1:
307+
self.start_step = annotated_start[0]
308+
elif len(annotated_start) == 0:
309+
self.start_step = "start" if "start" in self.nodes else None
310+
else:
311+
self.start_step = None # Multiple annotated — lint will catch
312+
313+
# 3. Determine end step (annotation first, then name fallback)
314+
if len(annotated_end) == 1:
315+
self.end_step = annotated_end[0]
316+
elif len(annotated_end) == 0:
317+
self.end_step = "end" if "end" in self.nodes else None
318+
else:
319+
self.end_step = None # Multiple annotated — lint will catch
320+
321+
# 4. Assign types based on identified start/end.
322+
# Only upgrade "linear" → "start" for the entry point; do NOT override
323+
# "split", "foreach", etc. since those types are needed for
324+
# split/join balance checking.
325+
if self.start_step and self.start_step == self.end_step:
326+
# Single-step flow: terminal node that is also the entry point
327+
self.nodes[self.start_step].type = "end"
328+
else:
329+
if self.start_step:
330+
node = self.nodes[self.start_step]
331+
if node.type in (None, "linear"):
332+
node.type = "start"
333+
if self.end_step:
334+
self.nodes[self.end_step].type = "end"
335+
265336
def _create_nodes(self, flow):
266337
nodes = {}
267338
for element in dir(flow):
@@ -281,6 +352,9 @@ def _create_nodes(self, flow):
281352
func.__doc__,
282353
source_file,
283354
lineno,
355+
is_start_step=getattr(func, "is_start_step", False),
356+
is_end_step=getattr(func, "is_end_step", False),
357+
node_info=getattr(func, "node_info", None),
284358
)
285359
nodes[element] = node
286360
return nodes
@@ -338,8 +412,8 @@ def traverse(node, seen, split_parents, split_branches):
338412
split_branches + ([n] if add_split_branch else []),
339413
)
340414

341-
if "start" in self:
342-
traverse(self["start"], [], [], [])
415+
if self.start_step and self.start_step in self:
416+
traverse(self[self.start_step], [], [], [])
343417

344418
# fix the order of in_funcs
345419
for node in self.nodes.values():
@@ -445,6 +519,7 @@ def node_to_dict(name, node):
445519
for deco in chain(node.wrappers, node.config_decorators)
446520
],
447521
"next": node.out_funcs,
522+
"node_info": to_pod(node.node_info),
448523
}
449524
if d["type"] == "split-foreach":
450525
d["foreach_artifact"] = node.foreach_param
@@ -493,9 +568,15 @@ def populate_block(start_name, end_name):
493568
break
494569
return resulting_list
495570

496-
graph_structure = populate_block("start", "end")
571+
if self.start_step == self.end_step:
572+
# Single-step flow
573+
graph_structure = []
574+
else:
575+
graph_structure = populate_block(self.start_step, self.end_step)
497576

498-
steps_info["end"] = node_to_dict("end", self.nodes["end"])
499-
graph_structure.append("end")
577+
steps_info[self.end_step] = node_to_dict(
578+
self.end_step, self.nodes[self.end_step]
579+
)
580+
graph_structure.append(self.end_step)
500581

501582
return steps_info, graph_structure

0 commit comments

Comments
 (0)