Skip to content

Commit 8e7439d

Browse files
committed
smaller repro
1 parent 58a9778 commit 8e7439d

2 files changed

Lines changed: 58 additions & 125 deletions

File tree

datafusion/physical-plan/src/coalesce_partitions.rs

Lines changed: 57 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -350,10 +350,9 @@ impl ExecutionPlan for CoalescePartitionsExec {
350350
#[cfg(test)]
351351
mod tests {
352352
use super::*;
353-
use crate::RecordBatchStream;
354353
use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
355-
use crate::execution_plan::{Boundedness, EmissionType};
356354
use crate::expressions::col;
355+
use crate::memory::{LazyBatchGenerator, LazyMemoryExec};
357356
use crate::repartition::RepartitionExec;
358357
use crate::test::exec::{
359358
BlockingExec, PanicExec, assert_strong_count_converges_to_zero,
@@ -365,17 +364,14 @@ mod tests {
365364
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
366365
use arrow::record_batch::RecordBatch;
367366
use datafusion_common::Result;
368-
use datafusion_common::internal_err;
369-
use datafusion_common::tree_node::TreeNodeRecursion;
370367
use datafusion_functions_aggregate::count::count_udaf;
371-
use datafusion_physical_expr::EquivalenceProperties;
372-
use datafusion_physical_expr::PhysicalExpr;
373368
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
374369

375-
use futures::{FutureExt, Stream};
376-
use std::pin::Pin;
370+
use futures::FutureExt;
371+
use parking_lot::RwLock;
372+
use std::any::Any;
373+
use std::fmt;
377374
use std::sync::{Arc, Weak};
378-
use std::task::{Context, Poll};
379375
use std::time::{Duration, Instant};
380376

381377
#[tokio::test]
@@ -437,141 +433,73 @@ mod tests {
437433
.collect()
438434
}
439435

440-
#[derive(Debug)]
441-
struct CountingExec {
436+
#[derive(Debug, Clone)]
437+
struct CountingGenerator {
442438
schema: SchemaRef,
443-
partitions: usize,
444-
batches_per_partition: usize,
439+
partition: usize,
440+
next_batch: usize,
441+
max_batches: usize,
445442
rows_per_batch: usize,
446-
plan_ref: Arc<()>,
447-
cache: Arc<PlanProperties>,
448443
}
449444

450-
impl CountingExec {
445+
impl CountingGenerator {
451446
fn new(
452447
schema: SchemaRef,
453-
partitions: usize,
454-
batches_per_partition: usize,
448+
partition: usize,
449+
max_batches: usize,
455450
rows_per_batch: usize,
456451
) -> Self {
457-
let cache = Arc::new(PlanProperties::new(
458-
EquivalenceProperties::new(Arc::clone(&schema)),
459-
Partitioning::UnknownPartitioning(partitions),
460-
EmissionType::Incremental,
461-
Boundedness::Bounded,
462-
));
463-
464452
Self {
465453
schema,
466-
partitions,
467-
batches_per_partition,
454+
partition,
455+
next_batch: 0,
456+
max_batches,
468457
rows_per_batch,
469-
plan_ref: Arc::new(()),
470-
cache,
471458
}
472459
}
473-
474-
fn refs(&self) -> Weak<()> {
475-
Arc::downgrade(&self.plan_ref)
476-
}
477460
}
478461

479-
impl DisplayAs for CountingExec {
480-
fn fmt_as(
481-
&self,
482-
t: DisplayFormatType,
483-
f: &mut std::fmt::Formatter,
484-
) -> std::fmt::Result {
485-
match t {
486-
DisplayFormatType::Default | DisplayFormatType::Verbose => {
487-
write!(
488-
f,
489-
"CountingExec: partitions={}, batches_per_partition={}",
490-
self.partitions, self.batches_per_partition
491-
)
492-
}
493-
DisplayFormatType::TreeRender => {
494-
writeln!(f, "partitions={}", self.partitions)?;
495-
writeln!(f, "batches_per_partition={}", self.batches_per_partition)
496-
}
497-
}
462+
impl fmt::Display for CountingGenerator {
463+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
464+
write!(
465+
f,
466+
"CountingGenerator: partition={}, max_batches={}, rows_per_batch={}",
467+
self.partition, self.max_batches, self.rows_per_batch
468+
)
498469
}
499470
}
500471

501-
impl ExecutionPlan for CountingExec {
502-
fn name(&self) -> &'static str {
503-
"CountingExec"
472+
impl LazyBatchGenerator for CountingGenerator {
473+
fn as_any(&self) -> &dyn Any {
474+
self
504475
}
505476

506-
fn properties(&self) -> &Arc<PlanProperties> {
507-
&self.cache
508-
}
509-
510-
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
511-
vec![]
512-
}
513-
514-
fn apply_expressions(
515-
&self,
516-
_f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result<TreeNodeRecursion>,
517-
) -> Result<TreeNodeRecursion> {
518-
Ok(TreeNodeRecursion::Continue)
519-
}
520-
521-
fn with_new_children(
522-
self: Arc<Self>,
523-
_: Vec<Arc<dyn ExecutionPlan>>,
524-
) -> Result<Arc<dyn ExecutionPlan>> {
525-
internal_err!("Children cannot be replaced in {self:?}")
526-
}
527-
528-
fn execute(
529-
&self,
530-
partition: usize,
531-
_context: Arc<TaskContext>,
532-
) -> Result<SendableRecordBatchStream> {
533-
Ok(Box::pin(CountingStream {
534-
schema: Arc::clone(&self.schema),
535-
next_value: partition * self.batches_per_partition * self.rows_per_batch,
536-
remaining: self.batches_per_partition,
537-
rows_per_batch: self.rows_per_batch,
538-
}))
539-
}
540-
}
541-
542-
#[derive(Debug)]
543-
struct CountingStream {
544-
schema: SchemaRef,
545-
next_value: usize,
546-
remaining: usize,
547-
rows_per_batch: usize,
548-
}
477+
fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>> {
478+
if self.next_batch == self.max_batches {
479+
return Ok(None);
480+
}
549481

550-
impl Stream for CountingStream {
551-
type Item = Result<RecordBatch>;
482+
let start = ((self.partition * self.max_batches + self.next_batch)
483+
* self.rows_per_batch) as u64;
484+
self.next_batch += 1;
552485

553-
fn poll_next(
554-
mut self: Pin<&mut Self>,
555-
_cx: &mut Context<'_>,
556-
) -> Poll<Option<Self::Item>> {
557-
if self.remaining == 0 {
558-
return Poll::Ready(None);
559-
}
560-
self.remaining -= 1;
561-
let start = self.next_value as u64;
562-
self.next_value += self.rows_per_batch;
563486
let values =
564487
UInt64Array::from_iter_values(start..start + self.rows_per_batch as u64);
565-
Poll::Ready(Some(Ok(RecordBatch::try_new(
488+
489+
Ok(Some(RecordBatch::try_new(
566490
Arc::clone(&self.schema),
567491
vec![Arc::new(values) as ArrayRef],
568-
)?)))
492+
)?))
569493
}
570-
}
571494

572-
impl RecordBatchStream for CountingStream {
573-
fn schema(&self) -> SchemaRef {
574-
Arc::clone(&self.schema)
495+
fn reset_state(&self) -> Arc<RwLock<dyn LazyBatchGenerator>> {
496+
Arc::new(RwLock::new(Self {
497+
schema: Arc::clone(&self.schema),
498+
partition: self.partition,
499+
next_batch: 0,
500+
max_batches: self.max_batches,
501+
rows_per_batch: self.rows_per_batch,
502+
}))
575503
}
576504
}
577505

@@ -601,15 +529,21 @@ mod tests {
601529
async fn cancellation_delay_coalesce_repartition() -> Result<()> {
602530
let task_ctx = Arc::new(TaskContext::default());
603531
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::UInt64, true)]));
532+
let input_partitions = 2;
604533
let batches_per_input_partition = 8;
605534
let rows_per_batch = 128_000;
606-
let input = Arc::new(CountingExec::new(
607-
Arc::clone(&schema),
608-
2,
609-
batches_per_input_partition,
610-
rows_per_batch,
611-
));
612-
let input_refs = input.refs();
535+
let generators = (0..input_partitions)
536+
.map(|partition| {
537+
Arc::new(RwLock::new(CountingGenerator::new(
538+
Arc::clone(&schema),
539+
partition,
540+
batches_per_input_partition,
541+
rows_per_batch,
542+
))) as Arc<RwLock<dyn LazyBatchGenerator>>
543+
})
544+
.collect();
545+
let input = Arc::new(LazyMemoryExec::try_new(Arc::clone(&schema), generators)?);
546+
let input_refs = Arc::downgrade(&input);
613547
let mut plan: Arc<dyn ExecutionPlan> =
614548
high_cardinality_partial_aggregate(input, &schema)?;
615549

datafusion/physical-plan/src/memory.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ use std::fmt;
2222
use std::sync::Arc;
2323
use std::task::{Context, Poll};
2424

25-
use crate::coop::cooperative;
2625
use crate::execution_plan::{Boundedness, EmissionType, SchedulingType};
2726
use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
2827
use crate::{
@@ -354,7 +353,7 @@ impl ExecutionPlan for LazyMemoryExec {
354353
generator,
355354
baseline_metrics,
356355
};
357-
Ok(Box::pin(cooperative(stream)))
356+
Ok(Box::pin(stream))
358357
}
359358

360359
fn metrics(&self) -> Option<MetricsSet> {

0 commit comments

Comments
 (0)