Skip to content

Commit 4c88f5d

Browse files
authored
feat: Native Broadcast nested loop join support (#4429)
* native_support_broadcast_nested_loop_join
1 parent 7e41335 commit 4c88f5d

36 files changed

Lines changed: 3165 additions & 805 deletions

File tree

dev/diffs/3.5.8.diff

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -677,7 +677,7 @@ index 7af826583bd..3c3def1eb67 100644
677677
assert(shuffleMergeJoins.size == 1)
678678
}
679679
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
680-
index 44c8cb92fc3..f098beeca26 100644
680+
index 44c8cb92fc3..e29cb93ecda 100644
681681
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
682682
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
683683
@@ -31,7 +31,8 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
@@ -825,7 +825,19 @@ index 44c8cb92fc3..f098beeca26 100644
825825
checkAnswer(shjNonCodegenDF, Seq.empty)
826826
}
827827
}
828-
@@ -1486,7 +1507,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
828+
@@ -1470,7 +1491,10 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
829+
"/*+ BROADCAST(t2) */ t1.k as k"
830+
}
831+
val plan = sql(getAggQuery(selectExpr, joinType)).queryExecution.executedPlan
832+
- assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1)
833+
+ assert(collect(plan) {
834+
+ case _: BroadcastNestedLoopJoinExec | _: CometBroadcastNestedLoopJoinExec =>
835+
+ true
836+
+ }.size === 1)
837+
// No extra shuffle before aggregation
838+
assert(collect(plan) { case _: ShuffleExchangeExec => true }.size === 0)
839+
}
840+
@@ -1486,7 +1510,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
829841
val plan = sql(getAggQuery(selectExpr, joinType)).queryExecution.executedPlan
830842
assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1)
831843
// Have shuffle before aggregation
@@ -835,12 +847,17 @@ index 44c8cb92fc3..f098beeca26 100644
835847
}
836848

837849
def getJoinQuery(selectExpr: String, joinType: String): String = {
838-
@@ -1515,9 +1537,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
850+
@@ -1514,10 +1539,16 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
851+
"/*+ BROADCAST(right_t) */ k1 as k0"
839852
}
840853
val plan = sql(getJoinQuery(selectExpr, joinType)).queryExecution.executedPlan
841-
assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1)
854+
- assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1)
842855
- assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 3)
843856
+ assert(collect(plan) {
857+
+ case _: BroadcastNestedLoopJoinExec | _: CometBroadcastNestedLoopJoinExec =>
858+
+ true
859+
+ }.size === 1)
860+
+ assert(collect(plan) {
844861
+ case _: SortMergeJoinExec => true
845862
+ case _: CometSortMergeJoinExec => true
846863
+ }.size === 3)
@@ -850,7 +867,7 @@ index 44c8cb92fc3..f098beeca26 100644
850867
}
851868

852869
// Test output ordering is not preserved
853-
@@ -1526,9 +1551,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
870+
@@ -1526,9 +1557,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
854871
val selectExpr = "/*+ BROADCAST(left_t) */ k1 as k0"
855872
val plan = sql(getJoinQuery(selectExpr, joinType)).queryExecution.executedPlan
856873
assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1)
@@ -865,7 +882,7 @@ index 44c8cb92fc3..f098beeca26 100644
865882
}
866883

