@@ -350,10 +350,9 @@ impl ExecutionPlan for CoalescePartitionsExec {
350350#[ cfg( test) ]
351351mod tests {
352352 use super :: * ;
353- use crate :: RecordBatchStream ;
354353 use crate :: aggregates:: { AggregateExec , AggregateMode , PhysicalGroupBy } ;
355- use crate :: execution_plan:: { Boundedness , EmissionType } ;
356354 use crate :: expressions:: col;
355+ use crate :: memory:: { LazyBatchGenerator , LazyMemoryExec } ;
357356 use crate :: repartition:: RepartitionExec ;
358357 use crate :: test:: exec:: {
359358 BlockingExec , PanicExec , assert_strong_count_converges_to_zero,
@@ -365,17 +364,14 @@ mod tests {
365364 use arrow:: datatypes:: { DataType , Field , Schema , SchemaRef } ;
366365 use arrow:: record_batch:: RecordBatch ;
367366 use datafusion_common:: Result ;
368- use datafusion_common:: internal_err;
369- use datafusion_common:: tree_node:: TreeNodeRecursion ;
370367 use datafusion_functions_aggregate:: count:: count_udaf;
371- use datafusion_physical_expr:: EquivalenceProperties ;
372- use datafusion_physical_expr:: PhysicalExpr ;
373368 use datafusion_physical_expr:: aggregate:: AggregateExprBuilder ;
374369
375- use futures:: { FutureExt , Stream } ;
376- use std:: pin:: Pin ;
370+ use futures:: FutureExt ;
371+ use parking_lot:: RwLock ;
372+ use std:: any:: Any ;
373+ use std:: fmt;
377374 use std:: sync:: { Arc , Weak } ;
378- use std:: task:: { Context , Poll } ;
379375 use std:: time:: { Duration , Instant } ;
380376
381377 #[ tokio:: test]
@@ -437,141 +433,73 @@ mod tests {
437433 . collect ( )
438434 }
439435
440- #[ derive( Debug ) ]
441- struct CountingExec {
436+ #[ derive( Debug , Clone ) ]
437+ struct CountingGenerator {
442438 schema : SchemaRef ,
443- partitions : usize ,
444- batches_per_partition : usize ,
439+ partition : usize ,
440+ next_batch : usize ,
441+ max_batches : usize ,
445442 rows_per_batch : usize ,
446- plan_ref : Arc < ( ) > ,
447- cache : Arc < PlanProperties > ,
448443 }
449444
450- impl CountingExec {
445+ impl CountingGenerator {
451446 fn new (
452447 schema : SchemaRef ,
453- partitions : usize ,
454- batches_per_partition : usize ,
448+ partition : usize ,
449+ max_batches : usize ,
455450 rows_per_batch : usize ,
456451 ) -> Self {
457- let cache = Arc :: new ( PlanProperties :: new (
458- EquivalenceProperties :: new ( Arc :: clone ( & schema) ) ,
459- Partitioning :: UnknownPartitioning ( partitions) ,
460- EmissionType :: Incremental ,
461- Boundedness :: Bounded ,
462- ) ) ;
463-
464452 Self {
465453 schema,
466- partitions,
467- batches_per_partition,
454+ partition,
455+ next_batch : 0 ,
456+ max_batches,
468457 rows_per_batch,
469- plan_ref : Arc :: new ( ( ) ) ,
470- cache,
471458 }
472459 }
473-
474- fn refs ( & self ) -> Weak < ( ) > {
475- Arc :: downgrade ( & self . plan_ref )
476- }
477460 }
478461
479- impl DisplayAs for CountingExec {
480- fn fmt_as (
481- & self ,
482- t : DisplayFormatType ,
483- f : & mut std:: fmt:: Formatter ,
484- ) -> std:: fmt:: Result {
485- match t {
486- DisplayFormatType :: Default | DisplayFormatType :: Verbose => {
487- write ! (
488- f,
489- "CountingExec: partitions={}, batches_per_partition={}" ,
490- self . partitions, self . batches_per_partition
491- )
492- }
493- DisplayFormatType :: TreeRender => {
494- writeln ! ( f, "partitions={}" , self . partitions) ?;
495- writeln ! ( f, "batches_per_partition={}" , self . batches_per_partition)
496- }
497- }
462+ impl fmt:: Display for CountingGenerator {
463+ fn fmt ( & self , f : & mut fmt:: Formatter ) -> fmt:: Result {
464+ write ! (
465+ f,
466+ "CountingGenerator: partition={}, max_batches={}, rows_per_batch={}" ,
467+ self . partition, self . max_batches, self . rows_per_batch
468+ )
498469 }
499470 }
500471
501- impl ExecutionPlan for CountingExec {
502- fn name ( & self ) -> & ' static str {
503- "CountingExec"
472+ impl LazyBatchGenerator for CountingGenerator {
473+ fn as_any ( & self ) -> & dyn Any {
474+ self
504475 }
505476
506- fn properties ( & self ) -> & Arc < PlanProperties > {
507- & self . cache
508- }
509-
510- fn children ( & self ) -> Vec < & Arc < dyn ExecutionPlan > > {
511- vec ! [ ]
512- }
513-
514- fn apply_expressions (
515- & self ,
516- _f : & mut dyn FnMut ( & dyn PhysicalExpr ) -> Result < TreeNodeRecursion > ,
517- ) -> Result < TreeNodeRecursion > {
518- Ok ( TreeNodeRecursion :: Continue )
519- }
520-
521- fn with_new_children (
522- self : Arc < Self > ,
523- _: Vec < Arc < dyn ExecutionPlan > > ,
524- ) -> Result < Arc < dyn ExecutionPlan > > {
525- internal_err ! ( "Children cannot be replaced in {self:?}" )
526- }
527-
528- fn execute (
529- & self ,
530- partition : usize ,
531- _context : Arc < TaskContext > ,
532- ) -> Result < SendableRecordBatchStream > {
533- Ok ( Box :: pin ( CountingStream {
534- schema : Arc :: clone ( & self . schema ) ,
535- next_value : partition * self . batches_per_partition * self . rows_per_batch ,
536- remaining : self . batches_per_partition ,
537- rows_per_batch : self . rows_per_batch ,
538- } ) )
539- }
540- }
541-
542- #[ derive( Debug ) ]
543- struct CountingStream {
544- schema : SchemaRef ,
545- next_value : usize ,
546- remaining : usize ,
547- rows_per_batch : usize ,
548- }
477+ fn generate_next_batch ( & mut self ) -> Result < Option < RecordBatch > > {
478+ if self . next_batch == self . max_batches {
479+ return Ok ( None ) ;
480+ }
549481
550- impl Stream for CountingStream {
551- type Item = Result < RecordBatch > ;
482+ let start = ( ( self . partition * self . max_batches + self . next_batch )
483+ * self . rows_per_batch ) as u64 ;
484+ self . next_batch += 1 ;
552485
553- fn poll_next (
554- mut self : Pin < & mut Self > ,
555- _cx : & mut Context < ' _ > ,
556- ) -> Poll < Option < Self :: Item > > {
557- if self . remaining == 0 {
558- return Poll :: Ready ( None ) ;
559- }
560- self . remaining -= 1 ;
561- let start = self . next_value as u64 ;
562- self . next_value += self . rows_per_batch ;
563486 let values =
564487 UInt64Array :: from_iter_values ( start..start + self . rows_per_batch as u64 ) ;
565- Poll :: Ready ( Some ( Ok ( RecordBatch :: try_new (
488+
489+ Ok ( Some ( RecordBatch :: try_new (
566490 Arc :: clone ( & self . schema ) ,
567491 vec ! [ Arc :: new( values) as ArrayRef ] ,
568- ) ?) ) )
492+ ) ?) )
569493 }
570- }
571494
572- impl RecordBatchStream for CountingStream {
573- fn schema ( & self ) -> SchemaRef {
574- Arc :: clone ( & self . schema )
495+ fn reset_state ( & self ) -> Arc < RwLock < dyn LazyBatchGenerator > > {
496+ Arc :: new ( RwLock :: new ( Self {
497+ schema : Arc :: clone ( & self . schema ) ,
498+ partition : self . partition ,
499+ next_batch : 0 ,
500+ max_batches : self . max_batches ,
501+ rows_per_batch : self . rows_per_batch ,
502+ } ) )
575503 }
576504 }
577505
@@ -601,15 +529,21 @@ mod tests {
601529 async fn cancellation_delay_coalesce_repartition ( ) -> Result < ( ) > {
602530 let task_ctx = Arc :: new ( TaskContext :: default ( ) ) ;
603531 let schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new( "a" , DataType :: UInt64 , true ) ] ) ) ;
532+ let input_partitions = 2 ;
604533 let batches_per_input_partition = 8 ;
605534 let rows_per_batch = 128_000 ;
606- let input = Arc :: new ( CountingExec :: new (
607- Arc :: clone ( & schema) ,
608- 2 ,
609- batches_per_input_partition,
610- rows_per_batch,
611- ) ) ;
612- let input_refs = input. refs ( ) ;
535+ let generators = ( 0 ..input_partitions)
536+ . map ( |partition| {
537+ Arc :: new ( RwLock :: new ( CountingGenerator :: new (
538+ Arc :: clone ( & schema) ,
539+ partition,
540+ batches_per_input_partition,
541+ rows_per_batch,
542+ ) ) ) as Arc < RwLock < dyn LazyBatchGenerator > >
543+ } )
544+ . collect ( ) ;
545+ let input = Arc :: new ( LazyMemoryExec :: try_new ( Arc :: clone ( & schema) , generators) ?) ;
546+ let input_refs = Arc :: downgrade ( & input) ;
613547 let mut plan: Arc < dyn ExecutionPlan > =
614548 high_cardinality_partial_aggregate ( input, & schema) ?;
615549
0 commit comments