Skip to content

Commit c389ddc

Browse files
authored
Rewrite Choice step to make it usable from mlrun (#537)
* Rewrite Choice step to make it usable from mlrun [ML-7818](https://iguazio.atlassian.net/browse/ML-7818)
* Add missing space
* Hack to avoid issue with mlrun preview
* Improve docs
* Remove accidental kwargs, add type annotation
1 parent 9bdebf4 commit c389ddc

File tree

3 files changed

+82
-62
lines changed

3 files changed

+82
-62
lines changed

storey/flow.py

Lines changed: 51 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -252,15 +252,16 @@ def _event_string(event):
252252
def _should_terminate(self):
253253
return self._termination_received == len(self._inlets)
254254

255-
async def _do_downstream(self, event):
256-
if not self._outlets:
255+
async def _do_downstream(self, event, outlets=None):
256+
outlets = self._outlets if outlets is None else outlets
257+
if not outlets:
257258
return
258259
if event is _termination_obj:
259260
# Only propagate the termination object once we received one per inlet
260-
self._outlets[0]._termination_received += 1
261-
if self._outlets[0]._should_terminate():
262-
self._termination_result = await self._outlets[0]._do(_termination_obj)
263-
for outlet in self._outlets[1:] + self._get_recovery_steps():
261+
outlets[0]._termination_received += 1
262+
if outlets[0]._should_terminate():
263+
self._termination_result = await outlets[0]._do(_termination_obj)
264+
for outlet in outlets[1:] + self._get_recovery_steps():
264265
outlet._termination_received += 1
265266
if outlet._should_terminate():
266267
self._termination_result = self._termination_result_fn(
@@ -269,28 +270,28 @@ async def _do_downstream(self, event):
269270
return self._termination_result
270271
# If there is more than one outlet, allow concurrent execution.
271272
tasks = []
272-
if len(self._outlets) > 1:
273+
if len(outlets) > 1:
273274
awaitable_result = event._awaitable_result
274275
event._awaitable_result = None
275276
original_events = getattr(event, "_original_events", None)
276277
# Temporarily delete self-reference to avoid deepcopy getting stuck in an infinite loop
277278
event._original_events = None
278-
for i in range(1, len(self._outlets)):
279+
for i in range(1, len(outlets)):
279280
event_copy = copy.deepcopy(event)
280281
event_copy._awaitable_result = awaitable_result
281282
event_copy._original_events = original_events
282-
tasks.append(asyncio.get_running_loop().create_task(self._outlets[i]._do_and_recover(event_copy)))
283+
tasks.append(asyncio.get_running_loop().create_task(outlets[i]._do_and_recover(event_copy)))
283284
# Set self-reference back after deepcopy
284285
event._original_events = original_events
285286
event._awaitable_result = awaitable_result
286287
if self.verbose and self.logger:
287288
step_name = self.name
288289
event_string = self._event_string(event)
289-
self.logger.debug(f"{step_name} -> {self._outlets[0].name} | {event_string}")
290-
await self._outlets[0]._do_and_recover(event) # Optimization - avoids creating a task for the first outlet.
290+
self.logger.debug(f"{step_name} -> {outlets[0].name} | {event_string}")
291+
await outlets[0]._do_and_recover(event) # Optimization - avoids creating a task for the first outlet.
291292
for i, task in enumerate(tasks, start=1):
292293
if self.verbose and self.logger:
293-
self.logger.debug(f"{step_name} -> {self._outlets[i].name} | {event_string}")
294+
self.logger.debug(f"{step_name} -> {outlets[i].name} | {event_string}")
294295
await task
295296

296297
def _get_event_or_body(self, event):
@@ -347,46 +348,48 @@ def _get_uuid(self):
347348

348349

349350
class Choice(Flow):
350-
"""Redirects each input element into at most one of multiple downstreams.
351-
352-
:param choice_array: a list of (downstream, condition) tuples, where downstream is a step and condition is a
353-
function. The first condition in the list to evaluate as true for an input element causes that element to
354-
be redirected to that downstream step.
355-
:type choice_array: tuple of (Flow, Function (Event=>boolean))
356-
:param default: a default step for events that did not match any condition in choice_array. If not set, elements
357-
that don't match any condition will be discarded.
358-
:type default: Flow
359-
:param name: Name of this step, as it should appear in logs. Defaults to class name (Choice).
360-
:type name: string
361-
:param full_event: Whether user functions should receive and return Event objects (when True),
362-
or only the payload (when False). Defaults to False.
363-
:type full_event: boolean
351+
"""
352+
Redirects each input element into any number of predetermined downstream steps. Override select_outlets()
353+
to route events to any number of downstream steps.
364354
"""
365355

366-
def __init__(self, choice_array, default=None, **kwargs):
367-
Flow.__init__(self, **kwargs)
368-
369-
self._choice_array = choice_array
370-
for outlet, _ in choice_array:
371-
self.to(outlet)
372-
373-
if default:
374-
self.to(default)
375-
self._default = default
356+
def _init(self):
357+
super()._init()
358+
self._name_to_outlet = {}
359+
for outlet in self._outlets:
360+
if outlet.name in self._name_to_outlet:
361+
raise ValueError(f"Ambiguous outlet name '{outlet.name}' in Choice step")
362+
self._name_to_outlet[outlet.name] = outlet
363+
# TODO: hacky way of supporting mlrun preview, which replaces targets with a DFTarget
364+
self._passthrough_for_preview = list(self._name_to_outlet) == ["dataframe"]
365+
366+
def select_outlets(self, event) -> List[str]:
367+
"""
368+
Override this method to route events based on custom logic. The default implementation will route all
369+
events to all outlets.
370+
"""
371+
return list(self._name_to_outlet.keys())
376372

377373
async def _do(self, event):
378-
if not self._outlets or event is _termination_obj:
379-
return await super()._do_downstream(event)
380-
chosen_outlet = None
381-
element = self._get_event_or_body(event)
382-
for outlet, condition in self._choice_array:
383-
if condition(element):
384-
chosen_outlet = outlet
385-
break
386-
if chosen_outlet:
387-
await chosen_outlet._do(event)
388-
elif self._default:
389-
await self._default._do(event)
374+
if event is _termination_obj:
375+
return await self._do_downstream(_termination_obj)
376+
else:
377+
event_body = event if self._full_event else event.body
378+
outlet_names = self.select_outlets(event_body)
379+
outlets = []
380+
if self._passthrough_for_preview:
381+
outlet = self._name_to_outlet["dataframe"]
382+
outlets.append(outlet)
383+
else:
384+
for outlet_name in outlet_names:
385+
if outlet_name not in self._name_to_outlet:
386+
raise ValueError(
387+
f"select_outlets() returned outlet name '{outlet_name}', which is not one of the "
388+
f"defined outlets: " + ", ".join(self._name_to_outlet)
389+
)
390+
outlet = self._name_to_outlet[outlet_name]
391+
outlets.append(outlet)
392+
return await self._do_downstream(event, outlets=outlets)
390393

391394

392395
class Recover(Flow):

storey/sources.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,7 @@ async def _run_loop(self):
313313
await _commit_handled_events(self._outstanding_offsets, committer, commit_all=True)
314314
self._termination_future.set_result(termination_result)
315315
except BaseException as ex:
316+
traceback.print_exc()
316317
if self.logger:
317318
message = "An error was raised"
318319
raised_by = getattr(ex, "_raised_by_storey_step", None)

tests/test_flow.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1690,26 +1690,42 @@ def boom(_):
16901690

16911691

16921692
def test_choice():
1693-
small_reduce = Reduce(0, lambda acc, x: acc + x)
1693+
class MyChoice(Choice):
1694+
def select_outlets(self, event):
1695+
outlets = ["all_events"]
1696+
if event > 5:
1697+
outlets.append("more_than_five")
1698+
else:
1699+
outlets.append("up_to_five")
1700+
return outlets
16941701

1695-
big_reduce = build_flow([Map(lambda x: x * 100), Reduce(0, lambda acc, x: acc + x)])
1702+
source = SyncEmitSource()
1703+
my_choice = MyChoice(termination_result_fn=lambda x, y: x + y)
1704+
all_events = Map(lambda x: x, name="all_events")
1705+
more_than_five = Map(lambda x: x * 10, name="more_than_five")
1706+
up_to_five = Map(lambda x: x * 100, name="up_to_five")
1707+
sum_up_all_events = Reduce(0, lambda acc, x: acc + x)
1708+
sum_up_more_than_five = Reduce(0, lambda acc, x: acc + x)
1709+
sum_up_up_to_five = Reduce(0, lambda acc, x: acc + x)
1710+
1711+
source.to(my_choice)
1712+
my_choice.to(all_events)
1713+
my_choice.to(more_than_five)
1714+
my_choice.to(up_to_five)
1715+
all_events.to(sum_up_all_events)
1716+
more_than_five.to(sum_up_more_than_five)
1717+
up_to_five.to(sum_up_up_to_five)
16961718

1697-
controller = build_flow(
1698-
[
1699-
SyncEmitSource(),
1700-
Choice(
1701-
[(big_reduce, lambda x: x % 2 == 0)],
1702-
default=small_reduce,
1703-
termination_result_fn=lambda x, y: x + y,
1704-
),
1705-
]
1706-
).run()
1719+
controller = source.run()
17071720

1708-
for i in range(10):
1721+
for i in range(4, 8):
17091722
controller.emit(i)
1723+
17101724
controller.terminate()
17111725
termination_result = controller.await_termination()
1712-
assert termination_result == 2025
1726+
1727+
expected = sum(range(4, 8)) + sum(range(6, 8)) * 10 + sum(range(4, 6)) * 100
1728+
assert termination_result == expected
17131729

17141730

17151731
def test_metadata():

0 commit comments

Comments
 (0)