Skip to content

Commit 0bfaa08

Browse files
authored
Minor fixes to the cyclic pr (#598)
1 parent 74fec1a commit 0bfaa08

File tree

2 files changed

+14
-20
lines changed

2 files changed

+14
-20
lines changed

storey/flow.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def __init__(
4444
recovery_step=None,
4545
termination_result_fn=lambda x, y: x if x is not None else y,
4646
context=None,
47-
max_iteration: Optional[int] = None,
47+
max_iterations: Optional[int] = None,
4848
**kwargs,
4949
):
5050
self._outlets = []
@@ -67,7 +67,7 @@ def __init__(
6767
self._full_event = kwargs.get("full_event")
6868
self._input_path = kwargs.get("input_path")
6969
self._result_path = kwargs.get("result_path")
70-
self._max_iteration = max_iteration
70+
self._max_iterations = max_iterations
7171
self._runnable = False
7272
name = kwargs.get("name", None)
7373
if name:
@@ -284,17 +284,11 @@ def _event_string(event):
284284
def _should_terminate(self):
285285
return self._termination_received == len(self._inlets)
286286

287-
async def _do_downstream(self, event, outlets=None):
288-
if not outlets:
289-
if event is _termination_obj:
290-
outlets = self._outlets
291-
else:
292-
if self._selected_outlets:
293-
outlet_names = self._selected_outlets
294-
self._selected_outlets = None
295-
else:
296-
outlet_names = self.select_outlets(event.body)
297-
outlets = self._check_outlets_by_names(outlet_names) if outlet_names else self._outlets
287+
async def _do_downstream(self, event, outlets=None, select_outlets: bool = True):
288+
if not outlets and event is not _termination_obj and select_outlets:
289+
outlet_names = self.select_outlets(event.body)
290+
outlets = self._check_outlets_by_names(outlet_names) if outlet_names else None
291+
outlets = self._outlets if outlets is None else outlets
298292

299293
if not outlets:
300294
return
@@ -390,9 +384,9 @@ def _check_step_in_flow(self, type_to_check, visited=None):
390384
return False
391385

392386
def check_and_update_iteration_number(self, event) -> Optional[Callable]:
393-
if hasattr(event, "_cyclic_counter") and self._max_iteration is not None:
387+
if hasattr(event, "_cyclic_counter") and self._max_iterations is not None:
394388
counter = self.get_iteration_counter(event)
395-
if counter >= self._max_iteration:
389+
if counter >= self._max_iterations:
396390
raise RuntimeError(f"Max iterations exceeded in step '{self.name}' for event {event.id}")
397391
event._cyclic_counter[self.name] = counter + 1
398392
else:
@@ -457,7 +451,7 @@ def _init(self):
457451

458452
async def _do(self, event):
459453
if event is _termination_obj:
460-
return await self._do_downstream(_termination_obj)
454+
return await self._do_downstream(_termination_obj, select_outlets=False)
461455
else:
462456
event_body = event if self._full_event else event.body
463457
outlet_names = self.select_outlets(event_body)
@@ -466,8 +460,8 @@ async def _do(self, event):
466460
outlet = self._name_to_outlet["dataframe"]
467461
outlets.append(outlet)
468462
else:
469-
outlets = self._check_outlets_by_names(outlet_names) if outlet_names else self._outlets
470-
return await self._do_downstream(event, outlets=outlets)
463+
outlets = self._check_outlets_by_names(outlet_names)
464+
return await self._do_downstream(event, outlets=outlets, select_outlets=False)
471465

472466

473467
class Recover(Flow):

tests/test_flow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5538,10 +5538,10 @@ def select(event):
55385538
def test_cyclic_graphs(iterations, with_recovery):
55395539
source = SyncEmitSource()
55405540
my_loop = MyLoop(
5541-
fn=lambda x: x, iterations=iterations, name="my_loop", end="end", counter="counter", max_iteration=5
5541+
fn=lambda x: x, iterations=iterations, name="my_loop", end="end", counter="counter", max_iterations=5
55425542
)
55435543
start = Map(lambda x: x, name="start")
5544-
counter = Map(lambda x: x + 1, name="counter", max_iteration=5)
5544+
counter = Map(lambda x: x + 1, name="counter", max_iterations=5)
55455545
end = Map(lambda x: x, name="end")
55465546

55475547
source.to(start)

0 commit comments

Comments
 (0)