867884
// Test singe partition
868-
@@ -1538,7 +1566,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
885+
@@ -1538,7 +1572,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
869886
|FROM range(0, 10, 1, 1) t1 FULL OUTER JOIN range(0, 10, 1, 1) t2
870887
|""".stripMargin)
871888
val plan = fullJoinDF.queryExecution.executedPlan
@@ -875,7 +892,7 @@ index 44c8cb92fc3..f098beeca26 100644
875892
checkAnswer(fullJoinDF, Row(100))
876893
}
877894
}
878-
@@ -1611,6 +1640,9 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
895+
@@ -1611,6 +1646,9 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
879896
Seq(semiJoinDF, antiJoinDF).foreach { df =>
880897
assert(collect(df.queryExecution.executedPlan) {
881898
case j: ShuffledHashJoinExec if j.ignoreDuplicatedKey == ignoreDuplicatedKey => true
@@ -885,7 +902,7 @@ index 44c8cb92fc3..f098beeca26 100644
885902
}.size == 1)
886903
}
887904
}
888-
@@ -1655,14 +1687,20 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
905+
@@ -1655,14 +1693,20 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
889906

890907
test("SPARK-43113: Full outer join with duplicate stream-side references in condition (SMJ)") {
891908
def check(plan: SparkPlan): Unit = {
@@ -908,7 +925,7 @@ index 44c8cb92fc3..f098beeca26 100644
908925
}
909926
dupStreamSideColTest("SHUFFLE_HASH", check)
910927
}
911-
@@ -1798,7 +1836,8 @@ class ThreadLeakInSortMergeJoinSuite
928+
@@ -1798,7 +1842,8 @@ class ThreadLeakInSortMergeJoinSuite
912929
sparkConf.set(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD, 20))
913930
}
914931

dev/diffs/4.0.2.diff

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -824,7 +824,7 @@ index 53e47f428c3..a55d8f0c161 100644
824824
assert(shuffleMergeJoins.size == 1)
825825
}
826826
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
827-
index aaac0ebc9aa..fbef0774d46 100644
827+
index aaac0ebc9aa..276c592ec88 100644
828828
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
829829
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
830830
@@ -29,7 +29,8 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
@@ -972,7 +972,19 @@ index aaac0ebc9aa..fbef0774d46 100644
972972
checkAnswer(shjNonCodegenDF, Seq.empty)
973973
}
974974
}
975-
@@ -1489,7 +1510,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
975+
@@ -1473,7 +1494,10 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
976+
"/*+ BROADCAST(t2) */ t1.k as k"
977+
}
978+
val plan = sql(getAggQuery(selectExpr, joinType)).queryExecution.executedPlan
979+
- assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1)
980+
+ assert(collect(plan) {
981+
+ case _: BroadcastNestedLoopJoinExec | _: CometBroadcastNestedLoopJoinExec =>
982+
+ true
983+
+ }.size === 1)
984+
// No extra shuffle before aggregation
985+
assert(collect(plan) { case _: ShuffleExchangeExec => true }.size === 0)
986+
}
987+
@@ -1489,7 +1513,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
976988
val plan = sql(getAggQuery(selectExpr, joinType)).queryExecution.executedPlan
977989
assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1)
978990
// Have shuffle before aggregation
@@ -982,12 +994,17 @@ index aaac0ebc9aa..fbef0774d46 100644
982994
}
983995

984996
def getJoinQuery(selectExpr: String, joinType: String): String = {
985-
@@ -1518,9 +1540,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
997+
@@ -1517,10 +1542,16 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
998+
"/*+ BROADCAST(right_t) */ k1 as k0"
986999
}
9871000
val plan = sql(getJoinQuery(selectExpr, joinType)).queryExecution.executedPlan
988-
assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1)
1001+
- assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1)
9891002
- assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 3)
9901003
+ assert(collect(plan) {
1004+
+ case _: BroadcastNestedLoopJoinExec | _: CometBroadcastNestedLoopJoinExec =>
1005+
+ true
1006+
+ }.size === 1)
1007+
+ assert(collect(plan) {
9911008
+ case _: SortMergeJoinExec => true
9921009
+ case _: CometSortMergeJoinExec => true
9931010
+ }.size === 3)
@@ -997,7 +1014,7 @@ index aaac0ebc9aa..fbef0774d46 100644
9971014
}
9981015

9991016
// Test output ordering is not preserved
1000-
@@ -1529,9 +1554,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
1017+
@@ -1529,9 +1560,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
10011018
val selectExpr = "/*+ BROADCAST(left_t) */ k1 as k0"
10021019
val plan = sql(getJoinQuery(selectExpr, joinType)).queryExecution.executedPlan
10031020
assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1)
@@ -1012,7 +1029,7 @@ index aaac0ebc9aa..fbef0774d46 100644
10121029
}
10131030

10141031
// Test singe partition
1015-
@@ -1541,7 +1569,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
1032+
@@ -1541,7 +1575,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
10161033
|FROM range(0, 10, 1, 1) t1 FULL OUTER JOIN range(0, 10, 1, 1) t2
10171034
|""".stripMargin)
10181035
val plan = fullJoinDF.queryExecution.executedPlan
@@ -1022,7 +1039,7 @@ index aaac0ebc9aa..fbef0774d46 100644
10221039
checkAnswer(fullJoinDF, Row(100))
10231040
}
10241041
}
1025-
@@ -1614,6 +1643,9 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
1042+
@@ -1614,6 +1649,9 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
10261043
Seq(semiJoinDF, antiJoinDF).foreach { df =>
10271044
assert(collect(df.queryExecution.executedPlan) {
10281045
case j: ShuffledHashJoinExec if j.ignoreDuplicatedKey == ignoreDuplicatedKey => true
@@ -1032,7 +1049,7 @@ index aaac0ebc9aa..fbef0774d46 100644
10321049
}.size == 1)
10331050
}
10341051
}
1035-
@@ -1658,14 +1690,20 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
1052+
@@ -1658,14 +1696,20 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
10361053

