Skip to content

Commit 74fec1a

Browse files
authored
Allow cycle graph & add option to select outlets from all kind of steps (#593)
1 parent 463540f commit 74fec1a

File tree

3 files changed

+305
-46
lines changed

3 files changed

+305
-46
lines changed

storey/flow.py

Lines changed: 128 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def __init__(
4444
recovery_step=None,
4545
termination_result_fn=lambda x, y: x if x is not None else y,
4646
context=None,
47+
max_iteration: Optional[int] = None,
4748
**kwargs,
4849
):
4950
self._outlets = []
@@ -66,6 +67,7 @@ def __init__(
6667
self._full_event = kwargs.get("full_event")
6768
self._input_path = kwargs.get("input_path")
6869
self._result_path = kwargs.get("result_path")
70+
self._max_iteration = max_iteration
6971
self._runnable = False
7072
name = kwargs.get("name", None)
7173
if name:
@@ -74,10 +76,25 @@ def __init__(
7476
self.name = type(self).__name__
7577

7678
self._closeables = []
79+
self._selected_outlets: Optional[list[str]] = None
80+
self._create_name_to_outlet = True
7781

7882
def _init(self):
    """Reset the per-run state of this step.

    Clears the termination bookkeeping and the name -> outlet lookup table. The table is
    only populated when the concrete class overrides ``select_outlets`` and the step has
    not switched the table off through ``_create_name_to_outlet``.
    """
    self._termination_received = 0
    self._termination_result = None
    self._name_to_outlet = {}

    routes_by_name = self._method_is_overridden("select_outlets", Flow) and self._create_name_to_outlet
    if routes_by_name:
        self._init_name_to_outlet()
88+
89+
def _init_name_to_outlet(self):
90+
for outlet in self._outlets:
91+
if outlet.name in self._name_to_outlet:
92+
raise ValueError(f"Ambiguous outlet name '{outlet.name}' in step '{self.name}'")
93+
self._name_to_outlet[outlet.name] = outlet
94+
95+
def _method_is_overridden(self, method_name: str, parent_cls):
96+
"""Return True if the subclass overrides the given method."""
97+
return getattr(self.__class__, method_name) is not getattr(parent_cls, method_name)
8198

8299
def to_dict(self, fields=None, exclude=None):
83100
"""convert the step object to a python dictionary"""
@@ -187,16 +204,27 @@ def _get_recovery_step(self, exception):
187204
else:
188205
return self._recovery_step
189206

190-
def run(self, visited=None):
    """Initialise this step and, recursively, every step reachable from it.

    ``visited`` is the set of steps already started during this traversal; the root call
    leaves it as None and a fresh set is created. A step that is already in the set
    returns an empty list without being initialised again, which is what allows graphs
    containing cycles.

    Returns this step's ``_closeables`` list, extended with whatever its outlets and
    recovery steps returned. Raises ValueError when the step is neither a legal first
    step nor marked runnable by an upstream step.
    """
    if not (self._legal_first_step or self._runnable):
        raise ValueError("Flow must start with a source")

    visited = set() if visited is None else visited

    # Already started through another path (or through a cycle): nothing more to do.
    if self in visited:
        return []

    self._init()
    visited.add(self)

    for downstream in [*self._outlets, *self._get_recovery_steps()]:
        downstream._runnable = True
        self._closeables.extend(downstream.run(visited))
    return self._closeables
201229

202230
def _get_recovery_steps(self):
@@ -215,6 +243,7 @@ async def _do(self, event):
215243

216244
async def _do_and_recover(self, event):
217245
try:
246+
self.check_and_update_iteration_number(event)
218247
return await self._do(event)
219248
except BaseException as ex:
220249
if getattr(ex, "_raised_by_storey_step", None) is not None:
@@ -256,7 +285,17 @@ def _should_terminate(self):
256285
return self._termination_received == len(self._inlets)
257286

258287
async def _do_downstream(self, event, outlets=None):
259-
outlets = self._outlets if outlets is None else outlets
288+
if not outlets:
289+
if event is _termination_obj:
290+
outlets = self._outlets
291+
else:
292+
if self._selected_outlets:
293+
outlet_names = self._selected_outlets
294+
self._selected_outlets = None
295+
else:
296+
outlet_names = self.select_outlets(event.body)
297+
outlets = self._check_outlets_by_names(outlet_names) if outlet_names else self._outlets
298+
260299
if not outlets:
261300
return
262301
if event is _termination_obj:
@@ -320,21 +359,76 @@ def _user_fn_output_to_event(self, event, fn_result):
320359
mapped_event.body = fn_result
321360
return mapped_event
322361

323-
def _check_step_in_flow(self, type_to_check, visited=None):
    """Report whether a step of ``type_to_check`` is reachable from this step.

    The search covers this step, its outlets, and its recovery step(s), depth first.
    ``visited`` collects the steps already examined so that a cyclic graph terminates;
    callers normally leave it as None.
    """
    visited = set() if visited is None else visited

    # A step seen earlier in this traversal has already been examined.
    if self in visited:
        return False
    visited.add(self)

    if isinstance(self, type_to_check):
        return True

    if any(outlet._check_step_in_flow(type_to_check, visited) for outlet in self._outlets):
        return True

    # The recovery step is either a single step or a dict whose values are steps.
    recovery = self._recovery_step
    if isinstance(recovery, Flow):
        candidates = [recovery]
    elif isinstance(recovery, dict):
        candidates = recovery.values()
    else:
        candidates = ()
    return any(step._check_step_in_flow(type_to_check, visited) for step in candidates)
337391

392+
def check_and_update_iteration_number(self, event) -> None:
    """Record one more visit of ``event`` to this step and enforce ``max_iteration``.

    Visit counts live on the event itself, in a ``_cyclic_counter`` dict keyed by step
    name; the dict is created on the event if it does not have one yet.

    The dict is always updated in place. Previously a step without a ``max_iteration``
    replaced the whole dict with ``{self.name: 1}``, which discarded the counts recorded
    by every other step the event had passed through, and the limit was not checked on
    the event's first counted visit.

    :param event: the event about to be processed by this step.
    :raises RuntimeError: when this step has a ``max_iteration`` and the event has
        already visited it that many times.
    """
    # NOTE(review): _do_and_recover calls this for every event it handles; confirm
    # whether termination events are meant to be counted as well.
    counters = getattr(event, "_cyclic_counter", None)
    if counters is None:
        counters = {}
        event._cyclic_counter = counters

    visits = counters.get(self.name, 0)
    if self._max_iteration is not None and visits >= self._max_iteration:
        raise RuntimeError(f"Max iterations exceeded in step '{self.name}' for event {event.id}")
    counters[self.name] = visits + 1
400+
401+
def get_iteration_counter(self, event):
    """Return how many visits to this step are recorded on ``event`` (0 when none are)."""
    counters = getattr(event, "_cyclic_counter", {})
    return counters.get(self.name, 0)
403+
404+
def select_outlets(self, event) -> Optional[Collection[str]]:
    """
    Routing hook. Override it to return the names of the outlets that should receive the given input; the
    downstream dispatch calls it with the event body. This default returns None, which makes no selection, so
    the event is sent to every outlet.
    """
    return None
410+
411+
def _check_outlets_by_names(self, outlet_names: Collection[str]) -> list["Flow"]:
412+
outlets = []
413+
414+
# Check for duplicates
415+
if len(set(outlet_names)) != len(outlet_names):
416+
raise ValueError(
417+
f"Invalid outlet selection for '{self.name}': duplicate outlet names were provided "
418+
f"({', '.join(outlet_names)})."
419+
)
420+
421+
# Validate each outlet name
422+
for outlet_name in outlet_names:
423+
if outlet_name not in self._name_to_outlet:
424+
raise ValueError(
425+
f"Invalid outlet '{outlet_name}' for '{self.name}'. "
426+
f"Allowed outlets are: {', '.join(self._name_to_outlet)}."
427+
)
428+
outlets.append(self._name_to_outlet[outlet_name])
429+
430+
return outlets
431+
338432

339433
class WithUUID:
340434
def __init__(self):
@@ -358,20 +452,8 @@ class Choice(Flow):
358452

359453
def _init(self):
    """Run the base-class initialisation, then work out whether this step is in preview passthrough mode.

    Passthrough is on exactly when the name -> outlet table holds a single outlet named "dataframe";
    an empty table therefore leaves it off.
    """
    super()._init()
    # TODO: hacky way of supporting mlrun preview, which replaces targets with a DFTarget
    outlet_names = list(self._name_to_outlet)
    self._passthrough_for_preview = outlet_names == ["dataframe"]
375457

376458
async def _do(self, event):
377459
if event is _termination_obj:
@@ -384,19 +466,7 @@ async def _do(self, event):
384466
outlet = self._name_to_outlet["dataframe"]
385467
outlets.append(outlet)
386468
else:
387-
if len(set(outlet_names)) != len(outlet_names):
388-
raise ValueError(
389-
"select_outlets() returned duplicate outlets among the defined outlets: "
390-
+ ", ".join(outlet_names)
391-
)
392-
for outlet_name in outlet_names:
393-
if outlet_name not in self._name_to_outlet:
394-
raise ValueError(
395-
f"select_outlets() returned outlet name '{outlet_name}', which is not one of the "
396-
f"defined outlets: " + ", ".join(self._name_to_outlet)
397-
)
398-
outlet = self._name_to_outlet[outlet_name]
399-
outlets.append(outlet)
469+
outlets = self._check_outlets_by_names(outlet_names) if outlet_names else self._outlets
400470
return await self._do_downstream(event, outlets=outlets)
401471

402472

@@ -421,26 +491,36 @@ async def _do(self, event):
421491

422492

423493
class _UnaryFunctionFlow(Flow):
424-
def __init__(
    self, fn, long_running=None, pass_context=None, fn_select_outlets: Optional[Callable] = None, **kwargs
):
    """Set up a step that applies the user function ``fn`` to each event.

    :param fn: the callable to apply; may be a coroutine function.
    :param long_running: when truthy, ``_call`` runs the function in the event loop's default
        executor. Cannot be combined with a coroutine function.
    :param pass_context: when truthy, ``_call`` passes ``context=self.context`` to the function.
    :param fn_select_outlets: optional callable that receives the event body and returns the names
        of the outlets to route to; used by ``select_outlets``.
    :param kwargs: forwarded to the parent class constructor.
    :raises TypeError: if ``fn`` or a provided ``fn_select_outlets`` is not callable.
    :raises ValueError: if ``long_running`` is requested for a coroutine function.
    """
    super().__init__(**kwargs)
    if not callable(fn):
        raise TypeError(f"Expected a callable, got {type(fn)}")
    if asyncio.iscoroutinefunction(fn) and long_running:
        raise ValueError("long_running=True cannot be used in conjunction with a coroutine")
    self._long_running = long_running
    self._fn = fn
    self._pass_context = pass_context
    if fn_select_outlets and not callable(fn_select_outlets):
        # Report the type of the offending argument (the message used to show type(fn)).
        raise TypeError(f"Expected fn_select_outlets to be callable, got {type(fn_select_outlets)}")
    self._outlets_selector = fn_select_outlets
    # Kept as a real bool: the name -> outlet table is needed when a selector function was given
    # or when a subclass overrides select_outlets.
    self._create_name_to_outlet = bool(
        self._outlets_selector or self._method_is_overridden("select_outlets", _UnaryFunctionFlow)
    )
434511

435-
async def _call(self, element):
512+
async def _call(self, element, fn, pass_kwargs=True):
436513
if self._long_running:
437-
res = await asyncio.get_running_loop().run_in_executor(None, self._fn, element)
514+
res = await asyncio.get_running_loop().run_in_executor(None, fn, element)
438515
else:
439516
kwargs = {}
440517
if self._pass_context:
441518
kwargs = {"context": self.context}
442-
res = self._fn(element, **kwargs)
443-
if self._is_async:
519+
if pass_kwargs:
520+
res = fn(element, **kwargs)
521+
else:
522+
res = fn(element)
523+
if asyncio.iscoroutinefunction(fn):
444524
res = await res
445525
return res
446526

@@ -452,9 +532,15 @@ async def _do(self, event):
452532
return await self._do_downstream(_termination_obj)
453533
else:
454534
element = self._get_event_or_body(event)
455-
fn_result = await self._call(element)
535+
fn_result = await self._call(element, self._fn)
456536
await self._do_internal(event, fn_result)
457537

538+
def select_outlets(self, event_body) -> Optional[Collection[str]]:
    """Pick outlets with the selector function given at construction, if there is one.

    Without a selector function the decision is left to the parent class implementation.
    """
    selector = self._outlets_selector
    if not selector:
        return super().select_outlets(event_body)
    return selector(event_body)
543+
458544

459545
class DropColumns(Flow):
460546
def __init__(self, columns, **kwargs):
@@ -639,6 +725,7 @@ def __init__(self, long_running=None, **kwargs):
639725
raise ValueError("long_running=True cannot be used in conjunction with a coroutine do()")
640726
self._long_running = long_running
641727
self._filter = False
728+
self._create_name_to_outlet = True
642729

643730
def filter(self):
644731
# used in the .do() code to signal filtering

storey/sources.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ def _emit(self, event):
386386
if event is not _termination_obj:
387387
self._raise_on_error(self._ex)
388388

389-
def run(self):
389+
def run(self, visited=None):
390390
"""Starts the flow"""
391391
self._closeables = super().run()
392392

@@ -710,7 +710,7 @@ async def _emit(self, event):
710710
if event is not _termination_obj:
711711
self._raise_on_error()
712712

713-
def run(self):
713+
def run(self, visited=None):
714714
"""Starts the flow"""
715715
self._closeables = super().run()
716716
loop_task = asyncio.get_running_loop().create_task(self._run_loop_and_log_unexpected_error())
@@ -753,7 +753,7 @@ def _raise_on_error(self, ex):
753753
raise type(self._ex)("Flow execution terminated") from self._ex
754754
raise self._ex
755755

756-
def run(self):
756+
def run(self, visited=None):
757757
self._closeables = super().run()
758758

759759
self._init()

0 commit comments

Comments
 (0)