Skip to content
15 changes: 13 additions & 2 deletions storey/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2298,7 +2298,11 @@ async def _do(self, event):
if is_full_event_batched:
# reconstruct the full event batch
for i, sub_event in enumerate(sub_events_to_modify):
sub_event.body = result.data[i]
if isinstance(result.data, list):
sub_event.body = result.data[i]
elif isinstance(result.data, dict) and "error" in result.data:
# error case, set error to all sub events
sub_event.body = result.data
event.body = sub_events_to_modify
else:
event.body = result.data
Expand All @@ -2312,7 +2316,14 @@ async def _do(self, event):
}
if is_full_event_batched:
for i, sub_event in enumerate(sub_events_to_modify):
sub_event.body = {result.runnable_name: result.data[i] for result in results}
sub_event_body = {}
for result in results:
if isinstance(result.data, list):
sub_event_body[result.runnable_name] = result.data[i]
else:
Copy link
Collaborator

@royischoss royischoss Feb 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this case handled differently here, without checking that the value is a dict containing "error"? Either the check is missing or the comment is not accurate. As discussed, use the same check as above and add an `else` branch that raises an error, plus a test that covers it.

# error case, set error to all sub events
sub_event_body[result.runnable_name] = result.data
sub_event.body = sub_event_body
event.body = sub_events_to_modify
else:
event.body = {result.runnable_name: result.data for result in results}
Expand Down
168 changes: 168 additions & 0 deletions tests/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,21 @@ def run(self, data, path, origin_name=None):
return random_uuid


class RunnableRaiseIfNegative(ParallelExecutionRunnable):
    """Test runnable that doubles its input and rejects negative values."""

    def run(self, data, path, origin_name=None):
        """Return ``data`` doubled; for a list, return the list of doubled items.

        Raises:
            ValueError: on the first negative value encountered.
        """

        def doubled(value):
            # Validate before doubling so the message reports the offending input.
            if value < 0:
                raise ValueError(f"Value {value} is negative!")
            return value * 2

        if not isinstance(data, list):
            return doubled(data)
        # The comprehension stops at the first negative item, like the explicit loop it replaces.
        return [doubled(item) for item in data]


class MyParallelExecution(ParallelExecution):
def select_runnables(self, event):
return ["multiply", "add", "uuid"]
Expand Down Expand Up @@ -2505,6 +2520,159 @@ async def emit_event(i):
previous_batch_number = batch_number


def test_batch_error_handling_single_runnable():
    """Synchronous pytest entry point that runs the async single-runnable scenario."""
    scenario = async_test_batch_error_handling_single_runnable()
    asyncio.run(scenario)


@pytest.mark.asyncio
async def async_test_batch_error_handling_single_runnable():
    """Test error handling in batched parallel execution with single runnable.

    Two batches of three events are sent through ``Batch`` -> ``ParallelExecution``.
    The first batch is all positive and every value must come back doubled. The
    second batch contains one negative value, so the runnable raises and the same
    error payload is expected on every event of that batch.
    """
    flush_after_seconds = 0.3
    batch_size = 3
    # The only negative value emitted below is -5, so this is the single error expected.
    expected_error = {"error": "ValueError: Value -5 is negative!"}

    # NOTE(review): raise_exception=False presumably makes the failure surface in the
    # event body instead of propagating -- confirm against ParallelExecutionRunnable.
    runnables = [RunnableRaiseIfNegative("check_positive", raise_exception=False)]

    source = AsyncEmitSource()
    batch_step = Batch(max_events=batch_size, full_event=True, flush_after_seconds=flush_after_seconds)
    parallel_execution = ParallelExecution(
        runnables,
        execution_mechanism_by_runnable_name={
            "check_positive": "naive",
        },
    )
    reducer = Reduce([], append_and_return)

    controller = build_flow(
        [
            source,
            batch_step,
            parallel_execution,
            FlatMap(fn=lambda x: x.body, full_event=True),
            Complete(),
            reducer,
        ]
    ).run()

    async def emit_valid_event(value):
        invocation_result = await controller.emit(value)
        assert invocation_result == value * 2

        # Ensure different batch window. Non-blocking wait: a blocking time.sleep()
        # here would stall every other coroutine scheduled on this event loop.
        await asyncio.sleep(flush_after_seconds + 0.2)

    async def emit_error_event(value):
        invocation_result = await controller.emit(value)
        assert invocation_result == expected_error

    # Emit valid batch first (should succeed)
    tasks = [asyncio.create_task(emit_valid_event(v)) for v in [1, 2, 3]]  # All positive
    await asyncio.gather(*tasks)

    # Emit batch with negative value concurrently (should error for all in batch)
    tasks = [asyncio.create_task(emit_error_event(v)) for v in [4, -5, 6]]  # Middle value is negative
    await asyncio.gather(*tasks)

    await controller.terminate()
    batch_result = await controller.await_termination()

    # Should have 6 results total (3 valid + 3 error)
    assert len(batch_result) == 6

    # First 3 should succeed
    assert batch_result[0] == 2
    assert batch_result[1] == 4
    assert batch_result[2] == 6

    # Next 3 should all have error (error propagated to all in batch)
    for single_result in batch_result[3:]:
        assert single_result == expected_error


