Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
37fd9fc
fix: round for float/double
kazuyukitanimura Apr 25, 2026
9707af3
fix: round for float/double
kazuyukitanimura Apr 25, 2026
62b910d
fix: round for float/double
kazuyukitanimura Apr 25, 2026
1cbf7b4
fix: round for float/double
kazuyukitanimura Apr 25, 2026
dbc7019
fix: round for float/double
kazuyukitanimura Apr 25, 2026
ce10f8b
fix: round for float/double
kazuyukitanimura Apr 25, 2026
643a38e
fix: round for float/double
kazuyukitanimura Apr 25, 2026
e31e825
fix: round for float/double
kazuyukitanimura Apr 25, 2026
57e0287
fix: round for float/double
kazuyukitanimura Apr 25, 2026
3a7eea6
fix: round for float/double
kazuyukitanimura Apr 25, 2026
17eac11
fix: round for float/double
kazuyukitanimura Apr 25, 2026
300e403
fix: round for float/double
kazuyukitanimura Apr 25, 2026
3e39898
fix: round for float/double
kazuyukitanimura Apr 25, 2026
e045cc6
fix: round for float/double
kazuyukitanimura Apr 25, 2026
a3ab1f0
fix: round for float/double
kazuyukitanimura Apr 25, 2026
9adb5fd
fix: round for float/double
kazuyukitanimura Apr 25, 2026
60080a9
fix: round for float/double
kazuyukitanimura Apr 25, 2026
49a0209
fix: round for float/double
kazuyukitanimura Apr 25, 2026
fa786aa
fix: round for float/double
kazuyukitanimura Apr 25, 2026
7731338
fix: round for float/double
kazuyukitanimura Apr 26, 2026
f88a04c
fix: round for float/double
kazuyukitanimura Apr 27, 2026
8204171
Merge remote-tracking branch 'upstream/main' into fix-flaot-round
kazuyukitanimura May 12, 2026
2dd4d5a
Merge remote-tracking branch 'upstream/main' into fix-flaot-round
kazuyukitanimura May 16, 2026
203c319
fix: round for float/double
kazuyukitanimura May 16, 2026
3470e5f
fix: round for float/double
kazuyukitanimura May 16, 2026
f9fdfae
fix: round for float/double
kazuyukitanimura May 16, 2026
3c2af76
Merge remote-tracking branch 'upstream/main' into fix-flaot-round
kazuyukitanimura May 22, 2026
b69469b
address review comments
kazuyukitanimura May 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 38 additions & 17 deletions spark/src/main/scala/org/apache/comet/serde/arithmetic.scala
Original file line number Diff line number Diff line change
Expand Up @@ -303,23 +303,14 @@ object CometRound extends CometExpressionSerde[Round] {
exprToProtoInternal(Literal(null), inputs, binding)
case _: ByteType | ShortType | IntegerType | LongType if _scale >= 0 =>
childExpr // _scale(I.e. decimal place) >= 0 is a no-op for integer types in Spark
case _: FloatType | DoubleType =>
// We cannot properly match with the Spark behavior for floating-point numbers.
// Spark uses BigDecimal for rounding float/double, and BigDecimal first converts a
// double to string internally in order to create its own internal representation.
// The problem is BigDecimal uses java.lang.Double.toString() and it has complicated
// rounding algorithm. E.g. -5.81855622136895E8 is actually
// -581855622.13689494132995605468750. Note the 5th fractional digit is 4 instead of
// 5. Java(Scala)'s toString() rounds it up to -581855622.136895. This makes a
// difference when rounding at 5th digit, I.e. round(-5.81855622136895E8, 5) should be
// -5.818556221369E8, instead of -5.8185562213689E8. There is also an example that
// toString() does NOT round up. 6.1317116247283497E18 is 6131711624728349696. It can
// be rounded up to 6.13171162472835E18 that still represents the same double number.
// I.e. 6.13171162472835E18 == 6.1317116247283497E18. However, toString() does not.
// That results in round(6.1317116247283497E18, -5) == 6.1317116247282995E18 instead
// of 6.1317116247283999E18.
withInfo(r, "Comet does not support Spark's BigDecimal rounding")
None
case _: FloatType =>
// Spark rounds floats by widening to double, building a BigDecimal via
// java.lang.Double.toString, applying HALF_UP, and narrowing back. The toString
// algorithm differs between JDKs (notably 17 vs 21), so a native implementation
// can't match every JDK. Delegate to a JVM UDF that runs on the executor's JDK.
convertViaJvmUdf(r, "org.apache.comet.udf.RoundFloatUDF", _scale, inputs, binding)
case _: DoubleType =>
convertViaJvmUdf(r, "org.apache.comet.udf.RoundDoubleUDF", _scale, inputs, binding)
case _ =>
// `scale` must be Int64 type in DataFusion
val scaleExpr = exprToProtoInternal(Literal(_scale.toLong, LongType), inputs, binding)
Expand All @@ -334,6 +325,36 @@ object CometRound extends CometExpressionSerde[Round] {
}

}

