1818use async_trait:: async_trait;
1919use datafusion:: arrow:: ipc:: reader:: StreamReader ;
2020use datafusion:: common:: stats:: Precision ;
21+ use datafusion:: physical_plan:: coalesce:: { LimitedBatchCoalescer , PushBatchStatus } ;
2122use std:: any:: Any ;
2223use std:: collections:: HashMap ;
2324use std:: fmt:: Debug ;
@@ -41,12 +42,14 @@ use datafusion::arrow::record_batch::RecordBatch;
4142use datafusion:: common:: runtime:: SpawnedTask ;
4243
4344use datafusion:: error:: { DataFusionError , Result } ;
44- use datafusion:: physical_plan:: metrics:: { ExecutionPlanMetricsSet , MetricsSet } ;
45+ use datafusion:: physical_plan:: metrics:: {
46+ BaselineMetrics , ExecutionPlanMetricsSet , MetricsSet ,
47+ } ;
4548use datafusion:: physical_plan:: {
4649 ColumnStatistics , DisplayAs , DisplayFormatType , ExecutionPlan , Partitioning ,
4750 PlanProperties , RecordBatchStream , SendableRecordBatchStream , Statistics ,
4851} ;
49- use futures:: { Stream , StreamExt , TryStreamExt } ;
52+ use futures:: { Stream , StreamExt , TryStreamExt , ready } ;
5053
5154use crate :: error:: BallistaError ;
5255use datafusion:: execution:: context:: TaskContext ;
@@ -165,6 +168,7 @@ impl ExecutionPlan for ShuffleReaderExec {
165168 let max_message_size = config. ballista_grpc_client_max_message_size ( ) ;
166169 let force_remote_read = config. ballista_shuffle_reader_force_remote_read ( ) ;
167170 let prefer_flight = config. ballista_shuffle_reader_remote_prefer_flight ( ) ;
171+ let batch_size = config. batch_size ( ) ;
168172
169173 if force_remote_read {
170174 debug ! (
@@ -200,11 +204,18 @@ impl ExecutionPlan for ShuffleReaderExec {
200204 prefer_flight,
201205 ) ;
202206
203- let result = RecordBatchStreamAdapter :: new (
204- Arc :: new ( self . schema . as_ref ( ) . clone ( ) ) ,
207+ let input_stream = Box :: pin ( RecordBatchStreamAdapter :: new (
208+ self . schema . clone ( ) ,
205209 response_receiver. try_flatten ( ) ,
206- ) ;
207- Ok ( Box :: pin ( result) )
210+ ) ) ;
211+
212+ Ok ( Box :: pin ( CoalescedShuffleReaderStream :: new (
213+ input_stream,
214+ batch_size,
215+ None ,
216+ & self . metrics ,
217+ partition,
218+ ) ) )
208219 }
209220
210221 fn metrics ( & self ) -> Option < MetricsSet > {
@@ -594,6 +605,96 @@ async fn fetch_partition_object_store(
594605 ) )
595606}
596607
608+ struct CoalescedShuffleReaderStream {
609+ schema : SchemaRef ,
610+ input : SendableRecordBatchStream ,
611+ coalescer : LimitedBatchCoalescer ,
612+ completed : bool ,
613+ baseline_metrics : BaselineMetrics ,
614+ }
615+
616+ impl CoalescedShuffleReaderStream {
617+ pub fn new (
618+ input : SendableRecordBatchStream ,
619+ batch_size : usize ,
620+ limit : Option < usize > ,
621+ metrics : & ExecutionPlanMetricsSet ,
622+ partition : usize ,
623+ ) -> Self {
624+ let schema = input. schema ( ) ;
625+ Self {
626+ schema : schema. clone ( ) ,
627+ input,
628+ coalescer : LimitedBatchCoalescer :: new ( schema, batch_size, limit) ,
629+ completed : false ,
630+ baseline_metrics : BaselineMetrics :: new ( metrics, partition) ,
631+ }
632+ }
633+ }
634+
635+ impl Stream for CoalescedShuffleReaderStream {
636+ type Item = Result < RecordBatch > ;
637+
638+ fn poll_next (
639+ mut self : Pin < & mut Self > ,
640+ cx : & mut Context < ' _ > ,
641+ ) -> Poll < Option < Self :: Item > > {
642+ let elapsed_compute = self . baseline_metrics . elapsed_compute ( ) . clone ( ) ;
643+ let _timer = elapsed_compute. timer ( ) ;
644+
645+ loop {
646+ // If there is already a completed batch ready, return it directly
647+ if let Some ( batch) = self . coalescer . next_completed_batch ( ) {
648+ self . baseline_metrics . record_output ( batch. num_rows ( ) ) ;
649+ return Poll :: Ready ( Some ( Ok ( batch) ) ) ;
650+ }
651+
652+ // If the upstream is completed, then it is completed for this stream too
653+ if self . completed {
654+ return Poll :: Ready ( None ) ;
655+ }
656+
657+ // Pull from upstream
658+ match ready ! ( self . input. poll_next_unpin( cx) ) {
659+ // If upstream is completed, then flush remaning buffered batches
660+ None => {
661+ self . completed = true ;
662+ if let Err ( e) = self . coalescer . finish ( ) {
663+ return Poll :: Ready ( Some ( Err ( e) ) ) ;
664+ }
665+ }
666+ // If upstream is not completed, then push to coalescer
667+ Some ( Ok ( batch) ) => {
668+ if batch. num_rows ( ) > 0 {
669+ // Try to push to coalescer
670+ match self . coalescer . push_batch ( batch) {
671+ // If push is successful, then continue
672+ Ok ( PushBatchStatus :: Continue ) => {
673+ continue ;
674+ }
675+ // If limit is reached, then finish coalescer and set completed to true
676+ Ok ( PushBatchStatus :: LimitReached ) => {
677+ self . completed = true ;
678+ if let Err ( e) = self . coalescer . finish ( ) {
679+ return Poll :: Ready ( Some ( Err ( e) ) ) ;
680+ }
681+ }
682+ Err ( e) => return Poll :: Ready ( Some ( Err ( e) ) ) ,
683+ }
684+ }
685+ }
686+ Some ( Err ( e) ) => return Poll :: Ready ( Some ( Err ( e) ) ) ,
687+ }
688+ }
689+ }
690+ }
691+
692+ impl RecordBatchStream for CoalescedShuffleReaderStream {
693+ fn schema ( & self ) -> SchemaRef {
694+ self . schema . clone ( )
695+ }
696+ }
697+
597698#[ cfg( test) ]
598699mod tests {
599700 use super :: * ;
@@ -1052,10 +1153,179 @@ mod tests {
10521153 . unwrap ( )
10531154 }
10541155
1156+ fn create_custom_test_batch ( rows : usize ) -> RecordBatch {
1157+ let schema = create_test_schema ( ) ;
1158+
1159+ // 1. Create number column (0, 1, 2, ..., rows-1)
1160+ let number_vec: Vec < u32 > = ( 0 ..rows as u32 ) . collect ( ) ;
1161+ let number_array = UInt32Array :: from ( number_vec) ;
1162+
1163+ // 2. Create string column ("s0", "s1", ..., "s{rows-1}")
1164+ // Just to fill data, the content is not important
1165+ let string_vec: Vec < String > = ( 0 ..rows) . map ( |i| format ! ( "s{}" , i) ) . collect ( ) ;
1166+ let string_array = StringArray :: from ( string_vec) ;
1167+
1168+ RecordBatch :: try_new ( schema, vec ! [ Arc :: new( number_array) , Arc :: new( string_array) ] )
1169+ . unwrap ( )
1170+ }
1171+
10551172 fn create_test_schema ( ) -> SchemaRef {
10561173 Arc :: new ( Schema :: new ( vec ! [
10571174 Field :: new( "number" , DataType :: UInt32 , true ) ,
10581175 Field :: new( "str" , DataType :: Utf8 , true ) ,
10591176 ] ) )
10601177 }
1178+
1179+ use datafusion:: physical_plan:: memory:: MemoryStream ;
1180+
1181+ #[ tokio:: test]
1182+ async fn test_coalesce_stream_logic ( ) -> Result < ( ) > {
1183+ // 1. Create test data - 10 small batches, each with 3 rows
1184+ let schema = create_test_schema ( ) ;
1185+ let small_batch = create_test_batch ( ) ;
1186+ let batches = vec ! [ small_batch. clone( ) ; 10 ] ;
1187+
1188+ // 2. Create mock upstream stream (Input Stream)
1189+ let input_stream = MemoryStream :: try_new ( batches, schema. clone ( ) , None ) ?;
1190+ let input_stream = Box :: pin ( input_stream) as SendableRecordBatchStream ;
1191+
1192+ // 3. Configure Coalescer: target batch size to 10 rows
1193+ let target_batch_size = 10 ;
1194+
1195+ // 4. Manually build the CoalescedShuffleReaderStream
1196+ let coalesced_stream = CoalescedShuffleReaderStream :: new (
1197+ input_stream,
1198+ target_batch_size,
1199+ None ,
1200+ & ExecutionPlanMetricsSet :: new ( ) ,
1201+ 0 ,
1202+ ) ;
1203+
1204+ // 5. Execute stream and collect results
1205+ let output_batches = common:: collect ( Box :: pin ( coalesced_stream) ) . await ?;
1206+
1207+ // 6. Assertions
1208+ // Assert A: Data total not lost (30 rows)
1209+ let total_rows: usize = output_batches. iter ( ) . map ( |b| b. num_rows ( ) ) . sum ( ) ;
1210+ assert_eq ! ( total_rows, 30 ) ;
1211+
1212+ // Assert B: Batch count reduced (10 -> 3)
1213+ assert_eq ! ( output_batches. len( ) , 3 ) ;
1214+
1215+ // Assert C: Each batch size is correct (all should be 10)
1216+ assert_eq ! ( output_batches[ 0 ] . num_rows( ) , 10 ) ;
1217+ assert_eq ! ( output_batches[ 1 ] . num_rows( ) , 10 ) ;
1218+ assert_eq ! ( output_batches[ 2 ] . num_rows( ) , 10 ) ;
1219+
1220+ Ok ( ( ) )
1221+ }
1222+
1223+ #[ tokio:: test]
1224+ async fn test_coalesce_stream_remainder_flush ( ) -> Result < ( ) > {
1225+ let schema = create_test_schema ( ) ;
1226+ // Create 10 small batch, each with 3 rows. Total 30 rows.
1227+ let small_batch = create_test_batch ( ) ;
1228+ let batches = vec ! [ small_batch. clone( ) ; 10 ] ;
1229+
1230+ let input_stream = MemoryStream :: try_new ( batches, schema. clone ( ) , None ) ?;
1231+ let input_stream = Box :: pin ( input_stream) as SendableRecordBatchStream ;
1232+
1233+ // Target set to 100 rows.
1234+ // Because 30 < 100, it can never be filled. Must depend on the `finish()` mechanism to flush out these 30 rows at the end of the stream.
1235+ let target_batch_size = 100 ;
1236+
1237+ let coalesced_stream = CoalescedShuffleReaderStream :: new (
1238+ input_stream,
1239+ target_batch_size,
1240+ None ,
1241+ & ExecutionPlanMetricsSet :: new ( ) ,
1242+ 0 ,
1243+ ) ;
1244+
1245+ let output_batches = common:: collect ( Box :: pin ( coalesced_stream) ) . await ?;
1246+
1247+ // Assertions
1248+ assert_eq ! ( output_batches. len( ) , 1 ) ; // Should only have 1 batch
1249+ assert_eq ! ( output_batches[ 0 ] . num_rows( ) , 30 ) ; // Should contain all 30 rows
1250+
1251+ Ok ( ( ) )
1252+ }
1253+
1254+ #[ tokio:: test]
1255+ async fn test_coalesce_stream_large_batch ( ) -> Result < ( ) > {
1256+ let schema = create_test_schema ( ) ;
1257+
1258+ // 1. Create a large batch (20 rows)
1259+ let big_batch = create_custom_test_batch ( 20 ) ;
1260+ let batches = vec ! [ big_batch. clone( ) ; 10 ] ; // Total 200 rows
1261+
1262+ let input_stream = MemoryStream :: try_new ( batches, schema. clone ( ) , None ) ?;
1263+ let input_stream = Box :: pin ( input_stream) as SendableRecordBatchStream ;
1264+
1265+ // 2. Target set to small size, 10 rows
1266+ let target_batch_size = 10 ;
1267+
1268+ let coalesced_stream = CoalescedShuffleReaderStream :: new (
1269+ input_stream,
1270+ target_batch_size,
1271+ None ,
1272+ & ExecutionPlanMetricsSet :: new ( ) ,
1273+ 0 ,
1274+ ) ;
1275+
1276+ let output_batches = common:: collect ( Box :: pin ( coalesced_stream) ) . await ?;
1277+
1278+ // 3. Validation: It should not split the large batch, but directly output it
1279+ // Coalescer will not split the batch if size > (max_batch_size / 2)
1280+ assert_eq ! ( output_batches. len( ) , 10 ) ;
1281+ assert_eq ! ( output_batches[ 0 ] . num_rows( ) , 20 ) ;
1282+
1283+ Ok ( ( ) )
1284+ }
1285+
1286+ use futures:: stream;
1287+
1288+ #[ tokio:: test]
1289+ async fn test_coalesce_stream_error_propagation ( ) -> Result < ( ) > {
1290+ let schema = create_test_schema ( ) ;
1291+ let small_batch = create_test_batch ( ) ; // 3行
1292+
1293+ // 1. Construct a stream with error
1294+ let batches = vec ! [
1295+ Ok ( small_batch) ,
1296+ Err ( DataFusionError :: Execution (
1297+ "Network connection failed" . to_string( ) ,
1298+ ) ) ,
1299+ ] ;
1300+
1301+ // 2. Construct a stream with error
1302+ let stream = stream:: iter ( batches) ;
1303+ let input_stream =
1304+ Box :: pin ( RecordBatchStreamAdapter :: new ( schema. clone ( ) , stream) ) ;
1305+
1306+ // 3. Configure Coalescer
1307+ let target_batch_size = 10 ;
1308+
1309+ let coalesced_stream = CoalescedShuffleReaderStream :: new (
1310+ input_stream,
1311+ target_batch_size,
1312+ None ,
1313+ & ExecutionPlanMetricsSet :: new ( ) ,
1314+ 0 ,
1315+ ) ;
1316+
1317+ // 4. Execute stream
1318+ let result = common:: collect ( Box :: pin ( coalesced_stream) ) . await ;
1319+
1320+ // 5. Validation
1321+ assert ! ( result. is_err( ) ) ;
1322+ assert ! (
1323+ result
1324+ . unwrap_err( )
1325+ . to_string( )
1326+ . contains( "Network connection failed" )
1327+ ) ;
1328+
1329+ Ok ( ( ) )
1330+ }
10611331}
0 commit comments