Skip to content

Commit c0c1828

Browse files
authored
Merge branch 'main' into parquet-reader-audit
2 parents fb8078c + 184a883 commit c0c1828

31 files changed

Lines changed: 472 additions & 73 deletions

File tree

dev/diffs/3.4.3.diff

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,17 @@ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/s
500500
index a6b295578d6..91acca4306f 100644
501501
--- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
502502
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
503-
@@ -463,7 +463,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
503+
@@ -260,7 +260,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
504+
}
505+
}
506+
507+
- test("SPARK-33853: explain codegen - check presence of subquery") {
508+
+ test("SPARK-33853: explain codegen - check presence of subquery",
509+
+ IgnoreComet("Comet changes the WholeStageCodegen subtree count")) {
510+
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") {
511+
withTempView("df") {
512+
val df1 = spark.range(1, 100)
513+
@@ -463,7 +464,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
504514
}
505515
}
506516

@@ -510,7 +520,7 @@ index a6b295578d6..91acca4306f 100644
510520
withTempDir { dir =>
511521
Seq("parquet", "orc", "csv", "json").foreach { fmt =>
512522
val basePath = dir.getCanonicalPath + "/" + fmt
513-
@@ -541,7 +542,9 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
523+
@@ -541,7 +543,9 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
514524
}
515525
}
516526

dev/diffs/3.5.8.diff

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -478,10 +478,20 @@ index f33432ddb6f..b375e285dde 100644
478478
}
479479
assert(scanOption.isDefined)
480480
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
481-
index a206e97c353..fea1149b67d 100644
481+
index a206e97c353..8bd3ab5985a 100644
482482
--- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
483483
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
484-
@@ -467,7 +467,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
484+
@@ -264,7 +264,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
485+
}
486+
}
487+
488+
- test("SPARK-33853: explain codegen - check presence of subquery") {
489+
+ test("SPARK-33853: explain codegen - check presence of subquery",
490+
+ IgnoreComet("Comet changes the WholeStageCodegen subtree count")) {
491+
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") {
492+
withTempView("df") {
493+
val df1 = spark.range(1, 100)
494+
@@ -467,7 +468,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
485495
}
486496
}
487497

@@ -491,7 +501,7 @@ index a206e97c353..fea1149b67d 100644
491501
withTempDir { dir =>
492502
Seq("parquet", "orc", "csv", "json").foreach { fmt =>
493503
val basePath = dir.getCanonicalPath + "/" + fmt
494-
@@ -545,7 +546,9 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
504+
@@ -545,7 +547,9 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
495505
}
496506
}
497507

dev/diffs/4.0.2.diff

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -615,10 +615,20 @@ index 2c24cc7d570..12d897866da 100644
615615
}
616616
assert(scanOption.isDefined)
617617
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
618-
index 9c90e0105a4..fadf2f0f698 100644
618+
index 9c90e0105a4..ed6d4887b13 100644
619619
--- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
620620
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
621-
@@ -470,7 +470,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
621+
@@ -267,7 +267,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
622+
}
623+
}
624+
625+
- test("SPARK-33853: explain codegen - check presence of subquery") {
626+
+ test("SPARK-33853: explain codegen - check presence of subquery",
627+
+ IgnoreComet("Comet changes the WholeStageCodegen subtree count")) {
628+
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") {
629+
withTempView("df") {
630+
val df1 = spark.range(1, 100)
631+
@@ -470,7 +471,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
622632
}
623633
}
624634

@@ -628,7 +638,7 @@ index 9c90e0105a4..fadf2f0f698 100644
628638
withTempDir { dir =>
629639
Seq("parquet", "orc", "csv", "json").foreach { fmt =>
630640
val basePath = dir.getCanonicalPath + "/" + fmt
631-
@@ -548,7 +549,9 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
641+
@@ -548,7 +550,9 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
632642
}
633643
}
634644

dev/diffs/4.1.1.diff

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -695,10 +695,20 @@ index e1a2fd33c7c..632f4b695df 100644
695695
}
696696
assert(scanOption.isDefined)
697697
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
698-
index b27122a8de2..a4c5aac8212 100644
698+
index b27122a8de2..3c690dbe788 100644
699699
--- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
700700
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
701-
@@ -470,7 +470,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
701+
@@ -267,7 +267,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
702+
}
703+
}
704+
705+
- test("SPARK-33853: explain codegen - check presence of subquery") {
706+
+ test("SPARK-33853: explain codegen - check presence of subquery",
707+
+ IgnoreComet("Comet plan has a different WholeStageCodegen subtree count")) {
708+
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") {
709+
withTempView("df") {
710+
val df1 = spark.range(1, 100)
711+
@@ -470,7 +471,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
702712
}
703713
}
704714

@@ -708,7 +718,7 @@ index b27122a8de2..a4c5aac8212 100644
708718
withTempDir { dir =>
709719
Seq("parquet", "orc", "csv", "json").foreach { fmt =>
710720
val basePath = dir.getCanonicalPath + "/" + fmt
711-
@@ -548,7 +549,9 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
721+
@@ -548,7 +550,9 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
712722
}
713723
}
714724

docs/source/contributor-guide/spark_expressions_support.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@
411411
- [x] randn
412412
- [ ] random
413413
- [ ] randstr
414-
- [ ] rint
414+
- [x] rint
415415
- [x] round
416416
- [x] sec
417417
- [x] shiftleft

docs/source/user-guide/latest/expressions.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ of expressions that be disabled.
174174
| Rand | `rand` |
175175
| Randn | `randn` |
176176
| Remainder | `%` |
177+
| Rint | `rint` |
177178
| Round | `round` |
178179
| Sec | `sec` |
179180
| Signum | `signum` |

native/core/src/execution/jni_api.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ use datafusion_spark::function::map::str_to_map::SparkStrToMap;
6060
use datafusion_spark::function::math::expm1::SparkExpm1;
6161
use datafusion_spark::function::math::factorial::SparkFactorial;
6262
use datafusion_spark::function::math::hex::SparkHex;
63+
use datafusion_spark::function::math::rint::SparkRint;
6364
use datafusion_spark::function::math::trigonometry::SparkCsc;
6465
use datafusion_spark::function::math::trigonometry::SparkSec;
6566
use datafusion_spark::function::math::width_bucket::SparkWidthBucket;
@@ -605,6 +606,7 @@ fn register_datafusion_spark_function(session_ctx: &SessionContext) {
605606
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkTryParseUrl::default()));
606607
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkFactorial::default()));
607608
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkSec::default()));
609+
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkRint::default()));
608610
}
609611

