Skip to content

Commit 33c53fe

Browse files
authored
Optimize format number implementation (#14586)
### Description Contributes to #14588. This change optimizes the implementation of `format_number` originally introduced in #9281. The optimizations were found by Claude after it was tasked with optimizing this implementation against microbenchmarks. The optimizations are as follows: - Use a scalar "," separator with the [scalar stringConcatenate API](https://docs.rapids.ai/api/cudf-java/legacy/ai/rapids/cudf/columnvector#stringConcatenate(ai.rapids.cudf.Scalar,ai.rapids.cudf.Scalar,ai.rapids.cudf.ColumnView%5B%5D)) instead of allocating a full column. - Use numeric sign detection by casting to float, rather than a string conversion + string pattern match to look for a negative sign. - Use a small `signCol` that is conditionally prepended to negative numbers, instead of allocating a fully copy of the column with "-" prepended. #### Performance Note that this operator does not appear in NDS/NDS-H, so I have the following microbenchmarks for a targeted performance diff. Generated data: ```bash cat << 'EOF' | SPARK_HOME=/opt/spark-3.5.5 /opt/spark-3.5.5/bin/spark-shell \ --jars datagen/target/datagen_2.12-26.06.0-SNAPSHOT-spark355.jar \ --master "local[*]" import org.apache.spark.sql.tests.datagen._ val types = Seq("long", "int", "short", "byte", "DECIMAL(38,10)") for (t <- types) { DBGen().addTable("data", s"a $t", 10000000) .toDF(spark).write.mode("overwrite") .parquet(s"/tmp/format_number_bench/${t.replaceAll("[^a-zA-Z0-9]", "_")}") } println("done") :quit EOF ``` Benchmark runs: ```bash cat << 'EOF' | SPARK_HOME=/opt/spark-3.5.5 /opt/spark-3.5.5/bin/spark-shell \ --jars $PLUGIN_JAR \ --conf spark.plugins=com.nvidia.spark.SQLPlugin \ --master "local[*]" val types = Seq( ("long", "/tmp/format_number_bench/long"), ("int", "/tmp/format_number_bench/int"), ("short", "/tmp/format_number_bench/short"), ("byte", "/tmp/format_number_bench/byte"), ("decimal", "/tmp/format_number_bench/DECIMAL_38_10_") ) for ((name, path) <- types) { val df = spark.read.parquet(path) // warmup df.selectExpr("COUNT(format_number(a, 0))", "COUNT(format_number(a, 5))").collect() df.selectExpr("COUNT(format_number(a, 0))", "COUNT(format_number(a, 5))").collect() // timed for (i <- 1 to 3) { print(s"${name}_run${i}: ") spark.time(df.selectExpr("COUNT(format_number(a, 0))", "COUNT(format_number(a, 5))").collect()) } } :quit EOF ``` Results: | dtype | main | this change | speedup | |-----------|-----------------|-----------|---------| | long | 540 ms | 380 ms | **1.42x** | | int | 362 ms | 285 ms | **1.27x** | | short | 378 ms | 238 ms | **1.59x** | | byte | 244 ms | 186 ms | **1.31x** | | decimal(38,10) | 921 ms | 551 ms | **1.67x** | ### Checklists Documentation - [ ] Updated for new or modified user-facing features or behaviors - [X] No user-facing change Testing - [ ] Added or modified tests to cover new code paths - [X] Covered by existing tests - [ ] Not required Performance - [X] Tests ran and results are added in the PR description - [ ] Issue filed with a link in the PR description - [ ] Not required --------- Signed-off-by: Rishi Chandra <rishic@nvidia.com>
1 parent ef71f42 commit 33c53fe

1 file changed

Lines changed: 77 additions & 60 deletions

File tree

sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala

Lines changed: 77 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -2383,14 +2383,6 @@ case class GpuFormatNumber(x: Expression, d: Expression)
23832383
}
23842384
}
23852385

