Skip to content

Commit dddcf1d

Browse files
committed
two aggs on top of each other
1 parent 5f0a9d9 commit dddcf1d

1 file changed

Lines changed: 59 additions & 47 deletions

File tree

datafusion/physical-plan/src/coalesce_partitions.rs

Lines changed: 59 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,49 @@ mod tests {
433433
.collect()
434434
}
435435

436+
fn high_cardinality_partitioned_aggregate(
437+
input: Arc<dyn ExecutionPlan>,
438+
schema: &SchemaRef,
439+
output_partitions: usize,
440+
) -> Result<(Arc<dyn ExecutionPlan>, Weak<RepartitionExec>)> {
441+
let groups =
442+
PhysicalGroupBy::new_single(vec![(col("a", schema)?, "a".to_string())]);
443+
let count = AggregateExprBuilder::new(count_udaf(), vec![col("a", schema)?])
444+
.schema(Arc::clone(schema))
445+
.alias("COUNT(a)")
446+
.build()
447+
.map(Arc::new)?;
448+
449+
let partial = Arc::new(AggregateExec::try_new(
450+
AggregateMode::Partial,
451+
groups,
452+
vec![count],
453+
vec![None],
454+
input,
455+
Arc::clone(schema),
456+
)?);
457+
458+
let hash_exprs = partial.output_group_expr();
459+
let final_grouping_set = partial.group_expr().as_final();
460+
let final_aggr_expr = partial.aggr_expr().to_vec();
461+
let repartition = Arc::new(RepartitionExec::try_new(
462+
partial,
463+
Partitioning::Hash(hash_exprs, output_partitions),
464+
)?);
465+
let repartition_ref = Arc::downgrade(&repartition);
466+
467+
let final_partitioned = Arc::new(AggregateExec::try_new(
468+
AggregateMode::FinalPartitioned,
469+
final_grouping_set,
470+
final_aggr_expr,
471+
vec![None],
472+
repartition,
473+
Arc::clone(schema),
474+
)?);
475+
476+
Ok((final_partitioned, repartition_ref))
477+
}
478+
436479
#[derive(Debug, Clone)]
437480
struct CountingGenerator {
438481
schema: SchemaRef,
@@ -503,49 +546,6 @@ mod tests {
503546
}
504547
}
505548

506-
fn high_cardinality_partitioned_aggregate(
507-
input: Arc<dyn ExecutionPlan>,
508-
schema: &SchemaRef,
509-
output_partitions: usize,
510-
) -> Result<(Arc<dyn ExecutionPlan>, Weak<RepartitionExec>)> {
511-
let groups =
512-
PhysicalGroupBy::new_single(vec![(col("a", schema)?, "a".to_string())]);
513-
let count = AggregateExprBuilder::new(count_udaf(), vec![col("a", schema)?])
514-
.schema(Arc::clone(schema))
515-
.alias("COUNT(a)")
516-
.build()
517-
.map(Arc::new)?;
518-
519-
let partial = Arc::new(AggregateExec::try_new(
520-
AggregateMode::Partial,
521-
groups,
522-
vec![count],
523-
vec![None],
524-
input,
525-
Arc::clone(schema),
526-
)?);
527-
528-
let hash_exprs = partial.output_group_expr();
529-
let final_grouping_set = partial.group_expr().as_final();
530-
let final_aggr_expr = partial.aggr_expr().to_vec();
531-
let repartition = Arc::new(RepartitionExec::try_new(
532-
partial,
533-
Partitioning::Hash(hash_exprs, output_partitions),
534-
)?);
535-
let repartition_ref = Arc::downgrade(&repartition);
536-
537-
let final_partitioned = Arc::new(AggregateExec::try_new(
538-
AggregateMode::FinalPartitioned,
539-
final_grouping_set,
540-
final_aggr_expr,
541-
vec![None],
542-
repartition,
543-
Arc::clone(schema),
544-
)?);
545-
546-
Ok((final_partitioned, repartition_ref))
547-
}
548-
549549
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
550550
#[ignore = "temporary diagnostic reproducer for layered cancellation delay"]
551551
async fn cancellation_delay_coalesce_repartition() -> Result<()> {
@@ -567,9 +567,16 @@ mod tests {
567567
let input = Arc::new(LazyMemoryExec::try_new(Arc::clone(&schema), generators)?);
568568
let input_refs = Arc::downgrade(&input);
569569
let output_partitions = 32;
570-
let (partitioned_aggregate, repartition_ref) =
570+
let (partitioned_aggregate, lower_repartition_ref) =
571571
high_cardinality_partitioned_aggregate(input, &schema, output_partitions)?;
572-
let repartition_refs = vec![repartition_ref];
572+
let lower_coalesce = Arc::new(CoalescePartitionsExec::new(partitioned_aggregate));
573+
let (partitioned_aggregate, upper_repartition_ref) =
574+
high_cardinality_partitioned_aggregate(
575+
lower_coalesce,
576+
&schema,
577+
output_partitions,
578+
)?;
579+
let repartition_refs = vec![lower_repartition_ref, upper_repartition_ref];
573580
let plan = Arc::new(CoalescePartitionsExec::new(partitioned_aggregate));
574581

575582
let handle = tokio::spawn(collect(plan, task_ctx));
@@ -582,9 +589,14 @@ mod tests {
582589
let drop_times = wait_for_repartition_drop_times(&repartition_refs, start).await;
583590
let total_elapsed = start.elapsed();
584591

585-
for elapsed in drop_times {
592+
for (idx, elapsed) in drop_times.into_iter().enumerate() {
593+
let repartition = match idx {
594+
0 => "lower",
595+
1 => "upper",
596+
_ => "unknown",
597+
};
586598
println!(
587-
"final_partitioned_hash_repartition_drop_elapsed_ms={}",
599+
"{repartition}_final_partitioned_hash_repartition_drop_elapsed_ms={}",
588600
elapsed.as_millis()
589601
);
590602
}

0 commit comments

Comments
 (0)