610612
/// Prepares arrow arrays for output.

spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala

Lines changed: 142 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,17 @@ import scala.collection.mutable.ListBuffer
2323

2424
import org.apache.spark.sql.SparkSession
2525
import org.apache.spark.sql.catalyst.expressions.{Divide, DoubleLiteral, EqualNullSafe, EqualTo, Expression, FloatLiteral, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, Remainder}
26+
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, Final, Partial}
2627
import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
2728
import org.apache.spark.sql.catalyst.rules.Rule
29+
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
2830
import org.apache.spark.sql.catalyst.util.sideBySide
2931
import org.apache.spark.sql.comet._
3032
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}
3133
import org.apache.spark.sql.comet.util.Utils
3234
import org.apache.spark.sql.execution._
3335
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
34-
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
36+
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
3537
import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec}
3638
import org.apache.spark.sql.execution.datasources.WriteFilesExec
3739
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
@@ -57,6 +59,14 @@ import org.apache.comet.shims.{ShimCometStreaming, ShimSubqueryBroadcast}
5759

5860
object CometExecRule {
5961

62+
/**
63+
* Tag applied to Partial-mode aggregate operators that must NOT be converted to Comet because
64+
* the corresponding Final-mode aggregate cannot be converted, and the aggregate functions have
65+
* incompatible intermediate buffer formats between Spark and Comet.
66+
*/
67+
val COMET_UNSAFE_PARTIAL: TreeNodeTag[String] =
68+
TreeNodeTag[String]("comet.unsafePartialAgg")
69+
6070
/**
6171
* Fully native operators.
6272
*/
@@ -568,6 +578,12 @@ case class CometExecRule(session: SparkSession)
568578
normalizedPlan
569579
}
570580

581+
// Tag Partial aggregates that must not be converted to Comet because the
582+
// corresponding Final aggregate cannot be converted and the intermediate buffer
583+
// formats are incompatible. This runs before transform() so the tags are checked
584+
// during the bottom-up conversion. Tags persist through AQE stage creation.
585+
tagUnsafePartialAggregates(planWithJoinRewritten)
586+
571587
var newPlan = transform(planWithJoinRewritten)
572588

