Skip to content

Commit f07fa81

Browse files
Merge branch 'upstream-development' into support_mrs_batch
2 parents ac3999e + 0bc0410 commit f07fa81

File tree

2 files changed

+152
-112
lines changed

2 files changed

+152
-112
lines changed

storey/flow.py

Lines changed: 91 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1870,6 +1870,53 @@ def _static_run(*args, **kwargs):
18701870
return _sval._run(*args, **kwargs)
18711871

18721872

1873+
def _streaming_run_wrapper(
    runnable: ParallelExecutionRunnable,
    input,
    path: str,
    origin_name: Optional[str],
    queue: multiprocessing.Queue,
) -> None:
    """Drive a streaming runnable and forward its output through ``queue``.

    Intended to execute in a child process. The generator returned by
    ``runnable._run`` is consumed here and every item is relayed to the
    consumer as a tagged tuple:

    * ``("chunk", item)`` for each item produced by the generator,
    * ``("done", None)`` once the generator is exhausted,
    * ``("error", (exception type name, message, formatted traceback))`` if
      anything raises while producing or sending.

    Args:
        runnable: Object whose ``_run`` yields the chunks to stream.
        input: Event body handed to ``runnable._run``.
        path: Event path handed to ``runnable._run``.
        origin_name: Name of the originating runnable, if any.
        queue: Queue used to ship the tagged messages to the consumer.
    """
    try:
        chunks = runnable._run(input, path, origin_name)
        for item in chunks:
            queue.put(("chunk", item))
        queue.put(("done", None))
    except Exception as exc:
        # The exception object itself is not sent; only a textual description
        # goes through the queue so the consumer can re-raise with context.
        failure = (type(exc).__name__, str(exc), traceback.format_exc())
        queue.put(("error", failure))
1891+
1892+
1893+
def _static_streaming_run(input, path: str, origin_name: Optional[str], queue: multiprocessing.Queue) -> None:
    """Stream from the process-global runnable for the dedicated_process mechanism.

    Delegates to ``_streaming_run_wrapper`` using the module-level ``_sval``
    runnable, so only the event data and the queue need to be passed in.
    """
    global _sval
    _streaming_run_wrapper(
        runnable=_sval,
        input=input,
        path=path,
        origin_name=origin_name,
        queue=queue,
    )
1897+
1898+
1899+
async def _async_read_streaming_queue(
    queue: multiprocessing.Queue, loop: Optional[asyncio.AbstractEventLoop] = None
) -> AsyncGenerator:
    """Async generator that reads chunks from a multiprocessing queue.

    This runs in the parent process and yields chunks sent by the child process,
    without blocking the asyncio event loop.

    Messages are ``(msg_type, payload)`` tuples:

    * ``"chunk"`` - ``payload`` is yielded to the consumer.
    * ``"done"``  - the stream is complete and the generator finishes.
    * ``"error"`` - ``payload`` is ``(exc_type, exc_msg, exc_tb)`` and is
      re-raised here as a ``RuntimeError`` carrying the original traceback text.

    Any other message type is ignored and reading continues.

    Args:
        queue: Queue the producer writes tagged messages to.
        loop: Event loop whose default executor performs the blocking
            ``queue.get`` calls. Defaults to the currently running loop.

    Raises:
        RuntimeError: If the producer reported an error.

    Note:
        ``queue.get`` is called without a timeout, so if the producer never
        sends ``"done"`` or ``"error"`` this generator waits indefinitely.
    """
    # Honor an explicitly supplied loop; previously the argument was always
    # overwritten with the running loop, making the parameter a no-op.
    if loop is None:
        loop = asyncio.get_running_loop()
    while True:
        # Use run_in_executor to avoid blocking the event loop
        msg_type, payload = await loop.run_in_executor(None, queue.get)
        if msg_type == "chunk":
            yield payload
        elif msg_type == "done":
            return
        elif msg_type == "error":
            exc_type, exc_msg, exc_tb = payload
            raise RuntimeError(f"{exc_type}: {exc_msg}\n\nOriginal traceback:\n{exc_tb}")
