@@ -48,7 +48,17 @@ def deindent_docstring(doc):
4848
4949class DAGNode (object ):
5050 def __init__ (
51- self , func_ast , decos , wrappers , config_decorators , doc , source_file , lineno
51+ self ,
52+ func_ast ,
53+ decos ,
54+ wrappers ,
55+ config_decorators ,
56+ doc ,
57+ source_file ,
58+ lineno ,
59+ is_start_step = False ,
60+ is_end_step = False ,
61+ node_info = None ,
5262 ):
5363 self .name = func_ast .name
5464 self .source_file = source_file
@@ -60,6 +70,12 @@ def __init__(
6070 self .config_decorators = config_decorators
6171 self .doc = deindent_docstring (doc )
6272 self .parallel_step = any (getattr (deco , "IS_PARALLEL" , False ) for deco in decos )
73+ # Explicit start/end annotations from @step(start=True) / @step(end=True)
74+ self .is_start_step = is_start_step
75+ self .is_end_step = is_end_step
76+ # Generic metadata dict for extensions to attach extra info to this node.
77+ # Serialized to _graph_info via to_pod; live references accessible via flow._graph.
78+ self .node_info = node_info or {}
6379
6480 # these attributes are populated by _parse
6581 self .tail_next_lineno = 0
@@ -140,10 +156,9 @@ def _parse(self, func_ast, lineno):
140156 self .num_args = len (func_ast .args .args )
141157 tail = func_ast .body [- 1 ]
142158
143- # end doesn't need a transition
144- if self .name == "end" :
145- # TYPE: end
146- self .type = "end"
159+ # Note: type assignment for start/end steps is handled by
160+ # FlowGraph._identify_start_end() based on graph structure,
161+ # not by name.
147162
148163 # ensure that the tail an expression
149164 if not isinstance (tail , ast .Expr ):
@@ -212,10 +227,10 @@ def _parse(self, func_ast, lineno):
212227 self .type = "split"
213228 self .invalid_tail_next = False
214229 elif len (self .out_funcs ) == 1 :
215- # TYPE: linear
216- if self . name == "start" :
217- self . type = "start"
218- elif self .num_args > 1 :
230+ # TYPE: linear (or join)
231+ # Note: "start" type is assigned later by
232+ # FlowGraph._identify_start_end() based on structure.
233+ if self .num_args > 1 :
219234 self .type = "join"
220235 else :
221236 self .type = "linear"
@@ -259,9 +274,65 @@ def __init__(self, flow):
259274 self .doc = deindent_docstring (flow .__doc__ )
260275 # nodes sorted in topological order.
261276 self .sorted_nodes = []
277+ self ._identify_start_end ()
262278 self ._traverse_graph ()
263279 self ._postprocess ()
264280
281+ def _identify_start_end (self ):
282+ """
283+ Determine the start and end steps.
284+
285+ Uses explicit ``@step(start=True)`` / ``@step(end=True)`` annotations
286+ if present. Falls back to looking for steps named ``"start"`` /
287+ ``"end"`` for backward compatibility.
288+
289+ Sets ``self.start_step`` and ``self.end_step`` to step name strings,
290+ or ``None`` if the graph is malformed (validated later by lint).
291+ Also assigns the ``"start"`` and ``"end"`` node types.
292+ """
293+ # 1. Look for explicit annotations
294+ annotated_start = [
295+ name
296+ for name , node in self .nodes .items ()
297+ if node .is_start_step and not name .startswith ("_" )
298+ ]
299+ annotated_end = [
300+ name
301+ for name , node in self .nodes .items ()
302+ if node .is_end_step and not name .startswith ("_" )
303+ ]
304+
305+ # 2. Determine start step (annotation first, then name fallback)
306+ if len (annotated_start ) == 1 :
307+ self .start_step = annotated_start [0 ]
308+ elif len (annotated_start ) == 0 :
309+ self .start_step = "start" if "start" in self .nodes else None
310+ else :
311+ self .start_step = None # Multiple annotated — lint will catch
312+
313+ # 3. Determine end step (annotation first, then name fallback)
314+ if len (annotated_end ) == 1 :
315+ self .end_step = annotated_end [0 ]
316+ elif len (annotated_end ) == 0 :
317+ self .end_step = "end" if "end" in self .nodes else None
318+ else :
319+ self .end_step = None # Multiple annotated — lint will catch
320+
321+ # 4. Assign types based on identified start/end.
322+ # Only upgrade "linear" → "start" for the entry point; do NOT override
323+ # "split", "foreach", etc. since those types are needed for
324+ # split/join balance checking.
325+ if self .start_step and self .start_step == self .end_step :
326+ # Single-step flow: terminal node that is also the entry point
327+ self .nodes [self .start_step ].type = "end"
328+ else :
329+ if self .start_step :
330+ node = self .nodes [self .start_step ]
331+ if node .type in (None , "linear" ):
332+ node .type = "start"
333+ if self .end_step :
334+ self .nodes [self .end_step ].type = "end"
335+
265336 def _create_nodes (self , flow ):
266337 nodes = {}
267338 for element in dir (flow ):
@@ -281,6 +352,9 @@ def _create_nodes(self, flow):
281352 func .__doc__ ,
282353 source_file ,
283354 lineno ,
355+ is_start_step = getattr (func , "is_start_step" , False ),
356+ is_end_step = getattr (func , "is_end_step" , False ),
357+ node_info = getattr (func , "node_info" , None ),
284358 )
285359 nodes [element ] = node
286360 return nodes
@@ -338,8 +412,8 @@ def traverse(node, seen, split_parents, split_branches):
338412 split_branches + ([n ] if add_split_branch else []),
339413 )
340414
341- if "start" in self :
342- traverse (self ["start" ], [], [], [])
415+ if self . start_step and self . start_step in self :
416+ traverse (self [self . start_step ], [], [], [])
343417
344418 # fix the order of in_funcs
345419 for node in self .nodes .values ():
@@ -445,6 +519,7 @@ def node_to_dict(name, node):
445519 for deco in chain (node .wrappers , node .config_decorators )
446520 ],
447521 "next" : node .out_funcs ,
522+ "node_info" : to_pod (node .node_info ),
448523 }
449524 if d ["type" ] == "split-foreach" :
450525 d ["foreach_artifact" ] = node .foreach_param
@@ -493,9 +568,15 @@ def populate_block(start_name, end_name):
493568 break
494569 return resulting_list
495570
496- graph_structure = populate_block ("start" , "end" )
571+ if self .start_step == self .end_step :
572+ # Single-step flow
573+ graph_structure = []
574+ else :
575+ graph_structure = populate_block (self .start_step , self .end_step )
497576
498- steps_info ["end" ] = node_to_dict ("end" , self .nodes ["end" ])
499- graph_structure .append ("end" )
577+ steps_info [self .end_step ] = node_to_dict (
578+ self .end_step , self .nodes [self .end_step ]
579+ )
580+ graph_structure .append (self .end_step )
500581
501582 return steps_info , graph_structure
0 commit comments