573589
// if the plan cannot be run fully natively then explain why (when appropriate
@@ -788,4 +804,129 @@ case class CometExecRule(session: SparkSession)
788804
}
789805
}
790806

807+
/**
 * Tags Partial-mode aggregates that must stay on Spark because their Final-mode counterpart
 * cannot be converted to Comet and the aggregate functions do not all support mixed
 * Partial/Final execution (incompatible intermediate buffer formats).
 *
 * This guards against the crash described in issue #1389, where a Comet Partial produces
 * intermediate data in a format that the Spark Final cannot interpret.
 *
 * @param plan
 *   plan to scan with `plan.foreach`; matching Partial aggregates are tagged in place with
 *   `CometExecRule.COMET_UNSAFE_PARTIAL`
 */
private def tagUnsafePartialAggregates(plan: SparkPlan): Unit = {
  val fallbackReason =
    "Partial aggregate disabled: corresponding final aggregate " +
      "cannot be converted to Comet and intermediate buffer formats are incompatible"

  // Only single-mode Final aggregates qualify. Multi-mode Finals come from Spark's
  // distinct-aggregate rewrite, where a Comet partial (if any) feeds a Spark PartialMerge
  // rather than a Final directly; that is a different code path than the
  // Comet-Partial -> Spark-Final scenario from issue #1389.
  // The three conditions are evaluated in this order and short-circuit.
  def isUnconvertibleFinal(candidate: BaseAggregateExec): Boolean = {
    val exprs = candidate.aggregateExpressions
    exprs.map(_.mode).distinct == Seq(Final) &&
    !QueryPlanSerde.allAggsSupportMixedExecution(exprs) &&
    !canAggregateBeConverted(candidate, Final)
  }

  plan.foreach {
    case finalAgg: BaseAggregateExec if isUnconvertibleFinal(finalAgg) =>
      // Tag only a Partial that would otherwise have been converted. If the Partial itself
      // cannot be converted there is no buffer-format mismatch to guard against, and tagging
      // would mask the natural, more specific fallback reason.
      findPartialAggInPlan(finalAgg.child)
        .filter(canAggregateBeConverted(_, Partial))
        .foreach(_.setTagValue(CometExecRule.COMET_UNSAFE_PARTIAL, fallbackReason))
    case _ =>
  }
}
842+
843+
/**
 * Conservative check for whether an aggregate could be converted to Comet. It looks at
 * operator enablement, grouping expressions, aggregate expressions, and result expressions.
 * The sparkFinalMode / child-native checks are deliberately left out because they depend on
 * transformation state.
 *
 * WARNING: the checks below intentionally mirror the predicate checks in
 * `CometBaseAggregate.doConvert` (operators.scala). Any change to the convertibility rules
 * there must be reflected here, otherwise the tagging pass drifts and either crashes (missed
 * tag) or over-disables (spurious tag). A shared predicate helper would be preferable.
 *
 * @param agg
 *   aggregate operator to inspect
 * @param expectedMode
 *   the single mode all aggregate expressions of `agg` are required to have
 * @return
 *   true when every check passes; false as soon as one fails (later checks are not run)
 */
