storey/flow.py: 44 additions & 9 deletions
@@ -1728,6 +1728,20 @@ def __init__(self, runnable_name: str, data: Any, runtime: float, timestamp: dat
         self.timestamp = timestamp


+class _StreamingResult:
+    """Wraps a streaming generator with timing metadata for model monitoring."""
+
+    def __init__(
+        self,
+        runnable_name: str,
+        generator: Generator | AsyncGenerator,
+        timestamp: datetime.datetime,
+    ):
+        self.runnable_name = runnable_name
+        self.generator = generator
+        self.timestamp = timestamp
+
+
 class ParallelExecutionMechanisms(str, enum.Enum):
     process_pool = "process_pool"
     dedicated_process = "dedicated_process"
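For context, a minimal sketch of the round-trip this wrapper enables: a runnable's generator is wrapped together with a start timestamp at execution time, and downstream code unwraps it before consuming chunks. The simplified class and the run_step helper below are illustrative stand-ins, not storey APIs.

import datetime
import inspect
from typing import Any


class StreamingResultSketch:
    """Simplified stand-in for _StreamingResult: a generator plus its start time."""

    def __init__(self, runnable_name: str, generator, timestamp: datetime.datetime):
        self.runnable_name = runnable_name
        self.generator = generator
        self.timestamp = timestamp


def run_step(result: Any, name: str) -> Any:
    # Wrap generators so the stream's start time travels with it downstream.
    if inspect.isgenerator(result):
        return StreamingResultSketch(name, result, datetime.datetime.now(tz=datetime.timezone.utc))
    return result


def chunks():
    yield from ("Hello", ", ", "world")


wrapped = run_step(chunks(), "streamer")
assert isinstance(wrapped, StreamingResultSketch)
print(wrapped.timestamp, list(wrapped.generator))  # start time, then ["Hello", ", ", "world"]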
@@ -1840,9 +1854,9 @@ def _run(self, body: Any, path: str, origin_name: Optional[str] = None) -> Any:
         start = time.monotonic()
         try:
             result = self.run(body, path, origin_name)
-            # Return generator directly for streaming support
+            # Return streaming result with timing metadata for streaming support
             if _is_generator(result):
-                return result
+                return _StreamingResult(origin_name or self.name, result, timestamp)
             body = result
         except Exception as e:
             if self._raise_exception:
@@ -1858,9 +1872,9 @@ async def _async_run(self, body: Any, path: str, origin_name: Optional[str] = No
         try:
             result = self.run_async(body, path, origin_name)

-            # Return generator directly for streaming support
+            # Return streaming result with timing metadata for streaming support
             if _is_generator(result):
-                return result
+                return _StreamingResult(origin_name or self.name, result, timestamp)

             # Await if coroutine
             if asyncio.iscoroutine(result):
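The reason both branches return early, before the usual runtime measurement: creating a generator returns immediately, and the work happens only as chunks are consumed, so a time.monotonic() delta taken at this point would report near-zero runtime. A small self-contained demonstration:

import time


def slow_chunks():
    for i in range(3):
        time.sleep(0.1)  # simulated per-chunk work
        yield i


start = time.monotonic()
gen = slow_chunks()
creation_time = time.monotonic() - start  # ~0: no chunk has been produced yet

start = time.monotonic()
chunks = list(gen)  # consuming the stream is where the time is actually spent
consumption_time = time.monotonic() - start  # ~0.3s

print(f"create: {creation_time:.4f}s, consume: {consumption_time:.4f}s")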
@@ -1902,7 +1916,10 @@ def _streaming_run_wrapper(
        sending each chunk through the multiprocessing queue.
    """
    try:
-        for chunk in runnable._run(input, path, origin_name):
+        result = runnable._run(input, path, origin_name)
+        # Unwrap _StreamingResult to get the generator
+        generator = result.generator if isinstance(result, _StreamingResult) else result
+        for chunk in generator:
            queue.put(("chunk", chunk))
        queue.put(("done", None))
    except Exception as e:
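The parent-process side of this protocol is not shown in the hunk; a sketch of a consumer follows. The ("chunk", value) and ("done", None) tuple shapes come from the code above, but the ("error", ...) case is an assumption, since the wrapper's except clause is cut off in this diff.

import multiprocessing


def drain_stream(queue: "multiprocessing.Queue"):
    """Yield chunks from the child process until it signals completion."""
    while True:
        kind, payload = queue.get()  # blocks until the child puts the next tuple
        if kind == "chunk":
            yield payload
        elif kind == "done":
            return
        else:  # assumed: an ("error", exception) tuple, re-raised in the parent
            raise payload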
@@ -2291,15 +2308,33 @@ async def _do(self, event):
         # Check for streaming response (only when a single runnable is selected)
         if len(runnables) == 1 and results:
             result = results[0]
-            # Check if the result is a generator (streaming response)
-            if _is_generator(result):
+            # Check if the result is a streaming result (contains generator with timing metadata)
+            if isinstance(result, _StreamingResult):
+                # Set timing metadata on the event before emitting chunks
+                # For streaming, microsec is None since we don't have total runtime
+                metadata = {
+                    "microsec": None,
+                    "when": result.timestamp.isoformat(sep=" ", timespec="microseconds"),
+                }
+                self.set_event_metadata(event, metadata)
+                await self._emit_streaming_chunks(event, result.generator)
+                return None
+            # Handle raw generator from process-based streaming (no timing info from subprocess)
+            elif _is_generator(result):
+                # Use current timestamp as fallback for process-based streaming
+                timestamp = datetime.datetime.now(tz=datetime.timezone.utc)
+                metadata = {
+                    "microsec": None,
+                    "when": timestamp.isoformat(sep=" ", timespec="microseconds"),
+                }
+                self.set_event_metadata(event, metadata)
                 await self._emit_streaming_chunks(event, result)
                 return None

         # Non-streaming path
-        # Check if any results are generators (not allowed with multiple runnables)
+        # Check if any results are streaming (not allowed with multiple runnables)
         for result in results:
-            if _is_generator(result):
+            if isinstance(result, _StreamingResult):
                 raise StreamingError(
                     "Streaming is not supported when multiple runnables are selected. "
                     "Streaming runnables must be the only runnable selected for an event."
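Both streaming branches hand set_event_metadata the same flat dict shape; 'microsec' is left as None at stream start because the total runtime is only known once the Collector sees the final chunk (see collector.py below). A standalone illustration of the 'when' format:

import datetime

timestamp = datetime.datetime.now(tz=datetime.timezone.utc)
metadata = {
    "microsec": None,  # filled in later by Collector._calculate_streaming_duration
    "when": timestamp.isoformat(sep=" ", timespec="microseconds"),
}
print(metadata["when"])  # e.g. "2024-01-01 12:00:00.000000+00:00"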
storey/steps/collector.py: 30 additions & 0 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 #
 import copy
+import datetime
 from collections import defaultdict

 from ..dtypes import StreamCompletion, _termination_obj
@@ -46,6 +47,31 @@ def __init__(self, expected_completions: int = 1, **kwargs):
             lambda: {"chunks": [], "completions": 0, "first_event": None}
         )

+    def _calculate_streaming_duration(self, event):
+        """
+        Calculate total streaming duration and update event metadata with microsec.
+
+        Uses the 'when' timestamp from the first chunk's metadata (set by ParallelExecution)
+        to calculate total elapsed time from stream start to completion.
+
+        Streaming is only supported with a single selected runnable, so metadata is always
+        flat (top-level 'when' and 'microsec'), never nested under model names.
+        """
+        if not hasattr(event, "_metadata") or not event._metadata:
+            return
+
+        when_str = event._metadata.get("when")
+        if not when_str:
+            return
+
+        try:
+            start_time = datetime.datetime.fromisoformat(when_str)
+            now = datetime.datetime.now(tz=datetime.timezone.utc)
+            event._metadata["microsec"] = int((now - start_time).total_seconds() * 1_000_000)
+        except (ValueError, TypeError) as exc:
+            if self.logger:
+                self.logger.warning(f"Failed to calculate streaming duration from 'when' timestamp '{when_str}': {exc}")
+
     async def _do(self, event):
         if event is _termination_obj:
             return await self._do_downstream(_termination_obj)
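A worked example of the computation above, with a fixed elapsed time standing in for datetime.datetime.now():

import datetime

when = "2024-01-01 12:00:00.000000+00:00"  # same format 'when' is stored in
start_time = datetime.datetime.fromisoformat(when)
now = start_time + datetime.timedelta(seconds=1.5)  # stand-in for now(tz=utc)
microsec = int((now - start_time).total_seconds() * 1_000_000)
assert microsec == 1_500_000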
@@ -73,6 +99,10 @@ async def _do(self, event):
             del collected_event.streaming_step
         if hasattr(collected_event, "chunk_id"):
             del collected_event.chunk_id
+
+        # Calculate total streaming duration (microsec) if timing metadata exists
+        self._calculate_streaming_duration(collected_event)
+
         await self._do_downstream(collected_event)

         # Clean up
tests/test_streaming.py: 64 additions & 0 deletions
@@ -1084,6 +1084,32 @@ def stream_chunks(x):

         asyncio.run(_test())

+    def test_streaming_with_multiple_runnables_raises_error(self):
+        """Test that streaming raises an error when multiple runnables are selected."""
+        streaming = StreamingRunnable(name="streamer")
+        non_streaming = NonStreamingRunnable(name="non_streamer")
+
+        controller = build_flow(
+            [
+                SyncEmitSource(),
+                ParallelExecution(
+                    runnables=[streaming, non_streaming],
+                    execution_mechanism_by_runnable_name={
+                        "streamer": ParallelExecutionMechanisms.naive,
+                        "non_streamer": ParallelExecutionMechanisms.naive,
+                    },
+                ),
+                Reduce([], lambda acc, x: acc + [x]),
+            ]
+        ).run()
+
+        try:
+            controller.emit("test")
+        finally:
+            controller.terminate()
+        with pytest.raises(StreamingError, match="Streaming is not supported when multiple runnables are selected"):
+            controller.await_termination()
+

 class TestStreamingWithIntermediateSteps:
     """Tests for streaming through intermediate non-streaming steps."""
@@ -1338,6 +1364,44 @@ def test_parallel_execution_streaming_error_propagation(self, execution_mechanis
         with pytest.raises(expected_error, match="Simulated streaming error"):
             controller.await_termination()

+    def test_parallel_execution_streaming_single_runnable_sets_metadata(self):
+        """Test that streaming ParallelExecution with a single runnable sets timing metadata.
+
+        This mirrors the non-streaming behavior, where _metadata includes 'when' and 'microsec'.
+        After Collector aggregates chunks, the collected event should have timing metadata.
+        The 'microsec' field should contain the total streaming duration calculated by Collector.
+        """
+        runnable = StreamingRunnable(name="streamer")
+        controller = build_flow(
+            [
+                SyncEmitSource(),
+                ParallelExecution(
+                    runnables=[runnable],
+                    execution_mechanism_by_runnable_name={"streamer": ParallelExecutionMechanisms.naive},
+                ),
+                Collector(),
+                Reduce([], lambda acc, x: acc + [x], full_event=True),
+            ]
+        ).run()
+
+        try:
+            controller.emit("test")
+        finally:
+            controller.terminate()
+        result = controller.await_termination()
+
+        assert len(result) == 1
+        event = result[0]
+        assert hasattr(event, "_metadata"), "Expected event to have _metadata attribute"
+        metadata = event._metadata
+        assert "when" in metadata, "Expected _metadata to include 'when' field"
+        assert "microsec" in metadata, "Expected _metadata to include 'microsec' field"
+        # Verify 'when' is a valid ISO timestamp string
+        assert isinstance(metadata["when"], str), "Expected 'when' to be a string"
+        # Verify 'microsec' is a non-negative integer (total streaming duration calculated by Collector)
+        assert isinstance(metadata["microsec"], int), "Expected 'microsec' to be an integer"
+        assert metadata["microsec"] >= 0, "Expected 'microsec' to be non-negative"
+

 class TestStreamingGraphSplits:
     """Tests for streaming through branching graph topologies."""