@@ -23,15 +23,17 @@ import scala.collection.mutable.ListBuffer
2323
2424import org .apache .spark .sql .SparkSession
2525import org .apache .spark .sql .catalyst .expressions .{Divide , DoubleLiteral , EqualNullSafe , EqualTo , Expression , FloatLiteral , GreaterThan , GreaterThanOrEqual , KnownFloatingPointNormalized , LessThan , LessThanOrEqual , NamedExpression , Remainder }
26+ import org .apache .spark .sql .catalyst .expressions .aggregate .{AggregateMode , Final , Partial }
2627import org .apache .spark .sql .catalyst .optimizer .NormalizeNaNAndZero
2728import org .apache .spark .sql .catalyst .rules .Rule
29+ import org .apache .spark .sql .catalyst .trees .TreeNodeTag
2830import org .apache .spark .sql .catalyst .util .sideBySide
2931import org .apache .spark .sql .comet ._
3032import org .apache .spark .sql .comet .execution .shuffle .{CometColumnarShuffle , CometNativeShuffle , CometShuffleExchangeExec }
3133import org .apache .spark .sql .comet .util .Utils
3234import org .apache .spark .sql .execution ._
3335import org .apache .spark .sql .execution .adaptive .{AdaptiveSparkPlanExec , AQEShuffleReadExec , BroadcastQueryStageExec , ShuffleQueryStageExec }
34- import org .apache .spark .sql .execution .aggregate .{HashAggregateExec , ObjectHashAggregateExec }
36+ import org .apache .spark .sql .execution .aggregate .{BaseAggregateExec , HashAggregateExec , ObjectHashAggregateExec }
3537import org .apache .spark .sql .execution .command .{DataWritingCommandExec , ExecutedCommandExec }
3638import org .apache .spark .sql .execution .datasources .WriteFilesExec
3739import org .apache .spark .sql .execution .datasources .csv .CSVFileFormat
@@ -57,6 +59,14 @@ import org.apache.comet.shims.{ShimCometStreaming, ShimSubqueryBroadcast}
5759
5860object CometExecRule {
5961
62+ /**
63+ * Tag applied to Partial-mode aggregate operators that must NOT be converted to Comet because
64+ * the corresponding Final-mode aggregate cannot be converted, and the aggregate functions have
65+ * incompatible intermediate buffer formats between Spark and Comet.
66+ */
67+ val COMET_UNSAFE_PARTIAL : TreeNodeTag [String ] =
68+ TreeNodeTag [String ](" comet.unsafePartialAgg" )
69+
6070 /**
6171 * Fully native operators.
6272 */
@@ -568,6 +578,12 @@ case class CometExecRule(session: SparkSession)
568578 normalizedPlan
569579 }
570580
581+ // Tag Partial aggregates that must not be converted to Comet because the
582+ // corresponding Final aggregate cannot be converted and the intermediate buffer
583+ // formats are incompatible. This runs before transform() so the tags are checked
584+ // during the bottom-up conversion. Tags persist through AQE stage creation.
585+ tagUnsafePartialAggregates(planWithJoinRewritten)
586+
571587 var newPlan = transform(planWithJoinRewritten)
572588
573589 // if the plan cannot be run fully natively then explain why (when appropriate
@@ -788,4 +804,129 @@ case class CometExecRule(session: SparkSession)
788804 }
789805 }
790806
807+ /**
808+ * Walk the plan to find Final-mode aggregates that cannot be converted to Comet. For each such
809+ * Final, if the aggregate functions have incompatible intermediate buffer formats, tag the
810+ * corresponding Partial-mode aggregate so it will also be skipped during conversion.
811+ *
812+ * This prevents the crash described in issue #1389 where a Comet Partial produces intermediate
813+ * data in a format that the Spark Final cannot interpret.
814+ */
815+ private def tagUnsafePartialAggregates (plan : SparkPlan ): Unit = {
816+ plan.foreach {
817+ case agg : BaseAggregateExec =>
818+ // Only consider single-mode Final aggregates. Multi-mode Finals come from Spark's
819+ // distinct-aggregate rewrite, where the Comet partial (if any) feeds into a Spark
820+ // PartialMerge rather than directly into a Final, which is a different code path
821+ // than the Comet-Partial → Spark-Final crash scenario from issue #1389.
822+ val modes = agg.aggregateExpressions.map(_.mode).distinct
823+ if (modes == Seq (Final ) &&
824+ ! QueryPlanSerde .allAggsSupportMixedExecution(agg.aggregateExpressions) &&
825+ ! canAggregateBeConverted(agg, Final )) {
826+ findPartialAggInPlan(agg.child).foreach { partial =>
827+ // Only tag if the Partial would otherwise have been converted. If the Partial
828+ // itself cannot be converted (e.g. the aggregate function is incompatible for the
829+ // input type), there is no buffer-format mismatch to guard against, and tagging
830+ // would mask the natural, more specific fallback reason.
831+ if (canAggregateBeConverted(partial, Partial )) {
832+ partial.setTagValue(
833+ CometExecRule .COMET_UNSAFE_PARTIAL ,
834+ " Partial aggregate disabled: corresponding final aggregate " +
835+ " cannot be converted to Comet and intermediate buffer formats are incompatible" )
836+ }
837+ }
838+ }
839+ case _ =>
840+ }
841+ }
842+
843+ /**
844+ * Conservative check for whether an aggregate could be converted to Comet. Checks operator
845+ * enablement, grouping expressions, aggregate expressions, and result expressions.
846+ * Intentionally skips the sparkFinalMode / child-native checks since those depend on
847+ * transformation state.
848+ *
849+ * WARNING: this intentionally mirrors the predicate checks in `CometBaseAggregate.doConvert`
850+ * (operators.scala). Any change to the convertibility rules there must be reflected here or
851+ * this tagging pass will drift and either crash (missed tag) or over-disable (spurious tag). A
852+ * shared predicate helper would be preferable.
853+ */
854+ private def canAggregateBeConverted (
855+ agg : BaseAggregateExec ,
856+ expectedMode : AggregateMode ): Boolean = {
857+ val handler = allExecs.get(agg.getClass)
858+ if (handler.isEmpty) return false
859+ val serde = handler.get.asInstanceOf [CometOperatorSerde [SparkPlan ]]
860+ if (! isOperatorEnabled(serde, agg.asInstanceOf [SparkPlan ])) return false
861+
862+ // ObjectHashAggregate has an extra shuffle-enabled guard in its convert method
863+ agg match {
864+ case _ : ObjectHashAggregateExec if ! isCometShuffleEnabled(agg.conf) => return false
865+ case _ =>
866+ }
867+
868+ val aggregateExpressions = agg.aggregateExpressions
869+ val groupingExpressions = agg.groupingExpressions
870+
871+ if (groupingExpressions.isEmpty && aggregateExpressions.isEmpty) return false
872+
873+ if (groupingExpressions.exists(e => QueryPlanSerde .containsMapType(e.dataType))) return false
874+
875+ if (! groupingExpressions.forall(e =>
876+ QueryPlanSerde .exprToProto(e, agg.child.output).isDefined)) {
877+ return false
878+ }
879+
880+ if (aggregateExpressions.isEmpty) {
881+ // Result expressions always checked when there are no aggregate expressions
882+ val attributes =
883+ groupingExpressions.map(_.toAttribute) ++ agg.aggregateAttributes
884+ return agg.resultExpressions.forall(e =>
885+ QueryPlanSerde .exprToProto(e, attributes).isDefined)
886+ }
887+
888+ val modes = aggregateExpressions.map(_.mode).distinct
889+ if (modes.size != 1 || modes.head != expectedMode) return false
890+
891+ // In Final mode, exprToProto resolves against the child's output; in Partial/non-Final mode
892+ // it must bind to input attributes. This mirrors the `binding` calculation in
893+ // `CometBaseAggregate.doConvert`.
894+ val binding = expectedMode != Final
895+ if (! aggregateExpressions.forall(e =>
896+ QueryPlanSerde .aggExprToProto(e, agg.child.output, binding, agg.conf).isDefined)) {
897+ return false
898+ }
899+
900+ // doConvert only checks resultExpressions in Final mode when aggregate expressions exist
901+ // (Partial emits the buffer directly). Mirror that here to avoid false negatives.
902+ if (expectedMode == Final ) {
903+ val attributes =
904+ groupingExpressions.map(_.toAttribute) ++ agg.aggregateAttributes
905+ agg.resultExpressions.forall(e => QueryPlanSerde .exprToProto(e, attributes).isDefined)
906+ } else {
907+ true
908+ }
909+ }
910+
911+ /**
912+ * Look for a Partial-mode aggregate that feeds directly into the given plan (the child of a
913+ * Final). Walks through exchanges and AQE stages only, stopping at anything else including
914+ * other aggregate stages. This avoids tagging unrelated Partials found deeper in the plan (e.g.
915+ * the non-distinct Partial in a distinct-aggregate rewrite, which is separated from the Final
916+ * by intermediate PartialMerge stages). Requires `aggregateExpressions.nonEmpty` so that
917+ * group-by-only dedup stages are not mistaken for the partial we want to tag.
918+ */
919+ private def findPartialAggInPlan (plan : SparkPlan ): Option [BaseAggregateExec ] = plan match {
920+ case agg : BaseAggregateExec
921+ if agg.aggregateExpressions.nonEmpty &&
922+ agg.aggregateExpressions.forall(e => e.mode == Partial ) =>
923+ Some (agg)
924+ case a : AQEShuffleReadExec => findPartialAggInPlan(a.child)
925+ case s : ShuffleQueryStageExec => findPartialAggInPlan(s.plan)
926+ case e : ShuffleExchangeExec => findPartialAggInPlan(e.child)
927+ case other =>
928+ logDebug(s " findPartialAggInPlan: stopping at ${other.nodeName}; not a known passthrough " )
929+ None
930+ }
931+
791932}
0 commit comments