Commit 784b03e

Add timing metadata for streaming ParallelExecution responses (#613)
* Add timing metadata for streaming ParallelExecution responses

  [ML-11879](https://iguazio.atlassian.net/browse/ML-11879)

  Streaming responses from `ParallelExecution` were missing the `when` and `microsec` timing metadata that non-streaming responses include. This metadata is required for model monitoring in MLRun.

  Changes
  * Add `_StreamingResult` class to wrap streaming generators with timing info
  * Set timing metadata on events before emitting streaming chunks
  * Handle both in-process streaming (`_StreamingResult`) and process-based streaming (raw generators)

  Notes
  * For streaming, `microsec` is set to `None`, since the total runtime isn't available until streaming completes (sketched below)
  * For process-based streaming, `when` uses the timestamp at which chunks start arriving (timing from the subprocess isn't available)

* Implement latency metadata in Collector
* Lint
* Remove code that handles the unsupported multi-model stream result
* Add missing unit test for `StreamingError`
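A minimal sketch (not part of the commit) of the metadata lifecycle the notes above describe; the field names match the diffs below, while the concrete values are made up:

    # On the streaming event, ParallelExecution only knows the start time.
    first_chunk_metadata = {
        "when": "2024-01-01 12:00:00.000000+00:00",  # stream start timestamp
        "microsec": None,  # total runtime unknown until the stream completes
    }

    # After Collector aggregates the chunks, it back-fills the total duration.
    collected_event_metadata = {
        "when": "2024-01-01 12:00:00.000000+00:00",
        "microsec": 1_234_567,  # elapsed time from stream start to completion
    }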
1 parent 404d886 commit 784b03e

File tree

3 files changed: +138, -9 lines changed

storey/flow.py

Lines changed: 44 additions & 9 deletions
@@ -1728,6 +1728,20 @@ def __init__(self, runnable_name: str, data: Any, runtime: float, timestamp: dat
         self.timestamp = timestamp
 
 
+class _StreamingResult:
+    """Wraps a streaming generator with timing metadata for model monitoring."""
+
+    def __init__(
+        self,
+        runnable_name: str,
+        generator: Generator | AsyncGenerator,
+        timestamp: datetime.datetime,
+    ):
+        self.runnable_name = runnable_name
+        self.generator = generator
+        self.timestamp = timestamp
+
+
 class ParallelExecutionMechanisms(str, enum.Enum):
     process_pool = "process_pool"
     dedicated_process = "dedicated_process"
@@ -1840,9 +1854,9 @@ def _run(self, body: Any, path: str, origin_name: Optional[str] = None) -> Any:
         start = time.monotonic()
         try:
             result = self.run(body, path, origin_name)
-            # Return generator directly for streaming support
+            # Return streaming result with timing metadata for streaming support
             if _is_generator(result):
-                return result
+                return _StreamingResult(origin_name or self.name, result, timestamp)
             body = result
         except Exception as e:
             if self._raise_exception:
@@ -1858,9 +1872,9 @@ async def _async_run(self, body: Any, path: str, origin_name: Optional[str] = No
         try:
             result = self.run_async(body, path, origin_name)
 
-            # Return generator directly for streaming support
+            # Return streaming result with timing metadata for streaming support
             if _is_generator(result):
-                return result
+                return _StreamingResult(origin_name or self.name, result, timestamp)
 
             # Await if coroutine
             if asyncio.iscoroutine(result):
@@ -1902,7 +1916,10 @@ def _streaming_run_wrapper(
     sending each chunk through the multiprocessing queue.
     """
     try:
-        for chunk in runnable._run(input, path, origin_name):
+        result = runnable._run(input, path, origin_name)
+        # Unwrap _StreamingResult to get the generator
+        generator = result.generator if isinstance(result, _StreamingResult) else result
+        for chunk in generator:
             queue.put(("chunk", chunk))
         queue.put(("done", None))
     except Exception as e:
@@ -2291,15 +2308,33 @@ async def _do(self, event):
         # Check for streaming response (only when a single runnable is selected)
         if len(runnables) == 1 and results:
             result = results[0]
-            # Check if the result is a generator (streaming response)
-            if _is_generator(result):
+            # Check if the result is a streaming result (contains generator with timing metadata)
+            if isinstance(result, _StreamingResult):
+                # Set timing metadata on the event before emitting chunks
+                # For streaming, microsec is None since we don't have total runtime
+                metadata = {
+                    "microsec": None,
+                    "when": result.timestamp.isoformat(sep=" ", timespec="microseconds"),
+                }
+                self.set_event_metadata(event, metadata)
+                await self._emit_streaming_chunks(event, result.generator)
+                return None
+            # Handle raw generator from process-based streaming (no timing info from subprocess)
+            elif _is_generator(result):
+                # Use current timestamp as fallback for process-based streaming
+                timestamp = datetime.datetime.now(tz=datetime.timezone.utc)
+                metadata = {
+                    "microsec": None,
+                    "when": timestamp.isoformat(sep=" ", timespec="microseconds"),
+                }
+                self.set_event_metadata(event, metadata)
                 await self._emit_streaming_chunks(event, result)
                 return None
 
         # Non-streaming path
-        # Check if any results are generators (not allowed with multiple runnables)
+        # Check if any results are streaming (not allowed with multiple runnables)
         for result in results:
-            if _is_generator(result):
+            if isinstance(result, _StreamingResult):
                 raise StreamingError(
                     "Streaming is not supported when multiple runnables are selected. "
                     "Streaming runnables must be the only runnable selected for an event."

storey/steps/collector.py

Lines changed: 30 additions & 0 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 #
 import copy
+import datetime
 from collections import defaultdict
 
 from ..dtypes import StreamCompletion, _termination_obj
@@ -46,6 +47,31 @@ def __init__(self, expected_completions: int = 1, **kwargs):
             lambda: {"chunks": [], "completions": 0, "first_event": None}
         )
 
+    def _calculate_streaming_duration(self, event):
+        """
+        Calculate total streaming duration and update event metadata with microsec.
+
+        Uses the 'when' timestamp from the first chunk's metadata (set by ParallelExecution)
+        to calculate total elapsed time from stream start to completion.
+
+        Streaming is only supported with a single selected runnable, so metadata is always
+        flat (top-level 'when' and 'microsec'), never nested under model names.
+        """
+        if not hasattr(event, "_metadata") or not event._metadata:
+            return
+
+        when_str = event._metadata.get("when")
+        if not when_str:
+            return
+
+        try:
+            start_time = datetime.datetime.fromisoformat(when_str)
+            now = datetime.datetime.now(tz=datetime.timezone.utc)
+            event._metadata["microsec"] = int((now - start_time).total_seconds() * 1_000_000)
+        except (ValueError, TypeError) as exc:
+            if self.logger:
+                self.logger.warning(f"Failed to calculate streaming duration from 'when' timestamp '{when_str}': {exc}")
+
     async def _do(self, event):
         if event is _termination_obj:
             return await self._do_downstream(_termination_obj)
@@ -73,6 +99,10 @@ async def _do(self, event):
                 del collected_event.streaming_step
             if hasattr(collected_event, "chunk_id"):
                 del collected_event.chunk_id
+
+            # Calculate total streaming duration (microsec) if timing metadata exists
+            self._calculate_streaming_duration(collected_event)
+
             await self._do_downstream(collected_event)
 
         # Clean up
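A standalone sketch of the duration arithmetic performed by `_calculate_streaming_duration` above, using a plain dict in place of a storey event; the `when` value here is a made-up example:

    import datetime

    metadata = {"when": "2024-01-01 12:00:00.000000+00:00", "microsec": None}

    # Parse the stream-start timestamp set by ParallelExecution and back-fill
    # the total elapsed time in microseconds.
    start_time = datetime.datetime.fromisoformat(metadata["when"])
    now = datetime.datetime.now(tz=datetime.timezone.utc)
    metadata["microsec"] = int((now - start_time).total_seconds() * 1_000_000)

    print(metadata["microsec"])  # total elapsed time in microseconds since "when"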

tests/test_streaming.py

Lines changed: 64 additions & 0 deletions
@@ -1084,6 +1084,32 @@ def stream_chunks(x):
 
         asyncio.run(_test())
 
+    def test_streaming_with_multiple_runnables_raises_error(self):
+        """Test that streaming raises an error when multiple runnables are selected."""
+        streaming = StreamingRunnable(name="streamer")
+        non_streaming = NonStreamingRunnable(name="non_streamer")
+
+        controller = build_flow(
+            [
+                SyncEmitSource(),
+                ParallelExecution(
+                    runnables=[streaming, non_streaming],
+                    execution_mechanism_by_runnable_name={
+                        "streamer": ParallelExecutionMechanisms.naive,
+                        "non_streamer": ParallelExecutionMechanisms.naive,
+                    },
+                ),
+                Reduce([], lambda acc, x: acc + [x]),
+            ]
+        ).run()
+
+        try:
+            controller.emit("test")
+        finally:
+            controller.terminate()
+        with pytest.raises(StreamingError, match="Streaming is not supported when multiple runnables are selected"):
+            controller.await_termination()
+
 
 class TestStreamingWithIntermediateSteps:
     """Tests for streaming through intermediate non-streaming steps."""
@@ -1338,6 +1364,44 @@ def test_parallel_execution_streaming_error_propagation(self, execution_mechanis
         with pytest.raises(expected_error, match="Simulated streaming error"):
             controller.await_termination()
 
+    def test_parallel_execution_streaming_single_runnable_sets_metadata(self):
+        """Test that streaming ParallelExecution with single runnable sets timing metadata.
+
+        This mirrors the non-streaming behavior where _metadata includes 'when' and 'microsec'.
+        After Collector aggregates chunks, the collected event should have timing metadata.
+        The 'microsec' field should contain the total streaming duration calculated by Collector.
+        """
+        runnable = StreamingRunnable(name="streamer")
+        controller = build_flow(
+            [
+                SyncEmitSource(),
+                ParallelExecution(
+                    runnables=[runnable],
+                    execution_mechanism_by_runnable_name={"streamer": ParallelExecutionMechanisms.naive},
+                ),
+                Collector(),
+                Reduce([], lambda acc, x: acc + [x], full_event=True),
+            ]
+        ).run()
+
+        try:
+            controller.emit("test")
+        finally:
+            controller.terminate()
+        result = controller.await_termination()
+
+        assert len(result) == 1
+        event = result[0]
+        assert hasattr(event, "_metadata"), "Expected event to have _metadata attribute"
+        metadata = event._metadata
+        assert "when" in metadata, "Expected _metadata to include 'when' field"
+        assert "microsec" in metadata, "Expected _metadata to include 'microsec' field"
+        # Verify 'when' is a valid ISO timestamp string
+        assert isinstance(metadata["when"], str), "Expected 'when' to be a string"
+        # Verify 'microsec' is a positive integer (total streaming duration calculated by Collector)
+        assert isinstance(metadata["microsec"], int), "Expected 'microsec' to be an integer"
+        assert metadata["microsec"] >= 0, "Expected 'microsec' to be non-negative"
+
 
 class TestStreamingGraphSplits:
     """Tests for streaming through branching graph topologies."""