10371054
test("SPARK-43113: Full outer join with duplicate stream-side references in condition (SMJ)") {
10381055
def check(plan: SparkPlan): Unit = {
@@ -1055,7 +1072,7 @@ index aaac0ebc9aa..fbef0774d46 100644
10551072
}
10561073
dupStreamSideColTest("SHUFFLE_HASH", check)
10571074
}
1058-
@@ -1801,7 +1839,8 @@ class ThreadLeakInSortMergeJoinSuite
1075+
@@ -1801,7 +1845,8 @@ class ThreadLeakInSortMergeJoinSuite
10591076
sparkConf.set(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD, 20))
10601077
}
10611078

dev/diffs/4.1.2.diff

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,6 +1043,18 @@ index 885512d4d19..113ae17ad9f 100644
10431043
checkAnswer(shjNonCodegenDF, Seq.empty)
10441044
}
10451045
}
1046+
@@ -1485,7 +1507,10 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
1047+
"/*+ BROADCAST(t2) */ t1.k as k"
1048+
}
1049+
val plan = sql(getAggQuery(selectExpr, joinType)).queryExecution.executedPlan
1050+
- assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1)
1051+
+ assert(collect(plan) {
1052+
+ case _: BroadcastNestedLoopJoinExec | _: CometBroadcastNestedLoopJoinExec =>
1053+
+ true
1054+
+ }.size === 1)
1055+
// No extra shuffle before aggregation
1056+
assert(collect(plan) { case _: ShuffleExchangeExec => true }.size === 0)
1057+
}
10461058
@@ -1501,7 +1523,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
10471059
val plan = sql(getAggQuery(selectExpr, joinType)).queryExecution.executedPlan
10481060
assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1)
@@ -1053,12 +1065,16 @@ index 885512d4d19..113ae17ad9f 100644
10531065
}
10541066

10551067
def getJoinQuery(selectExpr: String, joinType: String): String = {
1056-
@@ -1530,9 +1553,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
1068+
@@ -1530,10 +1556,16 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
10571069
}
10581070
val plan = sql(getJoinQuery(selectExpr, joinType)).queryExecution.executedPlan
1059-
assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1)
1071+
- assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1)
10601072
- assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 3)
10611073
+ assert(collect(plan) {
1074+
+ case _: BroadcastNestedLoopJoinExec | _: CometBroadcastNestedLoopJoinExec =>
1075+
+ true
1076+
+ }.size === 1)
1077+
+ assert(collect(plan) {
10621078
+ case _: SortMergeJoinExec => true
10631079
+ case _: CometSortMergeJoinExec => true
10641080
+ }.size === 3)

docs/source/user-guide/latest/operators.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ not supported by Comet will fall back to regular Spark execution.
2727
| BatchScanExec | Yes | Supports Parquet files and Apache Iceberg Parquet scans. See the [Comet Compatibility Guide] for more information. |
2828
| BroadcastExchangeExec | Yes | |
2929
| BroadcastHashJoinExec | Yes | |
30+
| BroadcastNestedLoopJoinExec | Yes | Falls back to Spark when the broadcast side is also the preserved (outer) side of the join (e.g. LEFT OUTER with BROADCAST on the left). |
3031
| ExpandExec | Yes | |
3132
| FileSourceScanExec | Yes | Supports Parquet files. See the [Comet Compatibility Guide] for more information. |
3233
| FilterExec | Yes | |

docs/source/user-guide/latest/understanding-comet-plans.md

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -202,21 +202,22 @@ by role. Names match what is shown in the plan output.
202202
These run natively in DataFusion. When several appear consecutively in a plan,
203203
they execute as a single fused native block.
204204

205-
| Node | Spark equivalent |
206-
| ---------------------------- | ----------------------------------------------- |
207-
| `CometProject` | `ProjectExec` |
208-
| `CometFilter` | `FilterExec` |
209-
| `CometSort` | `SortExec` |
210-
| `CometLocalLimit` | `LocalLimitExec` |
211-
| `CometGlobalLimit` | `GlobalLimitExec` |
212-
| `CometExpand` | `ExpandExec` |
213-
| `CometExplode` | `GenerateExec` (for `explode` and `posexplode`) |
214-
| `CometHashAggregate` | `HashAggregateExec`, `ObjectHashAggregateExec` |
215-
| `CometHashJoin` | `ShuffledHashJoinExec` |
216-
| `CometBroadcastHashJoin` | `BroadcastHashJoinExec` |
217-
| `CometSortMergeJoin` | `SortMergeJoinExec` |
218-
| `CometWindow` | `WindowExec` |
219-
| `CometTakeOrderedAndProject` | `TakeOrderedAndProjectExec` |
205+
| Node | Spark equivalent |
206+
| ------------------------------ | ----------------------------------------------- |
207+
| `CometProject` | `ProjectExec` |
208+
| `CometFilter` | `FilterExec` |
209+
| `CometSort` | `SortExec` |
210+
| `CometLocalLimit` | `LocalLimitExec` |
211+
| `CometGlobalLimit` | `GlobalLimitExec` |
212+
| `CometExpand` | `ExpandExec` |
213+
| `CometExplode` | `GenerateExec` (for `explode` and `posexplode`) |
214+
| `CometHashAggregate` | `HashAggregateExec`, `ObjectHashAggregateExec` |
215+
| `CometHashJoin` | `ShuffledHashJoinExec` |
216+
| `CometBroadcastHashJoin` | `BroadcastHashJoinExec` |
217+
| `CometBroadcastNestedLoopJoin` | `BroadcastNestedLoopJoinExec` |
218+
| `CometSortMergeJoin` | `SortMergeJoinExec` |
219+
| `CometWindow` | `WindowExec` |
220+
| `CometTakeOrderedAndProject` | `TakeOrderedAndProjectExec` |
220221

