@@ -350,15 +350,33 @@ impl ExecutionPlan for CoalescePartitionsExec {
350350#[ cfg( test) ]
351351mod tests {
352352 use super :: * ;
353+ use crate :: RecordBatchStream ;
354+ use crate :: aggregates:: { AggregateExec , AggregateMode , PhysicalGroupBy } ;
355+ use crate :: execution_plan:: { Boundedness , EmissionType } ;
356+ use crate :: expressions:: col;
357+ use crate :: repartition:: RepartitionExec ;
353358 use crate :: test:: exec:: {
354359 BlockingExec , PanicExec , assert_strong_count_converges_to_zero,
355360 } ;
356361 use crate :: test:: { self , assert_is_pending} ;
357362 use crate :: { collect, common} ;
358363
359- use arrow:: datatypes:: { DataType , Field , Schema } ;
360-
361- use futures:: FutureExt ;
364+ use arrow:: array:: { ArrayRef , UInt64Array } ;
365+ use arrow:: datatypes:: { DataType , Field , Schema , SchemaRef } ;
366+ use arrow:: record_batch:: RecordBatch ;
367+ use datafusion_common:: Result ;
368+ use datafusion_common:: internal_err;
369+ use datafusion_common:: tree_node:: TreeNodeRecursion ;
370+ use datafusion_functions_aggregate:: count:: count_udaf;
371+ use datafusion_physical_expr:: EquivalenceProperties ;
372+ use datafusion_physical_expr:: PhysicalExpr ;
373+ use datafusion_physical_expr:: aggregate:: AggregateExprBuilder ;
374+
375+ use futures:: { FutureExt , Stream } ;
376+ use std:: pin:: Pin ;
377+ use std:: sync:: { Arc , Weak } ;
378+ use std:: task:: { Context , Poll } ;
379+ use std:: time:: { Duration , Instant } ;
362380
363381 #[ tokio:: test]
364382 async fn merge ( ) -> Result < ( ) > {
@@ -390,6 +408,259 @@ mod tests {
390408 Ok ( ( ) )
391409 }
392410
411+ async fn wait_for_repartition_drop_times (
412+ refs : & [ Weak < RepartitionExec > ] ,
413+ start : Instant ,
414+ ) -> Vec < Duration > {
415+ let mut drop_times = vec ! [ None ; refs. len( ) ] ;
416+ tokio:: time:: timeout ( Duration :: from_secs ( 10 ) , async {
417+ loop {
418+ for ( idx, refs) in refs. iter ( ) . enumerate ( ) {
419+ if drop_times[ idx] . is_none ( ) && refs. strong_count ( ) == 0 {
420+ drop_times[ idx] = Some ( start. elapsed ( ) ) ;
421+ }
422+ }
423+
424+ if drop_times. iter ( ) . all ( Option :: is_some) {
425+ break ;
426+ }
427+
428+ tokio:: time:: sleep ( Duration :: from_millis ( 1 ) ) . await ;
429+ }
430+ } )
431+ . await
432+ . unwrap ( ) ;
433+
434+ drop_times
435+ . into_iter ( )
436+ . map ( |drop_time| drop_time. expect ( "all repartition refs dropped" ) )
437+ . collect ( )
438+ }
439+
440+ #[ derive( Debug ) ]
441+ struct CountingExec {
442+ schema : SchemaRef ,
443+ partitions : usize ,
444+ batches_per_partition : usize ,
445+ rows_per_batch : usize ,
446+ plan_ref : Arc < ( ) > ,
447+ cache : Arc < PlanProperties > ,
448+ }
449+
450+ impl CountingExec {
451+ fn new (
452+ schema : SchemaRef ,
453+ partitions : usize ,
454+ batches_per_partition : usize ,
455+ rows_per_batch : usize ,
456+ ) -> Self {
457+ let cache = Arc :: new ( PlanProperties :: new (
458+ EquivalenceProperties :: new ( Arc :: clone ( & schema) ) ,
459+ Partitioning :: UnknownPartitioning ( partitions) ,
460+ EmissionType :: Incremental ,
461+ Boundedness :: Bounded ,
462+ ) ) ;
463+
464+ Self {
465+ schema,
466+ partitions,
467+ batches_per_partition,
468+ rows_per_batch,
469+ plan_ref : Arc :: new ( ( ) ) ,
470+ cache,
471+ }
472+ }
473+
474+ fn refs ( & self ) -> Weak < ( ) > {
475+ Arc :: downgrade ( & self . plan_ref )
476+ }
477+ }
478+
479+ impl DisplayAs for CountingExec {
480+ fn fmt_as (
481+ & self ,
482+ t : DisplayFormatType ,
483+ f : & mut std:: fmt:: Formatter ,
484+ ) -> std:: fmt:: Result {
485+ match t {
486+ DisplayFormatType :: Default | DisplayFormatType :: Verbose => {
487+ write ! (
488+ f,
489+ "CountingExec: partitions={}, batches_per_partition={}" ,
490+ self . partitions, self . batches_per_partition
491+ )
492+ }
493+ DisplayFormatType :: TreeRender => {
494+ writeln ! ( f, "partitions={}" , self . partitions) ?;
495+ writeln ! ( f, "batches_per_partition={}" , self . batches_per_partition)
496+ }
497+ }
498+ }
499+ }
500+
501+ impl ExecutionPlan for CountingExec {
502+ fn name ( & self ) -> & ' static str {
503+ "CountingExec"
504+ }
505+
506+ fn properties ( & self ) -> & Arc < PlanProperties > {
507+ & self . cache
508+ }
509+
510+ fn children ( & self ) -> Vec < & Arc < dyn ExecutionPlan > > {
511+ vec ! [ ]
512+ }
513+
514+ fn apply_expressions (
515+ & self ,
516+ _f : & mut dyn FnMut ( & dyn PhysicalExpr ) -> Result < TreeNodeRecursion > ,
517+ ) -> Result < TreeNodeRecursion > {
518+ Ok ( TreeNodeRecursion :: Continue )
519+ }
520+
521+ fn with_new_children (
522+ self : Arc < Self > ,
523+ _: Vec < Arc < dyn ExecutionPlan > > ,
524+ ) -> Result < Arc < dyn ExecutionPlan > > {
525+ internal_err ! ( "Children cannot be replaced in {self:?}" )
526+ }
527+
528+ fn execute (
529+ & self ,
530+ partition : usize ,
531+ _context : Arc < TaskContext > ,
532+ ) -> Result < SendableRecordBatchStream > {
533+ Ok ( Box :: pin ( CountingStream {
534+ schema : Arc :: clone ( & self . schema ) ,
535+ next_value : partition * self . batches_per_partition * self . rows_per_batch ,
536+ remaining : self . batches_per_partition ,
537+ rows_per_batch : self . rows_per_batch ,
538+ } ) )
539+ }
540+ }
541+
542+ #[ derive( Debug ) ]
543+ struct CountingStream {
544+ schema : SchemaRef ,
545+ next_value : usize ,
546+ remaining : usize ,
547+ rows_per_batch : usize ,
548+ }
549+
550+ impl Stream for CountingStream {
551+ type Item = Result < RecordBatch > ;
552+
553+ fn poll_next (
554+ mut self : Pin < & mut Self > ,
555+ _cx : & mut Context < ' _ > ,
556+ ) -> Poll < Option < Self :: Item > > {
557+ if self . remaining == 0 {
558+ return Poll :: Ready ( None ) ;
559+ }
560+ self . remaining -= 1 ;
561+ let start = self . next_value as u64 ;
562+ self . next_value += self . rows_per_batch ;
563+ let values =
564+ UInt64Array :: from_iter_values ( start..start + self . rows_per_batch as u64 ) ;
565+ Poll :: Ready ( Some ( Ok ( RecordBatch :: try_new (
566+ Arc :: clone ( & self . schema ) ,
567+ vec ! [ Arc :: new( values) as ArrayRef ] ,
568+ ) ?) ) )
569+ }
570+ }
571+
572+ impl RecordBatchStream for CountingStream {
573+ fn schema ( & self ) -> SchemaRef {
574+ Arc :: clone ( & self . schema )
575+ }
576+ }
577+
578+ fn high_cardinality_partial_aggregate (
579+ input : Arc < dyn ExecutionPlan > ,
580+ schema : & SchemaRef ,
581+ ) -> Result < Arc < AggregateExec > > {
582+ let groups =
583+ PhysicalGroupBy :: new_single ( vec ! [ ( col( "a" , schema) ?, "a" . to_string( ) ) ] ) ;
584+ let count = AggregateExprBuilder :: new ( count_udaf ( ) , vec ! [ col( "a" , schema) ?] )
585+ . schema ( Arc :: clone ( schema) )
586+ . alias ( "COUNT(a)" )
587+ . build ( ) ?;
588+
589+ Ok ( Arc :: new ( AggregateExec :: try_new (
590+ AggregateMode :: Partial ,
591+ groups,
592+ vec ! [ Arc :: new( count) ] ,
593+ vec ! [ None ] ,
594+ input,
595+ Arc :: clone ( schema) ,
596+ ) ?) )
597+ }
598+
599+ #[ tokio:: test( flavor = "multi_thread" , worker_threads = 4 ) ]
600+ #[ ignore = "temporary diagnostic reproducer for layered cancellation delay" ]
601+ async fn cancellation_delay_coalesce_repartition ( ) -> Result < ( ) > {
602+ let task_ctx = Arc :: new ( TaskContext :: default ( ) ) ;
603+ let schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new( "a" , DataType :: UInt64 , true ) ] ) ) ;
604+ let batches_per_input_partition = 8 ;
605+ let rows_per_batch = 128_000 ;
606+ let input = Arc :: new ( CountingExec :: new (
607+ Arc :: clone ( & schema) ,
608+ 2 ,
609+ batches_per_input_partition,
610+ rows_per_batch,
611+ ) ) ;
612+ let input_refs = input. refs ( ) ;
613+ let mut plan: Arc < dyn ExecutionPlan > =
614+ high_cardinality_partial_aggregate ( input, & schema) ?;
615+
616+ let layers = 512 ;
617+ let output_partitions = 32 ;
618+ let mut repartition_refs = Vec :: with_capacity ( layers) ;
619+
620+ for _ in 0 ..layers {
621+ let repartition = Arc :: new ( RepartitionExec :: try_new (
622+ plan,
623+ Partitioning :: RoundRobinBatch ( output_partitions) ,
624+ ) ?) ;
625+ repartition_refs. push ( Arc :: downgrade ( & repartition) ) ;
626+
627+ plan = Arc :: new ( CoalescePartitionsExec :: new ( repartition) ) ;
628+ }
629+
630+ let handle = tokio:: spawn ( collect ( plan, task_ctx) ) ;
631+ tokio:: time:: sleep ( Duration :: from_millis ( 100 ) ) . await ;
632+ assert ! ( !handle. is_finished( ) , "query finished before cancellation" ) ;
633+
634+ let start = Instant :: now ( ) ;
635+ handle. abort ( ) ;
636+
637+ let drop_times = wait_for_repartition_drop_times ( & repartition_refs, start) . await ;
638+ let total_elapsed = start. elapsed ( ) ;
639+
640+ for ( idx, elapsed) in drop_times. iter ( ) . enumerate ( ) . rev ( ) {
641+ let layer_from_top = layers - idx;
642+ if layer_from_top != 1 && layer_from_top != layers && layer_from_top % 32 != 0
643+ {
644+ continue ;
645+ }
646+ println ! (
647+ "layer_from_top={layer_from_top} repartition_drop_elapsed_ms={}" ,
648+ elapsed. as_millis( )
649+ ) ;
650+ }
651+ println ! (
652+ "layers={layers} output_partitions={output_partitions} input_rows_per_partition={} cancellation_elapsed_ms={}" ,
653+ batches_per_input_partition * rows_per_batch,
654+ total_elapsed. as_millis( )
655+ ) ;
656+ println ! (
657+ "input_plan_strong_count_after_cancel={}" ,
658+ input_refs. strong_count( )
659+ ) ;
660+
661+ Ok ( ( ) )
662+ }
663+
393664 #[ tokio:: test]
394665 async fn test_drop_cancel ( ) -> Result < ( ) > {
395666 let task_ctx = Arc :: new ( TaskContext :: default ( ) ) ;
0 commit comments