Skip to content
182 changes: 141 additions & 41 deletions storey/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(
recovery_step=None,
termination_result_fn=lambda x, y: x if x is not None else y,
context=None,
max_iteration: Optional[int] = None,
**kwargs,
):
self._outlets = []
Expand All @@ -66,6 +67,7 @@ def __init__(
self._full_event = kwargs.get("full_event")
self._input_path = kwargs.get("input_path")
self._result_path = kwargs.get("result_path")
self._max_iteration = max_iteration
self._runnable = False
name = kwargs.get("name", None)
if name:
Expand All @@ -74,10 +76,25 @@ def __init__(
self.name = type(self).__name__

self._closeables = []
self._selected_outlets: Optional[list[str]] = None
self._create_name_to_outlet = True

def _init(self):
    """Reset per-run termination state and, if routing is customized, build the outlet lookup."""
    self._termination_received = 0
    self._termination_result = None
    self._name_to_outlet = {}
    # The override check runs first; the opt-out flag is only consulted when it passes.
    routing_is_custom = self._method_is_overridden("select_outlets", Flow)
    if routing_is_custom and self._create_name_to_outlet:
        self._init_name_to_outlet()

def _init_name_to_outlet(self):
    """Fill ``self._name_to_outlet`` with one entry per outlet, keyed by the outlet's name.

    :raises ValueError: if two outlets share the same name.
    """
    registry = self._name_to_outlet
    for step in self._outlets:
        step_name = step.name
        if step_name in registry:
            raise ValueError(f"Ambiguous outlet name '{step_name}' in step '{self.name}'")
        registry[step_name] = step

def _method_is_overridden(self, method_name: str, parent_cls):
    """Tell whether this instance's class supplies its own ``method_name``.

    The attribute looked up on the instance's class is compared by identity with the
    attribute looked up on ``parent_cls``; a different object means it was overridden.
    """
    own_impl = getattr(self.__class__, method_name)
    inherited_impl = getattr(parent_cls, method_name)
    return own_impl is not inherited_impl

def to_dict(self, fields=None, exclude=None):
"""convert the step object to a python dictionary"""
Expand Down Expand Up @@ -187,16 +204,27 @@ def _get_recovery_step(self, exception):
else:
return self._recovery_step

def run(self, visited=None):
    """Initialize this step and, recursively, every step reachable from it.

    :param visited: set of steps already started in this traversal. Callers at the root
        leave it as None; it is created here and handed down to outlets so that a step
        reachable through a cycle is initialized only once.
    :returns: this step's list of closeables, extended with those returned by its outlets
        and recovery steps. A step that was already visited returns an empty list.
    :raises ValueError: if the step is neither a legal first step nor marked runnable.
    """
    if not self._legal_first_step and not self._runnable:
        raise ValueError("Flow must start with a source")

    # Only the root call creates the set; recursive calls share it.
    if visited is None:
        visited = set()

    # Cycle guard: a step that was already started must not be initialized again.
    if self in visited:
        return []
    self._init()
    visited.add(self)

    downstream = [*self._outlets, *self._get_recovery_steps()]
    for outlet in downstream:
        # Mark the outlet runnable so that its own first-step check passes.
        outlet._runnable = True
        self._closeables.extend(outlet.run(visited))
    return self._closeables

def _get_recovery_steps(self):
Expand All @@ -215,6 +243,7 @@ async def _do(self, event):

async def _do_and_recover(self, event):
try:
self.check_and_update_iteration_number(event)
return await self._do(event)
except BaseException as ex:
if getattr(ex, "_raised_by_storey_step", None) is not None:
Expand Down Expand Up @@ -256,7 +285,20 @@ def _should_terminate(self):
return self._termination_received == len(self._inlets)

async def _do_downstream(self, event, outlets=None):
outlets = self._outlets if outlets is None else outlets
if not outlets:
if event is _termination_obj:
outlets = self._outlets
else:
if self._selected_outlets:
outlet_names = self._selected_outlets
self._selected_outlets = None
else:
if asyncio.iscoroutinefunction(self.select_outlets):
outlet_names = await self.select_outlets(event.body)
else:
outlet_names = self.select_outlets(event.body)
outlets = self._check_outlets_by_names(outlet_names) if outlet_names else self._outlets

if not outlets:
return
if event is _termination_obj:
Expand Down Expand Up @@ -320,21 +362,79 @@ def _user_fn_output_to_event(self, event, fn_result):
mapped_event.body = fn_result
return mapped_event

def _check_step_in_flow(self, type_to_check, visited=None):
    """Report whether a step of the given type is reachable from this step.

    The search covers this step, its outlets and its recovery step(s), recursively.

    :param type_to_check: class (or tuple of classes) tested with ``isinstance``.
    :param visited: set of steps already examined; created on the root call and shared
        with every recursive call so that cyclic graphs terminate.
    :returns: True if a matching step is found, otherwise False.
    """
    if visited is None:
        visited = set()

    # Cycle guard: a step that was already examined cannot contribute a new match.
    if self in visited:
        return False
    visited.add(self)

    if isinstance(self, type_to_check):
        return True

    if any(outlet._check_step_in_flow(type_to_check, visited) for outlet in self._outlets):
        return True

    # The recovery step is either a single step or a dict whose values are steps.
    recovery = self._recovery_step
    if isinstance(recovery, Flow):
        return recovery._check_step_in_flow(type_to_check, visited)
    if isinstance(recovery, dict):
        return any(step._check_step_in_flow(type_to_check, visited) for step in recovery.values())
    return False

def check_and_update_iteration_number(self, event) -> None:
    """Count one more pass of ``event`` through this step, enforcing ``max_iteration``.

    Per-step counters live in ``event._cyclic_counter``, a dict keyed by step name. The
    dict is created on first use and is never replaced afterwards, so counters recorded
    by other steps survive even when this step has no iteration limit.

    :param event: the event being processed; it must expose ``id`` for the error message.
    :raises RuntimeError: if this step has a limit and the event has already passed
        through it ``max_iteration`` times.
    """
    counter = self.get_iteration_counter(event)
    if self._max_iteration is not None and counter >= self._max_iteration:
        raise RuntimeError(f"Max iterations exceeded in step '{self.name}' for event {event.id}")
    if not hasattr(event, "_cyclic_counter"):
        event._cyclic_counter = {}
    # Update in place rather than rebinding, to keep the other steps' entries.
    event._cyclic_counter[self.name] = counter + 1

def get_iteration_counter(self, event):
    """Return how many passes through this step are recorded on ``event``.

    The count is read from ``event._cyclic_counter`` under this step's name; 0 is
    returned when the event carries no counters or has no entry for this step.
    """
    absent = object()
    counters = getattr(event, "_cyclic_counter", absent)
    return 0 if counters is absent else counters.get(self.name, 0)

def select_outlets(self, event) -> Optional[Collection[str]]:
    """
    Routing hook: override it to return the names of the outlets an event should be sent to.

    This default implementation returns None, which routes every event to all outlets.
    """
    return None

def _check_outlets_by_names(self, outlet_names: Collection[str]) -> list["Flow"]:
    """Resolve outlet names to outlet steps, in the order the names were given.

    :param outlet_names: names to look up in ``self._name_to_outlet``.
    :returns: the matching outlets, one per name.
    :raises ValueError: if a name appears more than once, or is not a known outlet name.
        The duplicate check covers the whole collection before any name is looked up.
    """
    if len(outlet_names) != len(set(outlet_names)):
        raise ValueError(
            f"Invalid outlet selection for '{self.name}': duplicate outlet names were provided "
            f"({', '.join(outlet_names)})."
        )

    lookup = self._name_to_outlet
    resolved = []
    for requested in outlet_names:
        if requested in lookup:
            resolved.append(lookup[requested])
            continue
        raise ValueError(
            f"Invalid outlet '{requested}' for '{self.name}'. "
            f"Allowed outlets are: {', '.join(lookup)}."
        )
    return resolved


class WithUUID:
def __init__(self):
Expand All @@ -358,20 +458,8 @@ class Choice(Flow):

def _init(self):
    """Run the inherited initialization, then detect the single-"dataframe"-outlet preview case."""
    super()._init()
    # TODO: hacky way of supporting mlrun preview, which replaces targets with a DFTarget
    # An empty lookup yields an empty list, so the comparison is already False in that case.
    self._passthrough_for_preview = list(self._name_to_outlet) == ["dataframe"]

async def _do(self, event):
if event is _termination_obj:
Expand All @@ -384,19 +472,7 @@ async def _do(self, event):
outlet = self._name_to_outlet["dataframe"]
outlets.append(outlet)
else:
if len(set(outlet_names)) != len(outlet_names):
raise ValueError(
"select_outlets() returned duplicate outlets among the defined outlets: "
+ ", ".join(outlet_names)
)
for outlet_name in outlet_names:
if outlet_name not in self._name_to_outlet:
raise ValueError(
f"select_outlets() returned outlet name '{outlet_name}', which is not one of the "
f"defined outlets: " + ", ".join(self._name_to_outlet)
)
outlet = self._name_to_outlet[outlet_name]
outlets.append(outlet)
outlets = self._check_outlets_by_names(outlet_names) if outlet_names else self._outlets
return await self._do_downstream(event, outlets=outlets)


Expand All @@ -421,26 +497,33 @@ async def _do(self, event):


class _UnaryFunctionFlow(Flow):
def __init__(
    self, fn, long_running=None, pass_context=None, fn_select_outlets: Optional[Callable] = None, **kwargs
):
    """Set up a step that applies ``fn`` to each element.

    :param fn: callable applied to each element; may be a coroutine function.
    :param long_running: stored on the step; cannot be truthy when ``fn`` is a coroutine function.
    :param pass_context: stored on the step for use when ``fn`` is called.
    :param fn_select_outlets: optional callable stored as the step's outlet selector.
    :param kwargs: forwarded to the parent initializer.
    :raises TypeError: if ``fn`` is not callable, or ``fn_select_outlets`` is truthy but not callable.
    :raises ValueError: if ``long_running`` is truthy and ``fn`` is a coroutine function.
    """
    super().__init__(**kwargs)
    if not callable(fn):
        raise TypeError(f"Expected a callable, got {type(fn)}")
    if asyncio.iscoroutinefunction(fn) and long_running:
        raise ValueError("long_running=True cannot be used in conjunction with a coroutine")
    self._long_running = long_running
    self._fn = fn
    self._pass_context = pass_context
    if fn_select_outlets and not callable(fn_select_outlets):
        # Report the type of the offending argument, not the type of fn.
        raise TypeError(f"Expected fn_select_outlets to be callable, got {type(fn_select_outlets)}")
    self._outlets_selector = fn_select_outlets

async def _call(self, element, fn, pass_kwargs=True):
    """Invoke ``fn`` on ``element`` and return its result.

    :param element: the value handed to ``fn`` as its first argument.
    :param fn: the callable to invoke.
    :param pass_kwargs: when False, ``fn`` is called with ``element`` only, even if the
        step was configured with ``pass_context``.
    :returns: the result of ``fn``; awaited first when ``fn`` is a coroutine function.
    """
    if self._long_running:
        # Run in the loop's default executor. The step's own fn is never a coroutine
        # function here, since that combination is rejected at construction.
        return await asyncio.get_running_loop().run_in_executor(None, fn, element)

    # The context is passed only when the step asked for it and the caller allows kwargs.
    kwargs = {"context": self.context} if pass_kwargs and self._pass_context else {}
    res = fn(element, **kwargs)
    if asyncio.iscoroutinefunction(fn):
        res = await res
    return res

Expand All @@ -452,9 +535,21 @@ async def _do(self, event):
return await self._do_downstream(_termination_obj)
else:
element = self._get_event_or_body(event)
fn_result = await self._call(element)
fn_result = await self._call(element, self._fn)
await self._do_internal(event, fn_result)

async def select_outlets(self, event_body) -> Optional[Collection[str]]:
    """Choose the outlets for an event using the configured selector, if there is one.

    :param event_body: the value handed to the selector.
    :returns: the outlet names produced by ``fn_select_outlets``, or the parent
        implementation's result when no selector was configured.
    """
    import inspect

    if not self._outlets_selector:
        return super().select_outlets(event_body)

    selected = self._outlets_selector(event_body)
    # An async selector returns an awaitable; resolve it so callers get the names
    # themselves rather than a coroutine object.
    if inspect.isawaitable(selected):
        selected = await selected
    return selected

def _init(self):
    """Decide whether the outlet lookup is needed, then run the inherited initialization.

    The lookup is requested when a selector function was supplied, or when a subclass
    overrides ``select_outlets``.
    """
    has_selector = self._outlets_selector is not None
    self._create_name_to_outlet = has_selector or self._method_is_overridden(
        "select_outlets", _UnaryFunctionFlow
    )
    super()._init()


class DropColumns(Flow):
def __init__(self, columns, **kwargs):
Expand Down Expand Up @@ -640,6 +735,11 @@ def __init__(self, long_running=None, **kwargs):
self._long_running = long_running
self._filter = False

def _init(self):
    """Switch on the outlet-name lookup flag, then run the inherited initialization."""
    # set_next_outlets() relies on _name_to_outlet, so ask for it to be built.
    self._create_name_to_outlet = True
    super()._init()

def filter(self):
# used in the .do() code to signal filtering
self._filter = True
Expand Down
6 changes: 3 additions & 3 deletions storey/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def _emit(self, event):
if event is not _termination_obj:
self._raise_on_error(self._ex)

def run(self):
def run(self, visited=None):
"""Starts the flow"""
self._closeables = super().run()

Expand Down Expand Up @@ -710,7 +710,7 @@ async def _emit(self, event):
if event is not _termination_obj:
self._raise_on_error()

def run(self):
def run(self, visited=None):
"""Starts the flow"""
self._closeables = super().run()
loop_task = asyncio.get_running_loop().create_task(self._run_loop_and_log_unexpected_error())
Expand Down Expand Up @@ -753,7 +753,7 @@ def _raise_on_error(self, ex):
raise type(self._ex)("Flow execution terminated") from self._ex
raise self._ex

def run(self):
def run(self, visited=None):
self._closeables = super().run()

self._init()
Expand Down
Loading