Skip to content

Commit 7a3ab1c

Browse files
Kontinuationpaleolimbot
authored andcommitted
fix(rust/sedona-spatial-join): wrap probe-side repartition in ProbeShuffleExec to prevent optimizer stripping (#677)
1 parent 31adbe1 commit 7a3ab1c

5 files changed

Lines changed: 267 additions & 10 deletions

File tree

rust/sedona-spatial-join/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ pub use exec::SpatialJoinExec;
3333
// Re-export function for register the spatial join planner
3434
pub use planner::register_planner;
3535

36+
// Re-export ProbeShuffleExec so that integration tests (and other crates) can verify
37+
// its presence in optimized physical plans.
38+
pub use planner::probe_shuffle_exec::ProbeShuffleExec;
39+
3640
// Re-export types needed for external usage (e.g., in Comet)
3741
pub use index::{SpatialIndex, SpatialJoinBuildMetrics};
3842
pub use spatial_predicate::SpatialPredicate;

rust/sedona-spatial-join/src/planner.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ use datafusion_common::Result;
2626
mod logical_plan_node;
2727
mod optimizer;
2828
mod physical_planner;
29+
pub mod probe_shuffle_exec;
2930
mod spatial_expr_utils;
3031

3132
/// Register Sedona spatial join planning hooks.

rust/sedona-spatial-join/src/planner/physical_planner.rs

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,13 @@ use datafusion_common::{plan_err, DFSchema, JoinSide, Result};
3131
use datafusion_expr::logical_plan::UserDefinedLogicalNode;
3232
use datafusion_expr::LogicalPlan;
3333
use datafusion_physical_expr::create_physical_expr;
34-
use datafusion_physical_expr::Partitioning;
3534
use datafusion_physical_plan::joins::utils::JoinFilter;
3635
use datafusion_physical_plan::joins::NestedLoopJoinExec;
37-
use datafusion_physical_plan::repartition::RepartitionExec;
38-
use datafusion_physical_plan::ExecutionPlanProperties;
3936
use sedona_common::sedona_internal_err;
4037

4138
use crate::exec::SpatialJoinExec;
4239
use crate::planner::logical_plan_node::SpatialJoinPlanNode;
40+
use crate::planner::probe_shuffle_exec::ProbeShuffleExec;
4341
use crate::planner::spatial_expr_utils::{is_spatial_predicate_supported, transform_join_filter};
4442
use crate::spatial_predicate::SpatialPredicate;
4543
use sedona_common::option::SedonaOptions;
@@ -325,11 +323,7 @@ fn repartition_probe_side(
325323
}
326324
};
327325

328-
let num_partitions = probe_plan.output_partitioning().partition_count();
329-
*probe_plan = Arc::new(RepartitionExec::try_new(
330-
Arc::clone(probe_plan),
331-
Partitioning::RoundRobinBatch(num_partitions),
332-
)?);
326+
*probe_plan = Arc::new(ProbeShuffleExec::try_new(Arc::clone(probe_plan))?);
333327

