Skip to content

Commit 0bc0410

Browse files
authored
Support streaming + process-based execution mechanism (#609)
* Support streaming + process-based execution mechanism [ML-11878](https://iguazio.atlassian.net/browse/ML-11878) Enables streaming (generator-based) runnables to work with process_pool and dedicated_process execution mechanisms. Since generators cannot cross process boundaries directly, this implementation uses multiprocessing queues to transfer chunks from child processes to the parent: * Child process iterates the generator and sends chunks via queue * Parent process yields chunks from a queue-reading generator * Errors are serialized and re-raised in the parent as RuntimeError Changes: * `_streaming_run_wrapper` / `_static_streaming_run` - execute streaming runnables in child process `_read_streaming_queue` / `_async_read_streaming_queue` - yield chunks from queue in parent * Lazy-initialized `multiprocessing.Manager` for queue creation * Comprehensive test coverage for all execution mechanisms * Delete unused method
1 parent 212ce90 commit 0bc0410

File tree

2 files changed

+152
-112
lines changed

2 files changed

+152
-112
lines changed

storey/flow.py

Lines changed: 91 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1850,6 +1850,53 @@ def _static_run(*args, **kwargs):
18501850
return _sval._run(*args, **kwargs)
18511851

18521852

1853+
def _streaming_run_wrapper(
1854+
runnable: ParallelExecutionRunnable,
1855+
input,
1856+
path: str,
1857+
origin_name: Optional[str],
1858+
queue: multiprocessing.Queue,
1859+
) -> None:
1860+
"""Wrapper that runs a streaming runnable and sends chunks via queue.
1861+
1862+
This function runs in a child process and iterates over the generator,
1863+
sending each chunk through the multiprocessing queue.
1864+
"""
1865+
try:
1866+
for chunk in runnable._run(input, path, origin_name):
1867+
queue.put(("chunk", chunk))
1868+
queue.put(("done", None))
1869+
except Exception as e:
1870+
queue.put(("error", (type(e).__name__, str(e), traceback.format_exc())))
1871+
1872+
1873+
def _static_streaming_run(input, path: str, origin_name: Optional[str], queue: multiprocessing.Queue) -> None:
1874+
"""Streaming wrapper for dedicated_process using the global runnable."""
1875+
global _sval
1876+
_streaming_run_wrapper(_sval, input, path, origin_name, queue)
1877+
1878+
1879+
async def _async_read_streaming_queue(
1880+
queue: multiprocessing.Queue, loop: Optional[asyncio.AbstractEventLoop] = None
1881+
) -> AsyncGenerator:
1882+
"""Async generator that reads chunks from a multiprocessing queue.
1883+
1884+
This runs in the parent process and yields chunks sent by the child process,
1885+
without blocking the asyncio event loop.
1886+
"""
1887+
loop = asyncio.get_running_loop()
1888+
while True:
1889+
# Use run_in_executor to avoid blocking the event loop
1890+
msg_type, payload = await loop.run_in_executor(None, queue.get)
1891+
if msg_type == "chunk":
1892+
yield payload
1893+
elif msg_type == "done":
1894+
break
1895+
elif msg_type == "error":
1896+
exc_type, exc_msg, exc_tb = payload
1897+
raise RuntimeError(f"{exc_type}: {exc_msg}\n\nOriginal traceback:\n{exc_tb}")
1898+
1899+
18531900
class RunnableExecutor:
18541901
"""
18551902
Manages and executes `ParallelExecutionRunnable` instances using various parallel execution mechanisms.
@@ -1877,6 +1924,8 @@ def __init__(
18771924
self._process_executor_by_runnable_name = {}
18781925
self._mp_context = multiprocessing.get_context("spawn")
18791926
self._executors = {}
1927+
self._is_streaming_by_runnable_name: dict[str, bool] = {}
1928+
self._manager = None # Lazy-initialized multiprocessing Manager for queue-based streaming
18801929

18811930
def add_runnable(self, runnable: ParallelExecutionRunnable, execution_mechanism: str) -> None:
18821931
"""
@@ -1913,15 +1962,9 @@ def init_runnable(self, runnable: Union[str, ParallelExecutionRunnable]) -> None
19131962

19141963
execution_mechanism = self._execution_mechanism_by_runnable_name[runnable.name]
19151964

1916-
# Check for streaming + process-based execution (incompatible combination)
1917-
if execution_mechanism in ParallelExecutionMechanisms.process():
1918-
is_streaming = inspect.isgeneratorfunction(runnable.run) or inspect.isasyncgenfunction(runnable.run_async)
1919-
if is_streaming:
1920-
raise StreamingError(
1921-
f"Streaming is not supported with process-based execution mechanisms. "
1922-
f"Runnable '{runnable.name}' uses '{execution_mechanism}'. "
1923-
f"Use 'thread_pool', 'asyncio', or 'naive' for streaming runnables."
1924-
)
1965+
# Record whether this runnable is a streaming runnable (generator function)
1966+
is_streaming = inspect.isgeneratorfunction(runnable.run) or inspect.isasyncgenfunction(runnable.run_async)
1967+
self._is_streaming_by_runnable_name[runnable.name] = is_streaming
19251968

19261969
if execution_mechanism == ParallelExecutionMechanisms.process_pool:
19271970
self.num_processes += 1
@@ -1956,6 +1999,12 @@ def init_executors(self):
19561999
if num_threads:
19572000
self._executors[ParallelExecutionMechanisms.thread_pool] = ThreadPoolExecutor(max_workers=num_threads)
19582001

2002+
def _get_manager(self):
2003+
"""Get or create the multiprocessing Manager for queue-based streaming."""
2004+
if self._manager is None:
2005+
self._manager = self._mp_context.Manager()
2006+
return self._manager
2007+
19592008
def run_executor(
19602009
self,
19612010
runnable: Union[ParallelExecutionRunnable, str],
@@ -1985,36 +2034,49 @@ def run_executor(
19852034
raise ValueError(f"select_runnables() returned more than one outlet named '{runnable.name}'")
19862035

19872036
execution_mechanism = self._execution_mechanism_by_runnable_name[runnable.name]
2037+
is_streaming = self._is_streaming_by_runnable_name.get(runnable.name, False)
19882038

19892039
input = (
19902040
event.body if execution_mechanism in ParallelExecutionMechanisms.process() else copy.deepcopy(event.body)
19912041
)
19922042

2043+
loop = asyncio.get_running_loop()
2044+
19932045
if execution_mechanism == ParallelExecutionMechanisms.asyncio:
1994-
future = asyncio.get_running_loop().create_task(
1995-
runnable._async_run(input, event.path, origin_runnable_name)
1996-
)
2046+
future = loop.create_task(runnable._async_run(input, event.path, origin_runnable_name))
19972047
elif execution_mechanism == ParallelExecutionMechanisms.naive:
1998-
future = asyncio.get_running_loop().create_future()
2048+
future = loop.create_future()
19992049
future.set_result(runnable._run(input, event.path, origin_runnable_name))
2000-
elif execution_mechanism == ParallelExecutionMechanisms.dedicated_process:
2001-
executor = self._process_executor_by_runnable_name[runnable.name]
2002-
future = asyncio.get_running_loop().run_in_executor(
2003-
executor,
2004-
_static_run,
2005-
input,
2006-
event.path,
2007-
origin_runnable_name,
2008-
)
2050+
elif execution_mechanism in ParallelExecutionMechanisms.process():
2051+
# Get the appropriate executor for this process mechanism
2052+
if execution_mechanism == ParallelExecutionMechanisms.dedicated_process:
2053+
executor = self._process_executor_by_runnable_name[runnable.name]
2054+
else: # process_pool
2055+
executor = self._executors[execution_mechanism]
2056+
2057+
if is_streaming:
2058+
# Use Manager's queue for cross-process streaming (regular queues can't be passed to executor)
2059+
queue = self._get_manager().Queue()
2060+
# Use appropriate streaming function based on mechanism
2061+
if execution_mechanism == ParallelExecutionMechanisms.dedicated_process:
2062+
loop.run_in_executor(
2063+
executor, _static_streaming_run, input, event.path, origin_runnable_name, queue
2064+
)
2065+
else:
2066+
loop.run_in_executor(
2067+
executor, _streaming_run_wrapper, runnable, input, event.path, origin_runnable_name, queue
2068+
)
2069+
future = loop.create_future()
2070+
future.set_result(_async_read_streaming_queue(queue))
2071+
else:
2072+
# Use appropriate run function based on mechanism
2073+
if execution_mechanism == ParallelExecutionMechanisms.dedicated_process:
2074+
future = loop.run_in_executor(executor, _static_run, input, event.path, origin_runnable_name)
2075+
else:
2076+
future = loop.run_in_executor(executor, runnable._run, input, event.path, origin_runnable_name)
20092077
else:
20102078
executor = self._executors[execution_mechanism]
2011-
future = asyncio.get_running_loop().run_in_executor(
2012-
executor,
2013-
runnable._run,
2014-
input,
2015-
event.path,
2016-
origin_runnable_name,
2017-
)
2079+
future = loop.run_in_executor(executor, runnable._run, input, event.path, origin_runnable_name)
20182080
return future
20192081

20202082

tests/test_streaming.py

Lines changed: 61 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,14 @@ async def run_async(self, body, path: str, origin_name: Optional[str] = None) ->
5454
yield f"{body}_chunk_{i}"
5555

5656

57+
class ErrorStreamingRunnable(ParallelExecutionRunnable):
58+
"""A streaming runnable that yields one chunk then raises an error."""
59+
60+
def run(self, body, path: str, origin_name: Optional[str] = None) -> Generator:
61+
yield f"{body}_chunk_0"
62+
raise ValueError("Simulated streaming error")
63+
64+
5765
class TestStreamingPrimitives:
5866
"""Tests for streaming primitive classes."""
5967

@@ -1090,29 +1098,6 @@ def double(x):
10901098
class TestParallelExecutionStreaming:
10911099
"""Tests for ParallelExecution streaming support."""
10921100

1093-
def test_parallel_execution_single_runnable_streaming(self):
1094-
"""Test streaming with a single runnable."""
1095-
runnable = StreamingRunnable(name="streamer")
1096-
controller = build_flow(
1097-
[
1098-
SyncEmitSource(),
1099-
ParallelExecution(
1100-
runnables=[runnable],
1101-
execution_mechanism_by_runnable_name={"streamer": ParallelExecutionMechanisms.naive},
1102-
),
1103-
Complete(),
1104-
]
1105-
).run()
1106-
1107-
try:
1108-
awaitable = controller.emit("test")
1109-
result = awaitable.await_result()
1110-
assert inspect.isgenerator(result)
1111-
assert list(result) == ["test_chunk_0", "test_chunk_1", "test_chunk_2"]
1112-
finally:
1113-
controller.terminate()
1114-
controller.await_termination()
1115-
11161101
def test_parallel_execution_async_runnable_streaming(self):
11171102
"""Test streaming with an async runnable."""
11181103
runnable = AsyncStreamingRunnable(name="async_streamer")
@@ -1190,15 +1175,24 @@ async def _test():
11901175

11911176
asyncio.run(_test())
11921177

1193-
def test_parallel_execution_streaming_with_thread_pool(self):
1194-
"""Test streaming works with thread_pool execution mechanism."""
1178+
@pytest.mark.parametrize(
1179+
"execution_mechanism",
1180+
[
1181+
ParallelExecutionMechanisms.naive,
1182+
ParallelExecutionMechanisms.thread_pool,
1183+
ParallelExecutionMechanisms.process_pool,
1184+
ParallelExecutionMechanisms.dedicated_process,
1185+
],
1186+
)
1187+
def test_parallel_execution_streaming_with_executor(self, execution_mechanism):
1188+
"""Test streaming works with various execution mechanisms."""
11951189
runnable = StreamingRunnable(name="streamer")
11961190
controller = build_flow(
11971191
[
11981192
SyncEmitSource(),
11991193
ParallelExecution(
12001194
runnables=[runnable],
1201-
execution_mechanism_by_runnable_name={"streamer": ParallelExecutionMechanisms.thread_pool},
1195+
execution_mechanism_by_runnable_name={"streamer": execution_mechanism},
12021196
),
12031197
Complete(),
12041198
]
@@ -1213,14 +1207,20 @@ def test_parallel_execution_streaming_with_thread_pool(self):
12131207
controller.terminate()
12141208
controller.await_termination()
12151209

1216-
def test_parallel_execution_streaming_with_shared_executor_thread_based(self):
1217-
"""Test streaming works with shared_executor when the shared executor uses threads."""
1218-
# Create a shared executor with a thread-based runnable
1210+
@pytest.mark.parametrize(
1211+
"execution_mechanism",
1212+
[
1213+
ParallelExecutionMechanisms.thread_pool,
1214+
ParallelExecutionMechanisms.process_pool,
1215+
ParallelExecutionMechanisms.dedicated_process,
1216+
],
1217+
)
1218+
def test_parallel_execution_streaming_with_shared_executor(self, execution_mechanism):
1219+
"""Test streaming works with shared_executor using different underlying mechanisms."""
12191220
shared_executor = RunnableExecutor()
12201221
shared_runnable = StreamingRunnable(name="shared_streamer")
1221-
shared_executor.add_runnable(shared_runnable, ParallelExecutionMechanisms.thread_pool)
1222+
shared_executor.add_runnable(shared_runnable, execution_mechanism)
12221223

1223-
# Create a proxy runnable that references the shared executor
12241224
proxy_runnable = StreamingRunnable(name="proxy", shared_runnable_name="shared_streamer")
12251225

12261226
class ContextWithExecutor:
@@ -1251,67 +1251,45 @@ def __init__(self, executor):
12511251
controller.await_termination()
12521252

12531253
@pytest.mark.parametrize(
1254-
"mechanism",
1255-
[ParallelExecutionMechanisms.process_pool, ParallelExecutionMechanisms.dedicated_process],
1254+
"execution_mechanism,expected_error",
1255+
[
1256+
(ParallelExecutionMechanisms.naive, ValueError),
1257+
(ParallelExecutionMechanisms.thread_pool, ValueError),
1258+
# Process-based mechanisms wrap errors in RuntimeError
1259+
(ParallelExecutionMechanisms.process_pool, RuntimeError),
1260+
(ParallelExecutionMechanisms.dedicated_process, RuntimeError),
1261+
],
12561262
)
1257-
def test_parallel_execution_streaming_with_process_based_fails_at_init(self, mechanism):
1258-
"""Test that StreamingError is raised at init time when streaming runnable uses process-based mechanism."""
1259-
runnable = StreamingRunnable(name="streamer")
1260-
1261-
flow = build_flow(
1263+
def test_parallel_execution_streaming_error_propagation(self, execution_mechanism, expected_error):
1264+
"""Test that errors in streaming are propagated correctly."""
1265+
runnable = ErrorStreamingRunnable(name="error_streamer")
1266+
controller = build_flow(
12621267
[
12631268
SyncEmitSource(),
12641269
ParallelExecution(
12651270
runnables=[runnable],
1266-
execution_mechanism_by_runnable_name={"streamer": mechanism},
1271+
execution_mechanism_by_runnable_name={"error_streamer": execution_mechanism},
12671272
),
12681273
Complete(),
12691274
]
1270-
)
1271-
1272-
expected_error_message = (
1273-
"Streaming is not supported with process-based execution mechanisms. "
1274-
f"Runnable 'streamer' uses '{mechanism}'. "
1275-
"Use 'thread_pool', 'asyncio', or 'naive' for streaming runnables."
1276-
)
1277-
with pytest.raises(StreamingError, match=expected_error_message):
1278-
flow.run()
1279-
1280-
def test_parallel_execution_streaming_with_shared_executor_process_based_fails_at_init(self):
1281-
"""Test that StreamingError is raised at init when shared_executor uses process-based mechanism."""
1282-
# Create a shared executor with a process-based runnable
1283-
shared_executor = RunnableExecutor()
1284-
shared_runnable = StreamingRunnable(name="shared_streamer")
1285-
shared_executor.add_runnable(shared_runnable, ParallelExecutionMechanisms.process_pool)
1286-
1287-
# Create a proxy runnable that references the shared executor
1288-
proxy_runnable = StreamingRunnable(name="proxy", shared_runnable_name="shared_streamer")
1289-
1290-
class ContextWithExecutor:
1291-
def __init__(self, executor):
1292-
self.executor = executor
1293-
1294-
context = ContextWithExecutor(shared_executor)
1275+
).run()
12951276

1296-
flow = build_flow(
1297-
[
1298-
SyncEmitSource(),
1299-
ParallelExecution(
1300-
runnables=[proxy_runnable],
1301-
execution_mechanism_by_runnable_name={"proxy": ParallelExecutionMechanisms.shared_executor},
1302-
context=context,
1303-
),
1304-
Complete(),
1305-
]
1306-
)
1307-
1308-
expected_error_message = (
1309-
"Streaming is not supported with process-based execution mechanisms. "
1310-
"Runnable 'shared_streamer' uses 'process_pool'. "
1311-
"Use 'thread_pool', 'asyncio', or 'naive' for streaming runnables."
1312-
)
1313-
with pytest.raises(StreamingError, match=expected_error_message):
1314-
flow.run()
1277+
try:
1278+
awaitable = controller.emit("test")
1279+
result = awaitable.await_result()
1280+
assert inspect.isgenerator(result)
1281+
# Should get first chunk, then error
1282+
chunks = []
1283+
with pytest.raises(expected_error, match="Simulated streaming error"):
1284+
for chunk in result:
1285+
chunks.append(chunk)
1286+
# Verify we got the first chunk before the error
1287+
assert chunks == ["test_chunk_0"]
1288+
finally:
1289+
controller.terminate()
1290+
# Error is also propagated through termination
1291+
with pytest.raises(expected_error, match="Simulated streaming error"):
1292+
controller.await_termination()
13151293

13161294

13171295
class TestStreamingGraphSplits:

0 commit comments

Comments
 (0)