Skip to content

Commit ace4bab

Browse files
finished async_test_batch_error_handling_multiple_runnables
1 parent f8f82fe commit ace4bab

File tree

1 file changed

+120
-27
lines changed

1 file changed

+120
-27
lines changed

tests/test_flow.py

Lines changed: 120 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,21 @@ def run(self, data, path, origin_name=None):
126126
return random_uuid
127127

128128

129+
class RunnableRaiseIfNegative(ParallelExecutionRunnable):
130+
def run(self, data, path, origin_name=None):
131+
if isinstance(data, list):
132+
results = []
133+
for item in data:
134+
if item < 0:
135+
raise ValueError(f"Value {item} is negative!")
136+
results.append(item * 2)
137+
return results
138+
else:
139+
if data < 0:
140+
raise ValueError(f"Value {data} is negative!")
141+
return data * 2
142+
143+
129144
class MyParallelExecution(ParallelExecution):
130145
def select_runnables(self, event):
131146
return ["multiply", "add", "uuid"]
@@ -2504,26 +2519,15 @@ async def emit_event(i):
25042519
assert termination_result[i]["uuid"] == expected_uuid
25052520
previous_batch_number = batch_number
25062521

2522+
25072523
def test_batch_error_handling_single_runnable():
25082524
asyncio.run(async_test_batch_error_handling_single_runnable())
25092525

2526+
25102527
@pytest.mark.asyncio
25112528
async def async_test_batch_error_handling_single_runnable():
25122529
"""Test error handling in batched parallel execution with single runnable."""
25132530
flush_after_seconds = 0.3
2514-
class RunnableRaiseIfNegative(ParallelExecutionRunnable):
2515-
def run(self, data, path, origin_name=None):
2516-
if isinstance(data, list):
2517-
results = []
2518-
for item in data:
2519-
if item < 0:
2520-
raise ValueError(f"Value {item} is negative!")
2521-
results.append(item * 2)
2522-
return results
2523-
else:
2524-
if data < 0:
2525-
raise ValueError(f"Value {data} is negative!")
2526-
return data * 2
25272531

25282532
batch_size = 3
25292533
runnables = [RunnableRaiseIfNegative("check_positive", raise_exception=False)]
@@ -2535,26 +2539,29 @@ def run(self, data, path, origin_name=None):
25352539
execution_mechanism_by_runnable_name={
25362540
"check_positive": "naive",
25372541
},
2538-
25392542
)
25402543
reducer = Reduce([], append_and_return)
25412544

2542-
controller = build_flow([
2543-
source,
2544-
batch_step,
2545-
parallel_execution,
2546-
FlatMap(fn=lambda x: x.body, full_event=True),
2547-
Complete(),
2548-
reducer,
2549-
]).run()
2545+
controller = build_flow(
2546+
[
2547+
source,
2548+
batch_step,
2549+
parallel_execution,
2550+
FlatMap(fn=lambda x: x.body, full_event=True),
2551+
Complete(),
2552+
reducer,
2553+
]
2554+
).run()
25502555

25512556
async def emit_valid_event(value):
25522557
invocation_result = await controller.emit(value)
25532558
assert invocation_result == value * 2
2559+
25542560
time.sleep(flush_after_seconds + 0.2) # Ensure different batch window
2561+
25552562
async def emit_error_event(value):
25562563
invocation_result = await controller.emit(value)
2557-
assert invocation_result == {'error': 'ValueError: Value -5 is negative!'}
2564+
assert invocation_result == {"error": "ValueError: Value -5 is negative!"}
25582565

25592566
# Emit valid batch first (should succeed)
25602567
tasks = [asyncio.create_task(emit_valid_event(v)) for v in [1, 2, 3]] # All positive
@@ -2571,13 +2578,99 @@ async def emit_error_event(value):
25712578
assert len(batch_result) == 6
25722579

25732580
# First 3 should succeed
2574-
assert batch_result[0] == 2 # 1 * 2
2575-
assert batch_result[1] == 4 # 2 * 2
2576-
assert batch_result[2] == 6 # 3 * 2
2581+
assert batch_result[0] == 2
2582+
assert batch_result[1] == 4
2583+
assert batch_result[2] == 6
25772584

