Skip to content

Commit 3bb082a

Browse files
authored
[spark] Fix group by partial partition of a multi partition table (apache#6375)
1 parent db56793 commit 3bb082a

File tree

2 files changed

+35
-7
lines changed

2 files changed

+35
-7
lines changed

paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/LocalAggregator.scala

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ import org.apache.paimon.spark.data.SparkInternalRow
2424
import org.apache.paimon.stats.SimpleStatsEvolutions
2525
import org.apache.paimon.table.FileStoreTable
2626
import org.apache.paimon.table.source.DataSplit
27-
import org.apache.paimon.utils.{InternalRowUtils, ProjectedRow}
27+
import org.apache.paimon.types.RowType
28+
import org.apache.paimon.utils.ProjectedRow
2829

2930
import org.apache.spark.sql.catalyst.InternalRow
3031
import org.apache.spark.sql.catalyst.expressions.JoinedRow
@@ -40,7 +41,8 @@ class LocalAggregator(table: FileStoreTable) {
4041
private val partitionType = SparkTypeUtils.toPartitionType(table)
4142
private val groupByEvaluatorMap = new mutable.HashMap[InternalRow, Seq[AggFuncEvaluator[_]]]()
4243
private var requiredGroupByType: Seq[DataType] = _
43-
private var requiredGroupByIndexMapping: Seq[Int] = _
44+
private var requiredGroupByIndexMapping: Array[Int] = _
45+
private var requiredGroupByPaimonType: RowType = _
4446
private var aggFuncEvaluatorGetter: () => Seq[AggFuncEvaluator[_]] = _
4547
private var isInitialized = false
4648
private lazy val simpleStatsEvolutions = {
@@ -78,15 +80,14 @@ class LocalAggregator(table: FileStoreTable) {
7880
partitionType.getFieldIndex(r.fieldNames().head)
7981
}
8082

83+
requiredGroupByPaimonType = partitionType.project(requiredGroupByIndexMapping)
84+
8185
isInitialized = true
8286
}
8387

8488
private def requiredGroupByRow(partitionRow: BinaryRow): InternalRow = {
85-
val projectedRow =
86-
ProjectedRow.from(requiredGroupByIndexMapping.toArray).replaceRow(partitionRow)
87-
// `ProjectedRow` does not support `hashCode`, so do a deep copy
88-
val genericRow = InternalRowUtils.copyInternalRow(projectedRow, partitionType)
89-
SparkInternalRow.create(partitionType).replace(genericRow)
89+
val projectedRow = ProjectedRow.from(requiredGroupByIndexMapping).replaceRow(partitionRow)
90+
SparkInternalRow.create(requiredGroupByPaimonType).replace(projectedRow)
9091
}
9192

9293
def update(dataSplit: DataSplit): Unit = {

paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PushDownAggregatesTest.scala

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,4 +267,31 @@ class PushDownAggregatesTest extends PaimonSparkTestBase with AdaptiveSparkPlanH
267267
})
268268
})
269269
}
270+
271+
test("Push down aggregate: group by partial partition of a multi partition table") {
272+
sql(s"""
273+
|CREATE TABLE T (
274+
|c1 STRING,
275+
|c2 STRING,
276+
|c3 STRING,
277+
|c4 STRING,
278+
|c5 DATE)
279+
|PARTITIONED BY (c5, c1)
280+
|TBLPROPERTIES ('primary-key' = 'c5, c1, c3')
281+
|""".stripMargin)
282+
283+
sql("INSERT INTO T VALUES ('t1', 'k1', 'v1', 'r1', '2025-01-01')")
284+
checkAnswer(
285+
sql("SELECT COUNT(*) FROM T GROUP BY c1"),
286+
Seq(Row(1))
287+
)
288+
checkAnswer(
289+
sql("SELECT c1, COUNT(*) FROM T GROUP BY c1"),
290+
Seq(Row("t1", 1))
291+
)
292+
checkAnswer(
293+
sql("SELECT COUNT(*), c1 FROM T GROUP BY c1"),
294+
Seq(Row(1, "t1"))
295+
)
296+
}
270297
}

0 commit comments

Comments (0)