/**
 * Serializes a `Round` over a floating-point child as a `JvmScalarUdf` expression so the rounding
 * is evaluated by the JVM class named `className` rather than natively.
 *
 * The UDF receives two arguments, in this order: the serialized child expression and the scale
 * serialized as an `IntegerType` literal.
 *
 * @param r
 *   the `Round` expression being converted; fallback information is attached to it on failure
 * @param className
 *   fully qualified name of the UDF class to invoke
 * @param scale
 *   the already-resolved rounding scale
 * @param inputs
 *   input attributes forwarded to `exprToProtoInternal`
 * @param binding
 *   binding flag forwarded to `exprToProtoInternal`
 * @return
 *   the serialized expression, or `None` when the child, the scale literal, or the return type
 *   cannot be serialized
 */
private def convertViaJvmUdf(
    r: Round,
    className: String,
    scale: Int,
    inputs: Seq[Attribute],
    binding: Boolean): Option[ExprOuterClass.Expr] = {
  // Both arguments are serialized up front, before either result is inspected, so the order of
  // the calls is the same as it has always been.
  val valueProto = exprToProtoInternal(r.child, inputs, binding)
  val scaleProto = exprToProtoInternal(Literal(scale, IntegerType), inputs, binding)

  (valueProto, scaleProto) match {
    case (Some(valueArg), Some(scaleArg)) =>
      serializeDataType(r.dataType) match {
        case Some(returnType) =>
          val udf = ExprOuterClass.JvmScalarUdf
            .newBuilder()
            .setClassName(className)
            .addArgs(valueArg)
            .addArgs(scaleArg)
            .setReturnType(returnType)
            .setReturnNullable(r.nullable)
            .build()
          Some(
            ExprOuterClass.Expr
              .newBuilder()
              .setJvmScalarUdf(udf)
              .build())
        case None =>
          withInfo(r, s"Unsupported return type ${r.dataType} for Round JVM UDF")
          None
      }
    case _ =>
      // At least one argument could not be serialized; attach the child to `r` as the reason.
      withInfo(r, r.child)
      None
  }
}
}
object CometUnaryMinus extends CometExpressionSerde[UnaryMinus] {

Expand Down
83 changes: 83 additions & 0 deletions spark/src/main/scala/org/apache/comet/udf/RoundDoubleUDF.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.comet.udf

import org.apache.arrow.vector.{Float8Vector, IntVector, ValueVector}

import org.apache.comet.CometArrowAllocator

/**
 * `round(double, scale)` implemented by delegating to Scala's `BigDecimal(d)`, which goes through
 * `java.lang.Double.toString` before applying the requested scale. This mirrors Spark's
 * `RoundBase` for `DoubleType` on whatever JDK the executor is running, so results agree with
 * Spark on Java 17 and Java 21 even though the underlying `Double.toString` algorithm differs
 * between them.
 *
 * Inputs:
 *   - inputs(0): Float8Vector value column (length = numRows, or length 1 when literal-folded)
 *   - inputs(1): IntVector scale, length-1 scalar (serde guarantees this)
 *
 * Output: Float8Vector, length numRows. Null inputs produce null outputs.
 */
class RoundDoubleUDF extends CometUDF {

