Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 91 additions & 29 deletions storey/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1850,6 +1850,53 @@ def _static_run(*args, **kwargs):
return _sval._run(*args, **kwargs)


def _streaming_run_wrapper(
    runnable: ParallelExecutionRunnable,
    input,
    path: str,
    origin_name: Optional[str],
    queue: multiprocessing.Queue,
) -> None:
    """Execute a streaming runnable in a child process, relaying its output via *queue*.

    Intended to run inside a worker process: it drains the runnable's generator,
    forwarding every value to the parent as a ("chunk", value) message, then
    signals normal completion with ("done", None). Any exception is reported
    back as ("error", (exception type name, message, formatted traceback))
    instead of being allowed to escape the worker.
    """
    send = queue.put
    try:
        for produced in runnable._run(input, path, origin_name):
            send(("chunk", produced))
        send(("done", None))
    except Exception as exc:
        send(("error", (type(exc).__name__, str(exc), traceback.format_exc())))


def _static_streaming_run(input, path: str, origin_name: Optional[str], queue: multiprocessing.Queue) -> None:
    """Streaming entry point for dedicated_process workers.

    Delegates to _streaming_run_wrapper using the process-global runnable
    (_sval) installed when the worker process was initialized. Reading a
    module-level name needs no ``global`` declaration.
    """
    _streaming_run_wrapper(_sval, input, path, origin_name, queue)


async def _async_read_streaming_queue(
    queue: multiprocessing.Queue, loop: Optional[asyncio.AbstractEventLoop] = None
) -> AsyncGenerator:
    """Yield chunks sent by a child process through *queue* without blocking the event loop.

    Runs in the parent process as the consumer side of _streaming_run_wrapper.
    Messages are (type, payload) tuples:

    * ``("chunk", value)`` - yielded to the caller
    * ``("done", None)``   - terminates the generator
    * ``("error", (exc type name, message, traceback))`` - re-raised as RuntimeError

    :param queue: queue written to by the child process.
    :param loop: event loop used to schedule the blocking queue reads; defaults
        to the running loop. (Bug fix: previously this parameter was accepted
        but unconditionally overwritten, i.e. silently ignored.)
    :raises RuntimeError: if the child reported an error; the child's exception
        type, message and traceback are embedded in the error message.
    """
    if loop is None:
        loop = asyncio.get_running_loop()
    while True:
        # queue.get() blocks, so hand it to the default thread-pool executor
        # to keep the asyncio event loop responsive.
        msg_type, payload = await loop.run_in_executor(None, queue.get)
        if msg_type == "chunk":
            yield payload
        elif msg_type == "done":
            break
        elif msg_type == "error":
            exc_type, exc_msg, exc_tb = payload
            raise RuntimeError(f"{exc_type}: {exc_msg}\n\nOriginal traceback:\n{exc_tb}")


class RunnableExecutor:
"""
Manages and executes `ParallelExecutionRunnable` instances using various parallel execution mechanisms.
Expand Down Expand Up @@ -1877,6 +1924,8 @@ def __init__(
self._process_executor_by_runnable_name = {}
self._mp_context = multiprocessing.get_context("spawn")
self._executors = {}
self._is_streaming_by_runnable_name: dict[str, bool] = {}
self._manager = None # Lazy-initialized multiprocessing Manager for queue-based streaming

def add_runnable(self, runnable: ParallelExecutionRunnable, execution_mechanism: str) -> None:
"""
Expand Down Expand Up @@ -1913,15 +1962,9 @@ def init_runnable(self, runnable: Union[str, ParallelExecutionRunnable]) -> None

execution_mechanism = self._execution_mechanism_by_runnable_name[runnable.name]

# Check for streaming + process-based execution (incompatible combination)
if execution_mechanism in ParallelExecutionMechanisms.process():
is_streaming = inspect.isgeneratorfunction(runnable.run) or inspect.isasyncgenfunction(runnable.run_async)
if is_streaming:
raise StreamingError(
f"Streaming is not supported with process-based execution mechanisms. "
f"Runnable '{runnable.name}' uses '{execution_mechanism}'. "
f"Use 'thread_pool', 'asyncio', or 'naive' for streaming runnables."
)
# Record whether this runnable is a streaming runnable (generator function)
is_streaming = inspect.isgeneratorfunction(runnable.run) or inspect.isasyncgenfunction(runnable.run_async)
self._is_streaming_by_runnable_name[runnable.name] = is_streaming

if execution_mechanism == ParallelExecutionMechanisms.process_pool:
self.num_processes += 1
Expand Down Expand Up @@ -1956,6 +1999,12 @@ def init_executors(self):
if num_threads:
self._executors[ParallelExecutionMechanisms.thread_pool] = ThreadPoolExecutor(max_workers=num_threads)

def _get_manager(self):
    """Return the multiprocessing Manager, creating it on first use.

    Created lazily because starting a Manager spawns a helper process, which
    is only needed when queue-based streaming is actually exercised.
    """
    manager = self._manager
    if manager is None:
        manager = self._mp_context.Manager()
        self._manager = manager
    return manager

