Skip to content

Commit a0ee9df

Browse files
committed
set_next
1 parent 2a5ff18 commit a0ee9df

File tree

2 files changed

+73
-5
lines changed

2 files changed

+73
-5
lines changed

storey/flow.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def __init__(
7777
self.name = type(self).__name__
7878

7979
self._closeables = []
80+
self._selected_outlet: Optional[list[str]] = None
8081

8182
def _init(self):
8283
self._termination_received = 0
@@ -277,10 +278,14 @@ async def _do_downstream(self, event, outlets=None):
277278
if outlets:
278279
outlets = outlets
279280
elif event is not _termination_obj:
280-
if asyncio.iscoroutinefunction(self.select_outlets):
281-
outlet_names = await self.select_outlets(event.body)
281+
if self._selected_outlet:
282+
outlet_names = self._selected_outlet
283+
self._selected_outlet = None
282284
else:
283-
outlet_names = self.select_outlets(event.body)
285+
if asyncio.iscoroutinefunction(self.select_outlets):
286+
outlet_names = await self.select_outlets(event.body)
287+
else:
288+
outlet_names = self.select_outlets(event.body)
284289
outlets = self._check_outlets_by_names(outlet_names) if outlet_names else self._outlets
285290
else:
286291
outlets = self._outlets
@@ -403,6 +408,15 @@ def select_outlets(self, event) -> typing.Optional[Collection[str]]:
403408
"""
404409
return None
405410

411+
def set_next_outlets(self, outlet_names: Union[str, list[str]]):
412+
"""
413+
Set the next outlets to which the event will be sent. This method can be used in conjunction with
414+
select_outlets() to dynamically determine the outlets for the next event.
415+
416+
:param outlet_names: A collection of outlet names to which the next event should be sent.
417+
"""
418+
self._selected_outlet = outlet_names if isinstance(outlet_names, list) else [outlet_names]
419+
406420
def _check_outlets_by_names(self, outlet_names: Collection[str]) -> list["Flow"]:
407421
outlets = []
408422

@@ -448,7 +462,7 @@ class Choice(Flow):
448462
def _init(self):
449463
super()._init()
450464
# TODO: hacky way of supporting mlrun preview, which replaces targets with a DFTarget
451-
self._passthrough_for_preview = list(self._name_to_outlet) == ["dataframe"] if self._name_to_outlet else False
465+
self._passthrough_for_preview = list(self._name_to_outlet) == ["dataframe"]
452466

453467
async def _do(self, event):
454468
if event is _termination_obj:

tests/test_flow.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5480,7 +5480,8 @@ def select_outlets(self, event):
54805480
return outlets
54815481

54825482

5483-
def test_regular_step_with_choice():
5483+
@pytest.mark.parametrize("fn_select_outlets", [True, False])
5484+
def test_regular_step_with_choice(fn_select_outlets):
54845485
class MyStep(Map):
54855486
def select_outlets(self, event):
54865487
outlets = ["all_events"]
@@ -5490,7 +5491,60 @@ def select_outlets(self, event):
54905491
outlets.append("up_to_five")
54915492
return outlets
54925493

5494+
def select(event):
5495+
outlets = ["all_events"]
5496+
if event > 5:
5497+
outlets.append("more_than_five")
5498+
else:
5499+
outlets.append("up_to_five")
5500+
return outlets
5501+
5502+
source = SyncEmitSource()
5503+
if fn_select_outlets:
5504+
my_step = Map(fn=lambda x: x, termination_result_fn=lambda x, y: x + y, fn_select_outlets=select)
5505+
else:
5506+
my_step = MyStep(fn=lambda x: x, termination_result_fn=lambda x, y: x + y)
5507+
5508+
all_events = Map(lambda x: x, name="all_events")
5509+
more_than_five = Map(lambda x: x * 10, name="more_than_five")
5510+
up_to_five = Map(lambda x: x * 100, name="up_to_five")
5511+
sum_up_all_events = Reduce(0, lambda acc, x: acc + x)
5512+
sum_up_more_than_five = Reduce(0, lambda acc, x: acc + x)
5513+
sum_up_up_to_five = Reduce(0, lambda acc, x: acc + x)
5514+
5515+
source.to(my_step)
5516+
my_step.to(all_events)
5517+
my_step.to(more_than_five)
5518+
my_step.to(up_to_five)
5519+
all_events.to(sum_up_all_events)
5520+
more_than_five.to(sum_up_more_than_five)
5521+
up_to_five.to(sum_up_up_to_five)
5522+
5523+
controller = source.run()
5524+
5525+
for i in range(4, 8):
5526+
controller.emit(i)
5527+
5528+
controller.terminate()
5529+
termination_result = controller.await_termination()
5530+
5531+
expected = sum(range(4, 8)) + sum(range(6, 8)) * 10 + sum(range(4, 6)) * 100
5532+
assert termination_result == expected
5533+
5534+
5535+
def test_regular_step_with_set_next():
5536+
class MyStep(MapClass):
5537+
def do(self, event):
5538+
outlets = ["all_events"]
5539+
if event > 5:
5540+
outlets.append("more_than_five")
5541+
else:
5542+
outlets.append("up_to_five")
5543+
self.set_next_outlets(outlets)
5544+
return event
5545+
54935546
source = SyncEmitSource()
5547+
54945548
my_step = MyStep(fn=lambda x: x, termination_result_fn=lambda x, y: x + y)
54955549
all_events = Map(lambda x: x, name="all_events")
54965550
more_than_five = Map(lambda x: x * 10, name="more_than_five")

0 commit comments

Comments
 (0)