@@ -252,15 +252,16 @@ def _event_string(event):
252252 def _should_terminate (self ):
253253 return self ._termination_received == len (self ._inlets )
254254
255- async def _do_downstream (self , event ):
256- if not self ._outlets :
255+ async def _do_downstream (self , event , outlets = None ):
256+ outlets = self ._outlets if outlets is None else outlets
257+ if not outlets :
257258 return
258259 if event is _termination_obj :
259260 # Only propagate the termination object once we received one per inlet
260- self . _outlets [0 ]._termination_received += 1
261- if self . _outlets [0 ]._should_terminate ():
262- self ._termination_result = await self . _outlets [0 ]._do (_termination_obj )
263- for outlet in self . _outlets [1 :] + self ._get_recovery_steps ():
261+ outlets [0 ]._termination_received += 1
262+ if outlets [0 ]._should_terminate ():
263+ self ._termination_result = await outlets [0 ]._do (_termination_obj )
264+ for outlet in outlets [1 :] + self ._get_recovery_steps ():
264265 outlet ._termination_received += 1
265266 if outlet ._should_terminate ():
266267 self ._termination_result = self ._termination_result_fn (
@@ -269,28 +270,28 @@ async def _do_downstream(self, event):
269270 return self ._termination_result
270271 # If there is more than one outlet, allow concurrent execution.
271272 tasks = []
272- if len (self . _outlets ) > 1 :
273+ if len (outlets ) > 1 :
273274 awaitable_result = event ._awaitable_result
274275 event ._awaitable_result = None
275276 original_events = getattr (event , "_original_events" , None )
276277 # Temporarily delete self-reference to avoid deepcopy getting stuck in an infinite loop
277278 event ._original_events = None
278- for i in range (1 , len (self . _outlets )):
279+ for i in range (1 , len (outlets )):
279280 event_copy = copy .deepcopy (event )
280281 event_copy ._awaitable_result = awaitable_result
281282 event_copy ._original_events = original_events
282- tasks .append (asyncio .get_running_loop ().create_task (self . _outlets [i ]._do_and_recover (event_copy )))
283+ tasks .append (asyncio .get_running_loop ().create_task (outlets [i ]._do_and_recover (event_copy )))
283284 # Set self-reference back after deepcopy
284285 event ._original_events = original_events
285286 event ._awaitable_result = awaitable_result
286287 if self .verbose and self .logger :
287288 step_name = self .name
288289 event_string = self ._event_string (event )
289- self .logger .debug (f"{ step_name } -> { self . _outlets [0 ].name } | { event_string } " )
290- await self . _outlets [0 ]._do_and_recover (event ) # Optimization - avoids creating a task for the first outlet.
290+ self .logger .debug (f"{ step_name } -> { outlets [0 ].name } | { event_string } " )
291+ await outlets [0 ]._do_and_recover (event ) # Optimization - avoids creating a task for the first outlet.
291292 for i , task in enumerate (tasks , start = 1 ):
292293 if self .verbose and self .logger :
293- self .logger .debug (f"{ step_name } -> { self . _outlets [i ].name } | { event_string } " )
294+ self .logger .debug (f"{ step_name } -> { outlets [i ].name } | { event_string } " )
294295 await task
295296
296297 def _get_event_or_body (self , event ):
@@ -347,46 +348,48 @@ def _get_uuid(self):
347348
348349
349350class Choice (Flow ):
350- """Redirects each input element into at most one of multiple downstreams.
351-
352- :param choice_array: a list of (downstream, condition) tuples, where downstream is a step and condition is a
353- function. The first condition in the list to evaluate as true for an input element causes that element to
354- be redirected to that downstream step.
355- :type choice_array: tuple of (Flow, Function (Event=>boolean))
356- :param default: a default step for events that did not match any condition in choice_array. If not set, elements
357- that don't match any condition will be discarded.
358- :type default: Flow
359- :param name: Name of this step, as it should appear in logs. Defaults to class name (Choice).
360- :type name: string
361- :param full_event: Whether user functions should receive and return Event objects (when True),
362- or only the payload (when False). Defaults to False.
363- :type full_event: boolean
351+ """
352+ Redirects each input element into any number of predetermined downstream steps. Override select_outlets()
353+ to route events to any number of downstream steps.
364354 """
365355
366- def __init__ (self , choice_array , default = None , ** kwargs ):
367- Flow .__init__ (self , ** kwargs )
368-
369- self ._choice_array = choice_array
370- for outlet , _ in choice_array :
371- self .to (outlet )
372-
373- if default :
374- self .to (default )
375- self ._default = default
356+ def _init (self ):
357+ super ()._init ()
358+ self ._name_to_outlet = {}
359+ for outlet in self ._outlets :
360+ if outlet .name in self ._name_to_outlet :
361+ raise ValueError (f"Ambiguous outlet name '{ outlet .name } ' in Choice step" )
362+ self ._name_to_outlet [outlet .name ] = outlet
363+ # TODO: hacky way of supporting mlrun preview, which replaces targets with a DFTarget
364+ self ._passthrough_for_preview = list (self ._name_to_outlet ) == ["dataframe" ]
365+
366+ def select_outlets (self , event ) -> List [str ]:
367+ """
368+ Override this method to route events based on a customer logic. The default implementation will route all
369+ events to all outlets.
370+ """
371+ return list (self ._name_to_outlet .keys ())
376372
377373 async def _do (self , event ):
378- if not self ._outlets or event is _termination_obj :
379- return await super ()._do_downstream (event )
380- chosen_outlet = None
381- element = self ._get_event_or_body (event )
382- for outlet , condition in self ._choice_array :
383- if condition (element ):
384- chosen_outlet = outlet
385- break
386- if chosen_outlet :
387- await chosen_outlet ._do (event )
388- elif self ._default :
389- await self ._default ._do (event )
374+ if event is _termination_obj :
375+ return await self ._do_downstream (_termination_obj )
376+ else :
377+ event_body = event if self ._full_event else event .body
378+ outlet_names = self .select_outlets (event_body )
379+ outlets = []
380+ if self ._passthrough_for_preview :
381+ outlet = self ._name_to_outlet ["dataframe" ]
382+ outlets .append (outlet )
383+ else :
384+ for outlet_name in outlet_names :
385+ if outlet_name not in self ._name_to_outlet :
386+ raise ValueError (
387+ f"select_outlets() returned outlet name '{ outlet_name } ', which is not one of the "
388+ f"defined outlets: " + ", " .join (self ._name_to_outlet )
389+ )
390+ outlet = self ._name_to_outlet [outlet_name ]
391+ outlets .append (outlet )
392+ return await self ._do_downstream (event , outlets = outlets )
390393
391394
392395class Recover (Flow ):
0 commit comments