  /**
   * Rounds every row of the value vector at the scalar scale.
   *
   * @param inputs
   *   exactly two vectors: the double values and the scale (see the class documentation)
   * @param numRows
   *   number of rows to produce
   * @return
   *   a newly allocated `Float8Vector` with `numRows` entries
   * @throws IllegalArgumentException
   *   if `inputs` does not have two elements or the scale is missing or null
   */
  override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = {
    require(inputs.length == 2, s"RoundDoubleUDF expects 2 inputs, got ${inputs.length}")
    val values = inputs(0).asInstanceOf[Float8Vector]
    val scaleVec = inputs(1).asInstanceOf[IntVector]
    require(
      scaleVec.getValueCount >= 1 && !scaleVec.isNull(0),
      "RoundDoubleUDF requires a non-null scalar scale")
    val scale = scaleVec.get(0)

    val out = new Float8Vector("round_double", CometArrowAllocator)
    // Set only once `out` is fully populated. If any step below throws, the vector is closed in
    // the `finally` block so its buffers are not leaked.
    var completed = false
    try {
      out.allocateNew(numRows)

      // A single-element value vector with numRows != 1 is a folded literal that must be
      // broadcast to every output row.
      val valueIsScalar = values.getValueCount == 1 && numRows != 1
      if (valueIsScalar) {
        if (values.isNull(0)) {
          var i = 0
          while (i < numRows) { out.setNull(i); i += 1 }
        } else {
          // Round once and reuse the result for all rows.
          val rounded = RoundDoubleUDF.roundDouble(values.get(0), scale)
          var i = 0
          while (i < numRows) { out.set(i, rounded); i += 1 }
        }
      } else {
        var i = 0
        while (i < numRows) {
          if (values.isNull(i)) {
            out.setNull(i)
          } else {
            out.set(i, RoundDoubleUDF.roundDouble(values.get(i), scale))
          }
          i += 1
        }
      }
      out.setValueCount(numRows)
      completed = true
      out
    } finally {
      if (!completed) out.close()
    }
  }
}

object RoundDoubleUDF {

  /**
   * Rounds `v` to `scale` decimal places using HALF_UP.
   *
   * NaN and the infinities are returned as-is. Any other value is converted to a `BigDecimal`,
   * rescaled, and converted back to a double.
   *
   * @param v
   *   the value to round
   * @param scale
   *   number of fractional digits to keep; negative values round to the left of the decimal point
   * @return
   *   the rounded value
   */
  def roundDouble(v: Double, scale: Int): Double =
    v match {
      case special if special.isNaN || special.isInfinite => special
      case finite =>
        val decimal = BigDecimal(finite)
        decimal.setScale(scale, BigDecimal.RoundingMode.HALF_UP).doubleValue
    }
}
83 changes: 83 additions & 0 deletions spark/src/main/scala/org/apache/comet/udf/RoundFloatUDF.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.comet.udf

import org.apache.arrow.vector.{Float4Vector, IntVector, ValueVector}

import org.apache.comet.CometArrowAllocator

/**
 * `round(float, scale)` implemented to mirror Spark's `RoundBase` for `FloatType`: widen to
 * double, build a `BigDecimal` via `java.lang.Double.toString`, apply HALF_UP at the requested
 * scale, then narrow back to float. The widening before BigDecimal construction is intentional:
 * it matches Spark and produces the same result string the JDK uses for the value.
 *
 * Inputs:
 *   - inputs(0): Float4Vector value column (length = numRows, or length 1 when literal-folded)
 *   - inputs(1): IntVector scale, length-1 scalar (serde guarantees this)
 *
 * Output: Float4Vector, length numRows. Null inputs produce null outputs.
 */
class RoundFloatUDF extends CometUDF {

