2626from asyncio import Task
2727from collections import defaultdict
2828from concurrent .futures import ProcessPoolExecutor , ThreadPoolExecutor
29- from typing import Any , Callable , Collection , Dict , Iterable , List , Optional , Set , Union
29+ from typing import (
30+ Any ,
31+ AsyncGenerator ,
32+ Callable ,
33+ Collection ,
34+ Dict ,
35+ Generator ,
36+ Iterable ,
37+ List ,
38+ Optional ,
39+ Set ,
40+ Union ,
41+ )
3042
3143import aiohttp
3244
4557from .utils import _split_path , get_in , stringify_key , update_in
4658
4759
60+ def _is_generator (obj ) -> bool :
61+ """Check if an object is a sync or async generator."""
62+ return inspect .isgenerator (obj ) or inspect .isasyncgen (obj )
63+
64+
4865class Flow :
4966 _legal_first_step = False
5067
@@ -534,11 +551,6 @@ class _StreamingStepMixin:
534551 want to support user-provided generator functions.
535552 """
536553
537- @staticmethod
538- def _is_generator (obj ) -> bool :
539- """Check if an object is a sync or async generator."""
540- return inspect .isgenerator (obj ) or inspect .isasyncgen (obj )
541-
542554 def _validate_not_already_streaming (self , event ):
543555 """Ensure we're not streaming on top of an already streaming event.
544556
@@ -552,35 +564,35 @@ def _validate_not_already_streaming(self, event):
552564 f"Step '{ self .name } ' received a streaming event from '{ streaming_step } '."
553565 )
554566
555- async def _emit_streaming_chunks (self , original_event , generator ) :
567+ async def _emit_streaming_chunks (self , event , generator : Generator | AsyncGenerator ) -> None :
556568 """Emit streaming chunks from a generator, then send StreamCompletion.
557569
558570 Args:
559- original_event : The original event that triggered this streaming response .
571+ event : The event that will be used to create chunk events .
560572 generator: A sync or async generator yielding chunk bodies.
561573 """
562- self ._validate_not_already_streaming (original_event )
574+ self ._validate_not_already_streaming (event )
563575
564576 chunk_id = 0
565577 if inspect .isgenerator (generator ):
566578 # Sync generator
567579 for chunk_body in generator :
568- chunk_event = self ._user_fn_output_to_event (original_event , chunk_body )
580+ chunk_event = self ._user_fn_output_to_event (event , chunk_body )
569581 chunk_event .streaming_step = self .name
570582 chunk_event .chunk_id = chunk_id
571583 await self ._do_downstream (chunk_event )
572584 chunk_id += 1
573585 else :
574586 # Async generator
575587 async for chunk_body in generator :
576- chunk_event = self ._user_fn_output_to_event (original_event , chunk_body )
588+ chunk_event = self ._user_fn_output_to_event (event , chunk_body )
577589 chunk_event .streaming_step = self .name
578590 chunk_event .chunk_id = chunk_id
579591 await self ._do_downstream (chunk_event )
580592 chunk_id += 1
581593
582594 # Send completion signal
583- await self ._do_downstream (StreamCompletion (self .name , original_event ))
595+ await self ._do_downstream (StreamCompletion (self .name , event ))
584596
585597
586598class _UnaryFunctionFlow (Flow ):
@@ -671,7 +683,7 @@ class Map(_UnaryFunctionFlow, _StreamingStepMixin):
671683
672684 async def _do_internal (self , event , fn_result ):
673685 # Check if the result is a generator (streaming response)
674- if self . _is_generator (fn_result ):
686+ if _is_generator (fn_result ):
675687 await self ._emit_streaming_chunks (event , fn_result )
676688 else :
677689 mapped_event = self ._user_fn_output_to_event (event , fn_result )
@@ -864,7 +876,7 @@ async def _do(self, event):
864876 fn_result = await self ._call (element )
865877 if not self ._filter :
866878 # Check if the result is a generator (streaming response)
867- if self . _is_generator (fn_result ):
879+ if _is_generator (fn_result ):
868880 await self ._emit_streaming_chunks (event , fn_result )
869881 else :
870882 mapped_event = self ._user_fn_output_to_event (event , fn_result )
@@ -1806,18 +1818,13 @@ async def run_async(self, body: Any, path: str, origin_name: Optional[str] = Non
18061818 """
18071819 return body
18081820
1809- @staticmethod
1810- def _is_generator (obj ) -> bool :
1811- """Check if an object is a sync or async generator."""
1812- return inspect .isgenerator (obj ) or inspect .isasyncgen (obj )
1813-
18141821 def _run (self , body : Any , path : str , origin_name : Optional [str ] = None ) -> Any :
18151822 timestamp = datetime .datetime .now (tz = datetime .timezone .utc )
18161823 start = time .monotonic ()
18171824 try :
18181825 result = self .run (body , path , origin_name )
18191826 # Return generator directly for streaming support
1820- if self . _is_generator (result ):
1827+ if _is_generator (result ):
18211828 return result
18221829 body = result
18231830 except Exception as e :
@@ -1837,7 +1844,7 @@ async def _async_run(self, body: Any, path: str, origin_name: Optional[str] = No
18371844 return self .run_async (body , path , origin_name )
18381845 result = await self .run_async (body , path , origin_name )
18391846 # Return generator directly for streaming support
1840- if self . _is_generator (result ):
1847+ if _is_generator (result ):
18411848 return result
18421849 body = result
18431850 except Exception as e :
@@ -2154,7 +2161,7 @@ async def _do(self, event):
21542161 if len (runnables ) == 1 and results :
21552162 result = results [0 ]
21562163 # Check if the result is a generator (streaming response)
2157- if self . _is_generator (result ):
2164+ if _is_generator (result ):
21582165 # Validate execution mechanism - streaming not supported with process-based execution
21592166 runnable_name = (
21602167 runnables [0 ].name if isinstance (runnables [0 ], ParallelExecutionRunnable ) else runnables [0 ]
@@ -2181,7 +2188,7 @@ async def _do(self, event):
21812188 else :
21822189 # Check if any results are generators (not allowed with multiple runnables)
21832190 for result in results :
2184- if self . _is_generator (result ):
2191+ if _is_generator (result ):
21852192 raise StreamingError (
21862193 "Streaming is not supported when multiple runnables are selected. "
21872194 "Streaming runnables must be the only runnable selected for an event."
0 commit comments