1919//! partitions to M output partitions based on a partitioning scheme, optionally
2020//! maintaining the order of the input rows in the output.
2121
22+ use std:: fmt:: { Debug , Formatter } ;
2223use std:: pin:: Pin ;
24+ use std:: sync:: atomic:: { AtomicBool , Ordering } ;
2325use std:: sync:: Arc ;
2426use std:: task:: { Context , Poll } ;
2527use std:: { any:: Any , vec} ;
@@ -44,7 +46,7 @@ use arrow::array::{PrimitiveArray, RecordBatch, RecordBatchOptions};
4446use arrow:: compute:: take_arrays;
4547use arrow:: datatypes:: { SchemaRef , UInt32Type } ;
4648use datafusion_common:: utils:: transpose;
47- use datafusion_common:: HashMap ;
49+ use datafusion_common:: { internal_err , HashMap } ;
4850use datafusion_common:: { not_impl_err, DataFusionError , Result } ;
4951use datafusion_common_runtime:: SpawnedTask ;
5052use datafusion_execution:: memory_pool:: MemoryConsumer ;
@@ -63,9 +65,8 @@ type MaybeBatch = Option<Result<RecordBatch>>;
6365type InputPartitionsToCurrentPartitionSender = Vec < DistributionSender < MaybeBatch > > ;
6466type InputPartitionsToCurrentPartitionReceiver = Vec < DistributionReceiver < MaybeBatch > > ;
6567
66- /// Inner state of [`RepartitionExec`].
6768#[ derive( Debug ) ]
68- struct RepartitionExecState {
69+ struct ConsumingInputStreamsState {
6970 /// Channels for sending batches from input partitions to output partitions.
7071 /// Key is the partition number.
7172 channels : HashMap <
@@ -81,16 +82,84 @@ struct RepartitionExecState {
8182 abort_helper : Arc < Vec < SpawnedTask < ( ) > > > ,
8283}
8384
85+ /// Inner state of [`RepartitionExec`].
86+ enum RepartitionExecState {
87+ /// Not initialized yet. This is the default state stored in the RepartitionExec node
88+ /// upon instantiation.
89+ NotInitialized ,
90+ /// Input streams are initialized, but they are still not being consumed. The node
91+ /// transitions to this state when the arrow's RecordBatch stream is created in
92+ /// RepartitionExec::execute(), but before any message is polled.
93+ InputStreamsInitialized ( Vec < ( SendableRecordBatchStream , RepartitionMetrics ) > ) ,
94+ /// The input streams are being consumed. The node transitions to the state when
95+ /// the first message in the arrow's RecordBatch stream is consumed.
96+ ConsumingInputStreams ( ConsumingInputStreamsState ) ,
97+ }
98+
99+ impl Default for RepartitionExecState {
100+ fn default ( ) -> Self {
101+ Self :: NotInitialized
102+ }
103+ }
104+
105+ impl Debug for RepartitionExecState {
106+ fn fmt ( & self , f : & mut Formatter < ' _ > ) -> std:: fmt:: Result {
107+ match self {
108+ RepartitionExecState :: NotInitialized => write ! ( f, "NotInitialized" ) ,
109+ RepartitionExecState :: InputStreamsInitialized ( v) => {
110+ write ! ( f, "InputStreamsInitialized({:?})" , v. len( ) )
111+ }
112+ RepartitionExecState :: ConsumingInputStreams ( v) => {
113+ write ! ( f, "ConsumingInputStreams({:?})" , v)
114+ }
115+ }
116+ }
117+ }
118+
84119impl RepartitionExecState {
85- fn new (
120+ fn ensure_input_streams_initialized (
121+ & mut self ,
86122 input : Arc < dyn ExecutionPlan > ,
87- partitioning : Partitioning ,
88123 metrics : ExecutionPlanMetricsSet ,
124+ output_partitions : usize ,
125+ ctx : Arc < TaskContext > ,
126+ ) -> Result < ( ) > {
127+ if !matches ! ( self , RepartitionExecState :: NotInitialized ) {
128+ return Ok ( ( ) ) ;
129+ }
130+
131+ let num_input_partitions = input. output_partitioning ( ) . partition_count ( ) ;
132+ let mut streams_and_metrics = Vec :: with_capacity ( num_input_partitions) ;
133+
134+ for i in 0 ..num_input_partitions {
135+ let metrics = RepartitionMetrics :: new ( i, output_partitions, & metrics) ;
136+
137+ let timer = metrics. fetch_time . timer ( ) ;
138+ let stream = input. execute ( i, Arc :: clone ( & ctx) ) ?;
139+ timer. done ( ) ;
140+
141+ streams_and_metrics. push ( ( stream, metrics) ) ;
142+ }
143+ * self = RepartitionExecState :: InputStreamsInitialized ( streams_and_metrics) ;
144+ Ok ( ( ) )
145+ }
146+
147+ fn consume_input_streams (
148+ & mut self ,
149+ partitioning : Partitioning ,
89150 preserve_order : bool ,
90151 name : String ,
91152 context : Arc < TaskContext > ,
92- ) -> Self {
93- let num_input_partitions = input. output_partitioning ( ) . partition_count ( ) ;
153+ ) -> Result < & mut ConsumingInputStreamsState > {
154+ let streams_and_metrics = match self {
155+ RepartitionExecState :: NotInitialized => {
156+ return internal_err ! ( "RepartitionExecState::init_input_streams must be called before consuming input streams" ) ;
157+ }
158+ RepartitionExecState :: ConsumingInputStreams ( value) => return Ok ( value) ,
159+ RepartitionExecState :: InputStreamsInitialized ( value) => value,
160+ } ;
161+
162+ let num_input_partitions = streams_and_metrics. len ( ) ;
94163 let num_output_partitions = partitioning. partition_count ( ) ;
95164
96165 let ( txs, rxs) = if preserve_order {
@@ -125,23 +194,21 @@ impl RepartitionExecState {
125194
126195 // launch one async task per *input* partition
127196 let mut spawned_tasks = Vec :: with_capacity ( num_input_partitions) ;
128- for i in 0 ..num_input_partitions {
197+ for ( i, ( stream, metrics) ) in
198+ std:: mem:: take ( streams_and_metrics) . into_iter ( ) . enumerate ( )
199+ {
129200 let txs: HashMap < _ , _ > = channels
130201 . iter ( )
131202 . map ( |( partition, ( tx, _rx, reservation) ) | {
132203 ( * partition, ( tx[ i] . clone ( ) , Arc :: clone ( reservation) ) )
133204 } )
134205 . collect ( ) ;
135206
136- let r_metrics = RepartitionMetrics :: new ( i, num_output_partitions, & metrics) ;
137-
138207 let input_task = SpawnedTask :: spawn ( RepartitionExec :: pull_from_input (
139- Arc :: clone ( & input) ,
140- i,
208+ stream,
141209 txs. clone ( ) ,
142210 partitioning. clone ( ) ,
143- r_metrics,
144- Arc :: clone ( & context) ,
211+ metrics,
145212 ) ) ;
146213
147214 // In a separate task, wait for each input to be done
@@ -154,28 +221,17 @@ impl RepartitionExecState {
154221 ) ) ;
155222 spawned_tasks. push ( wait_for_task) ;
156223 }
157-
158- Self {
224+ * self = Self :: ConsumingInputStreams ( ConsumingInputStreamsState {
159225 channels,
160226 abort_helper : Arc :: new ( spawned_tasks) ,
227+ } ) ;
228+ match self {
229+ RepartitionExecState :: ConsumingInputStreams ( value) => Ok ( value) ,
230+ _ => unreachable ! ( ) ,
161231 }
162232 }
163233}
164234
165- /// Lazily initialized state
166- ///
167- /// Note that the state is initialized ONCE for all partitions by a single task(thread).
168- /// This may take a short while. It is also like that multiple threads
169- /// call execute at the same time, because we have just started "target partitions" tasks
170- /// which is commonly set to the number of CPU cores and all call execute at the same time.
171- ///
172- /// Thus, use a **tokio** `OnceCell` for this initialization so as not to waste CPU cycles
173- /// in a mutex lock but instead allow other threads to do something useful.
174- ///
175- /// Uses a parking_lot `Mutex` to control other accesses as they are very short duration
176- /// (e.g. removing channels on completion) where the overhead of `await` is not warranted.
177- type LazyState = Arc < tokio:: sync:: OnceCell < Mutex < RepartitionExecState > > > ;
178-
179235/// A utility that can be used to partition batches based on [`Partitioning`]
180236pub struct BatchPartitioner {
181237 state : BatchPartitionerState ,
@@ -402,8 +458,12 @@ impl BatchPartitioner {
402458pub struct RepartitionExec {
403459 /// Input execution plan
404460 input : Arc < dyn ExecutionPlan > ,
405- /// Inner state that is initialized when the first output stream is created.
406- state : LazyState ,
461+ /// Inner state that is initialized when the parent calls .execute() on this node
462+ /// and consumed as soon as the parent starts consuming this node.
463+ state : Arc < Mutex < RepartitionExecState > > ,
464+ /// Stores whether the state has been initialized. Checking this AtomicBool is faster than
465+ /// locking the state's Mutex to check if the state is already initialized.
466+ state_initialized : Arc < AtomicBool > ,
407467 /// Execution metrics
408468 metrics : ExecutionPlanMetricsSet ,
409469 /// Boolean flag to decide whether to preserve ordering. If true means
@@ -482,11 +542,7 @@ impl RepartitionExec {
482542}
483543
484544impl DisplayAs for RepartitionExec {
485- fn fmt_as (
486- & self ,
487- t : DisplayFormatType ,
488- f : & mut std:: fmt:: Formatter ,
489- ) -> std:: fmt:: Result {
545+ fn fmt_as ( & self , t : DisplayFormatType , f : & mut Formatter ) -> std:: fmt:: Result {
490546 match t {
491547 DisplayFormatType :: Default | DisplayFormatType :: Verbose => {
492548 write ! (
@@ -580,7 +636,6 @@ impl ExecutionPlan for RepartitionExec {
580636 partition
581637 ) ;
582638
583- let lazy_state = Arc :: clone ( & self . state ) ;
584639 let input = Arc :: clone ( & self . input ) ;
585640 let partitioning = self . partitioning ( ) . clone ( ) ;
586641 let metrics = self . metrics . clone ( ) ;
@@ -592,30 +647,29 @@ impl ExecutionPlan for RepartitionExec {
592647 // Get existing ordering to use for merging
593648 let sort_exprs = self . sort_exprs ( ) . cloned ( ) . unwrap_or_default ( ) ;
594649
650+ let state = Arc :: clone ( & self . state ) ;
651+ if !self . state_initialized . swap ( true , Ordering :: Relaxed ) {
652+ state. lock ( ) . ensure_input_streams_initialized (
653+ Arc :: clone ( & input) ,
654+ metrics. clone ( ) ,
655+ partitioning. partition_count ( ) ,
656+ Arc :: clone ( & context) ,
657+ ) ?;
658+ }
659+
595660 let stream = futures:: stream:: once ( async move {
596661 let num_input_partitions = input. output_partitioning ( ) . partition_count ( ) ;
597662
598- let input_captured = Arc :: clone ( & input) ;
599- let metrics_captured = metrics. clone ( ) ;
600- let name_captured = name. clone ( ) ;
601- let context_captured = Arc :: clone ( & context) ;
602- let state = lazy_state
603- . get_or_init ( || async move {
604- Mutex :: new ( RepartitionExecState :: new (
605- input_captured,
606- partitioning,
607- metrics_captured,
608- preserve_order,
609- name_captured,
610- context_captured,
611- ) )
612- } )
613- . await ;
614-
615663 // lock scope
616664 let ( mut rx, reservation, abort_helper) = {
617665 // lock mutexes
618666 let mut state = state. lock ( ) ;
667+ let state = state. consume_input_streams (
668+ partitioning,
669+ preserve_order,
670+ name. clone ( ) ,
671+ Arc :: clone ( & context) ,
672+ ) ?;
619673
620674 // now return stream for the specified *output* partition which will
621675 // read from the channel
@@ -746,6 +800,7 @@ impl RepartitionExec {
746800 Ok ( RepartitionExec {
747801 input,
748802 state : Default :: default ( ) ,
803+ state_initialized : Arc :: new ( AtomicBool :: new ( false ) ) ,
749804 metrics : ExecutionPlanMetricsSet :: new ( ) ,
750805 preserve_order,
751806 cache,
@@ -825,24 +880,17 @@ impl RepartitionExec {
825880 ///
826881 /// txs hold the output sending channels for each output partition
827882 async fn pull_from_input (
828- input : Arc < dyn ExecutionPlan > ,
829- partition : usize ,
883+ mut stream : SendableRecordBatchStream ,
830884 mut output_channels : HashMap <
831885 usize ,
832886 ( DistributionSender < MaybeBatch > , SharedMemoryReservation ) ,
833887 > ,
834888 partitioning : Partitioning ,
835889 metrics : RepartitionMetrics ,
836- context : Arc < TaskContext > ,
837890 ) -> Result < ( ) > {
838891 let mut partitioner =
839892 BatchPartitioner :: try_new ( partitioning, metrics. repartition_time . clone ( ) ) ?;
840893
841- // execute the child operator
842- let timer = metrics. fetch_time . timer ( ) ;
843- let mut stream = input. execute ( partition, context) ?;
844- timer. done ( ) ;
845-
846894 // While there are still outputs to send to, keep pulling inputs
847895 let mut batches_until_yield = partitioner. num_partitions ( ) ;
848896 while !output_channels. is_empty ( ) {
@@ -1090,6 +1138,7 @@ mod tests {
10901138 use datafusion_common_runtime:: JoinSet ;
10911139 use datafusion_execution:: runtime_env:: RuntimeEnvBuilder ;
10921140 use insta:: assert_snapshot;
1141+ use itertools:: Itertools ;
10931142
10941143 #[ tokio:: test]
10951144 async fn one_to_many_round_robin ( ) -> Result < ( ) > {
@@ -1270,15 +1319,9 @@ mod tests {
12701319 let partitioning = Partitioning :: RoundRobinBatch ( 1 ) ;
12711320 let exec = RepartitionExec :: try_new ( Arc :: new ( input) , partitioning) . unwrap ( ) ;
12721321
1273- // Note: this should pass (the stream can be created) but the
1274- // error when the input is executed should get passed back
1275- let output_stream = exec. execute ( 0 , task_ctx) . unwrap ( ) ;
1276-
12771322 // Expect that an error is returned
1278- let result_string = crate :: common:: collect ( output_stream)
1279- . await
1280- . unwrap_err ( )
1281- . to_string ( ) ;
1323+ let result_string = exec. execute ( 0 , task_ctx) . err ( ) . unwrap ( ) . to_string ( ) ;
1324+
12821325 assert ! (
12831326 result_string. contains( "ErrorExec, unsurprisingly, errored in partition 0" ) ,
12841327 "actual: {result_string}"
@@ -1468,7 +1511,14 @@ mod tests {
14681511 } ) ;
14691512 let batches_with_drop = crate :: common:: collect ( output_stream1) . await . unwrap ( ) ;
14701513
1471- assert_eq ! ( batches_without_drop, batches_with_drop) ;
1514+ fn sort ( batch : Vec < RecordBatch > ) -> Vec < RecordBatch > {
1515+ batch
1516+ . into_iter ( )
1517+ . sorted_by_key ( |b| format ! ( "{b:?}" ) )
1518+ . collect ( )
1519+ }
1520+
1521+ assert_eq ! ( sort( batches_without_drop) , sort( batches_with_drop) ) ;
14721522 }
14731523
14741524 fn str_batches_to_vec ( batches : & [ RecordBatch ] ) -> Vec < & str > {
0 commit comments