334328
Ok((physical_left, physical_right))
335329
}
Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! [`ProbeShuffleExec`] — a round-robin repartitioning wrapper that is invisible
19+
//! to DataFusion's `EnforceDistribution` / `EnforceSorting` optimizer passes.
20+
//!
21+
//! Those passes unconditionally strip every [`RepartitionExec`] before
22+
//! re-evaluating distribution requirements. Because `SpatialJoinExec` reports
23+
//! `UnspecifiedDistribution` for its inputs, a bare `RepartitionExec` that was
24+
//! inserted by the extension planner is removed and never re-added.
25+
//!
26+
//! `ProbeShuffleExec` wraps a hidden, internal `RepartitionExec` so that:
27+
//! * **Optimizer passes** see an opaque node (not a `RepartitionExec`) and leave
28+
//! it alone.
29+
//! * **`children()` / `with_new_children()`** expose the *original* input so
30+
//! the rest of the optimizer tree can still be rewritten normally.
31+
//! * **`execute()`** delegates to the internal `RepartitionExec` which performs
32+
//! the actual round-robin shuffle.
33+
34+
use std::any::Any;
35+
use std::fmt;
36+
use std::sync::Arc;
37+
38+
use datafusion_common::config::ConfigOptions;
39+
use datafusion_common::{internal_err, plan_err, Result, Statistics};
40+
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
41+
use datafusion_physical_expr::PhysicalExpr;
42+
use datafusion_physical_plan::execution_plan::CardinalityEffect;
43+
use datafusion_physical_plan::filter_pushdown::{
44+
ChildPushdownResult, FilterDescription, FilterPushdownPhase, FilterPushdownPropagation,
45+
};
46+
use datafusion_physical_plan::metrics::MetricsSet;
47+
use datafusion_physical_plan::projection::ProjectionExec;
48+
use datafusion_physical_plan::repartition::RepartitionExec;
49+
use datafusion_physical_plan::{
50+
DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, Partitioning,
51+
PlanProperties,
52+
};
53+
54+
/// A round-robin repartitioning node that is invisible to DataFusion's
55+
/// physical optimizer passes.
56+
///
57+
/// See [module-level documentation](self) for motivation and design.
58+
#[derive(Debug)]
59+
pub struct ProbeShuffleExec {
60+
inner_repartition: RepartitionExec,
61+
}
62+
63+
impl ProbeShuffleExec {
64+
/// Create a new [`ProbeShuffleExec`] that round-robin repartitions `input`
65+
/// into the same number of output partitions as `input`. This will ensure
66+
/// that the probe workload of a spatial join will be evenly distributed.
67+
/// More importantly, shuffled probe side data will be less likely to
68+
/// cause skew issues when out-of-core, spatial partitioned spatial join is enabled,
69+
/// especially when the input probe data is sorted by their spatial locations.
70+
pub fn try_new(input: Arc<dyn ExecutionPlan>) -> Result<Self> {
71+
let num_partitions = input.output_partitioning().partition_count();
72+
let inner_repartition = RepartitionExec::try_new(
73+
Arc::clone(&input),
74+
Partitioning::RoundRobinBatch(num_partitions),
75+
)?;
76+
Ok(Self { inner_repartition })
77+
}
78+
79+
/// Try to wrap the given [`RepartitionExec`] `plan` with [`ProbeShuffleExec`].
80+
pub fn try_wrap_repartition(plan: Arc<dyn ExecutionPlan>) -> Result<Self> {
81+
let Some(repartition_exec) = plan.as_any().downcast_ref::<RepartitionExec>() else {
82+
return plan_err!(
83+
"ProbeShuffleExec can only wrap RepartitionExec, but got {}",
84+
plan.name()
85+
);
86+
};
87+
Ok(Self {
88+
inner_repartition: repartition_exec.clone(),
89+
})
90+
}
91+
92+
/// Number of output partitions.
93+
pub fn num_partitions(&self) -> usize {
94+
self.inner_repartition
95+
.properties()
96+
.output_partitioning()
97+
.partition_count()
98+
}
99+
}
100+
101+
impl DisplayAs for ProbeShuffleExec {
102+
fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
103+
match t {
104+
DisplayFormatType::Default | DisplayFormatType::Verbose => {
105+
write!(
106+
f,
107+
"ProbeShuffleExec: partitioning=RoundRobinBatch({})",
108+
self.num_partitions()
109+
)
110+
}
111+
DisplayFormatType::TreeRender => {
112+
write!(f, "partitioning=RoundRobinBatch({})", self.num_partitions())
113+
}
114+
}
115+
}
116+
}
117+
118+
impl ExecutionPlan for ProbeShuffleExec {
119+
fn name(&self) -> &str {
120+
"ProbeShuffleExec"
121+
}
122+
123+
fn as_any(&self) -> &dyn Any {
124+
self
125+
}
126+
127+
fn properties(&self) -> &PlanProperties {
128+
self.inner_repartition.properties()
129+
}
130+
131+
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
132+
vec![self.inner_repartition.input()]
133+
}
134+
135+
fn with_new_children(
136+
self: Arc<Self>,
137+
mut children: Vec<Arc<dyn ExecutionPlan>>,
138+
) -> Result<Arc<dyn ExecutionPlan>> {
139+
if children.len() != 1 {
140+
return internal_err!(
141+
"ProbeShuffleExec expects exactly 1 child, got {}",
142+
children.len()
143+
);
144+
}
145+
let child = children.remove(0);
146+
Ok(Arc::new(Self::try_new(child)?))
147+
}
148+
149+
fn execute(
150+
&self,
151+
partition: usize,
152+
context: Arc<TaskContext>,
153+
) -> Result<SendableRecordBatchStream> {
154+
self.inner_repartition.execute(partition, context)
155+
}
156+
157+
fn maintains_input_order(&self) -> Vec<bool> {
158+
self.inner_repartition.maintains_input_order()
159+
}
160+
161+
fn benefits_from_input_partitioning(&self) -> Vec<bool> {
162+
self.inner_repartition.benefits_from_input_partitioning()
163+
}
164+
165+
fn cardinality_effect(&self) -> CardinalityEffect {
166+
self.inner_repartition.cardinality_effect()
167+
}
168+
169+
fn metrics(&self) -> Option<MetricsSet> {
170+
self.inner_repartition.metrics()
171+
}
172+
173+
fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
174+
self.inner_repartition.partition_statistics(partition)
175+
}
176+
177+
fn try_swapping_with_projection(
178+
&self,
179+
projection: &ProjectionExec,
180+
) -> Result<Option<Arc<dyn ExecutionPlan>>> {
181+
let Some(new_repartition) = self
182+
.inner_repartition
183+
.try_swapping_with_projection(projection)?
184+
else {
185+
return Ok(None);
186+
};
187+
let new_plan = Self::try_wrap_repartition(new_repartition)?;
188+
Ok(Some(Arc::new(new_plan)))
189+
}
190+
191+
fn gather_filters_for_pushdown(
192+
&self,
193+
phase: FilterPushdownPhase,
194+
parent_filters: Vec<Arc<dyn PhysicalExpr>>,
195+
config: &ConfigOptions,
196+
) -> Result<FilterDescription> {
197+
self.inner_repartition
198+
.gather_filters_for_pushdown(phase, parent_filters, config)
199+
}
200+
201+
fn handle_child_pushdown_result(
202+
&self,
203+
phase: FilterPushdownPhase,
204+
child_pushdown_result: ChildPushdownResult,
205+
config: &ConfigOptions,
206+
) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
207+
self.inner_repartition
208+
.handle_child_pushdown_result(phase, child_pushdown_result, config)
209+
}
210+
211+
fn repartitioned(
212+
&self,
213+
target_partitions: usize,
214+
config: &ConfigOptions,
215+
) -> Result<Option<Arc<dyn ExecutionPlan>>> {
216+
let Some(plan) = self
217+
.inner_repartition
218+
.repartitioned(target_partitions, config)?
219+
else {
220+
return Ok(None);
221+
};
222+
let new_plan = Self::try_wrap_repartition(plan)?;
223+
Ok(Some(Arc::new(new_plan)))
224+
}
225+
}

