@@ -19,9 +19,12 @@ use crate::arithmetic_overflow_error;
1919use crate :: math_funcs:: utils:: { get_precision_scale, make_decimal_array, make_decimal_scalar} ;
2020use arrow:: array:: { Array , ArrowNativeTypeOp } ;
2121use arrow:: array:: { Int16Array , Int32Array , Int64Array , Int8Array } ;
22- use arrow:: datatypes:: DataType ;
22+ use arrow:: datatypes:: { DataType , Field } ;
2323use arrow:: error:: ArrowError ;
24+ use datafusion:: common:: config:: ConfigOptions ;
2425use datafusion:: common:: { exec_err, internal_err, DataFusionError , ScalarValue } ;
26+ use datafusion:: functions:: math:: round:: RoundFunc ;
27+ use datafusion:: logical_expr:: { ScalarFunctionArgs , ScalarUDFImpl } ;
2528use datafusion:: physical_plan:: ColumnarValue ;
2629use std:: { cmp:: min, sync:: Arc } ;
2730
@@ -107,6 +110,8 @@ pub fn spark_round(
107110 let ColumnarValue :: Scalar ( ScalarValue :: Int64 ( Some ( point) ) ) = point else {
108111 return internal_err ! ( "Invalid point argument for Round(): {:#?}" , point) ;
109112 } ;
113+ // DataFusion's RoundFunc expects Int32 for decimal_places
114+ let point_i32 = ColumnarValue :: Scalar ( ScalarValue :: Int32 ( Some ( * point as i32 ) ) ) ;
110115 match value {
111116 ColumnarValue :: Array ( array) => match array. data_type ( ) {
112117 DataType :: Int64 if * point < 0 => {
@@ -126,9 +131,18 @@ pub fn spark_round(
126131 let ( precision, scale) = get_precision_scale ( data_type) ;
127132 make_decimal_array ( array, precision, scale, & f)
128133 }
129- // Float32 / Float64 are routed to a JVM UDF (RoundFloatUDF / RoundDoubleUDF) by the
130- // serde, because matching Spark's BigDecimal-via-Double.toString rounding from native
131- // code does not stay consistent across JDK versions.
134+ DataType :: Float32 | DataType :: Float64 => {
135+ let round_udf = RoundFunc :: new ( ) ;
136+ let return_field = Arc :: new ( Field :: new ( "round" , array. data_type ( ) . clone ( ) , true ) ) ;
137+ let args_for_round = ScalarFunctionArgs {
138+ args : vec ! [ ColumnarValue :: Array ( Arc :: clone( array) ) , point_i32. clone( ) ] ,
139+ number_rows : array. len ( ) ,
140+ return_field,
141+ arg_fields : vec ! [ ] ,
142+ config_options : Arc :: new ( ConfigOptions :: default ( ) ) ,
143+ } ;
144+ round_udf. invoke_with_args ( args_for_round)
145+ }
132146 dt => exec_err ! ( "Not supported datatype for ROUND: {dt}" ) ,
133147 } ,
134148 ColumnarValue :: Scalar ( a) => match a {
@@ -149,6 +163,19 @@ pub fn spark_round(
149163 let ( precision, scale) = get_precision_scale ( data_type) ;
150164 make_decimal_scalar ( a, precision, scale, & f)
151165 }
166+ ScalarValue :: Float32 ( _) | ScalarValue :: Float64 ( _) => {
167+ let round_udf = RoundFunc :: new ( ) ;
168+ let data_type = a. data_type ( ) ;
169+ let return_field = Arc :: new ( Field :: new ( "round" , data_type, true ) ) ;
170+ let args_for_round = ScalarFunctionArgs {
171+ args : vec ! [ ColumnarValue :: Scalar ( a. clone( ) ) , point_i32. clone( ) ] ,
172+ number_rows : 1 ,
173+ return_field,
174+ arg_fields : vec ! [ ] ,
175+ config_options : Arc :: new ( ConfigOptions :: default ( ) ) ,
176+ } ;
177+ round_udf. invoke_with_args ( args_for_round)
178+ }
152179 dt => exec_err ! ( "Not supported datatype for ROUND: {dt}" ) ,
153180 } ,
154181 }
@@ -180,92 +207,77 @@ mod test {
180207
181208 use crate :: spark_round;
182209
183- use arrow:: array:: Decimal128Array ;
210+ use arrow:: array:: { Float32Array , Float64Array } ;
184211 use arrow:: datatypes:: DataType ;
212+ use datafusion:: common:: cast:: { as_float32_array, as_float64_array} ;
185213 use datafusion:: common:: { Result , ScalarValue } ;
186214 use datafusion:: physical_plan:: ColumnarValue ;
187215
/// Rounds a `Float32` array to two decimal places through `spark_round` and
/// compares the whole output array against the expected values.
///
/// The inputs include 0.125 and 0.785, which are expected to come back as
/// 0.13 and 0.79 respectively.
#[test]
#[cfg_attr(miri, ignore)] // rounding does not work when miri enabled
fn test_round_f32_array() -> Result<()> {
    let input = Float32Array::from(vec![
        125.2345, 15.3455, 0.1234, 0.125, 0.785, 123.123,
    ]);
    let want = Float32Array::from(vec![125.23, 15.35, 0.12, 0.13, 0.79, 123.12]);

    // The rounding position is handed over as an Int64 scalar.
    let scale = ColumnarValue::Scalar(ScalarValue::Int64(Some(2)));
    let round_args = vec![ColumnarValue::Array(Arc::new(input)), scale];

    match spark_round(&round_args, &DataType::Float32, false)? {
        ColumnarValue::Array(rounded) => {
            let got = as_float32_array(&rounded)?;
            assert_eq!(got, &want);
            Ok(())
        }
        // An array argument is expected to yield an array result.
        _ => unreachable!(),
    }
}
209233
/// Rounds a `Float64` array to two decimal places through `spark_round` and
/// compares the whole output array against the expected values.
///
/// Uses the same six inputs as the `Float32` array test, so both float widths
/// are checked against identical expectations.
#[test]
#[cfg_attr(miri, ignore)] // rounding does not work when miri enabled
fn test_round_f64_array() -> Result<()> {
    let input = Float64Array::from(vec![
        125.2345, 15.3455, 0.1234, 0.125, 0.785, 123.123,
    ]);
    let want = Float64Array::from(vec![125.23, 15.35, 0.12, 0.13, 0.79, 123.12]);

    // The rounding position is handed over as an Int64 scalar.
    let scale = ColumnarValue::Scalar(ScalarValue::Int64(Some(2)));
    let round_args = vec![ColumnarValue::Array(Arc::new(input)), scale];

    match spark_round(&round_args, &DataType::Float64, false)? {
        ColumnarValue::Array(rounded) => {
            let got = as_float64_array(&rounded)?;
            assert_eq!(got, &want);
            Ok(())
        }
        // An array argument is expected to yield an array result.
        _ => unreachable!(),
    }
}
231251
/// Rounds a single `Float32` scalar (125.2345) to two decimal places through
/// `spark_round` and expects a non-null `Float32` scalar equal to 125.23.
#[test]
#[cfg_attr(miri, ignore)] // rounding does not work when miri enabled
fn test_round_f32_scalar() -> Result<()> {
    let value = ColumnarValue::Scalar(ScalarValue::Float32(Some(125.2345)));
    let scale = ColumnarValue::Scalar(ScalarValue::Int64(Some(2)));
    let round_args = vec![value, scale];

    match spark_round(&round_args, &DataType::Float32, false)? {
        ColumnarValue::Scalar(ScalarValue::Float32(Some(rounded))) => {
            assert_eq!(rounded, 125.23);
            Ok(())
        }
        // Any other shape (array, null, or a different scalar type) is not expected here.
        _ => unreachable!(),
    }
}
251267
/// Rounds a single `Float64` scalar (125.2345) to two decimal places through
/// `spark_round` and expects a non-null `Float64` scalar equal to 125.23.
#[test]
#[cfg_attr(miri, ignore)] // rounding does not work when miri enabled
fn test_round_f64_scalar() -> Result<()> {
    let value = ColumnarValue::Scalar(ScalarValue::Float64(Some(125.2345)));
    let scale = ColumnarValue::Scalar(ScalarValue::Int64(Some(2)));
    let round_args = vec![value, scale];

    match spark_round(&round_args, &DataType::Float64, false)? {
        ColumnarValue::Scalar(ScalarValue::Float64(Some(rounded))) => {
            assert_eq!(rounded, 125.23);
            Ok(())
        }
        // Any other shape (array, null, or a different scalar type) is not expected here.
        _ => unreachable!(),
    }
}
271283}
0 commit comments