Skip to content

Commit 1620eeb

Browse files
authored
[VL] Use the same result type for decimal round (#4621)
1 parent 5c6aae0 commit 1620eeb

File tree

4 files changed

+98
-1
lines changed

4 files changed

+98
-1
lines changed

backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.{AggregateFunctionRewriteRule, FlushableHas
3535
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
3636
import org.apache.spark.sql.catalyst.catalog.BucketSpec
3737
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
38-
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, CreateNamedStruct, ElementAt, Expression, ExpressionInfo, GetArrayItem, GetMapValue, GetStructField, If, IsNaN, Literal, NamedExpression, NaNvl, StringSplit, StringTrim}
38+
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, CreateNamedStruct, ElementAt, Expression, ExpressionInfo, GetArrayItem, GetMapValue, GetStructField, If, IsNaN, Literal, NamedExpression, NaNvl, Round, StringSplit, StringTrim}
3939
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, HLLAdapter}
4040
import org.apache.spark.sql.catalyst.optimizer.BuildSide
4141
import org.apache.spark.sql.catalyst.plans.JoinType
@@ -357,6 +357,14 @@ class SparkPlanExecApiImpl extends SparkPlanExecApi {
357357
* * Expressions.
358358
*/
359359

360+
/**
 * Builds the Velox transformer for `round` applied to a decimal input.
 *
 * @param substraitExprName Substrait function name forwarded to the transformer.
 * @param child transformer of the value being rounded.
 * @param original the Spark `Round` expression being converted.
 * @return a [[DecimalRoundTransformer]] wrapping the three arguments unchanged.
 */
override def genDecimalRoundTransformer(
    substraitExprName: String,
    child: ExpressionTransformer,
    original: Round): ExpressionTransformer =
  DecimalRoundTransformer(substraitExprName, child, original)
367+
360368
/** Generate StringSplit transformer. */
361369
override def genStringSplitTransformer(
362370
substraitExprName: String,

gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,13 @@ trait SparkPlanExecApi {
295295
*/
296296
def genExtendedColumnarPostRules(): List[SparkSession => Rule[SparkPlan]]
297297

298+
/**
 * Generates the transformer for `round` applied to a decimal input.
 *
 * This default wraps the single child in a `GenericExpressionTransformer`; a backend may
 * override it to supply its own transformer.
 *
 * @param substraitExprName Substrait function name forwarded to the transformer.
 * @param child transformer of the value being rounded.
 * @param original the Spark `Round` expression being converted.
 */
def genDecimalRoundTransformer(
    substraitExprName: String,
    child: ExpressionTransformer,
    original: Round): ExpressionTransformer =
  GenericExpressionTransformer(substraitExprName, Seq(child), original)
304+
298305
def genGetStructFieldTransformer(
299306
substraitExprName: String,
300307
childTransformer: ExpressionTransformer,
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package io.glutenproject.expression
18+
19+
import io.glutenproject.expression.ConverterUtils.FunctionConfig
20+
import io.glutenproject.substrait.expression.{ExpressionBuilder, ExpressionNode}
21+
22+
import org.apache.spark.sql.catalyst.expressions._
23+
import org.apache.spark.sql.types.{DataType, DecimalType}
24+
25+
import com.google.common.collect.Lists
26+
27+
/**
 * Transformer for `round` over a decimal child.
 *
 * The result type is computed here from the child's decimal type and the requested scale,
 * rather than taken from `original.dataType`, so that the same result type is used for
 * different Spark versions.
 *
 * @param substraitExprName Substrait function name used to register the scalar function.
 * @param child transformer of the value being rounded.
 * @param original the Spark `Round` expression being converted; its child must be a decimal.
 * @throws UnsupportedOperationException if the child is not a decimal, or if the scale does
 *                                       not evaluate to a non-null `Int`.
 */
case class DecimalRoundTransformer(
    substraitExprName: String,
    child: ExpressionTransformer,
    original: Round)
  extends ExpressionTransformer {

  // The scale expression is evaluated once, against an empty row. A null result must not be
  // unboxed with `asInstanceOf[Int]`: that silently yields 0 and the plan would then round to
  // scale 0. Reject it explicitly instead.
  // NOTE(review): presumably callers treat UnsupportedOperationException as "not offloadable",
  // as with the non-decimal case below -- confirm.
  val toScale: Int = original.scale.eval(EmptyRow) match {
    case scale: Int => scale
    case other =>
      throw new UnsupportedOperationException(
        s"Round scale must evaluate to a non-null integer but received $other.")
  }

  // Use the same result type for different Spark versions.
  val dataType: DataType = original.child.dataType match {
    case decimalType: DecimalType =>
      val p = decimalType.precision
      val s = decimalType.scale
      // After rounding we may need one more digit in the integral part,
      // e.g. `ceil(9.9, 0)` -> `10`, `ceil(99, -1)` -> `100`.
      val integralLeastNumDigits = p - s + 1
      if (toScale < 0) {
        // negative scale means we need to adjust `-scale` number of digits before the decimal
        // point, which means we need at least `-scale + 1` digits (after rounding).
        val newPrecision = math.max(integralLeastNumDigits, -toScale + 1)
        // We have to accept the risk of overflow as we can't exceed the max precision.
        DecimalType(math.min(newPrecision, DecimalType.MAX_PRECISION), 0)
      } else {
        // Rounding never increases the number of fractional digits.
        val newScale = math.min(s, toScale)
        // We have to accept the risk of overflow as we can't exceed the max precision.
        DecimalType(
          math.min(integralLeastNumDigits + newScale, DecimalType.MAX_PRECISION),
          newScale)
      }
    case _ =>
      throw new UnsupportedOperationException(
        s"Decimal type is expected but received ${original.child.dataType.typeName}.")
  }

  /**
   * Emits the scalar function node `substraitExprName(child, toScale)` typed as [[dataType]].
   *
   * @param args the function map (`java.util.HashMap[String, java.lang.Long]`) in which the
   *             function is registered.
   */
  override def doTransform(args: Object): ExpressionNode = {
    val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]]
    // The function name is built from the child's (input) type only.
    val functionId = ExpressionBuilder.newScalarFunction(
      functionMap,
      ConverterUtils.makeFuncName(
        substraitExprName,
        Seq(original.child.dataType),
        FunctionConfig.OPT))

    ExpressionBuilder.makeScalarFunction(
      functionId,
      Lists.newArrayList[ExpressionNode](
        child.doTransform(args),
        ExpressionBuilder.makeIntLiteral(toScale)),
      ConverterUtils.getTypeNode(dataType, original.nullable)
    )
  }
}

gluten-core/src/main/scala/io/glutenproject/expression/ExpressionConverter.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,12 @@ object ExpressionConverter extends SQLConfHelper with Logging {
210210
replaceWithExpressionTransformerInternal(d.startDate, attributeSeq, expressionsMap),
211211
d
212212
)
213+
case r: Round if r.child.dataType.isInstanceOf[DecimalType] =>
214+
BackendsApiManager.getSparkPlanExecApiInstance.genDecimalRoundTransformer(
215+
substraitExprName,
216+
replaceWithExpressionTransformerInternal(r.child, attributeSeq, expressionsMap),
217+
r
218+
)
213219
case t: ToUnixTimestamp =>
214220
BackendsApiManager.getSparkPlanExecApiInstance.genUnixTimestampTransformer(
215221
substraitExprName,

0 commit comments

Comments
 (0)