221222
### JVM-Side Operators
222223

native/core/src/execution/jni_api.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ fn op_name(op: &OpStruct) -> &'static str {
239239
OpStruct::Explode(_) => "Explode",
240240
OpStruct::CsvScan(_) => "CsvScan",
241241
OpStruct::ShuffleScan(_) => "ShuffleScan",
242+
OpStruct::BroadcastNestedLoopJoin(_) => "BroadcastNestedLoopJoin",
242243
}
243244
}
244245

native/core/src/execution/planner.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ use arrow::row::{OwnedRow, RowConverter, SortField};
109109
use datafusion::common::utils::SingleRowListArrayBuilder;
110110
use datafusion::common::UnnestOptions;
111111
use datafusion::physical_plan::filter::FilterExec;
112+
use datafusion::physical_plan::joins::NestedLoopJoinExec;
112113
use datafusion::physical_plan::limit::GlobalLimitExec;
113114
use datafusion::physical_plan::unnest::{ListUnnest, UnnestExec};
114115
use datafusion_comet_proto::spark_expression::ListLiteral;
@@ -1228,6 +1229,57 @@ impl PhysicalPlanner {
12281229
Arc::new(SparkPlan::new(spark_plan.plan_id, aggregate, vec![child])),
12291230
))
12301231
}
1232+
1233+
OpStruct::BroadcastNestedLoopJoin(bnlj) => {
1234+
let (join_params, scans, shuffle_scans) = self.parse_join_parameters(
1235+
inputs,
1236+
children,
1237+
&[],
1238+
&[],
1239+
bnlj.join_type,
1240+
&bnlj.condition,
1241+
partition_count,
1242+
)?;
1243+
1244+
let left = Arc::clone(&join_params.left.native_plan);
1245+
let right = Arc::clone(&join_params.right.native_plan);
1246+
1247+
let nested_loop_join = Arc::new(NestedLoopJoinExec::try_new(
1248+
left,
1249+
right,
1250+
join_params.join_filter,
1251+
&join_params.join_type,
1252+
None,
1253+
)?);
1254+
1255+
if bnlj.build_side == BuildSide::BuildRight as i32 {
1256+
let swapped_join = nested_loop_join.as_ref().swap_inputs()?;
1257+
let mut additional_native_plans = vec![];
1258+
if swapped_join.as_any().is::<ProjectionExec>() {
1259+
additional_native_plans.push(Arc::clone(swapped_join.children()[0]));
1260+
}
1261+
Ok((
1262+
scans,
1263+
shuffle_scans,
1264+
Arc::new(SparkPlan::new_with_additional(
1265+
spark_plan.plan_id,
1266+
swapped_join,
1267+
vec![join_params.left, join_params.right],
1268+
additional_native_plans,
1269+
)),
1270+
))
1271+
} else {
1272+
Ok((
1273+
scans,
1274+
shuffle_scans,
1275+
Arc::new(SparkPlan::new(
1276+
spark_plan.plan_id,
1277+
nested_loop_join,
1278+
vec![join_params.left, join_params.right],
1279+
)),
1280+
))
1281+
}
1282+
}
12311283
OpStruct::Limit(limit) => {
12321284
assert_eq!(children.len(), 1);
12331285
let num = limit.limit;

native/core/src/execution/planner/operator_registry.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,5 +151,6 @@ fn get_operator_type(spark_operator: &Operator) -> Option<OperatorType> {
151151
OpStruct::Explode(_) => None, // Not yet in OperatorType enum
152152
OpStruct::CsvScan(_) => Some(OperatorType::CsvScan),
153153
OpStruct::ShuffleScan(_) => None, // Not yet in OperatorType enum
154+
OpStruct::BroadcastNestedLoopJoin(_) => None,
154155
}
155156
}

0 commit comments

Comments
 (0)