1918+
1919+
18731920
class RunnableExecutor:
18741921
"""
18751922
Manages and executes `ParallelExecutionRunnable` instances using various parallel execution mechanisms.
@@ -1897,6 +1944,8 @@ def __init__(
18971944
self._process_executor_by_runnable_name = {}
18981945
self._mp_context = multiprocessing.get_context("spawn")
18991946
self._executors = {}
1947+
self._is_streaming_by_runnable_name: dict[str, bool] = {}
1948+
self._manager = None # Lazy-initialized multiprocessing Manager for queue-based streaming
19001949

19011950
def add_runnable(self, runnable: ParallelExecutionRunnable, execution_mechanism: str) -> None:
19021951
"""
@@ -1933,15 +1982,9 @@ def init_runnable(self, runnable: Union[str, ParallelExecutionRunnable]) -> None
19331982

19341983
execution_mechanism = self._execution_mechanism_by_runnable_name[runnable.name]
19351984

1936-
# Check for streaming + process-based execution (incompatible combination)
1937-
if execution_mechanism in ParallelExecutionMechanisms.process():
1938-
is_streaming = inspect.isgeneratorfunction(runnable.run) or inspect.isasyncgenfunction(runnable.run_async)
1939-
if is_streaming:
1940-
raise StreamingError(
1941-
f"Streaming is not supported with process-based execution mechanisms. "
1942-
f"Runnable '{runnable.name}' uses '{execution_mechanism}'. "
1943-
f"Use 'thread_pool', 'asyncio', or 'naive' for streaming runnables."
1944-
)
1985+
# Record whether this runnable is a streaming runnable (generator function)
1986+
is_streaming = inspect.isgeneratorfunction(runnable.run) or inspect.isasyncgenfunction(runnable.run_async)
1987+
self._is_streaming_by_runnable_name[runnable.name] = is_streaming
19451988

19461989
if execution_mechanism == ParallelExecutionMechanisms.process_pool:
19471990
self.num_processes += 1
@@ -1976,6 +2019,12 @@ def init_executors(self):
19762019
if num_threads:
19772020
self._executors[ParallelExecutionMechanisms.thread_pool] = ThreadPoolExecutor(max_workers=num_threads)
19782021

2022+
def _get_manager(self):
    """Return the shared multiprocessing Manager, creating it on first use.

    The Manager is built from ``self._mp_context`` and cached on
    ``self._manager`` so later calls reuse the same instance.
    """
    manager = self._manager
    if manager is None:
        manager = self._mp_context.Manager()
        self._manager = manager
    return manager
