Skip to content

Commit bbf9d30

Browse files
feat: Add batch coalescing ability to shuffle reader exec (apache#1380)
* impl * fix format and simplify new
1 parent 2b51481 commit bbf9d30

1 file changed

Lines changed: 276 additions & 6 deletions

File tree

ballista/core/src/execution_plans/shuffle_reader.rs

Lines changed: 276 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
use async_trait::async_trait;
1919
use datafusion::arrow::ipc::reader::StreamReader;
2020
use datafusion::common::stats::Precision;
21+
use datafusion::physical_plan::coalesce::{LimitedBatchCoalescer, PushBatchStatus};
2122
use std::any::Any;
2223
use std::collections::HashMap;
2324
use std::fmt::Debug;
@@ -41,12 +42,14 @@ use datafusion::arrow::record_batch::RecordBatch;
4142
use datafusion::common::runtime::SpawnedTask;
4243

4344
use datafusion::error::{DataFusionError, Result};
44-
use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
45+
use datafusion::physical_plan::metrics::{
46+
BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet,
47+
};
4548
use datafusion::physical_plan::{
4649
ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning,
4750
PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics,
4851
};
49-
use futures::{Stream, StreamExt, TryStreamExt};
52+
use futures::{Stream, StreamExt, TryStreamExt, ready};
5053

5154
use crate::error::BallistaError;
5255
use datafusion::execution::context::TaskContext;
@@ -165,6 +168,7 @@ impl ExecutionPlan for ShuffleReaderExec {
165168
let max_message_size = config.ballista_grpc_client_max_message_size();
166169
let force_remote_read = config.ballista_shuffle_reader_force_remote_read();
167170
let prefer_flight = config.ballista_shuffle_reader_remote_prefer_flight();
171+
let batch_size = config.batch_size();
168172

169173
if force_remote_read {
170174
debug!(
@@ -200,11 +204,18 @@ impl ExecutionPlan for ShuffleReaderExec {
200204
prefer_flight,
201205
);
202206

203-
let result = RecordBatchStreamAdapter::new(
204-
Arc::new(self.schema.as_ref().clone()),
207+
let input_stream = Box::pin(RecordBatchStreamAdapter::new(
208+
self.schema.clone(),
205209
response_receiver.try_flatten(),
206-
);
207-
Ok(Box::pin(result))
210+
));
211+
212+
Ok(Box::pin(CoalescedShuffleReaderStream::new(
213+
input_stream,
214+
batch_size,
215+
None,
216+
&self.metrics,
217+
partition,
218+
)))
208219
}
209220

210221
fn metrics(&self) -> Option<MetricsSet> {
@@ -594,6 +605,96 @@ async fn fetch_partition_object_store(
594605
))
595606
}
596607

608+
struct CoalescedShuffleReaderStream {
609+
schema: SchemaRef,
610+
input: SendableRecordBatchStream,
611+
coalescer: LimitedBatchCoalescer,
612+
completed: bool,
613+
baseline_metrics: BaselineMetrics,
614+
}
615+
616+
impl CoalescedShuffleReaderStream {
617+
pub fn new(
618+
input: SendableRecordBatchStream,
619+
batch_size: usize,
620+
limit: Option<usize>,
621+
metrics: &ExecutionPlanMetricsSet,
622+
partition: usize,
623+
) -> Self {
624+
let schema = input.schema();
625+
Self {
626+
schema: schema.clone(),
627+
input,
628+
coalescer: LimitedBatchCoalescer::new(schema, batch_size, limit),
629+
completed: false,
630+
baseline_metrics: BaselineMetrics::new(metrics, partition),
631+
}
632+
}
633+
}
634+
635+
impl Stream for CoalescedShuffleReaderStream {
636+
type Item = Result<RecordBatch>;
637+
638+
fn poll_next(
639+
mut self: Pin<&mut Self>,
640+
cx: &mut Context<'_>,
641+
) -> Poll<Option<Self::Item>> {
642+
let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
643+
let _timer = elapsed_compute.timer();
644+
645+
loop {
646+
// If there is already a completed batch ready, return it directly
647+
if let Some(batch) = self.coalescer.next_completed_batch() {
648+
self.baseline_metrics.record_output(batch.num_rows());
649+
return Poll::Ready(Some(Ok(batch)));
650+
}
651+
652+
// If the upstream is completed, then it is completed for this stream too
653+
if self.completed {
654+
return Poll::Ready(None);
655+
}
656+
657+
// Pull from upstream
658+
match ready!(self.input.poll_next_unpin(cx)) {
659+
// If upstream is completed, then flush remaning buffered batches
660+
None => {
661+
self.completed = true;
662+
if let Err(e) = self.coalescer.finish() {
663+
return Poll::Ready(Some(Err(e)));
664+
}
665+
}
666+
// If upstream is not completed, then push to coalescer
667+
Some(Ok(batch)) => {
668+
if batch.num_rows() > 0 {
669+
// Try to push to coalescer
670+
match self.coalescer.push_batch(batch) {
671+
// If push is successful, then continue
672+
Ok(PushBatchStatus::Continue) => {
673+
continue;
674+
}
675+
// If limit is reached, then finish coalescer and set completed to true
676+
Ok(PushBatchStatus::LimitReached) => {
677+
self.completed = true;
678+
if let Err(e) = self.coalescer.finish() {
679+
return Poll::Ready(Some(Err(e)));
680+
}
681+
}
682+
Err(e) => return Poll::Ready(Some(Err(e))),
683+
}
684+
}
685+
}
686+
Some(Err(e)) => return Poll::Ready(Some(Err(e))),
687+
}
688+
}
689+
}
690+
}
691+
692+
impl RecordBatchStream for CoalescedShuffleReaderStream {
693+
fn schema(&self) -> SchemaRef {
694+
self.schema.clone()
695+
}
696+
}
697+
597698
#[cfg(test)]
598699
mod tests {
599700
use super::*;
@@ -1052,10 +1153,179 @@ mod tests {
10521153
.unwrap()
10531154
}
10541155

1156+
fn create_custom_test_batch(rows: usize) -> RecordBatch {
1157+
let schema = create_test_schema();
1158+
1159+
// 1. Create number column (0, 1, 2, ..., rows-1)
1160+
let number_vec: Vec<u32> = (0..rows as u32).collect();
1161+
let number_array = UInt32Array::from(number_vec);
1162+
1163+
// 2. Create string column ("s0", "s1", ..., "s{rows-1}")
1164+
// Just to fill data, the content is not important
1165+
let string_vec: Vec<String> = (0..rows).map(|i| format!("s{}", i)).collect();
1166+
let string_array = StringArray::from(string_vec);
1167+
1168+
RecordBatch::try_new(schema, vec![Arc::new(number_array), Arc::new(string_array)])
1169+
.unwrap()
1170+
}
1171+
10551172
fn create_test_schema() -> SchemaRef {
10561173
Arc::new(Schema::new(vec![
10571174
Field::new("number", DataType::UInt32, true),
10581175
Field::new("str", DataType::Utf8, true),
10591176
]))
10601177
}
1178+
1179+
use datafusion::physical_plan::memory::MemoryStream;
1180+
1181+
#[tokio::test]
1182+
async fn test_coalesce_stream_logic() -> Result<()> {
1183+
// 1. Create test data - 10 small batches, each with 3 rows
1184+
let schema = create_test_schema();
1185+
let small_batch = create_test_batch();
1186+
let batches = vec![small_batch.clone(); 10];
1187+
1188+
// 2. Create mock upstream stream (Input Stream)
1189+
let input_stream = MemoryStream::try_new(batches, schema.clone(), None)?;
1190+
let input_stream = Box::pin(input_stream) as SendableRecordBatchStream;
1191+
1192+
// 3. Configure Coalescer: target batch size to 10 rows
1193+
let target_batch_size = 10;
1194+
1195+
// 4. Manually build the CoalescedShuffleReaderStream
1196+
let coalesced_stream = CoalescedShuffleReaderStream::new(
1197+
input_stream,
1198+
target_batch_size,
1199+
None,
1200+
&ExecutionPlanMetricsSet::new(),
1201+
0,
1202+
);
1203+
1204+
// 5. Execute stream and collect results
1205+
let output_batches = common::collect(Box::pin(coalesced_stream)).await?;
1206+
1207+
// 6. Assertions
1208+
// Assert A: Data total not lost (30 rows)
1209+
let total_rows: usize = output_batches.iter().map(|b| b.num_rows()).sum();
1210+
assert_eq!(total_rows, 30);
1211+
1212+
// Assert B: Batch count reduced (10 -> 3)
1213+
assert_eq!(output_batches.len(), 3);
1214+
1215+
// Assert C: Each batch size is correct (all should be 10)
1216+
assert_eq!(output_batches[0].num_rows(), 10);
1217+
assert_eq!(output_batches[1].num_rows(), 10);
1218+
assert_eq!(output_batches[2].num_rows(), 10);
1219+
1220+
Ok(())
1221+
}
1222+
1223+
#[tokio::test]
1224+
async fn test_coalesce_stream_remainder_flush() -> Result<()> {
1225+
let schema = create_test_schema();
1226+
// Create 10 small batch, each with 3 rows. Total 30 rows.
1227+
let small_batch = create_test_batch();
1228+
let batches = vec![small_batch.clone(); 10];
1229+
1230+
let input_stream = MemoryStream::try_new(batches, schema.clone(), None)?;
1231+
let input_stream = Box::pin(input_stream) as SendableRecordBatchStream;
1232+
1233+
// Target set to 100 rows.
1234+
// Because 30 < 100, it can never be filled. Must depend on the `finish()` mechanism to flush out these 30 rows at the end of the stream.
1235+
let target_batch_size = 100;
1236+
1237+
let coalesced_stream = CoalescedShuffleReaderStream::new(
1238+
input_stream,
1239+
target_batch_size,
1240+
None,
1241+
&ExecutionPlanMetricsSet::new(),
1242+
0,
1243+
);
1244+
1245+
let output_batches = common::collect(Box::pin(coalesced_stream)).await?;
1246+
1247+
// Assertions
1248+
assert_eq!(output_batches.len(), 1); // Should only have 1 batch
1249+
assert_eq!(output_batches[0].num_rows(), 30); // Should contain all 30 rows
1250+
1251+
Ok(())
1252+
}
1253+
1254+
#[tokio::test]
1255+
async fn test_coalesce_stream_large_batch() -> Result<()> {
1256+
let schema = create_test_schema();
1257+
1258+
// 1. Create a large batch (20 rows)
1259+
let big_batch = create_custom_test_batch(20);
1260+
let batches = vec![big_batch.clone(); 10]; // Total 200 rows
1261+
1262+
let input_stream = MemoryStream::try_new(batches, schema.clone(), None)?;
1263+
let input_stream = Box::pin(input_stream) as SendableRecordBatchStream;
1264+
1265+
// 2. Target set to small size, 10 rows
1266+
let target_batch_size = 10;
1267+
1268+
let coalesced_stream = CoalescedShuffleReaderStream::new(
1269+
input_stream,
1270+
target_batch_size,
1271+
None,
1272+
&ExecutionPlanMetricsSet::new(),
1273+
0,
1274+
);
1275+
1276+
let output_batches = common::collect(Box::pin(coalesced_stream)).await?;
1277+
1278+
// 3. Validation: It should not split the large batch, but directly output it
1279+
// Coalescer will not split the batch if size > (max_batch_size / 2)
1280+
assert_eq!(output_batches.len(), 10);
1281+
assert_eq!(output_batches[0].num_rows(), 20);
1282+
1283+
Ok(())
1284+
}
1285+
1286+
use futures::stream;
1287+
1288+
#[tokio::test]
1289+
async fn test_coalesce_stream_error_propagation() -> Result<()> {
1290+
let schema = create_test_schema();
1291+
let small_batch = create_test_batch(); // 3行
1292+
1293+
// 1. Construct a stream with error
1294+
let batches = vec![
1295+
Ok(small_batch),
1296+
Err(DataFusionError::Execution(
1297+
"Network connection failed".to_string(),
1298+
)),
1299+
];
1300+
1301+
// 2. Construct a stream with error
1302+
let stream = stream::iter(batches);
1303+
let input_stream =
1304+
Box::pin(RecordBatchStreamAdapter::new(schema.clone(), stream));
1305+
1306+
// 3. Configure Coalescer
1307+
let target_batch_size = 10;
1308+
1309+
let coalesced_stream = CoalescedShuffleReaderStream::new(
1310+
input_stream,
1311+
target_batch_size,
1312+
None,
1313+
&ExecutionPlanMetricsSet::new(),
1314+
0,
1315+
);
1316+
1317+
// 4. Execute stream
1318+
let result = common::collect(Box::pin(coalesced_stream)).await;
1319+
1320+
// 5. Validation
1321+
assert!(result.is_err());
1322+
assert!(
1323+
result
1324+
.unwrap_err()
1325+
.to_string()
1326+
.contains("Network connection failed")
1327+
);
1328+
1329+
Ok(())
1330+
}
10611331
}

0 commit comments

Comments
 (0)