2386-
private def negativeCheck(cv: ColumnVector): ColumnVector = {
2387-
withResource(cv.castTo(DType.STRING)) { cvStr =>
2388-
withResource(Scalar.fromString("-")) { negativeSign =>
2389-
cvStr.startsWith(negativeSign)
2390-
}
2391-
}
2392-
}
2393-
23942386
private def removeExtraCommas(str: ColumnVector): ColumnVector = {
23952387
withResource(Scalar.fromString(",")) { comma =>
23962388
str.rstrip(comma)
@@ -2406,18 +2398,21 @@ case class GpuFormatNumber(x: Expression, d: Expression)
24062398
}
24072399
}
24082400
}
2409-
val sepCol = withResource(Scalar.fromString(",")) { sep =>
2410-
ColumnVector.fromScalar(sep, str.getRowCount.toInt)
2411-
}
2412-
val substrs = closeOnExcept(sepCol) { _ =>
2413-
(0 until maxstrlen by 3).safeMap { i =>
2401+
if (maxstrlen <= 3) {
2402+
// no commas are needed for strings of 3 or fewer chars
2403+
str.incRefCount()
2404+
} else {
2405+
val substrs = (0 until maxstrlen by 3).safeMap { i =>
24142406
str.substring(i, i + 3).asInstanceOf[ColumnView]
24152407
}.toArray
2416-
}
2417-
withResource(substrs) { _ =>
2418-
withResource(sepCol) { _ =>
2419-
withResource(ColumnVector.stringConcatenate(substrs, sepCol)) { res =>
2420-
removeExtraCommas(res)
2408+
withResource(substrs) { _ =>
2409+
// join the 3-char chunks with commas using a scalar separator
2410+
withResource(Scalar.fromString(",")) { sep =>
2411+
withResource(Scalar.fromString("")) { narep =>
2412+
withResource(ColumnVector.stringConcatenate(sep, narep, substrs)) { res =>
2413+
removeExtraCommas(res)
2414+
}
2415+
}
24212416
}
24222417
}
24232418
}
@@ -2426,65 +2421,87 @@ case class GpuFormatNumber(x: Expression, d: Expression)
24262421
private def formatNumberNonKernel(cv: ColumnVector, d: Int): ColumnVector = {
24272422
val (integerPart, decimalPart) = getParts(cv, d)
24282423
// reverse integer part for adding commas
2429-
val resWithDecimalPart = withResource(decimalPart) { _ =>
2430-
val reversedIntegerPart = withResource(integerPart) { intPart =>
2431-
intPart.reverseStringsOrLists()
2432-
}
2433-
val reversedIntegerPartWithCommas = withResource(reversedIntegerPart) { _ =>
2434-
addCommas(reversedIntegerPart)
2435-
}
2436-
// reverse result back
2437-
val reverseBack = withResource(reversedIntegerPartWithCommas) { r =>
2438-
r.reverseStringsOrLists()
2439-
}
2440-
d match {
2441-
case 0 => {
2442-
// d == 0, only return integer part
2443-
reverseBack
2424+
val integerWithCommas = closeOnExcept(decimalPart) { _ =>
2425+
val reversed = withResource(integerPart) { _ =>
2426+
integerPart.reverseStringsOrLists()
2427+
}
2428+
val reversedWithCommas = withResource(reversed) { _ =>
2429+
addCommas(reversed)
2430+
}
2431+
withResource(reversedWithCommas) { _ =>
2432+
reversedWithCommas.reverseStringsOrLists()
2433+
}
2434+
}
2435+
// build a small per-row sign prefix column ("-" or "") based on the sign
2436+
// of the value, that we will prepend at the end.
2437+
// this way, we avoid creating bigger signed/unsigned versions of the formatted column
2438+
// followed by an ifElse.
2439+
val signCol = closeOnExcept(decimalPart) { _ =>
2440+
closeOnExcept(integerWithCommas) { _ =>
2441+
// since we only need the sign bit, cast to float and compare < 0.
2442+
// this is cheaper than casting to string and checking for "-".
2443+
val isNeg = withResource(cv.castTo(DType.FLOAT32)) { cvFloat =>
2444+
withResource(Scalar.fromFloat(0.0f)) { zero =>
2445+
cvFloat.lessThan(zero)
2446+
}
24442447
}
2445-
case _ => {
2446-
// d > 0, append decimal part to result
2447-
withResource(reverseBack) { _ =>
2448-
withResource(Scalar.fromString(".")) { point =>
2449-
withResource(Scalar.fromString("")) { empty =>
2450-
ColumnVector.stringConcatenate(point, empty, Array(reverseBack, decimalPart))
2451-
}
2448+
withResource(isNeg) { _ =>
2449+
withResource(Scalar.fromString("-")) { neg =>
2450+
withResource(Scalar.fromString("")) { empty =>
2451+
isNeg.ifElse(neg, empty)
24522452
}
24532453
}
24542454
}
24552455
}
24562456
}
2457-
// add negative sign back
2458-
val negCv = withResource(Scalar.fromString("-")) { negativeSign =>
2459-
ColumnVector.fromScalar(negativeSign, cv.getRowCount.toInt)
2460-
}
2461-
val formated = withResource(resWithDecimalPart) { _ =>
2462-
val resWithNeg = withResource(negCv) { _ =>
2463-
ColumnVector.stringConcatenate(Array(negCv, resWithDecimalPart))
2464-
}
2465-
withResource(negativeCheck(cv)) { isNegative =>
2466-
withResource(resWithNeg) { _ =>
2467-
isNegative.ifElse(resWithNeg, resWithDecimalPart)
2457+
// single concatenation pass for sign + integer [+ "." + decimal]
2458+
val formatted = d match {
2459+
case 0 =>
2460+
decimalPart.close()
2461+
// no decimal - just prepend the precomputed sign prefix
2462+
withResource(signCol) { _ =>
2463+
withResource(integerWithCommas) { _ =>
2464+
ColumnVector.stringConcatenate(
2465+
Array[ColumnView](signCol, integerWithCommas))
2466+
}
2467+
}
2468+
case _ =>
2469+
// join integer and decimal with scalar separator "."
2470+
val intDotDec = closeOnExcept(signCol) { _ =>
2471+
withResource(integerWithCommas) { _ =>
2472+
withResource(decimalPart) { _ =>
2473+
withResource(Scalar.fromString(".")) { dot =>
2474+
withResource(Scalar.fromString("")) { narep =>
2475+
ColumnVector.stringConcatenate(dot, narep,
2476+
Array[ColumnView](integerWithCommas, decimalPart))
2477+
}
2478+
}
2479+
}
2480+
}
2481+
}
2482+
// prepend the precomputed sign prefix to the formatted number
2483+
withResource(signCol) { _ =>
2484+
withResource(intDotDec) { _ =>
2485+
ColumnVector.stringConcatenate(
2486+
Array[ColumnView](signCol, intDotDec))
2487+
}
24682488
}
2469-
}
24702489
}
24712490
// handle null case
2472-
val anyNull = closeOnExcept(formated) { _ =>
2491+
val anyNull = closeOnExcept(formatted) { _ =>
24732492
cv.getNullCount > 0
24742493
}
2475-
val formatedWithNull = anyNull match {
2476-
case true => {
2477-
withResource(formated) { _ =>
2494+
anyNull match {
2495+
case true =>
2496+
withResource(formatted) { _ =>
24782497
withResource(cv.isNull) { isNull =>
24792498
withResource(Scalar.fromNull(DType.STRING)) { nullScalar =>
2480-
isNull.ifElse(nullScalar, formated)
2499+
isNull.ifElse(nullScalar, formatted)
24812500
}
24822501
}
24832502
}
2484-
}
2485-
case false => formated
2503+
case false => formatted
24862504
}
2487-
formatedWithNull
24882505
}
24892506

24902507
override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = {

0 commit comments

Comments
 (0)