2027+
19792028
def run_executor(
19802029
self,
19812030
runnable: Union[ParallelExecutionRunnable, str],
@@ -2005,36 +2054,49 @@ def run_executor(
20052054
raise ValueError(f"select_runnables() returned more than one outlet named '{runnable.name}'")
20062055

20072056
execution_mechanism = self._execution_mechanism_by_runnable_name[runnable.name]
2057+
is_streaming = self._is_streaming_by_runnable_name.get(runnable.name, False)
20082058

20092059
input = (
20102060
event.body if execution_mechanism in ParallelExecutionMechanisms.process() else copy.deepcopy(event.body)
20112061
)
20122062

2063+
loop = asyncio.get_running_loop()
2064+
20132065
if execution_mechanism == ParallelExecutionMechanisms.asyncio:
2014-
future = asyncio.get_running_loop().create_task(
2015-
runnable._async_run(input, event.path, origin_runnable_name)
2016-
)
2066+
future = loop.create_task(runnable._async_run(input, event.path, origin_runnable_name))
20172067
elif execution_mechanism == ParallelExecutionMechanisms.naive:
2018-
future = asyncio.get_running_loop().create_future()
2068+
future = loop.create_future()
20192069
future.set_result(runnable._run(input, event.path, origin_runnable_name))
2020-
elif execution_mechanism == ParallelExecutionMechanisms.dedicated_process:
2021-
executor = self._process_executor_by_runnable_name[runnable.name]
2022-
future = asyncio.get_running_loop().run_in_executor(
2023-
executor,
2024-
_static_run,
2025-
input,
2026-
event.path,
2027-
origin_runnable_name,
2028-
)
2070+
elif execution_mechanism in ParallelExecutionMechanisms.process():
2071+
# Get the appropriate executor for this process mechanism
2072+
if execution_mechanism == ParallelExecutionMechanisms.dedicated_process:
2073+
executor = self._process_executor_by_runnable_name[runnable.name]
2074+
else: # process_pool
2075+
executor = self._executors[execution_mechanism]
2076+
2077+
if is_streaming:
2078+
# Use Manager's queue for cross-process streaming (regular queues can't be passed to executor)
2079+
queue = self._get_manager().Queue()
2080+
# Use appropriate streaming function based on mechanism
2081+
if execution_mechanism == ParallelExecutionMechanisms.dedicated_process:
2082+
loop.run_in_executor(
2083+
executor, _static_streaming_run, input, event.path, origin_runnable_name, queue
2084+
)
2085+
else:
2086+
loop.run_in_executor(
2087+
executor, _streaming_run_wrapper, runnable, input, event.path, origin_runnable_name, queue
2088+
)
2089+
future = loop.create_future()
2090+
future.set_result(_async_read_streaming_queue(queue))
2091+
else:
2092+
# Use appropriate run function based on mechanism
2093+
if execution_mechanism == ParallelExecutionMechanisms.dedicated_process:
2094+
future = loop.run_in_executor(executor, _static_run, input, event.path, origin_runnable_name)
2095+
else:
2096+
future = loop.run_in_executor(executor, runnable._run, input, event.path, origin_runnable_name)
20292097
else:
20302098
executor = self._executors[execution_mechanism]
2031-
future = asyncio.get_running_loop().run_in_executor(
2032-
executor,
2033-
runnable._run,
2034-
input,
2035-
event.path,
2036-
origin_runnable_name,
2037-
)
2099+
future = loop.run_in_executor(executor, runnable._run, input, event.path, origin_runnable_name)
20382100
return future
20392101

20402102

tests/test_streaming.py

Lines changed: 61 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,14 @@ async def run_async(self, body, path: str, origin_name: Optional[str] = None) ->
5454
yield f"{body}_chunk_{i}"
5555

5656

57+
class ErrorStreamingRunnable(ParallelExecutionRunnable):
    """Streaming runnable used to exercise mid-stream failures.

    Produces exactly one chunk and then fails with ``ValueError``.
    """

    def run(self, body, path: str, origin_name: Optional[str] = None) -> Generator:
        first_chunk = f"{body}_chunk_0"
        yield first_chunk
        # Fail after the first chunk so consumers see partial output, then the error.
        raise ValueError("Simulated streaming error")
63+
64+
5765
class TestStreamingPrimitives:
5866
"""Tests for streaming primitive classes."""
5967

@@ -1090,29 +1098,6 @@ def double(x):
10901098
class TestParallelExecutionStreaming:
10911099
"""Tests for ParallelExecution streaming support."""
10921100

1093-
def test_parallel_execution_single_runnable_streaming(self):
1094-
"""Test streaming with a single runnable."""
1095-
runnable = StreamingRunnable(name="streamer")
1096-
controller = build_flow(
1097-
[
1098-
SyncEmitSource(),
1099-
ParallelExecution(
1100-
runnables=[runnable],
1101-
execution_mechanism_by_runnable_name={"streamer": ParallelExecutionMechanisms.naive},
1102-
),
1103-
Complete(),
1104-
]
1105-
).run()
1106-
1107-
try:
1108-
awaitable = controller.emit("test")
1109-
result = awaitable.await_result()
1110-
assert inspect.isgenerator(result)
1111-
assert list(result) == ["test_chunk_0", "test_chunk_1", "test_chunk_2"]
1112-
finally:
1113-
controller.terminate()
1114-
controller.await_termination()
1115-
11161101
def test_parallel_execution_async_runnable_streaming(self):
11171102
"""Test streaming with an async runnable."""
11181103
runnable = AsyncStreamingRunnable(name="async_streamer")
@@ -1190,15 +1175,24 @@ async def _test():
11901175

11911176
asyncio.run(_test())
11921177

1193-
def test_parallel_execution_streaming_with_thread_pool(self):
1194-
"""Test streaming works with thread_pool execution mechanism."""
1178+
@pytest.mark.parametrize(
1179+
"execution_mechanism",
1180+
[
1181+
ParallelExecutionMechanisms.naive,
1182+
ParallelExecutionMechanisms.thread_pool,
1183+
ParallelExecutionMechanisms.process_pool,
1184+
ParallelExecutionMechanisms.dedicated_process,
1185+
],
1186+
)
1187+
def test_parallel_execution_streaming_with_executor(self, execution_mechanism):
1188+
"""Test streaming works with various execution mechanisms."""
11951189
runnable = StreamingRunnable(name="streamer")
11961190
controller = build_flow(
11971191
[
11981192
SyncEmitSource(),
11991193
ParallelExecution(
12001194
runnables=[runnable],
1201-
execution_mechanism_by_runnable_name={"streamer": ParallelExecutionMechanisms.thread_pool},
1195+
execution_mechanism_by_runnable_name={"streamer": execution_mechanism},
12021196
),
12031197
Complete(),
12041198
]
@@ -1213,14 +1207,20 @@ def test_parallel_execution_streaming_with_thread_pool(self):
12131207
controller.terminate()
12141208
controller.await_termination()
12151209

1216-
def test_parallel_execution_streaming_with_shared_executor_thread_based(self):
1217-
"""Test streaming works with shared_executor when the shared executor uses threads."""
1218-
# Create a shared executor with a thread-based runnable
1210+
@pytest.mark.parametrize(
1211+
"execution_mechanism",
1212+
[
1213+
ParallelExecutionMechanisms.thread_pool,
1214+
ParallelExecutionMechanisms.process_pool,
1215+
ParallelExecutionMechanisms.dedicated_process,
1216+
],
1217+
)
1218+
def test_parallel_execution_streaming_with_shared_executor(self, execution_mechanism):
1219+
"""Test streaming works with shared_executor using different underlying mechanisms."""
12191220
shared_executor = RunnableExecutor()
12201221
shared_runnable = StreamingRunnable(name="shared_streamer")
1221-
shared_executor.add_runnable(shared_runnable, ParallelExecutionMechanisms.thread_pool)
1222+
shared_executor.add_runnable(shared_runnable, execution_mechanism)
12221223

1223-
# Create a proxy runnable that references the shared executor
12241224
proxy_runnable = StreamingRunnable(name="proxy", shared_runnable_name="shared_streamer")
12251225

12261226
class ContextWithExecutor:
@@ -1251,67 +1251,45 @@ def __init__(self, executor):
12511251
controller.await_termination()
12521252

12531253
@pytest.mark.parametrize(
1254-
"mechanism",
1255-
[ParallelExecutionMechanisms.process_pool, ParallelExecutionMechanisms.dedicated_process],
1254+
"execution_mechanism,expected_error",
1255+
[
1256+
(ParallelExecutionMechanisms.naive, ValueError),
1257+
(ParallelExecutionMechanisms.thread_pool, ValueError),
1258+
# Process-based mechanisms wrap errors in RuntimeError
1259+
(ParallelExecutionMechanisms.process_pool, RuntimeError),
1260+
(ParallelExecutionMechanisms.dedicated_process, RuntimeError),
1261+
],
12561262
)
1257-
def test_parallel_execution_streaming_with_process_based_fails_at_init(self, mechanism):
1258-
"""Test that StreamingError is raised at init time when streaming runnable uses process-based mechanism."""
1259-
runnable = StreamingRunnable(name="streamer")
1260-
1261-
flow = build_flow(
1263+
def test_parallel_execution_streaming_error_propagation(self, execution_mechanism, expected_error):
1264+
"""Test that errors in streaming are propagated correctly."""
1265+
runnable = ErrorStreamingRunnable(name="error_streamer")
1266+
controller = build_flow(
12621267
[
12631268
SyncEmitSource(),
12641269
ParallelExecution(
12651270
runnables=[runnable],
1266-
execution_mechanism_by_runnable_name={"streamer": mechanism},
1271+
execution_mechanism_by_runnable_name={"error_streamer": execution_mechanism},
12671272
),
12681273
Complete(),
12691274
]
1270-
)
1271-
1272-
expected_error_message = (
1273-
"Streaming is not supported with process-based execution mechanisms. "
1274-
f"Runnable 'streamer' uses '{mechanism}'. "
1275-
"Use 'thread_pool', 'asyncio', or 'naive' for streaming runnables."
1276-
)
1277-
with pytest.raises(StreamingError, match=expected_error_message):
1278-
flow.run()
1279-
1280-
def test_parallel_execution_streaming_with_shared_executor_process_based_fails_at_init(self):
1281-
"""Test that StreamingError is raised at init when shared_executor uses process-based mechanism."""
1282-
# Create a shared executor with a process-based runnable
1283-
shared_executor = RunnableExecutor()
1284-
shared_runnable = StreamingRunnable(name="shared_streamer")
1285-
shared_executor.add_runnable(shared_runnable, ParallelExecutionMechanisms.process_pool)
1286-
1287-
# Create a proxy runnable that references the shared executor
1288-
proxy_runnable = StreamingRunnable(name="proxy", shared_runnable_name="shared_streamer")
1289-
1290-
class ContextWithExecutor:
1291-
def __init__(self, executor):
1292-
self.executor = executor
1293-
1294-
context = ContextWithExecutor(shared_executor)
1275+
).run()
12951276

1296-
flow = build_flow(
1297-
[
1298-
SyncEmitSource(),
1299-
ParallelExecution(
1300-
runnables=[proxy_runnable],
1301-
execution_mechanism_by_runnable_name={"proxy": ParallelExecutionMechanisms.shared_executor},
1302-
context=context,
1303-
),
1304-
Complete(),
1305-
]
1306-
)
1307-
1308-
expected_error_message = (
1309-
"Streaming is not supported with process-based execution mechanisms. "
1310-
"Runnable 'shared_streamer' uses 'process_pool'. "
1311-
"Use 'thread_pool', 'asyncio', or 'naive' for streaming runnables."
1312-
)
1313-
with pytest.raises(StreamingError, match=expected_error_message):
1314-
flow.run()
1277+
try:
1278+
awaitable = controller.emit("test")
1279+
result = awaitable.await_result()
1280+
assert inspect.isgenerator(result)
1281+
# Should get first chunk, then error
1282+
chunks = []
1283+
with pytest.raises(expected_error, match="Simulated streaming error"):
1284+
for chunk in result:
1285+
chunks.append(chunk)
1286+
# Verify we got the first chunk before the error
1287+
assert chunks == ["test_chunk_0"]
1288+
finally:
1289+
controller.terminate()
1290+
# Error is also propagated through termination
1291+
with pytest.raises(expected_error, match="Simulated streaming error"):
1292+
controller.await_termination()
13151293

13161294

13171295
class TestStreamingGraphSplits:

0 commit comments

Comments
 (0)