@@ -54,6 +54,14 @@ async def run_async(self, body, path: str, origin_name: Optional[str] = None) ->
5454 yield f"{ body } _chunk_{ i } "
5555
5656
57+ class ErrorStreamingRunnable (ParallelExecutionRunnable ):
58+ """A streaming runnable that yields one chunk then raises an error."""
59+
60+ def run (self , body , path : str , origin_name : Optional [str ] = None ) -> Generator :
61+ yield f"{ body } _chunk_0"
62+ raise ValueError ("Simulated streaming error" )
63+
64+
5765class TestStreamingPrimitives :
5866 """Tests for streaming primitive classes."""
5967
@@ -1090,29 +1098,6 @@ def double(x):
10901098class TestParallelExecutionStreaming :
10911099 """Tests for ParallelExecution streaming support."""
10921100
1093- def test_parallel_execution_single_runnable_streaming (self ):
1094- """Test streaming with a single runnable."""
1095- runnable = StreamingRunnable (name = "streamer" )
1096- controller = build_flow (
1097- [
1098- SyncEmitSource (),
1099- ParallelExecution (
1100- runnables = [runnable ],
1101- execution_mechanism_by_runnable_name = {"streamer" : ParallelExecutionMechanisms .naive },
1102- ),
1103- Complete (),
1104- ]
1105- ).run ()
1106-
1107- try :
1108- awaitable = controller .emit ("test" )
1109- result = awaitable .await_result ()
1110- assert inspect .isgenerator (result )
1111- assert list (result ) == ["test_chunk_0" , "test_chunk_1" , "test_chunk_2" ]
1112- finally :
1113- controller .terminate ()
1114- controller .await_termination ()
1115-
11161101 def test_parallel_execution_async_runnable_streaming (self ):
11171102 """Test streaming with an async runnable."""
11181103 runnable = AsyncStreamingRunnable (name = "async_streamer" )
@@ -1190,15 +1175,24 @@ async def _test():
11901175
11911176 asyncio .run (_test ())
11921177
1193- def test_parallel_execution_streaming_with_thread_pool (self ):
1194- """Test streaming works with thread_pool execution mechanism."""
1178+ @pytest .mark .parametrize (
1179+ "execution_mechanism" ,
1180+ [
1181+ ParallelExecutionMechanisms .naive ,
1182+ ParallelExecutionMechanisms .thread_pool ,
1183+ ParallelExecutionMechanisms .process_pool ,
1184+ ParallelExecutionMechanisms .dedicated_process ,
1185+ ],
1186+ )
1187+ def test_parallel_execution_streaming_with_executor (self , execution_mechanism ):
1188+ """Test streaming works with various execution mechanisms."""
11951189 runnable = StreamingRunnable (name = "streamer" )
11961190 controller = build_flow (
11971191 [
11981192 SyncEmitSource (),
11991193 ParallelExecution (
12001194 runnables = [runnable ],
1201- execution_mechanism_by_runnable_name = {"streamer" : ParallelExecutionMechanisms . thread_pool },
1195+ execution_mechanism_by_runnable_name = {"streamer" : execution_mechanism },
12021196 ),
12031197 Complete (),
12041198 ]
@@ -1213,14 +1207,20 @@ def test_parallel_execution_streaming_with_thread_pool(self):
12131207 controller .terminate ()
12141208 controller .await_termination ()
12151209
1216- def test_parallel_execution_streaming_with_shared_executor_thread_based (self ):
1217- """Test streaming works with shared_executor when the shared executor uses threads."""
1218- # Create a shared executor with a thread-based runnable
1210+ @pytest .mark .parametrize (
1211+ "execution_mechanism" ,
1212+ [
1213+ ParallelExecutionMechanisms .thread_pool ,
1214+ ParallelExecutionMechanisms .process_pool ,
1215+ ParallelExecutionMechanisms .dedicated_process ,
1216+ ],
1217+ )
1218+ def test_parallel_execution_streaming_with_shared_executor (self , execution_mechanism ):
1219+ """Test streaming works with shared_executor using different underlying mechanisms."""
12191220 shared_executor = RunnableExecutor ()
12201221 shared_runnable = StreamingRunnable (name = "shared_streamer" )
1221- shared_executor .add_runnable (shared_runnable , ParallelExecutionMechanisms . thread_pool )
1222+ shared_executor .add_runnable (shared_runnable , execution_mechanism )
12221223
1223- # Create a proxy runnable that references the shared executor
12241224 proxy_runnable = StreamingRunnable (name = "proxy" , shared_runnable_name = "shared_streamer" )
12251225
12261226 class ContextWithExecutor :
@@ -1251,67 +1251,45 @@ def __init__(self, executor):
12511251 controller .await_termination ()
12521252
12531253 @pytest .mark .parametrize (
1254- "mechanism" ,
1255- [ParallelExecutionMechanisms .process_pool , ParallelExecutionMechanisms .dedicated_process ],
1254+ "execution_mechanism,expected_error" ,
1255+ [
1256+ (ParallelExecutionMechanisms .naive , ValueError ),
1257+ (ParallelExecutionMechanisms .thread_pool , ValueError ),
1258+ # Process-based mechanisms wrap errors in RuntimeError
1259+ (ParallelExecutionMechanisms .process_pool , RuntimeError ),
1260+ (ParallelExecutionMechanisms .dedicated_process , RuntimeError ),
1261+ ],
12561262 )
1257- def test_parallel_execution_streaming_with_process_based_fails_at_init (self , mechanism ):
1258- """Test that StreamingError is raised at init time when streaming runnable uses process-based mechanism."""
1259- runnable = StreamingRunnable (name = "streamer" )
1260-
1261- flow = build_flow (
1263+ def test_parallel_execution_streaming_error_propagation (self , execution_mechanism , expected_error ):
1264+ """Test that errors in streaming are propagated correctly."""
1265+ runnable = ErrorStreamingRunnable (name = "error_streamer" )
1266+ controller = build_flow (
12621267 [
12631268 SyncEmitSource (),
12641269 ParallelExecution (
12651270 runnables = [runnable ],
1266- execution_mechanism_by_runnable_name = {"streamer " : mechanism },
1271+ execution_mechanism_by_runnable_name = {"error_streamer " : execution_mechanism },
12671272 ),
12681273 Complete (),
12691274 ]
1270- )
1271-
1272- expected_error_message = (
1273- "Streaming is not supported with process-based execution mechanisms. "
1274- f"Runnable 'streamer' uses '{ mechanism } '. "
1275- "Use 'thread_pool', 'asyncio', or 'naive' for streaming runnables."
1276- )
1277- with pytest .raises (StreamingError , match = expected_error_message ):
1278- flow .run ()
1279-
1280- def test_parallel_execution_streaming_with_shared_executor_process_based_fails_at_init (self ):
1281- """Test that StreamingError is raised at init when shared_executor uses process-based mechanism."""
1282- # Create a shared executor with a process-based runnable
1283- shared_executor = RunnableExecutor ()
1284- shared_runnable = StreamingRunnable (name = "shared_streamer" )
1285- shared_executor .add_runnable (shared_runnable , ParallelExecutionMechanisms .process_pool )
1286-
1287- # Create a proxy runnable that references the shared executor
1288- proxy_runnable = StreamingRunnable (name = "proxy" , shared_runnable_name = "shared_streamer" )
1289-
1290- class ContextWithExecutor :
1291- def __init__ (self , executor ):
1292- self .executor = executor
1293-
1294- context = ContextWithExecutor (shared_executor )
1275+ ).run ()
12951276
1296- flow = build_flow (
1297- [
1298- SyncEmitSource (),
1299- ParallelExecution (
1300- runnables = [proxy_runnable ],
1301- execution_mechanism_by_runnable_name = {"proxy" : ParallelExecutionMechanisms .shared_executor },
1302- context = context ,
1303- ),
1304- Complete (),
1305- ]
1306- )
1307-
1308- expected_error_message = (
1309- "Streaming is not supported with process-based execution mechanisms. "
1310- "Runnable 'shared_streamer' uses 'process_pool'. "
1311- "Use 'thread_pool', 'asyncio', or 'naive' for streaming runnables."
1312- )
1313- with pytest .raises (StreamingError , match = expected_error_message ):
1314- flow .run ()
1277+ try :
1278+ awaitable = controller .emit ("test" )
1279+ result = awaitable .await_result ()
1280+ assert inspect .isgenerator (result )
1281+ # Should get first chunk, then error
1282+ chunks = []
1283+ with pytest .raises (expected_error , match = "Simulated streaming error" ):
1284+ for chunk in result :
1285+ chunks .append (chunk )
1286+ # Verify we got the first chunk before the error
1287+ assert chunks == ["test_chunk_0" ]
1288+ finally :
1289+ controller .terminate ()
1290+ # Error is also propagated through termination
1291+ with pytest .raises (expected_error , match = "Simulated streaming error" ):
1292+ controller .await_termination ()
13151293
13161294
13171295class TestStreamingGraphSplits :
0 commit comments