Skip to content

Commit 0471120

Browse files
committed
Merge remote-tracking branch 'mlrun/development' into support_mrs_batch
2 parents 7dff69e + af69b05 commit 0471120

File tree

2 files changed

+60
-2
lines changed

2 files changed

+60
-2
lines changed

storey/flow.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1823,6 +1823,18 @@ async def run_async(self, body: Any, path: str, origin_name: Optional[str] = Non
18231823
"""
18241824
return body
18251825

1826+
def is_streaming(self) -> bool:
1827+
"""
1828+
Returns True if this runnable produces streaming output (generator).
1829+
1830+
Override this method if your runnable's streaming behavior cannot be detected
1831+
by inspecting the run()/run_async() methods directly (e.g., when run() delegates
1832+
to another method that returns a generator).
1833+
1834+
:return: True if the runnable produces streaming output, False otherwise.
1835+
"""
1836+
return inspect.isgeneratorfunction(self.run) or inspect.isasyncgenfunction(self.run_async)
1837+
18261838
def _run(self, body: Any, path: str, origin_name: Optional[str] = None) -> Any:
18271839
timestamp = datetime.datetime.now(tz=datetime.timezone.utc)
18281840
start = time.monotonic()
@@ -1990,8 +2002,7 @@ def init_runnable(self, runnable: Union[str, ParallelExecutionRunnable]) -> None
19902002
execution_mechanism = self._execution_mechanism_by_runnable_name[runnable.name]
19912003

19922004
# Record whether this runnable is a streaming runnable (generator function)
1993-
is_streaming = inspect.isgeneratorfunction(runnable.run) or inspect.isasyncgenfunction(runnable.run_async)
1994-
self._is_streaming_by_runnable_name[runnable.name] = is_streaming
2005+
self._is_streaming_by_runnable_name[runnable.name] = runnable.is_streaming()
19952006

19962007
if execution_mechanism == ParallelExecutionMechanisms.process_pool:
19972008
self.num_processes += 1

tests/test_streaming.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,13 @@ def run(self, body, path: str, origin_name: Optional[str] = None) -> Generator:
6262
raise ValueError("Simulated streaming error")
6363

6464

65+
class NonStreamingRunnable(ParallelExecutionRunnable):
66+
"""A non-streaming runnable that returns a single value."""
67+
68+
def run(self, body, path: str, origin_name: Optional[str] = None):
69+
return f"{body}_result"
70+
71+
6572
class TestStreamingPrimitives:
6673
"""Tests for streaming primitive classes."""
6774

@@ -120,6 +127,46 @@ async def coro():
120127
c.close()
121128

122129

130+
class TestIsStreamingMethod:
131+
"""Tests for ParallelExecutionRunnable.is_streaming() method."""
132+
133+
def test_is_streaming_sync_generator(self):
134+
"""Test that a runnable with a sync generator run() is detected as streaming."""
135+
runnable = StreamingRunnable(name="test")
136+
assert runnable.is_streaming() is True
137+
138+
def test_is_streaming_async_generator(self):
139+
"""Test that a runnable with an async generator run_async() is detected as streaming."""
140+
runnable = AsyncStreamingRunnable(name="test")
141+
assert runnable.is_streaming() is True
142+
143+
def test_is_streaming_non_generator(self):
144+
"""Test that a runnable with a non-generator run() is not detected as streaming."""
145+
runnable = NonStreamingRunnable(name="test")
146+
assert runnable.is_streaming() is False
147+
148+
def test_is_streaming_base_class(self):
149+
"""Test that the base ParallelExecutionRunnable is not streaming by default."""
150+
runnable = ParallelExecutionRunnable(name="test")
151+
assert runnable.is_streaming() is False
152+
153+
def test_is_streaming_override(self):
154+
"""Test that is_streaming() can be overridden by subclasses."""
155+
156+
class OverriddenRunnable(ParallelExecutionRunnable):
157+
"""A runnable that overrides is_streaming() to return True."""
158+
159+
def is_streaming(self) -> bool:
160+
return True
161+
162+
def run(self, body, path: str, origin_name: Optional[str] = None):
163+
# Even though run() is not a generator, is_streaming() returns True
164+
return f"{body}_result"
165+
166+
runnable = OverriddenRunnable(name="test")
167+
assert runnable.is_streaming() is True
168+
169+
123170
class TestMapStreaming:
124171
"""Tests for Map step streaming support."""
125172

0 commit comments

Comments
 (0)