Skip to content

Commit 7cf04c8

Browse files
authored
fix: prevent native sort crash for Struct(Map(...)) keys (#4157)
1 parent 8f3cee5 commit 7cf04c8

2 files changed

Lines changed: 54 additions & 7 deletions

File tree

spark/src/main/scala/org/apache/spark/sql/comet/operators.scala

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
4343
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
4444
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
4545
import org.apache.spark.sql.internal.SQLConf
46-
import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, TimestampNTZType, TimestampType}
46+
import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, StructType, TimestampNTZType, TimestampType}
4747
import org.apache.spark.sql.vectorized.ColumnarBatch
4848
import org.apache.spark.util.SerializableConfiguration
4949
import org.apache.spark.util.io.ChunkedByteBuffer
@@ -1404,6 +1404,13 @@ case class CometUnionExec(
14041404

14051405
trait CometBaseAggregate {
14061406

1407+
/**
 * Reports whether a data type is a map or nests a map anywhere inside struct fields or array
 * elements.
 *
 * @param dt
 *   the data type to inspect
 * @return
 *   `true` when `dt` is a `MapType`, a `StructType` with at least one field whose type contains
 *   a map, or an `ArrayType` whose element type contains a map; `false` for every other type
 */
private def containsMapType(dt: DataType): Boolean = dt match {
  case _: MapType => true
  // Recurse into each struct field; one map-containing field is enough.
  case struct: StructType =>
    struct.fields.exists(field => containsMapType(field.dataType))
  // An array contains a map exactly when its element type does.
  case array: ArrayType => containsMapType(array.elementType)
  case _ => false
}
1413+
14071414
def doConvert(
14081415
aggregate: BaseAggregateExec,
14091416
builder: Operator.Builder,
@@ -1434,12 +1441,8 @@ trait CometBaseAggregate {
14341441
return None
14351442
}
14361443

1437-
if (groupingExpressions.exists(expr =>
1438-
expr.dataType match {
1439-
case _: MapType => true
1440-
case _ => false
1441-
})) {
1442-
withInfo(aggregate, "Grouping on map types is not supported")
1444+
if (groupingExpressions.exists(expr => containsMapType(expr.dataType))) {
1445+
withInfo(aggregate, "Grouping on map-containing types is not supported")
14431446
return None
14441447
}
14451448

spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
3333

3434
import org.apache.comet.CometConf
3535
import org.apache.comet.CometConf.COMET_EXEC_STRICT_FLOATING_POINT
36+
import org.apache.comet.CometSparkSessionExtensions.isSpark41Plus
3637
import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator, ParquetGenerator, SchemaGenOptions}
3738

3839
/**
@@ -408,6 +409,33 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
408409
}
409410
}
410411

412+
// Groups by a struct whose only field is a map. The query is run through
// checkSparkAnswerAndFallbackReason with the map-containing-types reason, and the resulting
// plan must not contain a CometHashAggregateExec.
test("grouping on struct containing map should fallback to Spark") {
  assume(isSpark41Plus, "Spark 4.1+ supports grouping on map-containing types")

  val expectedReason = "Grouping on map-containing types is not supported"
  val structOfMapQuery =
    """SELECT col1.data['key']
      |FROM VALUES
      |  (NAMED_STRUCT('data', MAP('key', 'value', 'num', '42'))),
      |  (NAMED_STRUCT('data', MAP('key', 'other', 'num', '7')))
      |t (col1)
      |GROUP BY col1
      |HAVING col1.data['num'] IS NOT NULL
      |ORDER BY col1.data['key']
      |""".stripMargin

  withSQLConf(
    CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true",
    CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
    val (_, cometPlan) = checkSparkAnswerAndFallbackReason(structOfMapQuery, expectedReason)

    // Collect every Comet hash aggregate left in the plan once the AQE wrapper is stripped.
    val cometAggregates = stripAQEPlan(cometPlan).collect { case agg: CometHashAggregateExec =>
      agg
    }
    assert(
      cometAggregates.isEmpty,
      "Expected aggregate to fall back to Spark for grouping on Struct(Map(...))")
  }
}
438+
411439
test("simple SUM, COUNT, MIN, MAX, AVG with non-distinct + null group keys") {
412440
Seq(true, false).foreach { dictionaryEnabled =>
413441
withParquetTable(
@@ -2059,4 +2087,20 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
20592087
sparkPlan.collect { case s: CometHashAggregateExec => s }.size
20602088
}
20612089

2090+
// Groups by an array whose elements are maps and runs the query through
// checkSparkAnswerAndFallbackReason with the map-containing-types reason.
test("group by array of map falls back to Spark (issue #4123)") {
  assume(isSpark41Plus, "Spark 4.1+ supports grouping on map-containing types")

  val expectedReason = "Grouping on map-containing types is not supported"
  val arrayOfMapQuery =
    """SELECT a, COUNT(*)
      |FROM VALUES
      |  (ARRAY(MAP('x', 10))),
      |  (ARRAY(MAP('y', 20))),
      |  (ARRAY(MAP('x', 10)))
      |t (a)
      |GROUP BY a
      |""".stripMargin

  withSQLConf(CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") {
    checkSparkAnswerAndFallbackReason(arrayOfMapQuery, expectedReason)
  }
}
2105+
20622106
}

0 commit comments

Comments
 (0)