Skip to content

Commit b563da7

Browse files
committed
Refactor
1 parent e3bdfdc commit b563da7

File tree

2 files changed

+39
-32
lines changed

2 files changed

+39
-32
lines changed

storey/flow.py

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,19 @@
2626
from asyncio import Task
2727
from collections import defaultdict
2828
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
29-
from typing import Any, Callable, Collection, Dict, Iterable, List, Optional, Set, Union
29+
from typing import (
30+
Any,
31+
AsyncGenerator,
32+
Callable,
33+
Collection,
34+
Dict,
35+
Generator,
36+
Iterable,
37+
List,
38+
Optional,
39+
Set,
40+
Union,
41+
)
3042

3143
import aiohttp
3244

@@ -45,6 +57,11 @@
4557
from .utils import _split_path, get_in, stringify_key, update_in
4658

4759

60+
def _is_generator(obj) -> bool:
61+
"""Check if an object is a sync or async generator."""
62+
return inspect.isgenerator(obj) or inspect.isasyncgen(obj)
63+
64+
4865
class Flow:
4966
_legal_first_step = False
5067

@@ -534,11 +551,6 @@ class _StreamingStepMixin:
534551
want to support user-provided generator functions.
535552
"""
536553

537-
@staticmethod
538-
def _is_generator(obj) -> bool:
539-
"""Check if an object is a sync or async generator."""
540-
return inspect.isgenerator(obj) or inspect.isasyncgen(obj)
541-
542554
def _validate_not_already_streaming(self, event):
543555
"""Ensure we're not streaming on top of an already streaming event.
544556
@@ -552,35 +564,35 @@ def _validate_not_already_streaming(self, event):
552564
f"Step '{self.name}' received a streaming event from '{streaming_step}'."
553565
)
554566

555-
async def _emit_streaming_chunks(self, original_event, generator):
567+
async def _emit_streaming_chunks(self, event, generator: Generator | AsyncGenerator) -> None:
556568
"""Emit streaming chunks from a generator, then send StreamCompletion.
557569
558570
Args:
559-
original_event: The original event that triggered this streaming response.
571+
event: The event that will be used to create chunk events.
560572
generator: A sync or async generator yielding chunk bodies.
561573
"""
562-
self._validate_not_already_streaming(original_event)
574+
self._validate_not_already_streaming(event)
563575

564576
chunk_id = 0
565577
if inspect.isgenerator(generator):
566578
# Sync generator
567579
for chunk_body in generator:
568-
chunk_event = self._user_fn_output_to_event(original_event, chunk_body)
580+
chunk_event = self._user_fn_output_to_event(event, chunk_body)
569581
chunk_event.streaming_step = self.name
570582
chunk_event.chunk_id = chunk_id
571583
await self._do_downstream(chunk_event)
572584
chunk_id += 1
573585
else:
574586
# Async generator
575587
async for chunk_body in generator:
576-
chunk_event = self._user_fn_output_to_event(original_event, chunk_body)
588+
chunk_event = self._user_fn_output_to_event(event, chunk_body)
577589
chunk_event.streaming_step = self.name
578590
chunk_event.chunk_id = chunk_id
579591
await self._do_downstream(chunk_event)
580592
chunk_id += 1
581593

582594
# Send completion signal
583-
await self._do_downstream(StreamCompletion(self.name, original_event))
595+
await self._do_downstream(StreamCompletion(self.name, event))
584596

585597

586598
class _UnaryFunctionFlow(Flow):
@@ -671,7 +683,7 @@ class Map(_UnaryFunctionFlow, _StreamingStepMixin):
671683

