Skip to content

Commit 533ef35

Browse files
authored
fix: Correct computation of selectivity for multi-key joins (apache#22725)
## Which issue does this PR close? - Closes apache#22724 ## Rationale for this change `estimate_inner_join_cardinality` sets `join_selectivity` to the selectivity of the last join key in the list. The intent was almost surely to use the selectivity of the most selective join key instead. ## What changes are included in this PR? * Fix formula for multi-key join selectivity estimation * Improve comment to reference Spark Catalyst behavior more clearly * Add unit test ## Are these changes tested? Yes, new test added. ## Are there any user-facing changes? No.
1 parent e71bd56 commit 533ef35

1 file changed

Lines changed: 69 additions & 3 deletions

File tree

  • datafusion/physical-plan/src/joins

datafusion/physical-plan/src/joins/utils.rs

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -640,8 +640,8 @@ fn estimate_inner_join_cardinality(
640640
..
641641
} = right_stats;
642642

643-
// The algorithm here is partly based on the non-histogram selectivity estimation
644-
// from Spark's Catalyst optimizer.
643+
// Follow Spark Catalyst's conservative NDV join estimate: for multi-key
644+
// joins, use the most selective key instead of multiplying all key denominators.
645645
let mut join_selectivity = Precision::Absent;
646646
for (left_stat, right_stat) in left_column_statistics
647647
.iter()
@@ -654,7 +654,11 @@ fn estimate_inner_join_cardinality(
654654
// Seems like there are a few implementations of this algorithm that implement
655655
// exponential decay for the selectivity (like Hive's Optiq Optimizer). Needs
656656
// further exploration.
657-
join_selectivity = max_distinct;
657+
join_selectivity = if join_selectivity.get_value().is_some() {
658+
join_selectivity.max(&max_distinct)
659+
} else {
660+
max_distinct
661+
};
658662
}
659663
}
660664

@@ -2730,6 +2734,68 @@ mod tests {
27302734
Ok(())
27312735
}
27322736

2737+
#[test]
2738+
fn test_join_cardinality_key_order() -> Result<()> {
2739+
// Reversing join key order should not change estimated cardinality
2740+
let left_col_stats = vec![
2741+
create_column_stats(Inexact(0), Inexact(100), Inexact(100), Absent),
2742+
create_column_stats(Inexact(0), Inexact(500), Inexact(500), Absent),
2743+
create_column_stats(Inexact(1000), Inexact(10000), Absent, Absent),
2744+
];
2745+
2746+
let right_col_stats = vec![
2747+
create_column_stats(Inexact(0), Inexact(100), Inexact(50), Absent),
2748+
create_column_stats(Inexact(0), Inexact(2000), Inexact(2500), Absent),
2749+
create_column_stats(Inexact(0), Inexact(100), Absent, Absent),
2750+
];
2751+
2752+
let join_on_ab = vec![
2753+
(
2754+
Arc::new(Column::new("a", 0)) as _,
2755+
Arc::new(Column::new("c", 0)) as _,
2756+
),
2757+
(
2758+
Arc::new(Column::new("b", 1)) as _,
2759+
Arc::new(Column::new("d", 1)) as _,
2760+
),
2761+
];
2762+
let join_on_ba = vec![
2763+
(
2764+
Arc::new(Column::new("b", 1)) as _,
2765+
Arc::new(Column::new("d", 1)) as _,
2766+
),
2767+
(
2768+
Arc::new(Column::new("a", 0)) as _,
2769+
Arc::new(Column::new("c", 0)) as _,
2770+
),
2771+
];
2772+
2773+
let stats_ab = estimate_join_cardinality(
2774+
&JoinType::Inner,
2775+
create_stats(Some(1000), left_col_stats.clone(), false),
2776+
create_stats(Some(2000), right_col_stats.clone(), false),
2777+
&join_on_ab,
2778+
)
2779+
.unwrap();
2780+
let stats_ba = estimate_join_cardinality(
2781+
&JoinType::Inner,
2782+
create_stats(Some(1000), left_col_stats.clone(), false),
2783+
create_stats(Some(2000), right_col_stats.clone(), false),
2784+
&join_on_ba,
2785+
)
2786+
.unwrap();
2787+
2788+
assert_eq!(stats_ab.num_rows, 1000);
2789+
assert_eq!(stats_ba.num_rows, stats_ab.num_rows);
2790+
assert_eq!(stats_ba.column_statistics, stats_ab.column_statistics);
2791+
assert_eq!(
2792+
stats_ab.column_statistics,
2793+
[left_col_stats, right_col_stats].concat()
2794+
);
2795+
2796+
Ok(())
2797+
}
2798+
27332799
#[test]
27342800
fn test_join_cardinality_when_one_column_is_disjoint() -> Result<()> {
27352801
// Left table (rows=1000)

0 commit comments

Comments
 (0)