Skip to content

Commit e939e51

Browse files
committed
feat: route higher-order functions through codegen dispatcher
Register the array and map higher-order (lambda) functions that previously fell back to Spark so they stay native via the codegen dispatcher: - array: transform, exists, forall, aggregate, array_sort (comparator), zip_with - map: map_filter, transform_keys, transform_values, map_zip_with These have no native (rust) implementation and extend Spark's CodegenFallback, which the dispatcher's canHandle already admits, so the projection stays native and matches Spark exactly. When the dispatcher is disabled they fall back to Spark. Adds CometHigherOrderFunctionSuite covering the dispatch path and the fallback-when-disabled path for each family.
1 parent ceecae7 commit e939e51

4 files changed

Lines changed: 185 additions & 3 deletions

File tree

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,13 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim {
7373
classOf[Flatten] -> CometFlatten,
7474
classOf[GetArrayItem] -> CometGetArrayItem,
7575
classOf[Size] -> CometSize,
76-
classOf[ArraysZip] -> CometArraysZip)
76+
classOf[ArraysZip] -> CometArraysZip,
77+
classOf[ArrayTransform] -> CometArrayTransform,
78+
classOf[ArrayExists] -> CometArrayExists,
79+
classOf[ArrayForAll] -> CometArrayForAll,
80+
classOf[ArrayAggregate] -> CometArrayAggregate,
81+
classOf[ArraySort] -> CometArraySort,
82+
classOf[ZipWith] -> CometZipWith)
7783

7884
private val conditionalExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] =
7985
Map(classOf[CaseWhen] -> CometCaseWhen, classOf[If] -> CometIf)
@@ -153,7 +159,11 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim {
153159
classOf[MapFromArrays] -> CometMapFromArrays,
154160
classOf[MapContainsKey] -> CometMapContainsKey,
155161
classOf[MapFromEntries] -> CometMapFromEntries,
156-
classOf[StringToMap] -> CometStrToMap)
162+
classOf[StringToMap] -> CometStrToMap,
163+
classOf[MapFilter] -> CometMapFilter,
164+
classOf[TransformKeys] -> CometTransformKeys,
165+
classOf[TransformValues] -> CometTransformValues,
166+
classOf[MapZipWith] -> CometMapZipWith)
157167

158168
private[comet] val structExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] =
159169
Map(

spark/src/main/scala/org/apache/comet/serde/arrays.scala

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ package org.apache.comet.serde
2222
import scala.annotation.tailrec
2323
import scala.jdk.CollectionConverters._
2424

25-
import org.apache.spark.sql.catalyst.expressions.{And, ArrayAppend, ArrayContains, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayPosition, ArrayRemove, ArrayRepeat, ArraysOverlap, ArraysZip, ArrayUnion, Attribute, Cast, CreateArray, ElementAt, EmptyRow, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size, Slice, SortArray}
25+
import org.apache.spark.sql.catalyst.expressions.{And, ArrayAggregate, ArrayAppend, ArrayContains, ArrayExcept, ArrayExists, ArrayFilter, ArrayForAll, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayPosition, ArrayRemove, ArrayRepeat, ArraySort, ArraysOverlap, ArraysZip, ArrayTransform, ArrayUnion, Attribute, Cast, CreateArray, ElementAt, EmptyRow, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size, Slice, SortArray, ZipWith}
2626
import org.apache.spark.sql.catalyst.util.GenericArrayData
2727
import org.apache.spark.sql.internal.SQLConf
2828
import org.apache.spark.sql.types._
@@ -843,3 +843,19 @@ trait ArraysBase {
843843
}
844844
}
845845
}
846+
847+
// Array higher-order (lambda) functions have no native (rust) implementation. They extend
848+
// Spark's `CodegenFallback`, which the codegen dispatcher admits: the lambda body evaluates
849+
// through the kernel's typed Arrow getters, so the projection stays native while matching Spark
850+
// exactly. See `CometCodegenHOFSuite`.
851+
object CometArrayTransform extends CometCodegenDispatch[ArrayTransform]
852+
853+
object CometArrayExists extends CometCodegenDispatch[ArrayExists]
854+
855+
object CometArrayForAll extends CometCodegenDispatch[ArrayForAll]
856+
857+
object CometArrayAggregate extends CometCodegenDispatch[ArrayAggregate]
858+
859+
object CometArraySort extends CometCodegenDispatch[ArraySort]
860+
861+
object CometZipWith extends CometCodegenDispatch[ZipWith]

spark/src/main/scala/org/apache/comet/serde/maps.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,14 @@ object CometMapFromEntries extends CometScalarFunction[MapFromEntries]("map_from
163163
}
164164