672684
async def _do_internal(self, event, fn_result):
673685
# Check if the result is a generator (streaming response)
674-
if self._is_generator(fn_result):
686+
if _is_generator(fn_result):
675687
await self._emit_streaming_chunks(event, fn_result)
676688
else:
677689
mapped_event = self._user_fn_output_to_event(event, fn_result)
@@ -864,7 +876,7 @@ async def _do(self, event):
864876
fn_result = await self._call(element)
865877
if not self._filter:
866878
# Check if the result is a generator (streaming response)
867-
if self._is_generator(fn_result):
879+
if _is_generator(fn_result):
868880
await self._emit_streaming_chunks(event, fn_result)
869881
else:
870882
mapped_event = self._user_fn_output_to_event(event, fn_result)
@@ -1806,18 +1818,13 @@ async def run_async(self, body: Any, path: str, origin_name: Optional[str] = Non
18061818
"""
18071819
return body
18081820

1809-
@staticmethod
1810-
def _is_generator(obj) -> bool:
1811-
"""Check if an object is a sync or async generator."""
1812-
return inspect.isgenerator(obj) or inspect.isasyncgen(obj)
1813-
18141821
def _run(self, body: Any, path: str, origin_name: Optional[str] = None) -> Any:
18151822
timestamp = datetime.datetime.now(tz=datetime.timezone.utc)
18161823
start = time.monotonic()
18171824
try:
18181825
result = self.run(body, path, origin_name)
18191826
# Return generator directly for streaming support
1820-
if self._is_generator(result):
1827+
if _is_generator(result):
18211828
return result
18221829
body = result
18231830
except Exception as e:
@@ -1837,7 +1844,7 @@ async def _async_run(self, body: Any, path: str, origin_name: Optional[str] = No
18371844
return self.run_async(body, path, origin_name)
18381845
result = await self.run_async(body, path, origin_name)
18391846
# Return generator directly for streaming support
1840-
if self._is_generator(result):
1847+
if _is_generator(result):
18411848
return result
18421849
body = result
18431850
except Exception as e:
@@ -2154,7 +2161,7 @@ async def _do(self, event):
21542161
if len(runnables) == 1 and results:
21552162
result = results[0]
21562163
# Check if the result is a generator (streaming response)
2157-
if self._is_generator(result):
2164+
if _is_generator(result):
21582165
# Validate execution mechanism - streaming not supported with process-based execution
21592166
runnable_name = (
21602167
runnables[0].name if isinstance(runnables[0], ParallelExecutionRunnable) else runnables[0]
@@ -2181,7 +2188,7 @@ async def _do(self, event):
21812188
else:
21822189
# Check if any results are generators (not allowed with multiple runnables)
21832190
for result in results:
2184-
if self._is_generator(result):
2191+
if _is_generator(result):
21852192
raise StreamingError(
21862193
"Streaming is not supported when multiple runnables are selected. "
21872194
"Streaming runnables must be the only runnable selected for an event."

tests/test_streaming.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
build_flow,
3434
)
3535
from storey.dtypes import Event, StreamChunk, StreamCompletion
36-
from storey.flow import _StreamingStepMixin
36+
from storey.flow import _is_generator
3737

3838

3939
class TestStreamingPrimitives:
@@ -60,27 +60,27 @@ def test_stream_completion_repr(self):
6060
assert "my_step" in repr(completion)
6161

6262

63-
class TestStreamingStepMixin:
64-
"""Tests for the _StreamingStepMixin utility methods."""
63+
class TestIsGenerator:
64+
"""Tests for the _is_generator utility function."""
6565

6666
def test_is_generator_sync(self):
6767
def gen():
6868
yield 1
6969
yield 2
7070

71-
assert _StreamingStepMixin._is_generator(gen())
71+
assert _is_generator(gen())
7272

7373
def test_is_generator_async(self):
7474
async def async_gen():
7575
yield 1
7676
yield 2
7777

78-
assert _StreamingStepMixin._is_generator(async_gen())
78+
assert _is_generator(async_gen())
7979

8080
def test_is_generator_non_generator(self):
81-
assert not _StreamingStepMixin._is_generator([1, 2, 3])
82-
assert not _StreamingStepMixin._is_generator("string")
83-
assert not _StreamingStepMixin._is_generator(42)
81+
assert not _is_generator([1, 2, 3])
82+
assert not _is_generator("string")
83+
assert not _is_generator(42)
8484

8585
def test_is_generator_coroutine(self):
8686
async def coro():
@@ -89,7 +89,7 @@ async def coro():
8989
# Coroutine is not a generator
9090
c = coro()
9191
try:
92-
assert not _StreamingStepMixin._is_generator(c)
92+
assert not _is_generator(c)
9393
finally:
9494
c.close()
9595

0 commit comments

Comments
 (0)