rust/sedona-spatial-join/tests/spatial_join_integration.rs

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use datafusion::{
2626
prelude::{SessionConfig, SessionContext},
2727
};
2828
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
29-
use datafusion_common::Result;
29+
use datafusion_common::{JoinSide, Result};
3030
use datafusion_expr::{ColumnarValue, JoinType};
3131
use datafusion_physical_plan::filter::FilterExec;
3232
use datafusion_physical_plan::joins::NestedLoopJoinExec;
@@ -43,7 +43,8 @@ use sedona_schema::{
4343
matchers::ArgMatcher,
4444
};
4545
use sedona_spatial_join::{
46-
register_planner, spatial_predicate::RelationPredicate, SpatialJoinExec, SpatialPredicate,
46+
register_planner, spatial_predicate::RelationPredicate, ProbeShuffleExec, SpatialJoinExec,
47+
SpatialPredicate,
4748
};
4849
use sedona_testing::datagen::RandomPartitionedDataBuilder;
4950
use tokio::sync::OnceCell;
@@ -801,6 +802,7 @@ async fn run_spatial_join_query(
801802
)?);
802803

803804
let is_optimized_spatial_join = options.is_some();
805+
let repartition_probe_side = options.as_ref().is_some_and(|o| o.repartition_probe_side);
804806
let ctx = setup_context(options, batch_size)?;
805807
ctx.register_table("L", Arc::clone(&mem_table_left))?;
806808
ctx.register_table("R", Arc::clone(&mem_table_right))?;
@@ -810,6 +812,9 @@ async fn run_spatial_join_query(
810812
let spatial_join_execs = collect_spatial_join_exec(&plan)?;
811813
if is_optimized_spatial_join {
812814
assert_eq!(spatial_join_execs.len(), 1);
815+
if repartition_probe_side {
816+
probe_side_of_spatial_join_exec_should_be_shuffled(spatial_join_execs[0]);
817+
}
813818
} else {
814819
assert!(spatial_join_execs.is_empty());
815820
}
@@ -829,6 +834,20 @@ fn collect_spatial_join_exec(plan: &Arc<dyn ExecutionPlan>) -> Result<Vec<&Spati
829834
Ok(spatial_join_execs)
830835
}
831836

837+
fn probe_side_of_spatial_join_exec_should_be_shuffled(sj: &SpatialJoinExec) {
838+
let probe_child = match &sj.on {
839+
SpatialPredicate::KNearestNeighbors(knn) => match knn.probe_side {
840+
JoinSide::Left => &sj.left,
841+
_ => &sj.right,
842+
},
843+
_ => &sj.right, // non-KNN: probe is always right after swap
844+
};
845+
assert!(
846+
subtree_contains_probe_shuffle_exec(probe_child),
847+
"ProbeShuffleExec should be present on the probe side of SpatialJoinExec"
848+
);
849+
}
850+
832851
async fn test_mark_join(
833852
join_type: JoinType,
834853
options: SpatialJoinOptions,
@@ -1613,6 +1632,20 @@ fn subtree_contains_filter_exec(plan: &Arc<dyn ExecutionPlan>) -> bool {
16131632
found
16141633
}
16151634

1635+
/// Recursively check whether any node in the physical plan tree is a `ProbeShuffleExec`.
1636+
fn subtree_contains_probe_shuffle_exec(plan: &Arc<dyn ExecutionPlan>) -> bool {
1637+
let mut found = false;
1638+
plan.apply(|node| {
1639+
if node.as_any().downcast_ref::<ProbeShuffleExec>().is_some() {
1640+
found = true;
1641+
return Ok(TreeNodeRecursion::Stop);
1642+
}
1643+
Ok(TreeNodeRecursion::Continue)
1644+
})
1645+
.expect("failed to walk plan");
1646+
found
1647+
}
1648+
16161649
/// Create a session context with two small tables for filter-pushdown tests.
16171650
///
16181651
/// L(id INT, x DOUBLE) and R(id INT, x DOUBLE) are all empty, this is just for exercising the

0 commit comments

Comments
 (0)