Skip to content

Commit 404cf24

Browse files
committed
New tests and related fixes
1 parent 7abd93b commit 404cf24

File tree

4 files changed

+177
-42
lines changed

4 files changed

+177
-42
lines changed

storey/flow.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,8 @@ def _init(self):
485485
async def _do(self, event):
486486
if event is _termination_obj:
487487
return await self._do_downstream(_termination_obj, select_outlets=False)
488-
# StreamCompletion objects should be forwarded to all outlets without routing
488+
# StreamCompletion should propagate to all outlets (like _termination_obj)
489+
# to avoid hangs in cyclic graphs and ensure all Collectors receive completions
489490
if isinstance(event, StreamCompletion):
490491
return await self._do_downstream(event, select_outlets=False)
491492
else:

storey/steps/collector.py

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
#
15+
import copy
1516
from collections import defaultdict
1617

17-
from ..dtypes import Event, StreamCompletion, _termination_obj
18+
from ..dtypes import StreamCompletion, _termination_obj
1819
from ..flow import Flow
1920

2021

@@ -41,7 +42,9 @@ def __init__(self, expected_completions: int = 1, **kwargs):
4142
raise ValueError("expected_completions must be at least 1")
4243
self._expected_completions = expected_completions
4344
# Map from event id -> {"chunks": [], "completions": 0, "first_event": Event}
44-
self._collected_streams = defaultdict(lambda: {"chunks": [], "completions": 0, "first_event": None})
45+
self._collected_streams: dict[str, dict] = defaultdict(
46+
lambda: {"chunks": [], "completions": 0, "first_event": None}
47+
)
4548

4649
async def _do(self, event):
4750
if event is _termination_obj:
@@ -50,51 +53,43 @@ async def _do(self, event):
5053
# Handle StreamCompletion sentinel
5154
if isinstance(event, StreamCompletion):
5255
stream_id = event.original_event.id
53-
stream_data = self._collected_streams.get(stream_id)
54-
55-
if stream_data is None:
56-
# No chunks received for this stream - unusual but possible
57-
return
56+
stream_data = self._collected_streams[stream_id] # Use [] to trigger defaultdict
5857

5958
stream_data["completions"] += 1
6059

6160
if stream_data["completions"] >= self._expected_completions:
6261
# Stream is complete - emit collected result
63-
first_event = stream_data["first_event"]
64-
if first_event:
65-
collected_body = [chunk.body for chunk in stream_data["chunks"]]
66-
if len(collected_body) == 1:
67-
collected_body = collected_body[0]
62+
# Use first_event if we have chunks, otherwise use original_event from completion (empty stream)
63+
base_event = stream_data["first_event"] or event.original_event
64+
collected_body = [chunk.body for chunk in stream_data["chunks"]]
65+
if len(collected_body) == 1:
66+
collected_body = collected_body[0]
6867

69-
# Create a new event with the collected body
70-
collected_event = Event(
71-
body=collected_body,
72-
key=first_event.key,
73-
processing_time=first_event.processing_time,
74-
id=first_event.id,
75-
headers=first_event.headers,
76-
method=first_event.method,
77-
path=first_event.path,
78-
content_type=first_event.content_type,
79-
awaitable_result=first_event._awaitable_result,
80-
)
81-
await self._do_downstream(collected_event)
68+
# Copy the original event to preserve all attributes (important for offset management)
69+
collected_event = copy.copy(base_event)
70+
collected_event.body = collected_body
71+
# Clear streaming attributes
72+
if hasattr(collected_event, "streaming_step"):
73+
del collected_event.streaming_step
74+
if hasattr(collected_event, "chunk_id"):
75+
del collected_event.chunk_id
76+
await self._do_downstream(collected_event)
8277

8378
# Clean up
8479
del self._collected_streams[stream_id]
85-
return
80+
return None
8681

8782
# Check if this is a streaming chunk (has streaming_step attribute)
88-
streaming_step = getattr(event, "streaming_step", None)
89-
if streaming_step:
83+
if hasattr(event, "streaming_step"):
9084
stream_id = event.id
9185
stream_data = self._collected_streams[stream_id]
9286
if stream_data["first_event"] is None:
9387
stream_data["first_event"] = event
9488
stream_data["chunks"].append(event)
89+
return None
9590
else:
9691
# Non-streaming event - pass through directly
97-
await self._do_downstream(event)
92+
return await self._do_downstream(event)
9893

