Skip to content
28 changes: 26 additions & 2 deletions storey/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2187,6 +2187,27 @@ def set_event_metadata(event, metadata: dict):
else:
event._metadata = metadata

def _get_subevent_body_in_batched_event(self, result, result_data_index=0):
unexpected_result_type = False
if isinstance(result.data, list):
sub_event_body = result.data[result_data_index]
elif isinstance(result.data, dict):
# error case, set error to all sub events
sub_event_body = result.data
if not "error" in result.data:
unexpected_result_type = True
else:
unexpected_result_type = True
sub_event_body = result.data
if unexpected_result_type:
if self.logger:
self.logger.warn(
f"Got result data for batched event that it is not a list."
f" This result will be set as body for all sub events in the batch."
f" Result data: {result.data}"
)
return sub_event_body

def select_runnables(self, event) -> Optional[Union[list[str], list[ParallelExecutionRunnable]]]:
"""
Given an event, returns a list of runnables (or a list of runnable names) to execute on it. It can also return
Expand Down Expand Up @@ -2298,7 +2319,7 @@ async def _do(self, event):
if is_full_event_batched:
# reconstruct the full event batch
for i, sub_event in enumerate(sub_events_to_modify):
sub_event.body = result.data[i]
sub_event.body = self._get_subevent_body_in_batched_event(result, i)
event.body = sub_events_to_modify
else:
event.body = result.data
Expand All @@ -2312,7 +2333,10 @@ async def _do(self, event):
}
if is_full_event_batched:
for i, sub_event in enumerate(sub_events_to_modify):
sub_event.body = {result.runnable_name: result.data[i] for result in results}
sub_event_body = {}
for result in results:
sub_event_body[result.runnable_name] = self._get_subevent_body_in_batched_event(result, i)
sub_event.body = sub_event_body
event.body = sub_events_to_modify
else:
event.body = {result.runnable_name: result.data for result in results}
Expand Down
Loading
Loading