Skip to content

Commit 6d32a0a

Browse files
authored
Mark collected events (#614)
The Collector step now sets `stream_collected = True` on collected events so that they can be identified downstream. [ML-11879](https://iguazio.atlassian.net/browse/ML-11879)
1 parent 784b03e commit 6d32a0a

File tree

2 files changed

+46
-1
lines changed

2 files changed

+46
-1
lines changed

storey/steps/collector.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -94,11 +94,12 @@ async def _do(self, event):
94 94
# Copy the original event to preserve all attributes (important for offset management)
95 95
collected_event = copy.copy(base_event)
96 96
collected_event.body = collected_body
97-
# Clear streaming attributes
97+
# Clear streaming attributes and mark as collected
98 98
if hasattr(collected_event, "streaming_step"):
99 99
del collected_event.streaming_step
100 100
if hasattr(collected_event, "chunk_id"):
101 101
del collected_event.chunk_id
102+
collected_event.stream_collected = True
102 103

103 104
# Calculate total streaming duration (microsec) if timing metadata exists
104 105
self._calculate_streaming_duration(collected_event)

tests/test_streaming.py

Lines changed: 44 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -447,6 +447,50 @@ def stream_chunks(x):
447 447
assert ["a_0", "a_1"] in result
448 448
assert ["b_0", "b_1"] in result
449 449

450+
def test_collector_sets_stream_collected_marker(self):
451+
"""Test that Collector sets stream_collected=True on collected events."""
452+
453+
def stream_chunks(x):
454+
for i in range(2):
455+
yield f"{x}_{i}"
456+
457+
controller = build_flow(
458+
[
459+
SyncEmitSource(),
460+
Map(stream_chunks),
461+
Collector(),
462+
Reduce([], lambda acc, x: acc + [x], full_event=True),
463+
]
464+
).run()
465+
466+
controller.emit("test")
467+
controller.terminate()
468+
result = controller.await_termination()
469+
470+
assert len(result) == 1
471+
event = result[0]
472+
assert event.stream_collected is True
473+
474+
def test_collector_passthrough_no_stream_collected_marker(self):
475+
"""Test that Collector does NOT set stream_collected on non-streaming events."""
476+
477+
controller = build_flow(
478+
[
479+
SyncEmitSource(),
480+
Map(lambda x: x * 2),
481+
Collector(),
482+
Reduce([], lambda acc, x: acc + [x], full_event=True),
483+
]
484+
).run()
485+
486+
controller.emit(5)
487+
controller.terminate()
488+
result = controller.await_termination()
489+
490+
assert len(result) == 1
491+
event = result[0]
492+
assert getattr(event, "stream_collected", False) is False
493+
450 494
def test_collector_invalid_expected_completions(self):
451 495
"""Test that Collector raises error for invalid expected_completions."""
452 496
with pytest.raises(ValueError, match="expected_completions must be at least 1"):

0 commit comments

Comments
 (0)