Skip to content

Commit ee2e2eb

Browse files
committed
Repartition exec experiment
1 parent 7eb85f3 commit ee2e2eb

1 file changed

Lines changed: 122 additions & 72 deletions

File tree

  • datafusion/physical-plan/src/repartition

datafusion/physical-plan/src/repartition/mod.rs

Lines changed: 122 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
//! partitions to M output partitions based on a partitioning scheme, optionally
2020
//! maintaining the order of the input rows in the output.
2121
22+
use std::fmt::{Debug, Formatter};
2223
use std::pin::Pin;
24+
use std::sync::atomic::{AtomicBool, Ordering};
2325
use std::sync::Arc;
2426
use std::task::{Context, Poll};
2527
use std::{any::Any, vec};
@@ -44,7 +46,7 @@ use arrow::array::{PrimitiveArray, RecordBatch, RecordBatchOptions};
4446
use arrow::compute::take_arrays;
4547
use arrow::datatypes::{SchemaRef, UInt32Type};
4648
use datafusion_common::utils::transpose;
47-
use datafusion_common::HashMap;
49+
use datafusion_common::{internal_err, HashMap};
4850
use datafusion_common::{not_impl_err, DataFusionError, Result};
4951
use datafusion_common_runtime::SpawnedTask;
5052
use datafusion_execution::memory_pool::MemoryConsumer;
@@ -63,9 +65,8 @@ type MaybeBatch = Option<Result<RecordBatch>>;
6365
type InputPartitionsToCurrentPartitionSender = Vec<DistributionSender<MaybeBatch>>;
6466
type InputPartitionsToCurrentPartitionReceiver = Vec<DistributionReceiver<MaybeBatch>>;
6567

