diff --git a/ballista/client/tests/context_checks.rs b/ballista/client/tests/context_checks.rs index 6c7a12d609..f5f3ddf0a1 100644 --- a/ballista/client/tests/context_checks.rs +++ b/ballista/client/tests/context_checks.rs @@ -1109,6 +1109,78 @@ mod supported { Ok(()) } + /// Regression test: nested CollectLeft HashJoinExec with + /// CoalescePartitionsExec should not deadlock. + /// + /// This reproduces the pattern from TPC-H Q2 where a chain of + /// small-table inner joins (region→nation→supplier) is broadcast-joined + /// against a large partitioned table (partsupp). The scheduler enables + /// CollectLeft for inner joins under the broadcast threshold, and each + /// executor task runs exactly ONE partition. If any cross-partition + /// synchronisation (e.g. a tokio Barrier) is used in the build-side + /// completion path, it will deadlock because only one partition + /// participates per task. + #[rstest] + #[case::standalone(standalone_context())] + #[tokio::test] + async fn nested_collect_left_should_not_deadlock( + #[future(awt)] + #[case] + ctx: SessionContext, + test_data: String, + ) -> datafusion::error::Result<()> { + // Use alltypes_plain.parquet registered as 3 different tables + // to create a nested inner join query where the optimizer + // should choose CollectLeft for the small tables. + ctx.register_parquet( + "fact_table", + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await?; + + ctx.register_parquet( + "dim_a", + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await?; + + ctx.register_parquet( + "dim_b", + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await?; + + // Query with nested inner joins: dim_b → dim_a → fact_table + let result = tokio::time::timeout( + std::time::Duration::from_secs(120), + ctx.sql( + "SELECT f.id, a.int_col, b.string_col + FROM fact_table f + INNER JOIN dim_a a ON f.id = a.id + INNER JOIN dim_b b ON a.tinyint_col = b.tinyint_col + ORDER BY f.id + LIMIT 5", + ) + .await? + .collect(), + ) + .await + .expect("nested CollectLeft joins should complete within 120s, not deadlock"); + + let result = result?; + // Verify we got results + assert!(!result.is_empty(), "query should return results"); + assert!( + result[0].num_rows() > 0, + "query should return at least one row" + ); + + Ok(()) + } + #[rstest] #[case::standalone(standalone_context())] #[case::remote(remote_context())] diff --git a/ballista/core/src/execution_plans/mod.rs b/ballista/core/src/execution_plans/mod.rs index 9f8f944b2b..96d4e5c1cb 100644 --- a/ballista/core/src/execution_plans/mod.rs +++ b/ballista/core/src/execution_plans/mod.rs @@ -48,3 +48,45 @@ pub use vortex_shuffle::{ LocalVortexShuffleStream, VortexWriteTracker, vortex_file_extension, write_stream_to_disk_vortex, }; + +use datafusion::common::tree_node::Transformed; +use datafusion::error::Result; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode}; +use std::sync::Arc; + +/// Rebuild a `HashJoinExec` node via `try_new()` to strip any dynamic-filter +/// accumulator (e.g. `SharedBuildAccumulator`). The accumulator uses a +/// cross-partition `Barrier` that deadlocks in Ballista where each task runs a +/// single partition. `try_new()` never attaches an accumulator, so this is +/// always safe. +/// +/// If `node` is not a `HashJoinExec`, returns `Transformed::no(node)`. +pub fn rebuild_hash_join_without_accumulator( + node: Arc, +) -> Result>> { + if let Some(hj) = node.as_any().downcast_ref::() { + let left = Arc::clone(hj.left()); + let left: Arc = if *hj.partition_mode() + == PartitionMode::CollectLeft + && left.properties().output_partitioning().partition_count() > 1 + { + Arc::new(CoalescePartitionsExec::new(left)) + } else { + left + }; + let rebuilt: Arc = Arc::new(HashJoinExec::try_new( + left, + Arc::clone(hj.right()), + hj.on().to_vec(), + hj.filter().cloned(), + hj.join_type(), + hj.projection.clone(), + *hj.partition_mode(), + hj.null_equality(), + )?); + return Ok(Transformed::yes(rebuilt)); + } + Ok(Transformed::no(node)) +} diff --git a/ballista/core/src/execution_plans/shuffle_writer.rs b/ballista/core/src/execution_plans/shuffle_writer.rs index a4ee4c3927..251f89e82e 100644 --- a/ballista/core/src/execution_plans/shuffle_writer.rs +++ b/ballista/core/src/execution_plans/shuffle_writer.rs @@ -58,6 +58,7 @@ use datafusion::physical_plan::metrics::{ self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, }; +use datafusion::common::tree_node::TreeNode; use datafusion::physical_plan::{ DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, SendableRecordBatchStream, Statistics, displayable, @@ -339,6 +340,9 @@ impl ShuffleWriterExec { async move { let now = Instant::now(); + let plan = plan + .transform_down(&super::rebuild_hash_join_without_accumulator)? + .data; let mut stream = plan.execute(input_partition, context)?; if use_memory { diff --git a/ballista/scheduler/src/state/mod.rs b/ballista/scheduler/src/state/mod.rs index 29e2dc683b..adaf9f7587 100644 --- a/ballista/scheduler/src/state/mod.rs +++ b/ballista/scheduler/src/state/mod.rs @@ -593,6 +593,15 @@ impl SchedulerState SchedulerState| { + let node = match ballista_core::execution_plans::rebuild_hash_join_without_accumulator(node)? { + t if t.transformed => return Ok(t), + t => t.data, + }; if node.output_partitioning().partition_count() == 0 { let empty: Arc = Arc::new(EmptyExec::new(node.schema()));