@@ -44,6 +44,7 @@ def __init__(
4444 recovery_step = None ,
4545 termination_result_fn = lambda x , y : x if x is not None else y ,
4646 context = None ,
47+ max_iteration : Optional [int ] = None ,
4748 ** kwargs ,
4849 ):
4950 self ._outlets = []
@@ -66,6 +67,7 @@ def __init__(
6667 self ._full_event = kwargs .get ("full_event" )
6768 self ._input_path = kwargs .get ("input_path" )
6869 self ._result_path = kwargs .get ("result_path" )
70+ self ._max_iteration = max_iteration
6971 self ._runnable = False
7072 name = kwargs .get ("name" , None )
7173 if name :
@@ -74,10 +76,25 @@ def __init__(
7476 self .name = type (self ).__name__
7577
7678 self ._closeables = []
79+ self ._selected_outlets : Optional [list [str ]] = None
80+ self ._create_name_to_outlet = True
7781
def _init(self):
    """Reset the per-run state of this step.

    Zeroes the termination bookkeeping and empties the name -> outlet lookup
    table. The table is rebuilt only when ``select_outlets`` is overridden
    relative to ``Flow`` and ``_create_name_to_outlet`` is truthy.
    """
    self._termination_received, self._termination_result = 0, None
    self._name_to_outlet = {}
    routing_is_customized = self._method_is_overridden("select_outlets", Flow)
    if routing_is_customized and self._create_name_to_outlet:
        self._init_name_to_outlet()
88+
89+ def _init_name_to_outlet (self ):
90+ for outlet in self ._outlets :
91+ if outlet .name in self ._name_to_outlet :
92+ raise ValueError (f"Ambiguous outlet name '{ outlet .name } ' in step '{ self .name } '" )
93+ self ._name_to_outlet [outlet .name ] = outlet
94+
95+ def _method_is_overridden (self , method_name : str , parent_cls ):
96+ """Return True if the subclass overrides the given method."""
97+ return getattr (self .__class__ , method_name ) is not getattr (parent_cls , method_name )
8198
8299 def to_dict (self , fields = None , exclude = None ):
83100 """convert the step object to a python dictionary"""
@@ -187,16 +204,27 @@ def _get_recovery_step(self, exception):
187204 else :
188205 return self ._recovery_step
189206
def run(self, visited=None):
    """Initialize this step and, recursively, everything downstream of it.

    Args:
        visited: set of steps already started during this traversal. Callers
            starting a traversal leave it as None; it is threaded through the
            recursive calls so that a step reachable via a cycle is started
            only once.

    Returns:
        ``self._closeables`` after it has been extended with the closeables
        returned by each outlet and recovery step, or an empty list when this
        step was already visited.

    Raises:
        ValueError: if the step is neither a legal first step nor runnable.
    """
    if not (self._legal_first_step or self._runnable):
        raise ValueError("Flow must start with a source")

    seen = set() if visited is None else visited
    if self in seen:
        # Already started earlier in this traversal (cyclic graph).
        return []

    self._init()
    seen.add(self)

    for downstream in [*self._outlets, *self._get_recovery_steps()]:
        downstream._runnable = True
        self._closeables.extend(downstream.run(seen))
    return self._closeables
201229
202230 def _get_recovery_steps (self ):
@@ -215,6 +243,7 @@ async def _do(self, event):
215243
216244 async def _do_and_recover (self , event ):
217245 try :
246+ self .check_and_update_iteration_number (event )
218247 return await self ._do (event )
219248 except BaseException as ex :
220249 if getattr (ex , "_raised_by_storey_step" , None ) is not None :
@@ -256,7 +285,17 @@ def _should_terminate(self):
256285 return self ._termination_received == len (self ._inlets )
257286
258287 async def _do_downstream (self , event , outlets = None ):
259- outlets = self ._outlets if outlets is None else outlets
288+ if not outlets :
289+ if event is _termination_obj :
290+ outlets = self ._outlets
291+ else :
292+ if self ._selected_outlets :
293+ outlet_names = self ._selected_outlets
294+ self ._selected_outlets = None
295+ else :
296+ outlet_names = self .select_outlets (event .body )
297+ outlets = self ._check_outlets_by_names (outlet_names ) if outlet_names else self ._outlets
298+
260299 if not outlets :
261300 return
262301 if event is _termination_obj :
@@ -320,21 +359,76 @@ def _user_fn_output_to_event(self, event, fn_result):
320359 mapped_event .body = fn_result
321360 return mapped_event
322361
323- def _check_step_in_flow (self , type_to_check ):
362+ def _check_step_in_flow (self , type_to_check , visited = None ):
363+ # initialize the visited set once at the top
364+ if visited is None :
365+ visited = set ()
366+
367+ # detect cycles
368+ if self in visited :
369+ return False
370+ visited .add (self )
371+
372+ # check this node
324373 if isinstance (self , type_to_check ):
325374 return True
375+
376+ # check outlets
326377 for outlet in self ._outlets :
327- if outlet ._check_step_in_flow (type_to_check ):
378+ if outlet ._check_step_in_flow (type_to_check , visited ):
328379 return True
380+
381+ # check recovery step
329382 if isinstance (self ._recovery_step , Flow ):
330- if self ._recovery_step ._check_step_in_flow (type_to_check ):
383+ if self ._recovery_step ._check_step_in_flow (type_to_check , visited ):
331384 return True
385+
332386 elif isinstance (self ._recovery_step , dict ):
333387 for step in self ._recovery_step .values ():
334- if step ._check_step_in_flow (type_to_check ):
388+ if step ._check_step_in_flow (type_to_check , visited ):
335389 return True
336390 return False
337391
def check_and_update_iteration_number(self, event) -> None:
    """Count one more visit of ``event`` to this step and enforce the limit.

    Visit counts are kept per step name in the ``_cyclic_counter`` dict that
    is attached to the event itself.

    Args:
        event: the event about to be processed by this step.

    Raises:
        RuntimeError: if this step has a ``max_iteration`` and the event has
            already visited it that many times.
    """
    counters = getattr(event, "_cyclic_counter", None)
    if counters is None:
        # First counted step for this event: start its bookkeeping.
        event._cyclic_counter = {self.name: 1}
        return

    visits = counters.get(self.name, 0)
    if self._max_iteration is not None and visits >= self._max_iteration:
        raise RuntimeError(f"Max iterations exceeded in step '{self.name}' for event {event.id}")
    # Update in place. Replacing the dict here (as happened for steps without
    # a limit) would erase the counts other steps in the cycle rely on.
    counters[self.name] = visits + 1
400+
def get_iteration_counter(self, event):
    """Return how many visits to this step are recorded on ``event``.

    Reads the event's ``_cyclic_counter`` dict, keyed by step name; yields 0
    when the event has no such dict or no entry for this step.
    """
    per_step_counts = getattr(event, "_cyclic_counter", {})
    return per_step_counts.get(self.name, 0)
403+
def select_outlets(self, event) -> Optional[Collection[str]]:
    """Routing hook: choose which outlets receive an event.

    Subclasses override this to return the names of the outlets the event
    should be sent to. This base implementation returns None, meaning no
    selection is made.
    """
    return None
410+
411+ def _check_outlets_by_names (self , outlet_names : Collection [str ]) -> list ["Flow" ]:
412+ outlets = []
413+
414+ # Check for duplicates
415+ if len (set (outlet_names )) != len (outlet_names ):
416+ raise ValueError (
417+ f"Invalid outlet selection for '{ self .name } ': duplicate outlet names were provided "
418+ f"({ ', ' .join (outlet_names )} )."
419+ )
420+
421+ # Validate each outlet name
422+ for outlet_name in outlet_names :
423+ if outlet_name not in self ._name_to_outlet :
424+ raise ValueError (
425+ f"Invalid outlet '{ outlet_name } ' for '{ self .name } '. "
426+ f"Allowed outlets are: { ', ' .join (self ._name_to_outlet )} ."
427+ )
428+ outlets .append (self ._name_to_outlet [outlet_name ])
429+
430+ return outlets
431+
338432
339433class WithUUID :
340434 def __init__ (self ):
@@ -358,20 +452,8 @@ class Choice(Flow):
358452
def _init(self):
    """Reset per-run state and detect the single-'dataframe'-outlet case.

    After the parent initialization, ``_passthrough_for_preview`` is True only
    when the outlet lookup table is non-empty and its only key is
    ``"dataframe"``.
    """
    super()._init()
    # TODO: hacky way of supporting mlrun preview, which replaces targets with a DFTarget
    self._passthrough_for_preview = bool(self._name_to_outlet) and list(self._name_to_outlet) == ["dataframe"]
375457
376458 async def _do (self , event ):
377459 if event is _termination_obj :
@@ -384,19 +466,7 @@ async def _do(self, event):
384466 outlet = self ._name_to_outlet ["dataframe" ]
385467 outlets .append (outlet )
386468 else :
387- if len (set (outlet_names )) != len (outlet_names ):
388- raise ValueError (
389- "select_outlets() returned duplicate outlets among the defined outlets: "
390- + ", " .join (outlet_names )
391- )
392- for outlet_name in outlet_names :
393- if outlet_name not in self ._name_to_outlet :
394- raise ValueError (
395- f"select_outlets() returned outlet name '{ outlet_name } ', which is not one of the "
396- f"defined outlets: " + ", " .join (self ._name_to_outlet )
397- )
398- outlet = self ._name_to_outlet [outlet_name ]
399- outlets .append (outlet )
469+ outlets = self ._check_outlets_by_names (outlet_names ) if outlet_names else self ._outlets
400470 return await self ._do_downstream (event , outlets = outlets )
401471
402472
@@ -421,26 +491,36 @@ async def _do(self, event):
421491
422492
423493class _UnaryFunctionFlow (Flow ):
def __init__(
    self, fn, long_running=None, pass_context=None, fn_select_outlets: Optional[Callable] = None, **kwargs
):
    """Set up a step that applies a single-argument function to each event.

    Args:
        fn: the callable applied to each element; may be a coroutine function.
        long_running: when truthy, ``fn`` is run in an executor rather than
            called inline. Cannot be combined with a coroutine function.
        pass_context: when truthy, the step's context is passed to ``fn`` as
            the ``context`` keyword argument.
        fn_select_outlets: optional callable used by ``select_outlets`` to
            pick outlet names from an event body.
        **kwargs: forwarded to the parent constructor.

    Raises:
        TypeError: if ``fn`` is not callable, or ``fn_select_outlets`` is
            given but is not callable.
        ValueError: if ``long_running`` is truthy and ``fn`` is a coroutine
            function.
    """
    super().__init__(**kwargs)
    if not callable(fn):
        raise TypeError(f"Expected a callable, got {type(fn)}")
    if asyncio.iscoroutinefunction(fn) and long_running:
        raise ValueError("long_running=True cannot be used in conjunction with a coroutine")
    self._long_running = long_running
    self._fn = fn
    self._pass_context = pass_context
    if fn_select_outlets and not callable(fn_select_outlets):
        # Report the type of the offending argument, not the type of fn.
        raise TypeError(f"Expected fn_select_outlets to be callable, got {type(fn_select_outlets)}")
    self._outlets_selector = fn_select_outlets
    # Keep this flag a real boolean instead of storing the selector callable in it.
    self._create_name_to_outlet = bool(self._outlets_selector) or self._method_is_overridden(
        "select_outlets", _UnaryFunctionFlow
    )
434511
435- async def _call (self , element ):
512+ async def _call (self , element , fn , pass_kwargs = True ):
436513 if self ._long_running :
437- res = await asyncio .get_running_loop ().run_in_executor (None , self . _fn , element )
514+ res = await asyncio .get_running_loop ().run_in_executor (None , fn , element )
438515 else :
439516 kwargs = {}
440517 if self ._pass_context :
441518 kwargs = {"context" : self .context }
442- res = self ._fn (element , ** kwargs )
443- if self ._is_async :
519+ if pass_kwargs :
520+ res = fn (element , ** kwargs )
521+ else :
522+ res = fn (element )
523+ if asyncio .iscoroutinefunction (fn ):
444524 res = await res
445525 return res
446526
@@ -452,9 +532,15 @@ async def _do(self, event):
452532 return await self ._do_downstream (_termination_obj )
453533 else :
454534 element = self ._get_event_or_body (event )
455- fn_result = await self ._call (element )
535+ fn_result = await self ._call (element , self . _fn )
456536 await self ._do_internal (event , fn_result )
457537
def select_outlets(self, event_body) -> Optional[Collection[str]]:
    """Pick outlet names for an event body.

    Delegates to the selector callable supplied at construction time when
    there is one; otherwise falls back to the parent implementation.
    """
    selector = self._outlets_selector
    if not selector:
        return super().select_outlets(event_body)
    return selector(event_body)
543+
458544
459545class DropColumns (Flow ):
460546 def __init__ (self , columns , ** kwargs ):
@@ -639,6 +725,7 @@ def __init__(self, long_running=None, **kwargs):
639725 raise ValueError ("long_running=True cannot be used in conjunction with a coroutine do()" )
640726 self ._long_running = long_running
641727 self ._filter = False
728+ self ._create_name_to_outlet = True
642729
643730 def filter (self ):
644731 # used in the .do() code to signal filtering
0 commit comments