Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions storey/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1796,6 +1796,18 @@ async def run_async(self, body: Any, path: str, origin_name: Optional[str] = Non
"""
return body

def is_streaming(self) -> bool:
"""
Returns True if this runnable produces streaming output (generator).

Override this method if your runnable's streaming behavior cannot be detected
by inspecting the run()/run_async() methods directly (e.g., when run() delegates
to another method that returns a generator).

:return: True if the runnable produces streaming output, False otherwise.
"""
return inspect.isgeneratorfunction(self.run) or inspect.isasyncgenfunction(self.run_async)

def _run(self, body: Any, path: str, origin_name: Optional[str] = None) -> Any:
timestamp = datetime.datetime.now(tz=datetime.timezone.utc)
start = time.monotonic()
Expand Down Expand Up @@ -1963,8 +1975,7 @@ def init_runnable(self, runnable: Union[str, ParallelExecutionRunnable]) -> None
execution_mechanism = self._execution_mechanism_by_runnable_name[runnable.name]

# Record whether this runnable is a streaming runnable (generator function)
is_streaming = inspect.isgeneratorfunction(runnable.run) or inspect.isasyncgenfunction(runnable.run_async)
self._is_streaming_by_runnable_name[runnable.name] = is_streaming
self._is_streaming_by_runnable_name[runnable.name] = runnable.is_streaming()

if execution_mechanism == ParallelExecutionMechanisms.process_pool:
self.num_processes += 1
Expand Down
47 changes: 47 additions & 0 deletions tests/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ def run(self, body, path: str, origin_name: Optional[str] = None) -> Generator:
raise ValueError("Simulated streaming error")


class NonStreamingRunnable(ParallelExecutionRunnable):
"""A non-streaming runnable that returns a single value."""

def run(self, body, path: str, origin_name: Optional[str] = None):
return f"{body}_result"


class TestStreamingPrimitives:
"""Tests for streaming primitive classes."""

Expand Down Expand Up @@ -120,6 +127,46 @@ async def coro():
c.close()


class TestIsStreamingMethod:
"""Tests for ParallelExecutionRunnable.is_streaming() method."""

def test_is_streaming_sync_generator(self):
"""Test that a runnable with a sync generator run() is detected as streaming."""
runnable = StreamingRunnable(name="test")
assert runnable.is_streaming() is True

def test_is_streaming_async_generator(self):
"""Test that a runnable with an async generator run_async() is detected as streaming."""
runnable = AsyncStreamingRunnable(name="test")
assert runnable.is_streaming() is True

def test_is_streaming_non_generator(self):
"""Test that a runnable with a non-generator run() is not detected as streaming."""
runnable = NonStreamingRunnable(name="test")
assert runnable.is_streaming() is False

def test_is_streaming_base_class(self):
"""Test that the base ParallelExecutionRunnable is not streaming by default."""
runnable = ParallelExecutionRunnable(name="test")
assert runnable.is_streaming() is False

def test_is_streaming_override(self):
"""Test that is_streaming() can be overridden by subclasses."""

class OverriddenRunnable(ParallelExecutionRunnable):
"""A runnable that overrides is_streaming() to return True."""

def is_streaming(self) -> bool:
return True

def run(self, body, path: str, origin_name: Optional[str] = None):
# Even though run() is not a generator, is_streaming() returns True
return f"{body}_result"

runnable = OverriddenRunnable(name="test")
assert runnable.is_streaming() is True


class TestMapStreaming:
"""Tests for Map step streaming support."""

Expand Down