Skip to content

Commit 4920180

Browse files
Switch rounding calls to using new API (#13497)
This changes the plugin to call the new spark-rapids-jni API for rounding rather than the old cudf API. The rounding of floating point functionality will be removed from cudf so we've moved the float rounding (and the API) into spark-rapids-jni. This is part of resolving [sr-jni issue 3585](NVIDIA/cudf-spark-jni#3585) [This PR](rapidsai/cudf#20110) removes the old java code from cudf, and will be merged in cudf 25.12. [This PR](NVIDIA/cudf-spark-jni#3770) moves the floating rounding into spark-rapids-jni and changes the API. This will be merged in 25.10 and is a dependency of this PR. Note that I have tested this code applying both of the above PRs and all of the spark-rapids integration tests pass. ### Checklists - [ ] This PR has added documentation for new or modified features or behaviors. - [ ] This PR has added new tests or modified existing tests to cover new code paths. (Please explain in the PR description how the new code paths are tested, such as names of the new/existing tests that cover them.) - [ ] Performance testing has been performed and its results are added in the PR description. Or, an issue has been filed with a link in the PR description. --------- Signed-off-by: Paul Mattione <pmattione@nvidia.com>
1 parent 7a894de commit 4920180

6 files changed

Lines changed: 24 additions & 22 deletions

File tree

sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,9 @@ import java.util.Optional
2222
import scala.collection.mutable.ArrayBuffer
2323

2424
import ai.rapids.cudf.{BinaryOp, CaptureGroups, ColumnVector, ColumnView, DType, RegexProgram, Scalar}
25-
import ai.rapids.cudf
2625
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
2726
import com.nvidia.spark.rapids.RapidsPluginImplicits._
28-
import com.nvidia.spark.rapids.jni.{CastException, CastStrings, DecimalUtils, GpuTimeZoneDB}
27+
import com.nvidia.spark.rapids.jni.{Arithmetic, CastException, CastStrings, DecimalUtils, GpuTimeZoneDB, RoundMode}
2928
import com.nvidia.spark.rapids.shims.{AnsiUtil, CastTimeToIntShim, GpuCastShims, GpuIntervalUtils, GpuTypeShims, NullIntolerantShim, SparkShimImpl}
3029
import org.apache.commons.text.StringEscapeUtils
3130

@@ -1568,7 +1567,7 @@ object GpuCast {
15681567
val rounded = if (!isScaleUpcast) {
15691568
// We have to round the data to the desired scale. Spark uses HALF_UP rounding in
15701569
// this case so we need to also.
1571-
input.round(to.scale, cudf.RoundMode.HALF_UP)
1570+
Arithmetic.round(input, to.scale, RoundMode.HALF_UP)
15721571
} else {
15731572
input.copyToColumnVector()
15741573
}

sql-plugin/src/main/scala/com/nvidia/spark/rapids/decimalExpressions.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2020-2023, NVIDIA CORPORATION.
2+
* Copyright (c) 2020-2025, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -15,9 +15,9 @@
1515
*/
1616
package com.nvidia.spark.rapids
1717

18-
import ai.rapids.cudf
19-
import ai.rapids.cudf.{ColumnVector, DecimalUtils, DType, Scalar}
18+
import ai.rapids.cudf.{ColumnVector, DType, Scalar}
2019
import com.nvidia.spark.rapids.Arm.withResource
20+
import com.nvidia.spark.rapids.jni.{Arithmetic, RoundMode}
2121

2222
import org.apache.spark.sql.catalyst.expressions.Expression
2323
import org.apache.spark.sql.types.{DataType, DecimalType, LongType}
@@ -48,7 +48,7 @@ case class GpuCheckOverflow(child: Expression,
4848
val rounded = if (resultDType.equals(base.getType)) {
4949
base.incRefCount()
5050
} else {
51-
withResource(base.round(dataType.scale, cudf.RoundMode.HALF_UP)) { rounded =>
51+
withResource(Arithmetic.round(base, dataType.scale, RoundMode.HALF_UP)) { rounded =>
5252
if (resultDType.getTypeId != base.getType.getTypeId) {
5353
rounded.castTo(resultDType)
5454
} else {
@@ -101,12 +101,12 @@ case class GpuMakeDecimal(
101101
override def toString: String = s"MakeDecimal($child,$precision,$sparkScale)"
102102

103103
private lazy val (minValue, maxValue) = {
104-
val bounds = DecimalUtils.bounds(dataType.precision, dataType.scale)
104+
val bounds = ai.rapids.cudf.DecimalUtils.bounds(dataType.precision, dataType.scale)
105105
(bounds.getKey.unscaledValue().longValue(), bounds.getValue.unscaledValue().longValue())
106106
}
107107

108108
override protected def doColumnar(input: GpuColumnVector): ColumnVector = {
109-
val outputType = DecimalUtils.createDecimalType(precision, sparkScale)
109+
val outputType = ai.rapids.cudf.DecimalUtils.createDecimalType(precision, sparkScale)
110110
val base = input.getBase
111111
val outOfBounds = withResource(Scalar.fromLong(maxValue)) { maxScalar =>
112112
withResource(base.greaterThan(maxScalar)) { over =>

sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import com.nvidia.spark.rapids.Arm._
2727
import com.nvidia.spark.rapids.ExprMeta
2828
import com.nvidia.spark.rapids.GpuOverrides.{extractStringLit, getTimeParserPolicy}
2929
import com.nvidia.spark.rapids.RapidsPluginImplicits._
30-
import com.nvidia.spark.rapids.jni.{DateTimeUtils, GpuTimeZoneDB}
30+
import com.nvidia.spark.rapids.jni.{Arithmetic, DateTimeUtils, GpuTimeZoneDB}
3131
import com.nvidia.spark.rapids.shims.{NullIntolerantShim, ShimBinaryExpression, ShimExpression}
3232

3333
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, ExpectsInputTypes, Expression, FromUnixTime, FromUTCTimestamp, ImplicitCastInputTypes, MonthsBetween, TimeZoneAwareExpression, ToUTCTimestamp, TruncDate, TruncTimestamp}
@@ -1418,7 +1418,7 @@ case class GpuMonthsBetween(ts1: Expression,
14181418
}
14191419
val roundedPartialMonth = if (needsRoundOff.get) {
14201420
withResource(partialMonth) { _ =>
1421-
partialMonth.round(8)
1421+
Arithmetic.round(partialMonth, 8)
14221422
}
14231423
} else {
14241424
partialMonth

sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import ai.rapids.cudf._
2222
import ai.rapids.cudf.ast.BinaryOperator
2323
import com.nvidia.spark.rapids._
2424
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
25-
import com.nvidia.spark.rapids.jni.CastStrings
25+
import com.nvidia.spark.rapids.jni.{Arithmetic, CastStrings, RoundMode}
2626

2727
import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
2828
import org.apache.spark.sql.rapids.shims.RapidsErrorUtils
@@ -608,7 +608,7 @@ abstract class GpuRoundBase(child: Expression, scale: Expression, outputType: Da
608608
case DecimalType.Fixed(_, s) =>
609609
// Only needs to perform round when required scale < input scale
610610
val rounded = if (scaleVal < s) {
611-
lhsValue.round(scaleVal, roundMode)
611+
Arithmetic.round(lhsValue, scaleVal, roundMode)
612612
} else {
613613
lhsValue.incRefCount()
614614
}
@@ -672,7 +672,7 @@ abstract class GpuRoundBase(child: Expression, scale: Expression, outputType: Da
672672
}
673673
}
674674
} else {
675-
lhs.round(scale, roundMode)
675+
Arithmetic.round(lhs, scale, roundMode)
676676
}
677677
}
678678

@@ -769,7 +769,7 @@ abstract class GpuRoundBase(child: Expression, scale: Expression, outputType: Da
769769
// just returns the original values
770770
lhs.incRefCount()
771771
} else {
772-
lhs.round(scale, roundMode)
772+
Arithmetic.round(lhs, scale, roundMode)
773773
}
774774
}
775775

sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@ import scala.annotation.tailrec
2424
import scala.collection.mutable
2525
import scala.collection.mutable.ArrayBuffer
2626

27-
import ai.rapids.cudf.{ast, BinaryOp, BinaryOperable, CaptureGroups, ColumnVector, ColumnView, DType, PadSide, RegexFlag, RegexProgram, RoundMode, Scalar, Table}
27+
import ai.rapids.cudf.{ast, BinaryOp, BinaryOperable, CaptureGroups, ColumnVector, ColumnView, DType, PadSide, RegexFlag, RegexProgram, Scalar, Table}
2828
import com.nvidia.spark.rapids._
2929
import com.nvidia.spark.rapids.Arm._
3030
import com.nvidia.spark.rapids.RapidsPluginImplicits._
31+
import com.nvidia.spark.rapids.jni.{Arithmetic, RoundMode}
3132
import com.nvidia.spark.rapids.jni.CastStrings
3233
import com.nvidia.spark.rapids.jni.GpuSubstringIndexUtils
3334
import com.nvidia.spark.rapids.jni.NumberConverter
@@ -2283,7 +2284,8 @@ case class GpuFormatNumber(x: Expression, d: Expression)
22832284
val appendZeroNum = (d - scale).max(0).min(d)
22842285
val (intPart, decTemp) = if (roundingScale <= 0) {
22852286
withResource(ArrayBuffer.empty[ColumnVector]) { resourceArray =>
2286-
val intPart = withResource(cv.round(roundingScale, RoundMode.HALF_EVEN)) { rounded =>
2287+
val intPart = withResource(Arithmetic.round(cv, roundingScale,
2288+
RoundMode.HALF_EVEN)) { rounded =>
22872289
rounded.castTo(DType.STRING)
22882290
}
22892291
resourceArray += intPart
@@ -2307,7 +2309,7 @@ case class GpuFormatNumber(x: Expression, d: Expression)
23072309
(intPartZeroHandled.incRefCount(), decPart.incRefCount())
23082310
}
23092311
} else {
2310-
withResource(cv.round(roundingScale, RoundMode.HALF_EVEN)) { rounded =>
2312+
withResource(Arithmetic.round(cv, roundingScale, RoundMode.HALF_EVEN)) { rounded =>
23112313
withResource(rounded.castTo(DType.STRING)) { roundedStr =>
23122314
withResource(roundedStr.stringSplit(".", 2)) { intAndDec =>
23132315
(intAndDec.getColumn(0).incRefCount(), intAndDec.getColumn(1).incRefCount())

sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,10 @@ package org.apache.spark.sql.rapids.shims
4545

4646
import java.math.BigInteger
4747

48-
import ai.rapids.cudf.{BinaryOperable, ColumnVector, ColumnView, DType, RoundMode, Scalar}
48+
import ai.rapids.cudf.{BinaryOperable, ColumnVector, ColumnView, DType, Scalar}
4949
import com.nvidia.spark.rapids.{BoolUtils, GpuBinaryExpression, GpuColumnVector, GpuScalar}
5050
import com.nvidia.spark.rapids.Arm.withResource
51+
import com.nvidia.spark.rapids.jni.{Arithmetic, RoundMode}
5152
import com.nvidia.spark.rapids.shims.NullIntolerantShim
5253

5354
import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
@@ -160,7 +161,7 @@ object IntervalUtils {
160161
// check Inf, -Inf, NaN
161162
checkDoubleInfNan(doubleCv)
162163

163-
withResource(doubleCv.round(RoundMode.HALF_UP)) { roundedDouble =>
164+
withResource(Arithmetic.round(doubleCv, RoundMode.HALF_UP)) { roundedDouble =>
164165
// throws exception if the result exceeds int limits
165166
withResource(roundedDouble.castTo(DType.INT64)) { long =>
166167
castLongToIntWithOverflowCheck(long)
@@ -192,7 +193,7 @@ object IntervalUtils {
192193
val MIN_LONG_AS_DOUBLE: Double = -9.223372036854776E18
193194
val MAX_LONG_AS_DOUBLE_PLUS_ONE: Double = 9.223372036854776E18
194195

195-
withResource(doubleCv.round(RoundMode.HALF_UP)) { z =>
196+
withResource(Arithmetic.round(doubleCv, RoundMode.HALF_UP)) { z =>
196197
withResource(Scalar.fromDouble(MAX_LONG_AS_DOUBLE_PLUS_ONE)) { max =>
197198
withResource(z.greaterOrEqualTo(max)) { invalid =>
198199
if (BoolUtils.isAnyValidTrue(invalid)) {
@@ -323,7 +324,7 @@ object IntervalUtils {
323324
leftDecimal.div(q, dT)
324325
}
325326
withResource(t) { t =>
326-
t.round(RoundMode.HALF_UP)
327+
Arithmetic.round(t, RoundMode.HALF_UP)
327328
}
328329
}
329330
}

0 commit comments

Comments
 (0)