Skip to content

Commit fd5fee1

Browse files
committed
Add overridable is_streaming() to runnables
Adds an `is_streaming()` method to `ParallelExecutionRunnable` that detects whether the runnable produces streaming output by inspecting `run()` and `run_async()` for generator functions. > > Subclasses can override this method when streaming detection requires custom logic (e.g., when the runnable delegates to other methods like `predict()` in mlrun). To enable a cleaner solution to [ML-11878](https://iguazio.atlassian.net/browse/ML-11878).
1 parent 0bc0410 commit fd5fee1

File tree

2 files changed

+60
-2
lines changed

2 files changed

+60
-2
lines changed

storey/flow.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1796,6 +1796,18 @@ async def run_async(self, body: Any, path: str, origin_name: Optional[str] = Non
17961796
"""
17971797
return body
17981798

1799+
def is_streaming(self) -> bool:
1800+
"""
1801+
Returns True if this runnable produces streaming output (generator).
1802+
1803+
Override this method if your runnable's streaming behavior cannot be detected
1804+
by inspecting the run()/run_async() methods directly (e.g., when run() delegates
1805+
to another method that returns a generator).
1806+
1807+
:return: True if the runnable produces streaming output, False otherwise.
1808+
"""
1809+
return inspect.isgeneratorfunction(self.run) or inspect.isasyncgenfunction(self.run_async)
1810+
17991811
def _run(self, body: Any, path: str, origin_name: Optional[str] = None) -> Any:
18001812
timestamp = datetime.datetime.now(tz=datetime.timezone.utc)
18011813
start = time.monotonic()
@@ -1963,8 +1975,7 @@ def init_runnable(self, runnable: Union[str, ParallelExecutionRunnable]) -> None
19631975
execution_mechanism = self._execution_mechanism_by_runnable_name[runnable.name]
19641976

19651977
# Record whether this runnable is a streaming runnable (generator function)
1966-
is_streaming = inspect.isgeneratorfunction(runnable.run) or inspect.isasyncgenfunction(runnable.run_async)
1967-
self._is_streaming_by_runnable_name[runnable.name] = is_streaming
1978+
self._is_streaming_by_runnable_name[runnable.name] = runnable.is_streaming()
19681979

19691980
if execution_mechanism == ParallelExecutionMechanisms.process_pool:
19701981
self.num_processes += 1

tests/test_streaming.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,13 @@ def run(self, body, path: str, origin_name: Optional[str] = None) -> Generator:
6262
raise ValueError("Simulated streaming error")
6363

6464

65+
class NonStreamingRunnable(ParallelExecutionRunnable):
66+
"""A non-streaming runnable that returns a single value."""
67+
68+
def run(self, body, path: str, origin_name: Optional[str] = None):
69+
return f"{body}_result"
70+
71+
6572
class TestStreamingPrimitives:
6673
"""Tests for streaming primitive classes."""
6774

@@ -120,6 +127,46 @@ async def coro():
120127
c.close()
121128

122129

130+
class TestIsStreamingMethod:
131+
"""Tests for ParallelExecutionRunnable.is_streaming() method."""
132+
133+
def test_is_streaming_sync_generator(self):
134+
"""Test that a runnable with a sync generator run() is detected as streaming."""
135+
runnable = StreamingRunnable(name="test")
136+
assert runnable.is_streaming() is True
137+
138+
def test_is_streaming_async_generator(self):
139+
"""Test that a runnable with an async generator run_async() is detected as streaming."""
140+
runnable = AsyncStreamingRunnable(name="test")
141+
assert runnable.is_streaming() is True
142+
143+
def test_is_streaming_non_generator(self):
144+
"""Test that a runnable with a non-generator run() is not detected as streaming."""
145+
runnable = NonStreamingRunnable(name="test")
146+
assert runnable.is_streaming() is False
147+
148+
def test_is_streaming_base_class(self):
149+
"""Test that the base ParallelExecutionRunnable is not streaming by default."""
150+
runnable = ParallelExecutionRunnable(name="test")
151+
assert runnable.is_streaming() is False
152+
153+
def test_is_streaming_override(self):
154+
"""Test that is_streaming() can be overridden by subclasses."""
155+
156+
class OverriddenRunnable(ParallelExecutionRunnable):
157+
"""A runnable that overrides is_streaming() to return True."""
158+
159+
def is_streaming(self) -> bool:
160+
return True
161+
162+
def run(self, body, path: str, origin_name: Optional[str] = None):
163+
# Even though run() is not a generator, is_streaming() returns True
164+
return f"{body}_result"
165+
166+
runnable = OverriddenRunnable(name="test")
167+
assert runnable.is_streaming() is True
168+
169+
123170
class TestMapStreaming:
124171
"""Tests for Map step streaming support."""
125172

0 commit comments

Comments
 (0)