Skip to content

Commit 67fa019

Browse files
committed
support allowDecimalPrecisionLoss
Signed-off-by: Yuan Zhou <[email protected]>
1 parent 9e2ec55 commit 67fa019

File tree

6 files changed

+33
-13
lines changed

6 files changed

+33
-13
lines changed

cpp/core/config/GlutenConfig.h

Lines changed: 2 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -34,6 +34,8 @@ const std::string kLegacySize = "spark.sql.legacy.sizeOfNull";
3434

3535
const std::string kSessionTimezone = "spark.sql.session.timeZone";
3636

37+
const std::string kAllowPrecisionLoss = "spark.sql.decimalOperations.allowPrecisionLoss";
38+
3739
const std::string kIgnoreMissingFiles = "spark.sql.files.ignoreMissingFiles";
3840

3941
const std::string kDefaultSessionTimezone = "spark.gluten.sql.session.timeZone.default";

cpp/velox/compute/WholeStageResultIterator.cc

Lines changed: 2 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -490,6 +490,8 @@ std::unordered_map<std::string, std::string> WholeStageResultIterator::getQueryC
490490
}
491491
// Adjust timestamp according to the above configured session timezone.
492492
configs[velox::core::QueryConfig::kAdjustTimestampToTimezone] = "true";
493+
// To align with Spark's behavior, allow decimal precision loss or not.
494+
configs[velox::core::QueryConfig::kAllowPrecisionLoss] = veloxCfg_->get<std::string>(kAllowPrecisionLoss, "true");
493495
// Align Velox size function with Spark.
494496
configs[velox::core::QueryConfig::kSparkLegacySizeOfNull] = std::to_string(veloxCfg_->get<bool>(kLegacySize, true));
495497

ep/build-velox/src/get_velox.sh

Lines changed: 2 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -16,8 +16,8 @@
1616

1717
set -exu
1818

19-
VELOX_REPO=https://github.com/oap-project/velox.git
20-
VELOX_BRANCH=2024_05_06
19+
VELOX_REPO=https://github.com/zhouyuan/velox.git
20+
VELOX_BRANCH=wip_decimal_precision_loss
2121
VELOX_HOME=""
2222

2323
#Set on run gluten on HDFS

gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala

Lines changed: 0 additions & 8 deletions
Original file line number · Diff line number · Diff line change
@@ -511,14 +511,6 @@ object ExpressionConverter extends SQLConfHelper with Logging {
511511
replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)),
512512
expr)
513513
case b: BinaryArithmetic if DecimalArithmeticUtil.isDecimalArithmetic(b) =>
514-
// PrecisionLoss=true: velox support / ch not support
515-
// PrecisionLoss=false: velox not support / ch support
516-
// TODO ch support PrecisionLoss=true
517-
if (!BackendsApiManager.getSettings.allowDecimalArithmetic) {
518-
throw new GlutenNotSupportException(
519-
s"Not support ${SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key} " +
520-
s"${conf.decimalOperationsAllowPrecisionLoss} mode")
521-
}
522514
val rescaleBinary = if (BackendsApiManager.getSettings.rescaleDecimalLiteral) {
523515
DecimalArithmeticUtil.rescaleLiteral(b)
524516
} else {

gluten-core/src/main/scala/org/apache/gluten/utils/DecimalArithmeticUtil.scala

Lines changed: 26 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -22,6 +22,7 @@ import org.apache.gluten.expression.{CheckOverflowTransformer, ChildTransformer,
2222

2323
import org.apache.spark.sql.catalyst.analysis.DecimalPrecision
2424
import org.apache.spark.sql.catalyst.expressions.{Add, BinaryArithmetic, Cast, Divide, Expression, Literal, Multiply, Pmod, PromotePrecision, Remainder, Subtract}
25+
import org.apache.spark.sql.internal.SQLConf
2526
import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType, IntegerType, LongType, ShortType}
2627

2728
object DecimalArithmeticUtil {
@@ -33,12 +34,14 @@ object DecimalArithmeticUtil {
3334

3435
val MIN_ADJUSTED_SCALE = 6
3536
val MAX_PRECISION = 38
37+
val MAX_SCALE = 38
3638

3739
// Returns the result decimal type of a decimal arithmetic computing.
3840
def getResultTypeForOperation(
3941
operationType: OperationType.Config,
4042
type1: DecimalType,
4143
type2: DecimalType): DecimalType = {
44+
val allowPrecisionLoss = SQLConf.get.decimalOperationsAllowPrecisionLoss
4245
var resultScale = 0
4346
var resultPrecision = 0
4447
operationType match {
@@ -54,16 +57,32 @@ object DecimalArithmeticUtil {
5457
resultScale = type1.scale + type2.scale
5558
resultPrecision = type1.precision + type2.precision + 1
5659
case OperationType.DIVIDE =>
57-
resultScale = Math.max(MIN_ADJUSTED_SCALE, type1.scale + type2.precision + 1)
58-
resultPrecision = type1.precision - type1.scale + type2.scale + resultScale
60+
if (allowPrecisionLoss) {
61+
resultScale = Math.max(MIN_ADJUSTED_SCALE, type1.scale + type2.precision + 1)
62+
resultPrecision = type1.precision - type1.scale + type2.scale + resultScale
63+
} else {
64+
var intDig = Math.min(MAX_SCALE, type1.precision - type1.scale + type2.scale)
65+
var decDig = Math.min(MAX_SCALE, Math.max(6, type1.scale + type2.precision + 1))
66+
val diff = (intDig + decDig) - MAX_SCALE
67+
if (diff > 0) {
68+
decDig -= diff / 2 + 1
69+
intDig = MAX_SCALE - decDig
70+
}
71+
resultScale = intDig + decDig
72+
resultPrecision = decDig
73+
}
5974
case OperationType.MOD =>
6075
resultScale = Math.max(type1.scale, type2.scale)
6176
resultPrecision =
6277
Math.min(type1.precision - type1.scale, type2.precision - type2.scale + resultScale)
6378
case other =>
6479
throw new GlutenNotSupportException(s"$other is not supported.")
6580
}
66-
adjustScaleIfNeeded(resultPrecision, resultScale)
81+
if (allowPrecisionLoss) {
82+
adjustScaleIfNeeded(resultPrecision, resultScale)
83+
} else {
84+
bounded(resultPrecision, resultScale)
85+
}
6786
}
6887

6988
// Returns the adjusted decimal type when the precision is larger the maximum.
@@ -79,6 +98,10 @@ object DecimalArithmeticUtil {
7998
DecimalType(typePrecision, typeScale)
8099
}
81100

101+
def bounded(precision: Int, scale: Int): DecimalType = {
102+
DecimalType(Math.min(precision, MAX_PRECISION), Math.min(scale, MAX_SCALE))
103+
}
104+
82105
// If casting between DecimalType, unnecessary cast is skipped to avoid data loss,
83106
// because argument input type of "cast" is actually the res type of "+-*/".
84107
// Cast will use a wider input type, then calculates result type with less scale than expected.

shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala

Lines changed: 1 addition & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -571,6 +571,7 @@ object GlutenConfig {
571571
GLUTEN_DEFAULT_SESSION_TIMEZONE_KEY,
572572
SQLConf.LEGACY_SIZE_OF_NULL.key,
573573
"spark.io.compression.codec",
574+
"spark.sql.decimalOperations.allowPrecisionLoss",
574575
COLUMNAR_VELOX_BLOOM_FILTER_EXPECTED_NUM_ITEMS.key,
575576
COLUMNAR_VELOX_BLOOM_FILTER_NUM_BITS.key,
576577
COLUMNAR_VELOX_BLOOM_FILTER_MAX_NUM_BITS.key,

0 commit comments

Comments (0)