  /**
   * Rounds every row of the value vector at the scalar scale.
   *
   * @param inputs
   *   exactly two vectors: the float values and the scale (see the class documentation)
   * @param numRows
   *   number of rows to produce
   * @return
   *   a newly allocated `Float4Vector` with `numRows` entries
   * @throws IllegalArgumentException
   *   if `inputs` does not have two elements or the scale is missing or null
   */
  override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = {
    require(inputs.length == 2, s"RoundFloatUDF expects 2 inputs, got ${inputs.length}")
    val values = inputs(0).asInstanceOf[Float4Vector]
    val scaleVec = inputs(1).asInstanceOf[IntVector]
    require(
      scaleVec.getValueCount >= 1 && !scaleVec.isNull(0),
      "RoundFloatUDF requires a non-null scalar scale")
    val scale = scaleVec.get(0)

    val out = new Float4Vector("round_float", CometArrowAllocator)
    // Set only once `out` is fully populated. If any step below throws, the vector is closed in
    // the `finally` block so its buffers are not leaked.
    var completed = false
    try {
      out.allocateNew(numRows)

      // A single-element value vector with numRows != 1 is a folded literal that must be
      // broadcast to every output row.
      val valueIsScalar = values.getValueCount == 1 && numRows != 1
      if (valueIsScalar) {
        if (values.isNull(0)) {
          var i = 0
          while (i < numRows) { out.setNull(i); i += 1 }
        } else {
          // Round once and reuse the result for all rows.
          val rounded = RoundFloatUDF.roundFloat(values.get(0), scale)
          var i = 0
          while (i < numRows) { out.set(i, rounded); i += 1 }
        }
      } else {
        var i = 0
        while (i < numRows) {
          if (values.isNull(i)) {
            out.setNull(i)
          } else {
            out.set(i, RoundFloatUDF.roundFloat(values.get(i), scale))
          }
          i += 1
        }
      }
      out.setValueCount(numRows)
      completed = true
      out
    } finally {
      if (!completed) out.close()
    }
  }
}

object RoundFloatUDF {