9994
async def _cleanup(self):
10095
# Warn about incomplete streams on cleanup

tests/test_flow.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5592,7 +5592,7 @@ def test_cyclic_graphs(iterations, with_recovery):
55925592
counter.to(my_loop)
55935593
my_loop.to(end)
55945594
end.to(Complete())
5595-
my_loop._outlets.append(counter)
5595+
my_loop.to(counter)
55965596
if with_recovery:
55975597
recovery_step = Map(lambda x: -1, name="end-2")
55985598
counter.set_recovery_step(recovery_step)
@@ -5636,10 +5636,10 @@ def test_two_cyclic_graphs():
56365636
start.to(counter)
56375637
counter.to(my_loop)
56385638
my_loop.to(counter_2)
5639-
my_loop._outlets.append(counter)
5639+
my_loop.to(counter)
56405640
counter_2.to(my_loop_2)
56415641
my_loop_2.to(end)
5642-
my_loop_2._outlets.append(counter_2)
5642+
my_loop_2.to(counter_2)
56435643
end.to(Complete())
56445644
controller = source.run()
56455645

@@ -5666,7 +5666,7 @@ def test_flow_reuse_with_cycle():
56665666
my_loop.to(end)
56675667
end.to(Complete())
56685668
# Create the cycle by appending counter as an outlet of my_loop
5669-
my_loop._outlets.append(counter)
5669+
my_loop.to(counter)
56705670

56715671
# Run the SAME flow 3 times to test reusability with cyclic structure
56725672
for run_num in range(3):

tests/test_streaming.py

Lines changed: 147 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,56 @@ def stream_chunks(x):
458458

459459
asyncio.run(_test())
460460

461+
def test_collector_empty_stream(self):
462+
"""Test that Collector emits an empty list for a stream with zero chunks."""
463+
464+
def empty_stream(x):
465+
return
466+
yield # Makes it a generator
467+
468+
controller = build_flow(
469+
[
470+
SyncEmitSource(),
471+
Map(empty_stream),
472+
Collector(),
473+
Reduce([], lambda acc, x: acc + [x]),
474+
]
475+
).run()
476+
477+
controller.emit("test")
478+
controller.terminate()
479+
result = controller.await_termination()
480+
481+
# Empty stream should emit an empty list
482+
assert len(result) == 1
483+
assert result[0] == []
484+
485+
def test_async_collector_empty_stream(self):
486+
"""Async version: Test that Collector emits an empty list for a stream with zero chunks."""
487+
488+
async def _test():
489+
def empty_stream(x):
490+
return
491+
yield # Makes it a generator
492+
493+
controller = build_flow(
494+
[
495+
AsyncEmitSource(),
496+
Map(empty_stream),
497+
Collector(),
498+
Reduce([], lambda acc, x: acc + [x]),
499+
]
500+
).run()
501+
502+
await controller.emit("test")
503+
await controller.terminate()
504+
result = await controller.await_termination()
505+
506+
assert len(result) == 1
507+
assert result[0] == []
508+
509+
asyncio.run(_test())
510+
461511

462512
class TestCompleteStreaming:
463513
"""Tests for Complete step streaming support."""
@@ -1190,13 +1240,16 @@ def select_outlets(self, event):
11901240
controller.terminate()
11911241
result = controller.await_termination()
11921242

1193-
# Should have 2 collected results (one per emit)
1194-
assert len(result) == 2
1195-
# low_value chunks go to branch_low
1196-
low_result = [r for r in result if any("LOW_" in str(item) for item in r)]
1197-
high_result = [r for r in result if any("HIGH_" in str(item) for item in r)]
1243+
# 4 results: chunks go to one branch, but StreamCompletion goes to all branches
1244+
# (like _termination_obj) to avoid hangs in cyclic graphs.
1245+
# Branches that don't receive chunks emit empty lists.
1246+
assert len(result) == 4
1247+
low_result = [r for r in result if r and any("LOW_" in str(item) for item in r)]
1248+
high_result = [r for r in result if r and any("HIGH_" in str(item) for item in r)]
1249+
empty_results = [r for r in result if r == []]
11981250
assert len(low_result) == 1
11991251
assert len(high_result) == 1
1252+
assert len(empty_results) == 2
12001253
assert "LOW_low_value_chunk_0" in low_result[0]
12011254
assert "HIGH_high_value_chunk_0" in high_result[0]
12021255

@@ -1235,11 +1288,14 @@ def select_outlets(self, event):
12351288
await controller.terminate()
12361289
result = await controller.await_termination()
12371290

