@@ -824,7 +824,7 @@ index 53e47f428c3..a55d8f0c161 100644
824824 assert(shuffleMergeJoins.size == 1)
825825 }
826826diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
827- index aaac0ebc9aa..fbef0774d46 100644
827+ index aaac0ebc9aa..276c592ec88 100644
828828--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
829829+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
830830@@ -29,7 +29,8 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
@@ -972,7 +972,19 @@ index aaac0ebc9aa..fbef0774d46 100644
972972 checkAnswer(shjNonCodegenDF, Seq.empty)
973973 }
974974 }
975- @@ -1489,7 +1510,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
975+ @@ -1473,7 +1494,10 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
976+ "/*+ BROADCAST(t2) */ t1.k as k"
977+ }
978+ val plan = sql(getAggQuery(selectExpr, joinType)).queryExecution.executedPlan
979+ - assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1)
980+ + assert(collect(plan) {
981+ + case _: BroadcastNestedLoopJoinExec | _: CometBroadcastNestedLoopJoinExec =>
982+ + true
983+ + }.size === 1)
984+ // No extra shuffle before aggregation
985+ assert(collect(plan) { case _: ShuffleExchangeExec => true }.size === 0)
986+ }
987+ @@ -1489,7 +1513,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
976988 val plan = sql(getAggQuery(selectExpr, joinType)).queryExecution.executedPlan
977989 assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1)
978990 // Have shuffle before aggregation
@@ -982,12 +994,17 @@ index aaac0ebc9aa..fbef0774d46 100644
982994 }
983995
984996 def getJoinQuery(selectExpr: String, joinType: String): String = {
985- @@ -1518,9 +1540,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
997+ @@ -1517,10 +1542,16 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
998+ "/*+ BROADCAST(right_t) */ k1 as k0"
986999 }
9871000 val plan = sql(getJoinQuery(selectExpr, joinType)).queryExecution.executedPlan
988- assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1)
1001+ - assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1)
9891002- assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 3)
9901003+ assert(collect(plan) {
1004+ + case _: BroadcastNestedLoopJoinExec | _: CometBroadcastNestedLoopJoinExec =>
1005+ + true
1006+ + }.size === 1)
1007+ + assert(collect(plan) {
9911008+ case _: SortMergeJoinExec => true
9921009+ case _: CometSortMergeJoinExec => true
9931010+ }.size === 3)
@@ -997,7 +1014,7 @@ index aaac0ebc9aa..fbef0774d46 100644
9971014 }
9981015
9991016 // Test output ordering is not preserved
1000- @@ -1529,9 +1554 ,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
1017+ @@ -1529,9 +1560 ,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
10011018 val selectExpr = "/*+ BROADCAST(left_t) */ k1 as k0"
10021019 val plan = sql(getJoinQuery(selectExpr, joinType)).queryExecution.executedPlan
10031020 assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1)
@@ -1012,7 +1029,7 @@ index aaac0ebc9aa..fbef0774d46 100644
10121029 }
10131030
10141031 // Test singe partition
1015- @@ -1541,7 +1569 ,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
1032+ @@ -1541,7 +1575 ,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
10161033 |FROM range(0, 10, 1, 1) t1 FULL OUTER JOIN range(0, 10, 1, 1) t2
10171034 |""".stripMargin)
10181035 val plan = fullJoinDF.queryExecution.executedPlan
@@ -1022,7 +1039,7 @@ index aaac0ebc9aa..fbef0774d46 100644
10221039 checkAnswer(fullJoinDF, Row(100))
10231040 }
10241041 }
1025- @@ -1614,6 +1643 ,9 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
1042+ @@ -1614,6 +1649 ,9 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
10261043 Seq(semiJoinDF, antiJoinDF).foreach { df =>
10271044 assert(collect(df.queryExecution.executedPlan) {
10281045 case j: ShuffledHashJoinExec if j.ignoreDuplicatedKey == ignoreDuplicatedKey => true
@@ -1032,7 +1049,7 @@ index aaac0ebc9aa..fbef0774d46 100644
10321049 }.size == 1)
10331050 }
10341051 }
1035- @@ -1658,14 +1690 ,20 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
1052+ @@ -1658,14 +1696 ,20 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
10361053
10371054 test("SPARK-43113: Full outer join with duplicate stream-side references in condition (SMJ)") {
10381055 def check(plan: SparkPlan): Unit = {
@@ -1055,7 +1072,7 @@ index aaac0ebc9aa..fbef0774d46 100644
10551072 }
10561073 dupStreamSideColTest("SHUFFLE_HASH", check)
10571074 }
1058- @@ -1801,7 +1839 ,8 @@ class ThreadLeakInSortMergeJoinSuite
1075+ @@ -1801,7 +1845 ,8 @@ class ThreadLeakInSortMergeJoinSuite
10591076 sparkConf.set(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD, 20))
10601077 }
10611078
0 commit comments