diff --git a/storey/flow.py b/storey/flow.py index 6c5091d2..5c405cde 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -1728,6 +1728,20 @@ def __init__(self, runnable_name: str, data: Any, runtime: float, timestamp: dat self.timestamp = timestamp +class _StreamingResult: + """Wraps a streaming generator with timing metadata for model monitoring.""" + + def __init__( + self, + runnable_name: str, + generator: Generator | AsyncGenerator, + timestamp: datetime.datetime, + ): + self.runnable_name = runnable_name + self.generator = generator + self.timestamp = timestamp + + class ParallelExecutionMechanisms(str, enum.Enum): process_pool = "process_pool" dedicated_process = "dedicated_process" @@ -1840,9 +1854,9 @@ def _run(self, body: Any, path: str, origin_name: Optional[str] = None) -> Any: start = time.monotonic() try: result = self.run(body, path, origin_name) - # Return generator directly for streaming support + # Return streaming result with timing metadata for streaming support if _is_generator(result): - return result + return _StreamingResult(origin_name or self.name, result, timestamp) body = result except Exception as e: if self._raise_exception: @@ -1858,9 +1872,9 @@ async def _async_run(self, body: Any, path: str, origin_name: Optional[str] = No try: result = self.run_async(body, path, origin_name) - # Return generator directly for streaming support + # Return streaming result with timing metadata for streaming support if _is_generator(result): - return result + return _StreamingResult(origin_name or self.name, result, timestamp) # Await if coroutine if asyncio.iscoroutine(result): @@ -1902,7 +1916,10 @@ def _streaming_run_wrapper( sending each chunk through the multiprocessing queue. """ try: - for chunk in runnable._run(input, path, origin_name): + result = runnable._run(input, path, origin_name) + # Unwrap _StreamingResult to get the generator + generator = result.generator if isinstance(result, _StreamingResult) else result + for chunk in generator: queue.put(("chunk", chunk)) queue.put(("done", None)) except Exception as e: @@ -2291,15 +2308,33 @@ async def _do(self, event): # Check for streaming response (only when a single runnable is selected) if len(runnables) == 1 and results: result = results[0] - # Check if the result is a generator (streaming response) - if _is_generator(result): + # Check if the result is a streaming result (contains generator with timing metadata) + if isinstance(result, _StreamingResult): + # Set timing metadata on the event before emitting chunks + # For streaming, microsec is None since we don't have total runtime + metadata = { + "microsec": None, + "when": result.timestamp.isoformat(sep=" ", timespec="microseconds"), + } + self.set_event_metadata(event, metadata) + await self._emit_streaming_chunks(event, result.generator) + return None + # Handle raw generator from process-based streaming (no timing info from subprocess) + elif _is_generator(result): + # Use current timestamp as fallback for process-based streaming + timestamp = datetime.datetime.now(tz=datetime.timezone.utc) + metadata = { + "microsec": None, + "when": timestamp.isoformat(sep=" ", timespec="microseconds"), + } + self.set_event_metadata(event, metadata) await self._emit_streaming_chunks(event, result) return None # Non-streaming path - # Check if any results are generators (not allowed with multiple runnables) + # Check if any results are streaming (not allowed with multiple runnables) for result in results: - if _is_generator(result): + if isinstance(result, _StreamingResult): raise StreamingError( "Streaming is not supported when multiple runnables are selected. " "Streaming runnables must be the only runnable selected for an event." diff --git a/storey/steps/collector.py b/storey/steps/collector.py index 3f0b1ba4..76a67d57 100644 --- a/storey/steps/collector.py +++ b/storey/steps/collector.py @@ -13,6 +13,7 @@ # limitations under the License. # import copy +import datetime from collections import defaultdict from ..dtypes import StreamCompletion, _termination_obj @@ -46,6 +47,31 @@ def __init__(self, expected_completions: int = 1, **kwargs): lambda: {"chunks": [], "completions": 0, "first_event": None} ) + def _calculate_streaming_duration(self, event): + """ + Calculate total streaming duration and update event metadata with microsec. + + Uses the 'when' timestamp from the first chunk's metadata (set by ParallelExecution) + to calculate total elapsed time from stream start to completion. + + Streaming is only supported with a single selected runnable, so metadata is always + flat (top-level 'when' and 'microsec'), never nested under model names. + """ + if not hasattr(event, "_metadata") or not event._metadata: + return + + when_str = event._metadata.get("when") + if not when_str: + return + + try: + start_time = datetime.datetime.fromisoformat(when_str) + now = datetime.datetime.now(tz=datetime.timezone.utc) + event._metadata["microsec"] = int((now - start_time).total_seconds() * 1_000_000) + except (ValueError, TypeError) as exc: + if self.logger: + self.logger.warning(f"Failed to calculate streaming duration from 'when' timestamp '{when_str}': {exc}") + async def _do(self, event): if event is _termination_obj: return await self._do_downstream(_termination_obj) @@ -73,6 +99,10 @@ async def _do(self, event): del collected_event.streaming_step if hasattr(collected_event, "chunk_id"): del collected_event.chunk_id + + # Calculate total streaming duration (microsec) if timing metadata exists + self._calculate_streaming_duration(collected_event) + await self._do_downstream(collected_event) # Clean up diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 84dee94b..35667ebe 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -1084,6 +1084,32 @@ def stream_chunks(x): asyncio.run(_test()) + def test_streaming_with_multiple_runnables_raises_error(self): + """Test that streaming raises an error when multiple runnables are selected.""" + streaming = StreamingRunnable(name="streamer") + non_streaming = NonStreamingRunnable(name="non_streamer") + + controller = build_flow( + [ + SyncEmitSource(), + ParallelExecution( + runnables=[streaming, non_streaming], + execution_mechanism_by_runnable_name={ + "streamer": ParallelExecutionMechanisms.naive, + "non_streamer": ParallelExecutionMechanisms.naive, + }, + ), + Reduce([], lambda acc, x: acc + [x]), + ] + ).run() + + try: + controller.emit("test") + finally: + controller.terminate() + with pytest.raises(StreamingError, match="Streaming is not supported when multiple runnables are selected"): + controller.await_termination() + class TestStreamingWithIntermediateSteps: """Tests for streaming through intermediate non-streaming steps.""" @@ -1338,6 +1364,44 @@ def test_parallel_execution_streaming_error_propagation(self, execution_mechanis with pytest.raises(expected_error, match="Simulated streaming error"): controller.await_termination() + def test_parallel_execution_streaming_single_runnable_sets_metadata(self): + """Test that streaming ParallelExecution with single runnable sets timing metadata. + + This mirrors the non-streaming behavior where _metadata includes 'when' and 'microsec'. + After Collector aggregates chunks, the collected event should have timing metadata. + The 'microsec' field should contain the total streaming duration calculated by Collector. + """ + runnable = StreamingRunnable(name="streamer") + controller = build_flow( + [ + SyncEmitSource(), + ParallelExecution( + runnables=[runnable], + execution_mechanism_by_runnable_name={"streamer": ParallelExecutionMechanisms.naive}, + ), + Collector(), + Reduce([], lambda acc, x: acc + [x], full_event=True), + ] + ).run() + + try: + controller.emit("test") + finally: + controller.terminate() + result = controller.await_termination() + + assert len(result) == 1 + event = result[0] + assert hasattr(event, "_metadata"), "Expected event to have _metadata attribute" + metadata = event._metadata + assert "when" in metadata, "Expected _metadata to include 'when' field" + assert "microsec" in metadata, "Expected _metadata to include 'microsec' field" + # Verify 'when' is a valid ISO timestamp string + assert isinstance(metadata["when"], str), "Expected 'when' to be a string" + # Verify 'microsec' is a positive integer (total streaming duration calculated by Collector) + assert isinstance(metadata["microsec"], int), "Expected 'microsec' to be an integer" + assert metadata["microsec"] >= 0, "Expected 'microsec' to be non-negative" + class TestStreamingGraphSplits: """Tests for streaming through branching graph topologies."""