def run_executor(
self,
runnable: Union[ParallelExecutionRunnable, str],
Expand Down Expand Up @@ -1985,36 +2034,49 @@ def run_executor(
raise ValueError(f"select_runnables() returned more than one outlet named '{runnable.name}'")

execution_mechanism = self._execution_mechanism_by_runnable_name[runnable.name]
is_streaming = self._is_streaming_by_runnable_name.get(runnable.name, False)

input = (
event.body if execution_mechanism in ParallelExecutionMechanisms.process() else copy.deepcopy(event.body)
)

loop = asyncio.get_running_loop()

if execution_mechanism == ParallelExecutionMechanisms.asyncio:
future = asyncio.get_running_loop().create_task(
runnable._async_run(input, event.path, origin_runnable_name)
)
future = loop.create_task(runnable._async_run(input, event.path, origin_runnable_name))
elif execution_mechanism == ParallelExecutionMechanisms.naive:
future = asyncio.get_running_loop().create_future()
future = loop.create_future()
future.set_result(runnable._run(input, event.path, origin_runnable_name))
elif execution_mechanism == ParallelExecutionMechanisms.dedicated_process:
executor = self._process_executor_by_runnable_name[runnable.name]
future = asyncio.get_running_loop().run_in_executor(
executor,
_static_run,
input,
event.path,
origin_runnable_name,
)
elif execution_mechanism in ParallelExecutionMechanisms.process():
# Get the appropriate executor for this process mechanism
if execution_mechanism == ParallelExecutionMechanisms.dedicated_process:
executor = self._process_executor_by_runnable_name[runnable.name]
else: # process_pool
executor = self._executors[execution_mechanism]

if is_streaming:
# Use Manager's queue for cross-process streaming (regular queues can't be passed to executor)
queue = self._get_manager().Queue()
# Use appropriate streaming function based on mechanism
if execution_mechanism == ParallelExecutionMechanisms.dedicated_process:
loop.run_in_executor(
executor, _static_streaming_run, input, event.path, origin_runnable_name, queue
)
else:
loop.run_in_executor(
executor, _streaming_run_wrapper, runnable, input, event.path, origin_runnable_name, queue
)
future = loop.create_future()
future.set_result(_async_read_streaming_queue(queue))
else:
# Use appropriate run function based on mechanism
if execution_mechanism == ParallelExecutionMechanisms.dedicated_process:
future = loop.run_in_executor(executor, _static_run, input, event.path, origin_runnable_name)
else:
future = loop.run_in_executor(executor, runnable._run, input, event.path, origin_runnable_name)
else:
executor = self._executors[execution_mechanism]
future = asyncio.get_running_loop().run_in_executor(
executor,
runnable._run,
input,
event.path,
origin_runnable_name,
)
future = loop.run_in_executor(executor, runnable._run, input, event.path, origin_runnable_name)
return future


Expand Down
144 changes: 61 additions & 83 deletions tests/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ async def run_async(self, body, path: str, origin_name: Optional[str] = None) ->
yield f"{body}_chunk_{i}"


class ErrorStreamingRunnable(ParallelExecutionRunnable):
    """Streaming runnable that produces a single chunk and then fails.

    Used to verify that exceptions raised mid-stream are propagated to the
    consumer of the stream.
    """

    def run(self, body, path: str, origin_name: Optional[str] = None) -> Generator:
        first_chunk = f"{body}_chunk_0"
        yield first_chunk
        raise ValueError("Simulated streaming error")


class TestStreamingPrimitives:
"""Tests for streaming primitive classes."""

Expand Down Expand Up @@ -1090,29 +1098,6 @@ def double(x):
class TestParallelExecutionStreaming:
"""Tests for ParallelExecution streaming support."""

def test_parallel_execution_single_runnable_streaming(self):
"""Test streaming with a single runnable."""
runnable = StreamingRunnable(name="streamer")
controller = build_flow(
[
SyncEmitSource(),
ParallelExecution(
runnables=[runnable],
execution_mechanism_by_runnable_name={"streamer": ParallelExecutionMechanisms.naive},
),
Complete(),
]
).run()

try:
awaitable = controller.emit("test")
result = awaitable.await_result()
assert inspect.isgenerator(result)
assert list(result) == ["test_chunk_0", "test_chunk_1", "test_chunk_2"]
finally:
controller.terminate()
controller.await_termination()

def test_parallel_execution_async_runnable_streaming(self):
"""Test streaming with an async runnable."""
runnable = AsyncStreamingRunnable(name="async_streamer")
Expand Down Expand Up @@ -1190,15 +1175,24 @@ async def _test():

asyncio.run(_test())

def test_parallel_execution_streaming_with_thread_pool(self):
"""Test streaming works with thread_pool execution mechanism."""
@pytest.mark.parametrize(
"execution_mechanism",
[
ParallelExecutionMechanisms.naive,
ParallelExecutionMechanisms.thread_pool,
ParallelExecutionMechanisms.process_pool,
ParallelExecutionMechanisms.dedicated_process,
],
)
def test_parallel_execution_streaming_with_executor(self, execution_mechanism):
"""Test streaming works with various execution mechanisms."""
runnable = StreamingRunnable(name="streamer")
controller = build_flow(
[
SyncEmitSource(),
ParallelExecution(
runnables=[runnable],
execution_mechanism_by_runnable_name={"streamer": ParallelExecutionMechanisms.thread_pool},
execution_mechanism_by_runnable_name={"streamer": execution_mechanism},
),
Complete(),
]
Expand All @@ -1213,14 +1207,20 @@ def test_parallel_execution_streaming_with_thread_pool(self):
controller.terminate()
controller.await_termination()

def test_parallel_execution_streaming_with_shared_executor_thread_based(self):
"""Test streaming works with shared_executor when the shared executor uses threads."""
# Create a shared executor with a thread-based runnable
@pytest.mark.parametrize(
"execution_mechanism",
[
ParallelExecutionMechanisms.thread_pool,
ParallelExecutionMechanisms.process_pool,
ParallelExecutionMechanisms.dedicated_process,
],
)
def test_parallel_execution_streaming_with_shared_executor(self, execution_mechanism):
"""Test streaming works with shared_executor using different underlying mechanisms."""
shared_executor = RunnableExecutor()
shared_runnable = StreamingRunnable(name="shared_streamer")
shared_executor.add_runnable(shared_runnable, ParallelExecutionMechanisms.thread_pool)
shared_executor.add_runnable(shared_runnable, execution_mechanism)

# Create a proxy runnable that references the shared executor
proxy_runnable = StreamingRunnable(name="proxy", shared_runnable_name="shared_streamer")

class ContextWithExecutor:
Expand Down Expand Up @@ -1251,67 +1251,45 @@ def __init__(self, executor):
controller.await_termination()

@pytest.mark.parametrize(
"mechanism",
[ParallelExecutionMechanisms.process_pool, ParallelExecutionMechanisms.dedicated_process],
"execution_mechanism,expected_error",
[
(ParallelExecutionMechanisms.naive, ValueError),
(ParallelExecutionMechanisms.thread_pool, ValueError),
# Process-based mechanisms wrap errors in RuntimeError
(ParallelExecutionMechanisms.process_pool, RuntimeError),
(ParallelExecutionMechanisms.dedicated_process, RuntimeError),
],
)
def test_parallel_execution_streaming_with_process_based_fails_at_init(self, mechanism):
"""Test that StreamingError is raised at init time when streaming runnable uses process-based mechanism."""
runnable = StreamingRunnable(name="streamer")

flow = build_flow(
def test_parallel_execution_streaming_error_propagation(self, execution_mechanism, expected_error):
"""Test that errors in streaming are propagated correctly."""
runnable = ErrorStreamingRunnable(name="error_streamer")
controller = build_flow(
[
SyncEmitSource(),
ParallelExecution(
runnables=[runnable],
execution_mechanism_by_runnable_name={"streamer": mechanism},
execution_mechanism_by_runnable_name={"error_streamer": execution_mechanism},
),
Complete(),
]
)

expected_error_message = (
"Streaming is not supported with process-based execution mechanisms. "
f"Runnable 'streamer' uses '{mechanism}'. "
"Use 'thread_pool', 'asyncio', or 'naive' for streaming runnables."
)
with pytest.raises(StreamingError, match=expected_error_message):
flow.run()

def test_parallel_execution_streaming_with_shared_executor_process_based_fails_at_init(self):
"""Test that StreamingError is raised at init when shared_executor uses process-based mechanism."""
# Create a shared executor with a process-based runnable
shared_executor = RunnableExecutor()
shared_runnable = StreamingRunnable(name="shared_streamer")
shared_executor.add_runnable(shared_runnable, ParallelExecutionMechanisms.process_pool)

# Create a proxy runnable that references the shared executor
proxy_runnable = StreamingRunnable(name="proxy", shared_runnable_name="shared_streamer")

class ContextWithExecutor:
def __init__(self, executor):
self.executor = executor

context = ContextWithExecutor(shared_executor)
).run()

flow = build_flow(
[
SyncEmitSource(),
ParallelExecution(
runnables=[proxy_runnable],
execution_mechanism_by_runnable_name={"proxy": ParallelExecutionMechanisms.shared_executor},
context=context,
),
Complete(),
]
)

expected_error_message = (
"Streaming is not supported with process-based execution mechanisms. "
"Runnable 'shared_streamer' uses 'process_pool'. "
"Use 'thread_pool', 'asyncio', or 'naive' for streaming runnables."
)
with pytest.raises(StreamingError, match=expected_error_message):
flow.run()
try:
awaitable = controller.emit("test")
result = awaitable.await_result()
assert inspect.isgenerator(result)
# Should get first chunk, then error
chunks = []
with pytest.raises(expected_error, match="Simulated streaming error"):
for chunk in result:
chunks.append(chunk)
# Verify we got the first chunk before the error
assert chunks == ["test_chunk_0"]
finally:
controller.terminate()
# Error is also propagated through termination
with pytest.raises(expected_error, match="Simulated streaming error"):
controller.await_termination()


class TestStreamingGraphSplits:
Expand Down