Skip to content

Commit 48a352f

Browse files
author
weixiuli
committed
Support null values in the bloom_filter Spark aggregate
- Add unit tests - Related to oap-project/velox#458
1 parent 8c6f164 commit 48a352f

File tree

1 file changed

+83
-0
lines changed

1 file changed

+83
-0
lines changed
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package io.glutenproject.execution
18+
19+
import org.apache.spark.sql.Row
20+
import org.apache.spark.sql.catalyst.FunctionIdentifier
21+
import org.apache.spark.sql.catalyst.expressions._
22+
import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
23+
import org.apache.spark.sql.types.{IntegerType, LongType, StructType}
24+
25+
class VeloxBloomFilterAggregateFunctionsSuite extends VeloxWholeStageTransformerSuite {

  /** Identifier under which the test-only `bloom_filter_agg` aggregate is registered. */
  val funcId_bloom_filter_agg = new FunctionIdentifier("bloom_filter_agg")

  override protected val backend: String = "velox"
  override protected val resourcePath: String = "/tpch-data-parquet-velox"
  override protected val fileFormat: String = "parquet"

  /** Name of the fixture table the suite writes its test rows to. */
  val table = "bloomTable"

  /**
   * Registers `bloom_filter_agg` with the session's function registry and creates the
   * fixture table. Every column is nullable and the fixture rows contain nulls, so the
   * aggregate's null handling can be exercised.
   */
  protected def registerFunAndcCreatTable(): Unit = {
    // Register 'bloom_filter_agg'. The builder maps the 1-, 2- and 3-argument call
    // forms onto the matching BloomFilterAggregate constructors.
    spark.sessionState.functionRegistry.registerFunction(
      funcId_bloom_filter_agg,
      new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"),
      (children: Seq[Expression]) =>
        children.size match {
          case 1 => new BloomFilterAggregate(children.head)
          case 2 => new BloomFilterAggregate(children.head, children(1))
          case 3 => new BloomFilterAggregate(children.head, children(1), children(2))
          // Fail with a clear message instead of an opaque MatchError on other arities.
          case n =>
            throw new IllegalArgumentException(
              s"bloom_filter_agg expects between 1 and 3 arguments, got $n")
        }
    )
    val schema2 = new StructType()
      .add("a2", IntegerType, nullable = true)
      .add("b2", LongType, nullable = true)
      .add("c2", IntegerType, nullable = true)
      .add("d2", IntegerType, nullable = true)
      .add("e2", IntegerType, nullable = true)
      .add("f2", IntegerType, nullable = true)
    // Rows deliberately contain nulls in several columns (including b2, the column
    // the test aggregates over).
    val data2 = Seq(
      Seq(67, 17L, 45, 91, null, null),
      Seq(98, 63L, 0, 89, null, 40),
      Seq(null, null, 68, 75, 20, 19))
    val rdd2 = spark.sparkContext.parallelize(data2)
    val rddRow2 = rdd2.map(s => Row.fromSeq(s))
    spark.createDataFrame(rddRow2, schema2).write.saveAsTable(table)
  }

  /** Drops the registered function and the fixture table; cleanup counterpart of the setup above. */
  protected def dropFunctionAndTable(): Unit = {
    spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg)
    spark.sql(s"DROP TABLE IF EXISTS $table")
  }

  override def beforeAll(): Unit = {
    super.beforeAll()
    registerFunAndcCreatTable()
  }

  override def afterAll(): Unit = {
    dropFunctionAndTable()
    super.afterAll()
  }

  test("Test bloom_filter_agg with Nulls input") {
    // Aggregating a column that contains nulls must not fail; nulls are skipped
    // by the bloom filter build.
    spark
      .sql(s"""
              |SELECT bloom_filter_agg(b2) from $table
         """.stripMargin)
      .show
  }
}

0 commit comments

Comments
 (0)