diff --git a/storey/flow.py b/storey/flow.py index f7a0fc80..6c5091d2 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -2187,6 +2187,27 @@ def set_event_metadata(event, metadata: dict): else: event._metadata = metadata + def _get_subevent_body_in_batched_event(self, result, result_data_index=0): + unexpected_result_type = False + if isinstance(result.data, list): + sub_event_body = result.data[result_data_index] + elif isinstance(result.data, dict): + # error case, set error to all sub events + sub_event_body = result.data + if "error" not in result.data: + unexpected_result_type = True + else: + unexpected_result_type = True + sub_event_body = result.data + if unexpected_result_type: + if self.logger: + self.logger.warn( + f"Got result data for batched event that is not a list." + f" This result will be set as body for all sub events in the batch." + f" Result data: {result.data}" + ) + return sub_event_body + def select_runnables(self, event) -> Optional[Union[list[str], list[ParallelExecutionRunnable]]]: """ Given an event, returns a list of runnables (or a list of runnable names) to execute on it. 
It can also return @@ -2298,7 +2319,7 @@ async def _do(self, event): if is_full_event_batched: # reconstruct the full event batch for i, sub_event in enumerate(sub_events_to_modify): - sub_event.body = result.data[i] + sub_event.body = self._get_subevent_body_in_batched_event(result, i) event.body = sub_events_to_modify else: event.body = result.data @@ -2312,7 +2333,10 @@ async def _do(self, event): } if is_full_event_batched: for i, sub_event in enumerate(sub_events_to_modify): - sub_event.body = {result.runnable_name: result.data[i] for result in results} + sub_event_body = {} + for result in results: + sub_event_body[result.runnable_name] = self._get_subevent_body_in_batched_event(result, i) + sub_event.body = sub_event_body event.body = sub_events_to_modify else: event.body = {result.runnable_name: result.data for result in results} diff --git a/tests/test_flow.py b/tests/test_flow.py index 237c40ff..9689fd6a 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -101,36 +101,6 @@ def raise_ex(self, element): return element -class RunnableMultiplyBy2(ParallelExecutionRunnable): - - def run(self, data, path, origin_name=None): - if isinstance(data, list): - return [sub_value * 2 for sub_value in data] - return data * 2 - - -class RunnableAdd10(ParallelExecutionRunnable): - - def run(self, data, path, origin_name=None): - if isinstance(data, list): - return [sub_value + 10 for sub_value in data] - return data + 10 - - -class RunnableGetRandom(ParallelExecutionRunnable): - - def run(self, data, path, origin_name=None): - random_uuid = str(uuid.uuid4()) - if isinstance(data, list): - return [random_uuid] * len(data) - return random_uuid - - -class MyParallelExecution(ParallelExecution): - def select_runnables(self, event): - return ["multiply", "add", "uuid"] - - def test_functional_flow(): controller = build_flow( [ @@ -2364,147 +2334,6 @@ def reduce_fn(acc, x): assert termination_result == [[0, 1, 2], [3, 4, 5, 6], [7, 8, 9]] -def 
test_basic_batch_with_parallel_execution(): - asyncio.run(async_test_basic_batch_with_parallel_execution()) - - -async def async_test_basic_batch_with_parallel_execution(): - """Test that Batch step with full_event=True works correctly with ParallelExecution.""" - batch_size = 3 - number_of_events = 10 - - runnables = [ - RunnableMultiplyBy2("multiply"), - RunnableAdd10("add"), - RunnableGetRandom("uuid"), - ] - parallel_execution = MyParallelExecution( - runnables, - execution_mechanism_by_runnable_name={ - "multiply": "naive", - "add": "naive", - "uuid": "naive", - }, - ) - controller = build_flow( - [ - AsyncEmitSource(), - Batch(max_events=batch_size, full_event=True, flush_after_seconds=2), - parallel_execution, - FlatMap(fn=lambda x: x.body, full_event=True), - Complete(), - Reduce(initial_value=[], fn=lambda acc, x: append_and_return(acc, x)), - ] - ).run() - - async def emit_event(i): - result = await controller.emit(i) - # Verify each event has the expected fields after parallel execution - assert "add" in result - assert "multiply" in result - assert "uuid" in result - assert result["add"] == 10 + i - assert result["multiply"] == i * 2 - - # Emit events in parallel using asyncio - try: - tasks = [asyncio.create_task(emit_event(i)) for i in range(number_of_events)] - await asyncio.gather(*tasks) - finally: - await controller.terminate() - termination_result = await controller.await_termination() - assert len(termination_result) == number_of_events - - previous_batch_number = -1 - expected_uuid = "" - for i in range(number_of_events): - batch_number = math.floor(i / batch_size) - if previous_batch_number == -1 or batch_number != previous_batch_number: - expected_uuid = termination_result[i]["uuid"] - else: - assert termination_result[i]["uuid"] == expected_uuid - previous_batch_number = batch_number - - -def test_batch_with_parallel_execution_split(): - asyncio.run(async_test_batch_with_parallel_execution_split()) - - -async def 
async_test_batch_with_parallel_execution_split(): - """Test batched ParallelExecution with splitting to multiple outlets (len(outlets) > 1 path).""" - batch_size = 3 - number_of_events = 10 - - runnables = [ - RunnableMultiplyBy2("multiply"), - RunnableAdd10("add"), - RunnableGetRandom("uuid"), - ] - - source = AsyncEmitSource() - batch_step = Batch(max_events=batch_size, full_event=True, flush_after_seconds=2) - parallel_execution = MyParallelExecution( - runnables, - execution_mechanism_by_runnable_name={ - "multiply": "naive", - "add": "naive", - "uuid": "naive", - }, - ) - - flat_map1 = FlatMap(fn=lambda x: x.body, full_event=True) - flat_map2 = FlatMap(fn=lambda x: x.body, full_event=True) - complete = Complete() - reducer = Reduce([], lambda acc, x: append_and_return(acc, x)) - - source.to(batch_step).to(parallel_execution) - parallel_execution.to(flat_map1).to(complete).to(reducer) - parallel_execution.to(flat_map2).to(reducer) - - controller = source.run() - - async def emit_event(i): - result = await controller.emit(i) - # Each event should get results from both branches (2 completions) - assert len(result) == 3 - assert result["add"] == 10 + i - assert result["multiply"] == i * 2 - assert "uuid" in result - return result - - try: - tasks = [asyncio.create_task(emit_event(i)) for i in range(number_of_events)] - await asyncio.gather(*tasks) - finally: - await controller.terminate() - termination_result = await controller.await_termination() - - # The final result should contain a duplicated value since the flow is split into two branches - expected_number_of_events = number_of_events * 2 - assert len(termination_result) == expected_number_of_events - - # Sort by value to ensure correct order after split - termination_result = sorted(termination_result, key=lambda x: (x["add"])) - - previous_batch_number = -1 - expected_uuid = "" - # because of the split, the batch size of the results is doubled - batch_size = batch_size * 2 - for i in 
range(expected_number_of_events): - # because of the split, we expect alternating add/multiply values every two items - fixed_index = math.floor(i / 2) - expected_add = 10 + fixed_index - expected_multiply = fixed_index * 2 - assert termination_result[i]["add"] == expected_add - assert termination_result[i]["multiply"] == expected_multiply - batch_number = math.floor(i / batch_size) - if previous_batch_number == -1 or batch_number != previous_batch_number: - expected_uuid = termination_result[i]["uuid"] - else: - assert termination_result[i]["uuid"] == expected_uuid - previous_batch_number = batch_number - - async def async_test_write_csv(tmpdir): file_path = f"{tmpdir}/test_write_csv/out.csv" controller = build_flow([AsyncEmitSource(), CSVTarget(file_path, columns=["n", "n*10"], header=True)]).run() @@ -6040,3 +5869,366 @@ def test_nosql_target_reuse_resets_closeables(): finally: controller.terminate() controller.await_termination() + + +class TestBatchWithParallelExecution: + class RunnableMultiplyBy2(ParallelExecutionRunnable): + + def run(self, data, path, origin_name=None): + if isinstance(data, list): + return [sub_value * 2 for sub_value in data] + return data * 2 + + class RunnableAdd10(ParallelExecutionRunnable): + + def run(self, data, path, origin_name=None): + if isinstance(data, list): + return [sub_value + 10 for sub_value in data] + return data + 10 + + class RunnableGetRandom(ParallelExecutionRunnable): + + def run(self, data, path, origin_name=None): + random_uuid = str(uuid.uuid4()) + if isinstance(data, list): + return [random_uuid] * len(data) + return random_uuid + + class RunnableRaiseIfNegative(ParallelExecutionRunnable): + def run(self, data, path, origin_name=None): + if isinstance(data, list): + results = [] + for item in data: + if item < 0: + raise ValueError(f"Value {item} is negative!") + results.append(item * 2) + return results + else: + if data < 0: + raise ValueError(f"Value {data} is negative!") + return data * 2 + + class 
RunnableReturnWrongType(ParallelExecutionRunnable): + def run(self, data, path, origin_name=None): + if isinstance(data, list): + # Wrong! Should return a list, not a dict + return {"result": "wrong_type"} + return data + + class MyParallelExecution(ParallelExecution): + def select_runnables(self, event): + return ["multiply", "add", "uuid"] + + def test_basic_batch_with_parallel_execution(self): + asyncio.run(self.async_test_basic_batch_with_parallel_execution()) + + async def async_test_basic_batch_with_parallel_execution(self): + """Test that Batch step with full_event=True works correctly with ParallelExecution.""" + batch_size = 3 + number_of_events = 10 + + runnables = [ + self.RunnableMultiplyBy2("multiply"), + self.RunnableAdd10("add"), + self.RunnableGetRandom("uuid"), + ] + parallel_execution = self.MyParallelExecution( + runnables, + execution_mechanism_by_runnable_name={ + "multiply": "naive", + "add": "naive", + "uuid": "naive", + }, + ) + controller = build_flow( + [ + AsyncEmitSource(), + Batch(max_events=batch_size, full_event=True, flush_after_seconds=2), + parallel_execution, + FlatMap(fn=lambda x: x.body, full_event=True), + Complete(), + Reduce(initial_value=[], fn=lambda acc, x: append_and_return(acc, x)), + ] + ).run() + + async def emit_event(i): + result = await controller.emit(i) + # Verify each event has the expected fields after parallel execution + assert "add" in result + assert "multiply" in result + assert "uuid" in result + assert result["add"] == 10 + i + assert result["multiply"] == i * 2 + + # Emit events in parallel using asyncio + try: + await self._emit_batch_concurrently(emit_event, range(number_of_events)) + finally: + await controller.terminate() + termination_result = await controller.await_termination() + assert len(termination_result) == number_of_events + + previous_batch_number = -1 + expected_uuid = "" + for i in range(number_of_events): + batch_number = math.floor(i / batch_size) + if previous_batch_number == -1 or 
batch_number != previous_batch_number: + expected_uuid = termination_result[i]["uuid"] + else: + assert termination_result[i]["uuid"] == expected_uuid + previous_batch_number = batch_number + + def test_batch_with_parallel_execution_split(self): + asyncio.run(self.async_test_batch_with_parallel_execution_split()) + + async def async_test_batch_with_parallel_execution_split(self): + """Test batched ParallelExecution with splitting to multiple outlets (len(outlets) > 1 path).""" + batch_size = 3 + number_of_events = 10 + + runnables = [ + self.RunnableMultiplyBy2("multiply"), + self.RunnableAdd10("add"), + self.RunnableGetRandom("uuid"), + ] + + source = AsyncEmitSource() + batch_step = Batch(max_events=batch_size, full_event=True, flush_after_seconds=2) + parallel_execution = self.MyParallelExecution( + runnables, + execution_mechanism_by_runnable_name={ + "multiply": "naive", + "add": "naive", + "uuid": "naive", + }, + ) + + flat_map1 = FlatMap(fn=lambda x: x.body, full_event=True) + flat_map2 = FlatMap(fn=lambda x: x.body, full_event=True) + complete = Complete() + reducer = Reduce([], lambda acc, x: append_and_return(acc, x)) + + source.to(batch_step).to(parallel_execution) + parallel_execution.to(flat_map1).to(complete).to(reducer) + parallel_execution.to(flat_map2).to(reducer) + + controller = source.run() + + async def emit_event(i): + result = await controller.emit(i) + # Each event should get results from both branches (2 completions) + assert len(result) == 3 + assert result["add"] == 10 + i + assert result["multiply"] == i * 2 + assert "uuid" in result + return result + + try: + await self._emit_batch_concurrently(emit_event, range(number_of_events)) + finally: + await controller.terminate() + termination_result = await controller.await_termination() + + # The final result should contain a duplicated value since the flow is split into two branches + expected_number_of_events = number_of_events * 2 + assert len(termination_result) == 
expected_number_of_events + + # Sort by value to ensure correct order after split + termination_result = sorted(termination_result, key=lambda x: (x["add"])) + + previous_batch_number = -1 + expected_uuid = "" + # because of the split, the batch size of the results is doubled + batch_size = batch_size * 2 + for i in range(expected_number_of_events): + # because of the split, we expect alternating add/multiply values every two items + fixed_index = math.floor(i / 2) + expected_add = 10 + fixed_index + expected_multiply = fixed_index * 2 + assert termination_result[i]["add"] == expected_add + assert termination_result[i]["multiply"] == expected_multiply + batch_number = math.floor(i / batch_size) + if previous_batch_number == -1 or batch_number != previous_batch_number: + expected_uuid = termination_result[i]["uuid"] + else: + assert termination_result[i]["uuid"] == expected_uuid + previous_batch_number = batch_number + + def _create_error_handling_flow(self, runnables, batch_size=3, flush_after_seconds=0.3): + """Helper to create flow for error handling tests.""" + source = AsyncEmitSource() + batch_step = Batch(max_events=batch_size, full_event=True, flush_after_seconds=flush_after_seconds) + + execution_mechanism_by_runnable_name = {runnable.name: "naive" for runnable in runnables} + parallel_execution = ParallelExecution( + runnables, execution_mechanism_by_runnable_name=execution_mechanism_by_runnable_name + ) + reducer = Reduce([], append_and_return) + + return build_flow( + [ + source, + batch_step, + parallel_execution, + FlatMap(fn=lambda x: x.body, full_event=True), + Complete(), + reducer, + ] + ).run() + + async def _emit_batch_concurrently(self, emit_fn, values): + """Helper to emit multiple values concurrently.""" + tasks = [asyncio.create_task(emit_fn(v)) for v in values] + await asyncio.gather(*tasks) + + def test_batch_error_handling_single_runnable(self): + asyncio.run(self.async_test_batch_error_handling_single_runnable()) + + async def 
async_test_batch_error_handling_single_runnable(self): + """Test error handling in batched parallel execution with single runnable.""" + + async def emit_valid_event(value): + invocation_result = await controller.emit(value) + assert invocation_result == value * 2 + + async def emit_error_event(value): + invocation_result = await controller.emit(value) + assert invocation_result == {"error": "ValueError: Value -5 is negative!"} + + flush_after_seconds = 0.3 + runnables = [self.RunnableRaiseIfNegative("check_positive", raise_exception=False)] + controller = self._create_error_handling_flow(runnables, flush_after_seconds=flush_after_seconds) + + # Emit valid batch first (should succeed) + await self._emit_batch_concurrently(emit_valid_event, [1, 2, 3]) + + time.sleep(flush_after_seconds + 0.2) # Ensure different batch window + + # Emit batch with negative value concurrently (should error for all in batch) + await self._emit_batch_concurrently(emit_error_event, [4, -5, 6]) + + await controller.terminate() + batch_result = await controller.await_termination() + + # Should have 6 results total (3 valid + 3 error) + assert len(batch_result) == 6 + + # First 3 should succeed + assert batch_result[0] == 2 + assert batch_result[1] == 4 + assert batch_result[2] == 6 + + # Next 3 should all have error (error propagated to all in batch) + for single_result in batch_result[3:]: + assert single_result == {"error": "ValueError: Value -5 is negative!"} + + def test_batch_error_handling_multiple_runnables(self): + asyncio.run(self.async_test_batch_error_handling_multiple_runnables()) + + @pytest.mark.asyncio + async def async_test_batch_error_handling_multiple_runnables(self): + """Test error handling in batched parallel execution with multiple runnables.""" + flush_after_seconds = 0.3 + + class RunnableAddTen(ParallelExecutionRunnable): + def run(self, data, path, origin_name=None): + if isinstance(data, list): + return [item + 10 for item in data] + return data + 10 + + async def 
emit_error_event(value): + invocation_result = await controller.emit(value) + # check_positive should have error, but add_ten should still work + assert invocation_result["check_positive"] == {"error": "ValueError: Value -5 is negative!"} + assert invocation_result["add_ten"] == value + 10 + + async def emit_valid_event(value): + invocation_result = await controller.emit(value) + assert invocation_result["check_positive"] == value * 2 + assert invocation_result["add_ten"] == value + 10 + + runnables = [ + self.RunnableRaiseIfNegative("check_positive", raise_exception=False), + RunnableAddTen("add_ten", raise_exception=False), + ] + controller = self._create_error_handling_flow(runnables, flush_after_seconds=flush_after_seconds) + # Emit valid batch first (should succeed) + await self._emit_batch_concurrently(emit_valid_event, [1, 2, 3]) + + time.sleep(flush_after_seconds + 0.2) # Ensure different batch window + # Emit batch with negative value concurrently (should error for check_positive, but add_ten works) + await self._emit_batch_concurrently(emit_error_event, [4, -5, 6]) + + await controller.terminate() + batch_result = await controller.await_termination() + + # Should have 6 results total (3 valid + 3 with partial error) + assert len(batch_result) == 6 + + # First 3 should succeed with both runnables + assert batch_result[0]["check_positive"] == 2 + assert batch_result[0]["add_ten"] == 11 + assert batch_result[1]["check_positive"] == 4 + assert batch_result[1]["add_ten"] == 12 + assert batch_result[2]["check_positive"] == 6 + assert batch_result[2]["add_ten"] == 13 + + # Next 3 should have error for check_positive but add_ten should work + for i, single_result in enumerate(batch_result[3:]): + assert single_result["check_positive"] == {"error": "ValueError: Value -5 is negative!"} + # add_ten should still work for values [4, -5, 6] + expected_add_values = [14, 5, 16] + assert single_result["add_ten"] == expected_add_values[i] + + def 
test_batch_unexpected_return_type_single_runnable(self): + asyncio.run(self.async_test_batch_unexpected_return_type_single_runnable()) + + async def async_test_batch_unexpected_return_type_single_runnable(self): + flush_after_seconds = 0.3 + runnables = [self.RunnableReturnWrongType("wrong_type", raise_exception=False)] + controller = self._create_error_handling_flow(runnables, flush_after_seconds=flush_after_seconds) + + async def emit_event(value): + invocation_result = await controller.emit(value) + # All events in batch should get the same dict result + assert invocation_result == {"result": "wrong_type"} + + # Emit batch that will trigger the warning (returning dict instead of list) + await self._emit_batch_concurrently(emit_event, [1, 2, 3]) + + await controller.terminate() + batch_result = await controller.await_termination() + + # All 3 should have the same dict body (warning was logged, but flow continued) + assert len(batch_result) == 3 + for single_result in batch_result: + assert single_result == {"result": "wrong_type"} + + def test_batch_unexpected_return_type_multiple_runnables(self): + asyncio.run(self.async_test_batch_unexpected_return_type_multiple_runnables()) + + async def async_test_batch_unexpected_return_type_multiple_runnables(self): + flush_after_seconds = 0.3 + runnables = [ + self.RunnableReturnWrongType("wrong_type", raise_exception=False), + self.RunnableAdd10("add_ten", raise_exception=False), + ] + controller = self._create_error_handling_flow(runnables, flush_after_seconds=flush_after_seconds) + + async def emit_event(value): + invocation_result = await controller.emit(value) + # wrong_type should return same dict for all, add_ten should work correctly + assert invocation_result["wrong_type"] == {"result": "wrong_type"} + assert invocation_result["add_ten"] == value + 10 + + # Emit batch that will trigger the warning (returning dict instead of list) + await self._emit_batch_concurrently(emit_event, [1, 2, 3]) + + await 
controller.terminate() + batch_result = await controller.await_termination() + + # All 3 should have the same dict for wrong_type, correct values for add_ten + assert len(batch_result) == 3 + for single_result in batch_result: + assert single_result["wrong_type"] == {"result": "wrong_type"} + assert single_result["add_ten"] in [11, 12, 13]