@@ -48,7 +48,16 @@ def deindent_docstring(doc):
4848
4949class DAGNode (object ):
5050 def __init__ (
51- self , func_ast , decos , wrappers , config_decorators , doc , source_file , lineno
51+ self ,
52+ func_ast ,
53+ decos ,
54+ wrappers ,
55+ config_decorators ,
56+ doc ,
57+ source_file ,
58+ lineno ,
59+ is_start_step = False ,
60+ is_end_step = False ,
5261 ):
5362 self .name = func_ast .name
5463 self .source_file = source_file
@@ -60,6 +69,9 @@ def __init__(
6069 self .config_decorators = config_decorators
6170 self .doc = deindent_docstring (doc )
6271 self .parallel_step = any (getattr (deco , "IS_PARALLEL" , False ) for deco in decos )
72+ # Explicit start/end annotations from @step(start=True) / @step(end=True)
73+ self .is_start_step = is_start_step
74+ self .is_end_step = is_end_step
6375
6476 # these attributes are populated by _parse
6577 self .tail_next_lineno = 0
@@ -264,41 +276,45 @@ def __init__(self, flow):
264276
265277 def _identify_start_end (self ):
266278 """
267- Determine the start and end steps from graph structure .
279+ Determine the start and end steps.
268280
269- Start step: the unique node with zero in-degree (no other node
270- transitions to it). End step: the unique node with zero out-degree
271- (it has no self.next() transitions). Internal steps (names starting
272- with '_') are excluded.
281+ Uses explicit ``@step(start=True)`` / ``@step(end=True)`` annotations
282+ if present. Falls back to looking for steps named ``"start"`` /
283+ ``"end"`` for backward compatibility.
273284
274- Sets self.start_step and self.end_step to the step name strings,
275- or None if the graph is malformed (validated later by lint).
276- Also assigns the "start" and "end" node types based on structure .
285+ Sets `` self.start_step`` and `` self.end_step`` to step name strings,
286+ or `` None`` if the graph is malformed (validated later by lint).
287+ Also assigns the `` "start"`` and `` "end"`` node types.
277288 """
278- # Compute in-degree from out_funcs (already set by _parse)
279- in_degree = {name : 0 for name in self .nodes }
280- for node in self .nodes .values ():
281- for target in node .out_funcs :
282- if target in in_degree :
283- in_degree [target ] += 1
284-
285- # Start = zero in-degree (exclude internal steps starting with _)
286- candidates_start = [
289+ # 1. Look for explicit annotations
290+ annotated_start = [
287291 name
288- for name , deg in in_degree .items ()
289- if deg == 0 and not name .startswith ("_" )
292+ for name , node in self . nodes .items ()
293+ if node . is_start_step and not name .startswith ("_" )
290294 ]
291- # End = zero out-degree (exclude internal steps starting with _)
292- candidates_end = [
295+ annotated_end = [
293296 name
294- for name in self .nodes
295- if not self . nodes [ name ]. out_funcs and not name .startswith ("_" )
297+ for name , node in self .nodes . items ()
298+ if node . is_end_step and not name .startswith ("_" )
296299 ]
297300
298- self .start_step = candidates_start [0 ] if len (candidates_start ) == 1 else None
299- self .end_step = candidates_end [0 ] if len (candidates_end ) == 1 else None
301+ # 2. Determine start step (annotation first, then name fallback)
302+ if len (annotated_start ) == 1 :
303+ self .start_step = annotated_start [0 ]
304+ elif len (annotated_start ) == 0 :
305+ self .start_step = "start" if "start" in self .nodes else None
306+ else :
307+ self .start_step = None # Multiple annotated — lint will catch
308+
309+ # 3. Determine end step (annotation first, then name fallback)
310+ if len (annotated_end ) == 1 :
311+ self .end_step = annotated_end [0 ]
312+ elif len (annotated_end ) == 0 :
313+ self .end_step = "end" if "end" in self .nodes else None
314+ else :
315+ self .end_step = None # Multiple annotated — lint will catch
300316
301- # Assign types based on structure .
317+ # 4. Assign types based on identified start/end .
302318 # Only upgrade "linear" → "start" for the entry point; do NOT override
303319 # "split", "foreach", etc. since those types are needed for
304320 # split/join balance checking.
@@ -332,6 +348,8 @@ def _create_nodes(self, flow):
332348 func .__doc__ ,
333349 source_file ,
334350 lineno ,
351+ is_start_step = getattr (func , "is_start_step" , False ),
352+ is_end_step = getattr (func , "is_end_step" , False ),
335353 )
336354 nodes [element ] = node
337355 return nodes
0 commit comments