25782585
# Next 3 should all have error (error propagated to all in batch)
25792586
for single_result in batch_result[3:]:
2580-
assert single_result == {'error': 'ValueError: Value -5 is negative!'}
2587+
assert single_result == {"error": "ValueError: Value -5 is negative!"}
2588+
2589+
2590+
def test_batch_error_handling_multiple_runnables():
2591+
asyncio.run(async_test_batch_error_handling_multiple_runnables())
2592+
2593+
2594+
@pytest.mark.asyncio
2595+
async def async_test_batch_error_handling_multiple_runnables():
2596+
"""Test error handling in batched parallel execution with multiple runnables."""
2597+
flush_after_seconds = 0.3
2598+
2599+
class RunnableAddTen(ParallelExecutionRunnable):
2600+
def run(self, data, path, origin_name=None):
2601+
if isinstance(data, list):
2602+
return [item + 10 for item in data]
2603+
return data + 10
2604+
2605+
batch_size = 3
2606+
runnables = [
2607+
RunnableRaiseIfNegative("check_positive", raise_exception=False),
2608+
RunnableAddTen("add_ten", raise_exception=False),
2609+
]
2610+
2611+
source = AsyncEmitSource()
2612+
batch_step = Batch(max_events=batch_size, full_event=True, flush_after_seconds=flush_after_seconds)
2613+
parallel_execution = ParallelExecution(
2614+
runnables,
2615+
execution_mechanism_by_runnable_name={
2616+
"check_positive": "naive",
2617+
"add_ten": "naive",
2618+
},
2619+
)
2620+
reducer = Reduce([], append_and_return)
2621+
2622+
controller = build_flow(
2623+
[
2624+
source,
2625+
batch_step,
2626+
parallel_execution,
2627+
FlatMap(fn=lambda x: x.body, full_event=True),
2628+
Complete(),
2629+
reducer,
2630+
]
2631+
).run()
2632+
2633+
async def emit_valid_event(value):
2634+
invocation_result = await controller.emit(value)
2635+
assert invocation_result["check_positive"] == value * 2
2636+
assert invocation_result["add_ten"] == value + 10
2637+
2638+
time.sleep(flush_after_seconds + 0.2) # Ensure different batch window
2639+
2640+
async def emit_error_event(value):
2641+
invocation_result = await controller.emit(value)
2642+
# check_positive should have error, but add_ten should still work
2643+
assert invocation_result["check_positive"] == {"error": "ValueError: Value -5 is negative!"}
2644+
assert invocation_result["add_ten"] == value + 10
2645+
2646+
# Emit valid batch first (should succeed)
2647+
tasks = [asyncio.create_task(emit_valid_event(v)) for v in [1, 2, 3]] # All positive
2648+
await asyncio.gather(*tasks)
2649+
2650+
# Emit batch with negative value concurrently (should error for check_positive, but add_ten works)
2651+
tasks = [asyncio.create_task(emit_error_event(v)) for v in [4, -5, 6]] # Middle value is negative
2652+
await asyncio.gather(*tasks)
2653+
2654+
await controller.terminate()
2655+
batch_result = await controller.await_termination()
2656+
2657+
# Should have 6 results total (3 valid + 3 with partial error)
2658+
assert len(batch_result) == 6
2659+
2660+
# First 3 should succeed with both runnables
2661+
assert batch_result[0]["check_positive"] == 2 # 1 * 2
2662+
assert batch_result[0]["add_ten"] == 11 # 1 + 10
2663+
assert batch_result[1]["check_positive"] == 4 # 2 * 2
2664+
assert batch_result[1]["add_ten"] == 12 # 2 + 10
2665+
assert batch_result[2]["check_positive"] == 6 # 3 * 2
2666+
assert batch_result[2]["add_ten"] == 13 # 3 + 10
2667+
2668+
# Next 3 should have error for check_positive but add_ten should work
2669+
for i, single_result in enumerate(batch_result[3:]):
2670+
assert single_result["check_positive"] == {"error": "ValueError: Value -5 is negative!"}
2671+
# add_ten should still work for values [4, -5, 6]
2672+
expected_add_values = [14, 5, 16]
2673+
assert single_result["add_ten"] == expected_add_values[i]
25812674

25822675

25832676
async def async_test_write_csv(tmpdir):

0 commit comments

Comments
 (0)