165165
object CometStrToMap extends CometScalarFunction[StringToMap]("str_to_map")
166+
167+
// Map higher-order (lambda) functions have no native (rust) implementation. They extend Spark's
168+
// `CodegenFallback`, which the codegen dispatcher admits, so the projection stays native while
169+
// matching Spark exactly. See `CometCodegenHOFSuite`.
170+
object CometMapFilter extends CometCodegenDispatch[MapFilter]
171+
172+
object CometTransformKeys extends CometCodegenDispatch[TransformKeys]
173+
174+
object CometTransformValues extends CometCodegenDispatch[TransformValues]
175+
176+
object CometMapZipWith extends CometCodegenDispatch[MapZipWith]
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.comet
21+
22+
import org.apache.spark.sql.CometTestBase
23+
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
24+
25+
/**
26+
* Higher-order (lambda) functions have no native (rust) implementation. They are wired into the
27+
* codegen dispatcher via [[org.apache.comet.serde.CometArrayTransform]] and friends, so a
28+
* top-level HOF projection stays native (running Spark's own `doGenCode`/interpreted-eval inside
29+
* the Comet kernel) and matches Spark exactly. When the dispatcher is disabled, the HOF has no
30+
* native path and the projection falls back to Spark.
31+
*
32+
* `CometCodegenHOFSuite` covers HOFs nested inside a registered `ScalaUDF`; this suite covers the
33+
* HOF as the top-level expression, which only stays native once the serde is registered.
34+
*/
35+
class CometHigherOrderFunctionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
36+
37+
private def withArrayTable(thunk: => Unit): Unit = {
38+
withTable("t") {
39+
sql("CREATE TABLE t (a ARRAY<INT>, b ARRAY<INT>) USING parquet")
40+
sql(
41+
"INSERT INTO t VALUES " +
42+
"(array(1, 2, 3), array(10, 20, 30)), " +
43+
"(array(-5, 5), array(1, 2)), " +
44+
"(array(), array()), " +
45+
"(null, null)")
46+
thunk
47+
}
48+
}
49+
50+
private def withMapTable(thunk: => Unit): Unit = {
51+
withTable("t") {
52+
sql("CREATE TABLE t (m MAP<STRING, INT>, n MAP<STRING, INT>) USING parquet")
53+
sql(
54+
"INSERT INTO t VALUES " +
55+
"(map('a', 1, 'b', 2), map('a', 10, 'c', 30)), " +
56+
"(map('x', -1), map('x', 5)), " +
57+
"(map(), map()), " +
58+
"(null, null)")
59+
thunk
60+
}
61+
}
62+
63+
test("transform") {
64+
withArrayTable {
65+
checkSparkAnswerAndOperator("SELECT transform(a, x -> x + 1) FROM t")
66+
}
67+
}
68+
69+
test("exists") {
70+
withArrayTable {
71+
checkSparkAnswerAndOperator("SELECT exists(a, x -> x > 2) FROM t")
72+
}
73+
}
74+
75+
test("forall") {
76+
withArrayTable {
77+
checkSparkAnswerAndOperator("SELECT forall(a, x -> x > 0) FROM t")
78+
}
79+
}
80+
81+
test("aggregate") {
82+
withArrayTable {
83+
checkSparkAnswerAndOperator("SELECT aggregate(a, 0, (acc, x) -> acc + x) FROM t")
84+
}
85+
}
86+
87+
test("array_sort with comparator") {
88+
withArrayTable {
89+
checkSparkAnswerAndOperator(
90+
"SELECT array_sort(a, (l, r) -> " +
91+
"CASE WHEN l < r THEN 1 WHEN l > r THEN -1 ELSE 0 END) FROM t")
92+
}
93+
}
94+
95+
test("zip_with") {
96+
withArrayTable {
97+
checkSparkAnswerAndOperator("SELECT zip_with(a, b, (x, y) -> x + y) FROM t")
98+
}
99+
}
100+
101+
test("map_filter") {
102+
withMapTable {
103+
checkSparkAnswerAndOperator("SELECT map_filter(m, (k, v) -> v > 1) FROM t")
104+
}
105+
}
106+
107+
test("transform_keys") {
108+
withMapTable {
109+
checkSparkAnswerAndOperator("SELECT transform_keys(m, (k, v) -> concat(k, 'X')) FROM t")
110+
}
111+
}
112+
113+
test("transform_values") {
114+
withMapTable {
115+
checkSparkAnswerAndOperator("SELECT transform_values(m, (k, v) -> v + 1) FROM t")
116+
}
117+
}
118+
119+
test("map_zip_with") {
120+
withMapTable {
121+
checkSparkAnswerAndOperator(
122+
"SELECT map_zip_with(m, n, (k, v1, v2) -> coalesce(v1, 0) + coalesce(v2, 0)) FROM t")
123+
}
124+
}
125+
126+
test("array HOF falls back to Spark when codegen dispatcher disabled") {
127+
withSQLConf(CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key -> "false") {
128+
withArrayTable {
129+
checkSparkAnswerAndFallbackReason(
130+
"SELECT transform(a, x -> x + 1) FROM t",
131+
CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key)
132+
}
133+
}
134+
}
135+
136+
test("map HOF falls back to Spark when codegen dispatcher disabled") {
137+
withSQLConf(CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key -> "false") {
138+
withMapTable {
139+
checkSparkAnswerAndFallbackReason(
140+
"SELECT transform_values(m, (k, v) -> v + 1) FROM t",
141+
CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key)
142+
}
143+
}
144+
}
145+
}

0 commit comments

Comments
 (0)