66-
/// Inner state of [`RepartitionExec`].
6768
#[derive(Debug)]
68-
struct RepartitionExecState {
69+
struct ConsumingInputStreamsState {
6970
/// Channels for sending batches from input partitions to output partitions.
7071
/// Key is the partition number.
7172
channels: HashMap<
@@ -81,16 +82,84 @@ struct RepartitionExecState {
8182
abort_helper: Arc<Vec<SpawnedTask<()>>>,
8283
}
8384

85+
/// Inner state of [`RepartitionExec`].
86+
enum RepartitionExecState {
87+
/// Not initialized yet. This is the default state stored in the RepartitionExec node
88+
/// upon instantiation.
89+
NotInitialized,
90+
/// Input streams are initialized, but they are still not being consumed. The node
91+
/// transitions to this state when the arrow's RecordBatch stream is created in
92+
/// RepartitionExec::execute(), but before any message is polled.
93+
InputStreamsInitialized(Vec<(SendableRecordBatchStream, RepartitionMetrics)>),
94+
/// The input streams are being consumed. The node transitions to this state when
95+
/// the first message in the arrow's RecordBatch stream is consumed.
96+
ConsumingInputStreams(ConsumingInputStreamsState),
97+
}
98+
99+
impl Default for RepartitionExecState {
100+
fn default() -> Self {
101+
Self::NotInitialized
102+
}
103+
}
104+
105+
impl Debug for RepartitionExecState {
106+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
107+
match self {
108+
RepartitionExecState::NotInitialized => write!(f, "NotInitialized"),
109+
RepartitionExecState::InputStreamsInitialized(v) => {
110+
write!(f, "InputStreamsInitialized({:?})", v.len())
111+
}
112+
RepartitionExecState::ConsumingInputStreams(v) => {
113+
write!(f, "ConsumingInputStreams({:?})", v)
114+
}
115+
}
116+
}
117+
}
118+
84119
impl RepartitionExecState {
85-
fn new(
120+
fn ensure_input_streams_initialized(
121+
&mut self,
86122
input: Arc<dyn ExecutionPlan>,
87-
partitioning: Partitioning,
88123
metrics: ExecutionPlanMetricsSet,
124+
output_partitions: usize,
125+
ctx: Arc<TaskContext>,
126+
) -> Result<()> {
127+
if !matches!(self, RepartitionExecState::NotInitialized) {
128+
return Ok(());
129+
}
130+
131+
let num_input_partitions = input.output_partitioning().partition_count();
132+
let mut streams_and_metrics = Vec::with_capacity(num_input_partitions);
133+
134+
for i in 0..num_input_partitions {
135+
let metrics = RepartitionMetrics::new(i, output_partitions, &metrics);
136+
137+
let timer = metrics.fetch_time.timer();
138+
let stream = input.execute(i, Arc::clone(&ctx))?;
139+
timer.done();
140+
141+
streams_and_metrics.push((stream, metrics));
142+
}
143+
*self = RepartitionExecState::InputStreamsInitialized(streams_and_metrics);
144+
Ok(())
145+
}
146+
147+
fn consume_input_streams(
148+
&mut self,
149+
partitioning: Partitioning,
89150
preserve_order: bool,
90151
name: String,
91152
context: Arc<TaskContext>,
92-
) -> Self {
93-
let num_input_partitions = input.output_partitioning().partition_count();
153+
) -> Result<&mut ConsumingInputStreamsState> {
154+
let streams_and_metrics = match self {
155+
RepartitionExecState::NotInitialized => {
156+
return internal_err!("RepartitionExecState::init_input_streams must be called before consuming input streams");
157+
}
158+
RepartitionExecState::ConsumingInputStreams(value) => return Ok(value),
159+
RepartitionExecState::InputStreamsInitialized(value) => value,
160+
};
161+
162+
let num_input_partitions = streams_and_metrics.len();
94163
let num_output_partitions = partitioning.partition_count();
95164

96165
let (txs, rxs) = if preserve_order {
@@ -125,23 +194,21 @@ impl RepartitionExecState {
125194

126195
// launch one async task per *input* partition
127196
let mut spawned_tasks = Vec::with_capacity(num_input_partitions);
128-
for i in 0..num_input_partitions {
197+
for (i, (stream, metrics)) in
198+
std::mem::take(streams_and_metrics).into_iter().enumerate()
199+
{
129200
let txs: HashMap<_, _> = channels
130201
.iter()
131202
.map(|(partition, (tx, _rx, reservation))| {
132203
(*partition, (tx[i].clone(), Arc::clone(reservation)))
133204
})
134205
.collect();
135206

136-
let r_metrics = RepartitionMetrics::new(i, num_output_partitions, &metrics);
137-
138207
let input_task = SpawnedTask::spawn(RepartitionExec::pull_from_input(
139-
Arc::clone(&input),
140-
i,
208+
stream,
141209
txs.clone(),
142210
partitioning.clone(),
143-
r_metrics,
144-
Arc::clone(&context),
211+
metrics,
145212
));
146213

147214
// In a separate task, wait for each input to be done
@@ -154,28 +221,17 @@ impl RepartitionExecState {
154221
));
155222
spawned_tasks.push(wait_for_task);
156223
}
157-
158-
Self {
224+
*self = Self::ConsumingInputStreams(ConsumingInputStreamsState {
159225
channels,
160226
abort_helper: Arc::new(spawned_tasks),
227+
});
228+
match self {
229+
RepartitionExecState::ConsumingInputStreams(value) => Ok(value),
230+
_ => unreachable!(),
161231
}
162232
}
163233
}
164234

165-
/// Lazily initialized state
166-
///
167-
/// Note that the state is initialized ONCE for all partitions by a single task(thread).
168-
/// This may take a short while. It is also likely that multiple threads
169-
/// call execute at the same time, because we have just started "target partitions" tasks
170-
/// which is commonly set to the number of CPU cores and all call execute at the same time.
171-
///
172-
/// Thus, use a **tokio** `OnceCell` for this initialization so as not to waste CPU cycles
173-
/// in a mutex lock but instead allow other threads to do something useful.
174-
///
175-
/// Uses a parking_lot `Mutex` to control other accesses as they are very short duration
176-
/// (e.g. removing channels on completion) where the overhead of `await` is not warranted.
177-
type LazyState = Arc<tokio::sync::OnceCell<Mutex<RepartitionExecState>>>;
178-
179235
/// A utility that can be used to partition batches based on [`Partitioning`]
180236
pub struct BatchPartitioner {
181237
state: BatchPartitionerState,
@@ -402,8 +458,12 @@ impl BatchPartitioner {
402458
pub struct RepartitionExec {
403459
/// Input execution plan
404460
input: Arc<dyn ExecutionPlan>,
405-
/// Inner state that is initialized when the first output stream is created.
406-
state: LazyState,
461+
/// Inner state that is initialized when the parent calls .execute() on this node
462+
/// and consumed as soon as the parent starts consuming this node.
463+
state: Arc<Mutex<RepartitionExecState>>,
464+
/// Stores whether the state has been initialized. Checking this AtomicBool is faster than
465+
/// locking the state's Mutex to check if the state is already initialized.
466+
state_initialized: Arc<AtomicBool>,
407467
/// Execution metrics
408468
metrics: ExecutionPlanMetricsSet,
409469
/// Boolean flag to decide whether to preserve ordering. If true means
@@ -482,11 +542,7 @@ impl RepartitionExec {
482542
}
483543

484544
impl DisplayAs for RepartitionExec {
485-
fn fmt_as(
486-
&self,
487-
t: DisplayFormatType,
488-
f: &mut std::fmt::Formatter,
489-
) -> std::fmt::Result {
545+
fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
490546
match t {
491547
DisplayFormatType::Default | DisplayFormatType::Verbose => {
492548
write!(
@@ -580,7 +636,6 @@ impl ExecutionPlan for RepartitionExec {
580636
partition
581637
);
582638

583-
let lazy_state = Arc::clone(&self.state);
584639
let input = Arc::clone(&self.input);
585640
let partitioning = self.partitioning().clone();
586641
let metrics = self.metrics.clone();
@@ -592,30 +647,29 @@ impl ExecutionPlan for RepartitionExec {
592647
// Get existing ordering to use for merging
593648
let sort_exprs = self.sort_exprs().cloned().unwrap_or_default();
594649

650+
let state = Arc::clone(&self.state);
651+
if !self.state_initialized.swap(true, Ordering::Relaxed) {
652+
state.lock().ensure_input_streams_initialized(
653+
Arc::clone(&input),
654+
metrics.clone(),
655+
partitioning.partition_count(),
656+
Arc::clone(&context),
657+
)?;
658+
}
659+
595660
let stream = futures::stream::once(async move {
596661
let num_input_partitions = input.output_partitioning().partition_count();
597662

598-
let input_captured = Arc::clone(&input);
599-
let metrics_captured = metrics.clone();
600-
let name_captured = name.clone();
601-
let context_captured = Arc::clone(&context);
602-
let state = lazy_state
603-
.get_or_init(|| async move {
604-
Mutex::new(RepartitionExecState::new(
605-
input_captured,
606-
partitioning,
607-
metrics_captured,
608-
preserve_order,
609-
name_captured,
610-
context_captured,
611-
))
612-
})
613-
.await;
614-
615663
// lock scope
616664
let (mut rx, reservation, abort_helper) = {
617665
// lock mutexes
618666
let mut state = state.lock();
667+
let state = state.consume_input_streams(
668+
partitioning,
669+
preserve_order,
670+
name.clone(),
671+
Arc::clone(&context),
672+
)?;
619673

620674
// now return stream for the specified *output* partition which will
621675
// read from the channel
@@ -746,6 +800,7 @@ impl RepartitionExec {
746800
Ok(RepartitionExec {
747801
input,
748802
state: Default::default(),
803+
state_initialized: Arc::new(AtomicBool::new(false)),
749804
metrics: ExecutionPlanMetricsSet::new(),
750805
preserve_order,
751806
cache,
@@ -825,24 +880,17 @@ impl RepartitionExec {
825880
///
826881
/// txs hold the output sending channels for each output partition
827882
async fn pull_from_input(
828-
input: Arc<dyn ExecutionPlan>,
829-
partition: usize,
883+
mut stream: SendableRecordBatchStream,
830884
mut output_channels: HashMap<
831885
usize,
832886
(DistributionSender<MaybeBatch>, SharedMemoryReservation),
833887
>,
834888
partitioning: Partitioning,
835889
metrics: RepartitionMetrics,
836-
context: Arc<TaskContext>,
837890
) -> Result<()> {
838891
let mut partitioner =
839892
BatchPartitioner::try_new(partitioning, metrics.repartition_time.clone())?;
840893

841-
// execute the child operator
842-
let timer = metrics.fetch_time.timer();
843-
let mut stream = input.execute(partition, context)?;
844-
timer.done();
845-
846894
// While there are still outputs to send to, keep pulling inputs
847895
let mut batches_until_yield = partitioner.num_partitions();
848896
while !output_channels.is_empty() {
@@ -1090,6 +1138,7 @@ mod tests {
10901138
use datafusion_common_runtime::JoinSet;
10911139
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
10921140
use insta::assert_snapshot;
1141+
use itertools::Itertools;
10931142

10941143
#[tokio::test]
10951144
async fn one_to_many_round_robin() -> Result<()> {
@@ -1270,15 +1319,9 @@ mod tests {
12701319
let partitioning = Partitioning::RoundRobinBatch(1);
12711320
let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
12721321

1273-
// Note: this should pass (the stream can be created) but the
1274-
// error when the input is executed should get passed back
1275-
let output_stream = exec.execute(0, task_ctx).unwrap();
1276-
12771322
// Expect that an error is returned
1278-
let result_string = crate::common::collect(output_stream)
1279-
.await
1280-
.unwrap_err()
1281-
.to_string();
1323+
let result_string = exec.execute(0, task_ctx).err().unwrap().to_string();
1324+
12821325
assert!(
12831326
result_string.contains("ErrorExec, unsurprisingly, errored in partition 0"),
12841327
"actual: {result_string}"
@@ -1468,7 +1511,14 @@ mod tests {
14681511
});
14691512
let batches_with_drop = crate::common::collect(output_stream1).await.unwrap();
14701513

1471-
assert_eq!(batches_without_drop, batches_with_drop);
1514+
fn sort(batch: Vec<RecordBatch>) -> Vec<RecordBatch> {
1515+
batch
1516+
.into_iter()
1517+
.sorted_by_key(|b| format!("{b:?}"))
1518+
.collect()
1519+
}
1520+
1521+
assert_eq!(sort(batches_without_drop), sort(batches_with_drop));
14721522
}
14731523

14741524
fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> {

0 commit comments

Comments
 (0)