Skip to content

Commit 58a9778

Browse files
committed
reproducer and fix for coalesce delaying cancellation
1 parent 83c2c01 commit 58a9778

2 files changed

Lines changed: 279 additions & 8 deletions

File tree

datafusion/physical-plan/src/coalesce_partitions.rs

Lines changed: 274 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -350,15 +350,33 @@ impl ExecutionPlan for CoalescePartitionsExec {
350350
#[cfg(test)]
351351
mod tests {
352352
use super::*;
353+
use crate::RecordBatchStream;
354+
use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
355+
use crate::execution_plan::{Boundedness, EmissionType};
356+
use crate::expressions::col;
357+
use crate::repartition::RepartitionExec;
353358
use crate::test::exec::{
354359
BlockingExec, PanicExec, assert_strong_count_converges_to_zero,
355360
};
356361
use crate::test::{self, assert_is_pending};
357362
use crate::{collect, common};
358363

359-
use arrow::datatypes::{DataType, Field, Schema};
360-
361-
use futures::FutureExt;
364+
use arrow::array::{ArrayRef, UInt64Array};
365+
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
366+
use arrow::record_batch::RecordBatch;
367+
use datafusion_common::Result;
368+
use datafusion_common::internal_err;
369+
use datafusion_common::tree_node::TreeNodeRecursion;
370+
use datafusion_functions_aggregate::count::count_udaf;
371+
use datafusion_physical_expr::EquivalenceProperties;
372+
use datafusion_physical_expr::PhysicalExpr;
373+
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
374+
375+
use futures::{FutureExt, Stream};
376+
use std::pin::Pin;
377+
use std::sync::{Arc, Weak};
378+
use std::task::{Context, Poll};
379+
use std::time::{Duration, Instant};
362380

363381
#[tokio::test]
364382
async fn merge() -> Result<()> {
@@ -390,6 +408,259 @@ mod tests {
390408
Ok(())
391409
}
392410

411+
async fn wait_for_repartition_drop_times(
412+
refs: &[Weak<RepartitionExec>],
413+
start: Instant,
414+
) -> Vec<Duration> {
415+
let mut drop_times = vec![None; refs.len()];
416+
tokio::time::timeout(Duration::from_secs(10), async {
417+
loop {
418+
for (idx, refs) in refs.iter().enumerate() {
419+
if drop_times[idx].is_none() && refs.strong_count() == 0 {
420+
drop_times[idx] = Some(start.elapsed());
421+
}
422+
}
423+
424+
if drop_times.iter().all(Option::is_some) {
425+
break;
426+
}
427+
428+
tokio::time::sleep(Duration::from_millis(1)).await;
429+
}
430+
})
431+
.await
432+
.unwrap();
433+
434+
drop_times
435+
.into_iter()
436+
.map(|drop_time| drop_time.expect("all repartition refs dropped"))
437+
.collect()
438+
}
439+
440+
#[derive(Debug)]
441+
struct CountingExec {
442+
schema: SchemaRef,
443+
partitions: usize,
444+
batches_per_partition: usize,
445+
rows_per_batch: usize,
446+
plan_ref: Arc<()>,
447+
cache: Arc<PlanProperties>,
448+
}
449+
450+
impl CountingExec {
451+
fn new(
452+
schema: SchemaRef,
453+
partitions: usize,
454+
batches_per_partition: usize,
455+
rows_per_batch: usize,
456+
) -> Self {
457+
let cache = Arc::new(PlanProperties::new(
458+
EquivalenceProperties::new(Arc::clone(&schema)),
459+
Partitioning::UnknownPartitioning(partitions),
460+
EmissionType::Incremental,
461+
Boundedness::Bounded,
462+
));
463+
464+
Self {
465+
schema,
466+
partitions,
467+
batches_per_partition,
468+
rows_per_batch,
469+
plan_ref: Arc::new(()),
470+
cache,
471+
}
472+
}
473+
474+
fn refs(&self) -> Weak<()> {
475+
Arc::downgrade(&self.plan_ref)
476+
}
477+
}
478+
479+
impl DisplayAs for CountingExec {
480+
fn fmt_as(
481+
&self,
482+
t: DisplayFormatType,
483+
f: &mut std::fmt::Formatter,
484+
) -> std::fmt::Result {
485+
match t {
486+
DisplayFormatType::Default | DisplayFormatType::Verbose => {
487+
write!(
488+
f,
489+
"CountingExec: partitions={}, batches_per_partition={}",
490+
self.partitions, self.batches_per_partition
491+
)
492+
}
493+
DisplayFormatType::TreeRender => {
494+
writeln!(f, "partitions={}", self.partitions)?;
495+
writeln!(f, "batches_per_partition={}", self.batches_per_partition)
496+
}
497+
}
498+
}
499+
}
500+
501+
impl ExecutionPlan for CountingExec {
502+
fn name(&self) -> &'static str {
503+
"CountingExec"
504+
}
505+
506+
fn properties(&self) -> &Arc<PlanProperties> {
507+
&self.cache
508+
}
509+
510+
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
511+
vec![]
512+
}
513+
514+
fn apply_expressions(
515+
&self,
516+
_f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result<TreeNodeRecursion>,
517+
) -> Result<TreeNodeRecursion> {
518+
Ok(TreeNodeRecursion::Continue)
519+
}
520+
521+
fn with_new_children(
522+
self: Arc<Self>,
523+
_: Vec<Arc<dyn ExecutionPlan>>,
524+
) -> Result<Arc<dyn ExecutionPlan>> {
525+
internal_err!("Children cannot be replaced in {self:?}")
526+
}
527+
528+
fn execute(
529+
&self,
530+
partition: usize,
531+
_context: Arc<TaskContext>,
532+
) -> Result<SendableRecordBatchStream> {
533+
Ok(Box::pin(CountingStream {
534+
schema: Arc::clone(&self.schema),
535+
next_value: partition * self.batches_per_partition * self.rows_per_batch,
536+
remaining: self.batches_per_partition,
537+
rows_per_batch: self.rows_per_batch,
538+
}))
539+
}
540+
}
541+
542+
#[derive(Debug)]
543+
struct CountingStream {
544+
schema: SchemaRef,
545+
next_value: usize,
546+
remaining: usize,
547+
rows_per_batch: usize,
548+
}
549+
550+
impl Stream for CountingStream {
551+
type Item = Result<RecordBatch>;
552+
553+
fn poll_next(
554+
mut self: Pin<&mut Self>,
555+
_cx: &mut Context<'_>,
556+
) -> Poll<Option<Self::Item>> {
557+
if self.remaining == 0 {
558+
return Poll::Ready(None);
559+
}
560+
self.remaining -= 1;
561+
let start = self.next_value as u64;
562+
self.next_value += self.rows_per_batch;
563+
let values =
564+
UInt64Array::from_iter_values(start..start + self.rows_per_batch as u64);
565+
Poll::Ready(Some(Ok(RecordBatch::try_new(
566+
Arc::clone(&self.schema),
567+
vec![Arc::new(values) as ArrayRef],
568+
)?)))
569+
}
570+
}
571+
572+
impl RecordBatchStream for CountingStream {
573+
fn schema(&self) -> SchemaRef {
574+
Arc::clone(&self.schema)
575+
}
576+
}
577+
578+
fn high_cardinality_partial_aggregate(
579+
input: Arc<dyn ExecutionPlan>,
580+
schema: &SchemaRef,
581+
) -> Result<Arc<AggregateExec>> {
582+
let groups =
583+
PhysicalGroupBy::new_single(vec![(col("a", schema)?, "a".to_string())]);
584+
let count = AggregateExprBuilder::new(count_udaf(), vec![col("a", schema)?])
585+
.schema(Arc::clone(schema))
586+
.alias("COUNT(a)")
587+
.build()?;
588+
589+
Ok(Arc::new(AggregateExec::try_new(
590+
AggregateMode::Partial,
591+
groups,
592+
vec![Arc::new(count)],
593+
vec![None],
594+
input,
595+
Arc::clone(schema),
596+
)?))
597+
}
598+
599+
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
600+
#[ignore = "temporary diagnostic reproducer for layered cancellation delay"]
601+
async fn cancellation_delay_coalesce_repartition() -> Result<()> {
602+
let task_ctx = Arc::new(TaskContext::default());
603+
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::UInt64, true)]));
604+
let batches_per_input_partition = 8;
605+
let rows_per_batch = 128_000;
606+
let input = Arc::new(CountingExec::new(
607+
Arc::clone(&schema),
608+
2,
609+
batches_per_input_partition,
610+
rows_per_batch,
611+
));
612+
let input_refs = input.refs();
613+
let mut plan: Arc<dyn ExecutionPlan> =
614+
high_cardinality_partial_aggregate(input, &schema)?;
615+
616+
let layers = 512;
617+
let output_partitions = 32;
618+
let mut repartition_refs = Vec::with_capacity(layers);
619+
620+
for _ in 0..layers {
621+
let repartition = Arc::new(RepartitionExec::try_new(
622+
plan,
623+
Partitioning::RoundRobinBatch(output_partitions),
624+
)?);
625+
repartition_refs.push(Arc::downgrade(&repartition));
626+
627+
plan = Arc::new(CoalescePartitionsExec::new(repartition));
628+
}
629+
630+
let handle = tokio::spawn(collect(plan, task_ctx));
631+
tokio::time::sleep(Duration::from_millis(100)).await;
632+
assert!(!handle.is_finished(), "query finished before cancellation");
633+
634+
let start = Instant::now();
635+
handle.abort();
636+
637+
let drop_times = wait_for_repartition_drop_times(&repartition_refs, start).await;
638+
let total_elapsed = start.elapsed();
639+
640+
for (idx, elapsed) in drop_times.iter().enumerate().rev() {
641+
let layer_from_top = layers - idx;
642+
if layer_from_top != 1 && layer_from_top != layers && layer_from_top % 32 != 0
643+
{
644+
continue;
645+
}
646+
println!(
647+
"layer_from_top={layer_from_top} repartition_drop_elapsed_ms={}",
648+
elapsed.as_millis()
649+
);
650+
}
651+
println!(
652+
"layers={layers} output_partitions={output_partitions} input_rows_per_partition={} cancellation_elapsed_ms={}",
653+
batches_per_input_partition * rows_per_batch,
654+
total_elapsed.as_millis()
655+
);
656+
println!(
657+
"input_plan_strong_count_after_cancel={}",
658+
input_refs.strong_count()
659+
);
660+
661+
Ok(())
662+
}
663+
393664
#[tokio::test]
394665
async fn test_drop_cancel() -> Result<()> {
395666
let task_ctx = Arc::new(TaskContext::default());

datafusion/physical-plan/src/stream.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,9 @@ impl RecordBatchReceiverStreamBuilder {
346346
Ok(stream) => stream,
347347
};
348348

349+
let plan_display = displayable(input.as_ref()).one_line().to_string();
350+
drop(input);
351+
349352
// Transfer batches from inner stream to the output tx
350353
// immediately.
351354
while let Some(item) = stream.next().await {
@@ -356,18 +359,15 @@ impl RecordBatchReceiverStreamBuilder {
356359
if output.send(item).await.is_err() {
357360
debug!(
358361
"Stopping execution: output is gone, plan cancelling: {}",
359-
displayable(input.as_ref()).one_line()
362+
plan_display
360363
);
361364
return Ok(());
362365
}
363366

364367
// Stop after the first error is encountered (Don't
365368
// drive all streams to completion)
366369
if is_err {
367-
debug!(
368-
"Stopping execution: plan returned error: {}",
369-
displayable(input.as_ref()).one_line()
370-
);
370+
debug!("Stopping execution: plan returned error: {}", plan_display);
371371
return Ok(());
372372
}
373373
}

0 commit comments

Comments
 (0)