Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion storey/steps/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,12 @@ async def _do(self, event):
# Copy the original event to preserve all attributes (important for offset management)
collected_event = copy.copy(base_event)
collected_event.body = collected_body
# Clear streaming attributes
# Clear streaming attributes and mark as collected
if hasattr(collected_event, "streaming_step"):
del collected_event.streaming_step
if hasattr(collected_event, "chunk_id"):
del collected_event.chunk_id
collected_event.stream_collected = True

# Calculate total streaming duration (microsec) if timing metadata exists
self._calculate_streaming_duration(collected_event)
Expand Down
44 changes: 44 additions & 0 deletions tests/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,50 @@ def stream_chunks(x):
assert ["a_0", "a_1"] in result
assert ["b_0", "b_1"] in result

def test_collector_sets_stream_collected_marker(self):
"""Test that Collector sets stream_collected=True on collected events."""

def stream_chunks(x):
for i in range(2):
yield f"{x}_{i}"

controller = build_flow(
[
SyncEmitSource(),
Map(stream_chunks),
Collector(),
Reduce([], lambda acc, x: acc + [x], full_event=True),
]
).run()

controller.emit("test")
controller.terminate()
result = controller.await_termination()

assert len(result) == 1
event = result[0]
assert event.stream_collected is True

def test_collector_passthrough_no_stream_collected_marker(self):
"""Test that Collector does NOT set stream_collected on non-streaming events."""

controller = build_flow(
[
SyncEmitSource(),
Map(lambda x: x * 2),
Collector(),
Reduce([], lambda acc, x: acc + [x], full_event=True),
]
).run()

controller.emit(5)
controller.terminate()
result = controller.await_termination()

assert len(result) == 1
event = result[0]
assert getattr(event, "stream_collected", False) is False

def test_collector_invalid_expected_completions(self):
"""Test that Collector raises error for invalid expected_completions."""
with pytest.raises(ValueError, match="expected_completions must be at least 1"):
Expand Down