  /**
   * Rounds `v` to `scale` decimal places using HALF_UP.
   *
   * NaN and the infinities are returned as-is. Any other value is widened to double, converted to
   * a `BigDecimal`, rescaled, and narrowed back to a float.
   *
   * @param v
   *   the value to round
   * @param scale
   *   number of fractional digits to keep; negative values round to the left of the decimal point
   * @return
   *   the rounded value
   */
  def roundFloat(v: Float, scale: Int): Float =
    v match {
      case special if special.isNaN || special.isInfinite => special
      case finite =>
        val widened = BigDecimal(finite.toDouble)
        widened.setScale(scale, BigDecimal.RoundingMode.HALF_UP).floatValue
    }
}
78 changes: 73 additions & 5 deletions spark/src/test/resources/sql-tests/expressions/math/round.sql
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,86 @@ CREATE TABLE test_round(d double, i int) USING parquet
statement
INSERT INTO test_round VALUES (2.5, 0), (3.5, 0), (-2.5, 0), (123.456, 2), (123.456, -1), (NULL, 0), (cast('NaN' as double), 0), (cast('Infinity' as double), 0), (0.0, 0)

query expect_fallback(BigDecimal rounding)
query
SELECT round(d, 0) FROM test_round WHERE i = 0

query expect_fallback(BigDecimal rounding)
query
SELECT round(d, 2) FROM test_round WHERE i = 2

query expect_fallback(BigDecimal rounding)
query
SELECT round(d, -1) FROM test_round WHERE i = -1

query expect_fallback(BigDecimal rounding)
query
SELECT round(d) FROM test_round

-- literal + literal
query expect_fallback(BigDecimal rounding)
query
SELECT round(123.456, 2), round(2.5, 0), round(3.5, 0), round(-2.5, 0), round(NULL, 0)

-- HALF_UP semantics: .5 always rounds away from zero
statement
CREATE TABLE test_round_half_up(d double) USING parquet

statement
INSERT INTO test_round_half_up VALUES (0.5), (1.5), (2.5), (-0.5), (-1.5), (-2.5)

query
SELECT d, round(d, 0) FROM test_round_half_up

-- various scales on a single value
query
SELECT round(123.456, 0), round(123.456, 1), round(123.456, 2), round(123.456, 3), round(123.456, 5)

query
SELECT round(123.456, -1), round(123.456, -2), round(123.456, -3)

-- special values
query
SELECT round(cast('NaN' as double), 2), round(cast('Infinity' as double), 2), round(cast('-Infinity' as double), 2)

query
SELECT round(0.0, 5), round(-0.0, 5)

-- very small values
query
SELECT round(1.0E-10, 15), round(1.0E-10, 10), round(1.0E-10, 5)

-- negative scale on doubles
query
SELECT round(9999.9, -1), round(9999.9, -2), round(9999.9, -3), round(9999.9, -4)

query
SELECT round(-9999.9, -1), round(-9999.9, -2), round(-9999.9, -3), round(-9999.9, -4)

-- float type
statement
CREATE TABLE test_round_float(f float) USING parquet

statement
INSERT INTO test_round_float VALUES (cast(2.5 as float)), (cast(3.5 as float)), (cast(-2.5 as float)), (cast(0.125 as float)), (cast(0.785 as float)), (cast(123.456 as float)), (cast('NaN' as float)), (cast('Infinity' as float)), (NULL)

query
SELECT round(f, 0) FROM test_round_float

query
SELECT round(f, 2) FROM test_round_float

query
SELECT round(f, -1) FROM test_round_float

-- BigDecimal rounding edge case from Spark
statement
CREATE TABLE test_round_edge(d double) USING parquet

statement
INSERT INTO test_round_edge VALUES (-5.81855622136895E8), (6.1317116247283497E18), (6.13171162472835E18)

query
SELECT round(d, 4), round(d, 5), round(d, 6) FROM test_round_edge

query
SELECT round('-8316362075006449156', -5)

-- round with column from table (not literals)
query
SELECT d, round(d, 0), round(d, 2), round(d, -1) FROM test_round
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
TakeOrderedAndProject
+- Project [COMET: Comet does not support Spark's BigDecimal rounding]
+- CometNativeColumnarToRow
CometNativeColumnarToRow
+- CometTakeOrderedAndProject
+- CometProject
+- CometSortMergeJoin
:- CometProject
: +- CometSortMergeJoin
Expand Down Expand Up @@ -76,4 +76,4 @@ TakeOrderedAndProject
+- CometFilter
+- CometNativeScan parquet spark_catalog.default.date_dim

Comet accelerated 71 out of 76 eligible operators (93%). Final plan contains 1 transitions between Spark and Comet.
Comet accelerated 73 out of 76 eligible operators (96%). Final plan contains 1 transitions between Spark and Comet.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
TakeOrderedAndProject
+- Project [COMET: Comet does not support Spark's BigDecimal rounding]
+- CometNativeColumnarToRow
CometNativeColumnarToRow
+- CometTakeOrderedAndProject
+- CometProject
+- CometSortMergeJoin
:- CometProject
: +- CometSortMergeJoin
Expand Down Expand Up @@ -76,4 +76,4 @@ TakeOrderedAndProject
+- CometFilter
+- CometNativeScan parquet spark_catalog.default.date_dim

Comet accelerated 71 out of 76 eligible operators (93%). Final plan contains 1 transitions between Spark and Comet.
Comet accelerated 73 out of 76 eligible operators (96%). Final plan contains 1 transitions between Spark and Comet.
Loading
Loading