1238-
assert len(result) == 2
1239-
low_result = [r for r in result if any("LOW_" in str(item) for item in r)]
1240-
high_result = [r for r in result if any("HIGH_" in str(item) for item in r)]
1291+
# 4 results: chunks go to one branch, but StreamCompletion goes to all branches
1292+
assert len(result) == 4
1293+
low_result = [r for r in result if r and any("LOW_" in str(item) for item in r)]
1294+
high_result = [r for r in result if r and any("HIGH_" in str(item) for item in r)]
1295+
empty_results = [r for r in result if r == []]
12411296
assert len(low_result) == 1
12421297
assert len(high_result) == 1
1298+
assert len(empty_results) == 2
12431299
assert "LOW_low_value_chunk_0" in low_result[0]
12441300
assert "HIGH_high_value_chunk_0" in high_result[0]
12451301

@@ -1403,3 +1459,86 @@ def failing_transform(x):
14031459
await controller.await_termination()
14041460

14051461
asyncio.run(_test())
1462+
1463+
def test_streaming_in_cycle_fails(self):
1464+
"""Test that streaming step inside a cycle fails on second iteration.
1465+
1466+
When a streaming step is inside a cycle, the first iteration streams chunks.
1467+
When those chunks loop back to the streaming step, they already have
1468+
streaming_step set, so the step should fail with StreamingError.
1469+
"""
1470+
1471+
class AlwaysLoop(Map):
1472+
"""A Map step that always routes back to the loop target."""
1473+
1474+
def __init__(self, loop_target, **kwargs):
1475+
super().__init__(**kwargs)
1476+
self._loop_target = loop_target
1477+
1478+
def select_outlets(self, event_body):
1479+
# Always loop back - the streaming error should stop us
1480+
return [self._loop_target]
1481+
1482+
def stream_chunks(x):
1483+
yield f"{x}_chunk_0"
1484+
yield f"{x}_chunk_1"
1485+
1486+
source = SyncEmitSource()
1487+
# The streaming map is the entry point of the loop
1488+
streaming_map = Map(stream_chunks, name="streamer", max_iterations=5)
1489+
loop_controller = AlwaysLoop(
1490+
fn=lambda x: x, name="loop_ctrl", loop_target="streamer", max_iterations=5
1491+
)
1492+
end = Reduce([], lambda acc, x: acc + [x], name="end")
1493+
1494+
source.to(streaming_map)
1495+
streaming_map.to(loop_controller)
1496+
loop_controller.to(end)
1497+
loop_controller.to(streaming_map) # Create cycle
1498+
1499+
controller = source.run()
1500+
1501+
controller.emit("test")
1502+
controller.terminate()
1503+
1504+
# Should fail because chunks looping back already have streaming_step set
1505+
with pytest.raises(StreamingError, match="Streaming on top of streaming is not allowed"):
1506+
controller.await_termination()
1507+
1508+
def test_async_streaming_in_cycle_fails(self):
1509+
"""Async version: Test that streaming step inside a cycle fails on second iteration."""
1510+
1511+
async def _test():
1512+
class AlwaysLoop(Map):
1513+
def __init__(self, loop_target, **kwargs):
1514+
super().__init__(**kwargs)
1515+
self._loop_target = loop_target
1516+
1517+
def select_outlets(self, event_body):
1518+
return [self._loop_target]
1519+
1520+
def stream_chunks(x):
1521+
yield f"{x}_chunk_0"
1522+
yield f"{x}_chunk_1"
1523+
1524+
source = AsyncEmitSource()
1525+
streaming_map = Map(stream_chunks, name="streamer", max_iterations=5)
1526+
loop_controller = AlwaysLoop(
1527+
fn=lambda x: x, name="loop_ctrl", loop_target="streamer", max_iterations=5
1528+
)
1529+
end = Reduce([], lambda acc, x: acc + [x], name="end")
1530+
1531+
source.to(streaming_map)
1532+
streaming_map.to(loop_controller)
1533+
loop_controller.to(end)
1534+
loop_controller.to(streaming_map) # Create cycle
1535+
1536+
controller = source.run()
1537+
1538+
await controller.emit("test")
1539+
await controller.terminate()
1540+
1541+
with pytest.raises(StreamingError, match="Streaming on top of streaming is not allowed"):
1542+
await controller.await_termination()
1543+
1544+
asyncio.run(_test())

0 commit comments

Comments
 (0)