def test_batch_error_handling_multiple_runnables():
    """Synchronous pytest entry point that runs the async multi-runnable scenario."""
    scenario = async_test_batch_error_handling_multiple_runnables()
    asyncio.run(scenario)


@pytest.mark.asyncio
async def async_test_batch_error_handling_multiple_runnables():
    """Test error handling in batched parallel execution with multiple runnables.

    Two batches of three events are sent through ``Batch`` -> ``ParallelExecution``
    with two runnables. In the batch containing a negative value, ``check_positive``
    is expected to report the same error for every event, while ``add_ten`` is
    expected to still return a per-event result.
    """
    flush_after_seconds = 0.3
    batch_size = 3
    # The only negative value emitted below is -5, so this is the single error expected.
    expected_error = {"error": "ValueError: Value -5 is negative!"}

    class RunnableAddTen(ParallelExecutionRunnable):
        """Runnable that adds 10 to a value, or to each item of a list."""

        def run(self, data, path, origin_name=None):
            if isinstance(data, list):
                return [item + 10 for item in data]
            return data + 10

    # NOTE(review): raise_exception=False presumably makes the failure surface in the
    # event body instead of propagating -- confirm against ParallelExecutionRunnable.
    runnables = [
        RunnableRaiseIfNegative("check_positive", raise_exception=False),
        RunnableAddTen("add_ten", raise_exception=False),
    ]

    source = AsyncEmitSource()
    batch_step = Batch(max_events=batch_size, full_event=True, flush_after_seconds=flush_after_seconds)
    parallel_execution = ParallelExecution(
        runnables,
        execution_mechanism_by_runnable_name={
            "check_positive": "naive",
            "add_ten": "naive",
        },
    )
    reducer = Reduce([], append_and_return)

    controller = build_flow(
        [
            source,
            batch_step,
            parallel_execution,
            FlatMap(fn=lambda x: x.body, full_event=True),
            Complete(),
            reducer,
        ]
    ).run()

    async def emit_valid_event(value):
        invocation_result = await controller.emit(value)
        assert invocation_result["check_positive"] == value * 2
        assert invocation_result["add_ten"] == value + 10

        # Ensure different batch window. Non-blocking wait: a blocking time.sleep()
        # here would stall every other coroutine scheduled on this event loop.
        await asyncio.sleep(flush_after_seconds + 0.2)

    async def emit_error_event(value):
        invocation_result = await controller.emit(value)
        # check_positive should have error, but add_ten should still work
        assert invocation_result["check_positive"] == expected_error
        assert invocation_result["add_ten"] == value + 10

    # Emit valid batch first (should succeed)
    tasks = [asyncio.create_task(emit_valid_event(v)) for v in [1, 2, 3]]  # All positive
    await asyncio.gather(*tasks)

    # Emit batch with negative value concurrently (should error for check_positive, but add_ten works)
    tasks = [asyncio.create_task(emit_error_event(v)) for v in [4, -5, 6]]  # Middle value is negative
    await asyncio.gather(*tasks)

    await controller.terminate()
    batch_result = await controller.await_termination()

    # Should have 6 results total (3 valid + 3 with partial error)
    assert len(batch_result) == 6

    # First 3 should succeed with both runnables
    assert batch_result[0]["check_positive"] == 2
    assert batch_result[0]["add_ten"] == 11
    assert batch_result[1]["check_positive"] == 4
    assert batch_result[1]["add_ten"] == 12
    assert batch_result[2]["check_positive"] == 6
    assert batch_result[2]["add_ten"] == 13

    # Next 3 should have error for check_positive but add_ten should work.
    # add_ten should still work for values [4, -5, 6]; the length assertion above
    # guarantees exactly three results remain to pair with these expectations.
    expected_add_values = [14, 5, 16]
    for single_result, expected_add in zip(batch_result[3:], expected_add_values):
        assert single_result["check_positive"] == expected_error
        assert single_result["add_ten"] == expected_add


async def async_test_write_csv(tmpdir):
file_path = f"{tmpdir}/test_write_csv/out.csv"
controller = build_flow([AsyncEmitSource(), CSVTarget(file_path, columns=["n", "n*10"], header=True)]).run()
Expand Down