private def canAggregateBeConverted(
    agg: BaseAggregateExec,
    expectedMode: AggregateMode): Boolean = {
  val groupingExprs = agg.groupingExpressions
  val aggExprs = agg.aggregateExpressions

  // A serde must be registered for this operator class and the operator must be enabled.
  def operatorIsEnabled: Boolean =
    allExecs.get(agg.getClass).exists { handler =>
      isOperatorEnabled(
        handler.asInstanceOf[CometOperatorSerde[SparkPlan]],
        agg.asInstanceOf[SparkPlan])
    }

  // ObjectHashAggregate has an extra shuffle-enabled guard in its convert method.
  def shuffleGuardPasses: Boolean = agg match {
    case _: ObjectHashAggregateExec => isCometShuffleEnabled(agg.conf)
    case _ => true
  }

  // There must be something to aggregate or group by, no map-typed grouping keys, and every
  // grouping expression must serialize against the child's output.
  def groupingIsConvertible: Boolean =
    (groupingExprs.nonEmpty || aggExprs.nonEmpty) &&
      !groupingExprs.exists(e => QueryPlanSerde.containsMapType(e.dataType)) &&
      groupingExprs.forall(e => QueryPlanSerde.exprToProto(e, agg.child.output).isDefined)

  def resultExpressionsAreConvertible: Boolean = {
    val attributes = groupingExprs.map(_.toAttribute) ++ agg.aggregateAttributes
    agg.resultExpressions.forall(e => QueryPlanSerde.exprToProto(e, attributes).isDefined)
  }

  def aggregateExpressionsAreConvertible: Boolean = {
    // In Final mode the expressions resolve against the child's output; in any other mode
    // they must bind to input attributes. This mirrors the `binding` calculation in
    // `CometBaseAggregate.doConvert`.
    val bindToInput = expectedMode != Final
    aggExprs.map(_.mode).distinct == Seq(expectedMode) &&
      aggExprs.forall(e =>
        QueryPlanSerde.aggExprToProto(e, agg.child.output, bindToInput, agg.conf).isDefined) &&
      // doConvert only checks resultExpressions in Final mode when aggregate expressions
      // exist (Partial emits the buffer directly), so the same is done here.
      (expectedMode != Final || resultExpressionsAreConvertible)
  }

  operatorIsEnabled && shuffleGuardPasses && groupingIsConvertible && {
    // Result expressions are always checked when there are no aggregate expressions.
    if (aggExprs.isEmpty) resultExpressionsAreConvertible
    else aggregateExpressionsAreConvertible
  }
}
910+
911+
/**
 * Finds the Partial-mode aggregate that feeds directly into the given plan (the child of a
 * Final aggregate).
 *
 * Only shuffle exchanges and AQE shuffle stages/reads are walked through; the search stops at
 * any other node, including other aggregate stages. This avoids tagging unrelated Partials
 * deeper in the plan (e.g. the non-distinct Partial in a distinct-aggregate rewrite, which is
 * separated from the Final by intermediate PartialMerge stages).
 *
 * @param plan
 *   node at which the search starts
 * @return
 *   the Partial aggregate if one is reached, otherwise `None`
 */
private def findPartialAggInPlan(plan: SparkPlan): Option[BaseAggregateExec] = {
  // `nonEmpty` is required so that group-by-only dedup stages are not mistaken for the
  // partial aggregate being looked for.
  def isPartialStage(candidate: BaseAggregateExec): Boolean = {
    val exprs = candidate.aggregateExpressions
    exprs.nonEmpty && exprs.forall(_.mode == Partial)
  }

  plan match {
    case partialAgg: BaseAggregateExec if isPartialStage(partialAgg) => Some(partialAgg)
    case read: AQEShuffleReadExec => findPartialAggInPlan(read.child)
    case stage: ShuffleQueryStageExec => findPartialAggInPlan(stage.plan)
    case exchange: ShuffleExchangeExec => findPartialAggInPlan(exchange.child)
    case node =>
      logDebug(s"findPartialAggInPlan: stopping at ${node.nodeName}; not a known passthrough")
      None
  }
}
931+
791932
}

spark/src/main/scala/org/apache/comet/serde/CometAggregateExpressionSerde.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,17 @@ trait CometAggregateExpressionSerde[T <: AggregateFunction] {
8181
*/
8282
def getSupportLevel(expr: T): SupportLevel = Compatible(None)
8383

84+
/**
85+
* Whether this aggregate's intermediate buffer format is compatible between Spark and Comet,
86+
* making it safe to run the Partial in one engine and the Final in the other. Aggregates with
87+
* simple single-value buffers (MIN, MAX, bitwise) are safe; those with complex or
88+
* differently-encoded buffers (AVG, SUM with decimals, CollectSet, Variance) are not. COUNT is
89+
* intentionally excluded: mixed COUNT partial/final regressed AQE's
90+
* PropagateEmptyRelationAfterAQE pattern (which matches BaseAggregateExec only) and the Spark
91+
* 4.0 count-bug decorrelation for correlated IN subqueries.
92+
*/
93+
def supportsMixedPartialFinal: Boolean = false
94+
8495
/**
8596
* Convert a Spark expression into a protocol buffer representation that can be passed into
8697
* native code.

0 commit comments

Comments
 (0)