1515// specific language governing permissions and limitations
1616// under the License.
1717
18- use std:: any:: Any ;
1918use std:: sync:: Arc ;
2019
21- use arrow:: array:: { Array , ArrayRef , BooleanArray , Float32Array , Float64Array } ;
22- use arrow:: datatypes:: DataType ;
20+ use arrow:: array:: { Array , ArrayRef , BooleanArray , PrimitiveArray } ;
21+ use arrow:: datatypes:: { ArrowPrimitiveType , DataType , Float32Type , Float64Type } ;
2322use datafusion_common:: utils:: take_function_args;
24- use datafusion_common:: { Result , ScalarValue , exec_err} ;
23+ use datafusion_common:: { Result , exec_err} ;
2524use datafusion_expr:: {
2625 ColumnarValue , ScalarFunctionArgs , ScalarUDFImpl , Signature , TypeSignature ,
2726 Volatility ,
2827} ;
28+ use num_traits:: Float ;
2929
3030/// Spark-compatible `isnan` expression
3131/// <https://spark.apache.org/docs/latest/api/sql/index.html#isnan>
@@ -60,10 +60,6 @@ impl SparkIsNaN {
6060}
6161
6262impl ScalarUDFImpl for SparkIsNaN {
63- fn as_any ( & self ) -> & dyn Any {
64- self
65- }
66-
6763 fn name ( & self ) -> & str {
6864 "isnan"
6965 }
@@ -82,37 +78,23 @@ impl ScalarUDFImpl for SparkIsNaN {
8278}
8379
8480fn spark_isnan ( args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
85- let [ value] = take_function_args ( "isnan" , args) ?;
86-
87- match value {
88- ColumnarValue :: Array ( array) => match array. data_type ( ) {
89- DataType :: Float32 => {
90- let array = array. as_any ( ) . downcast_ref :: < Float32Array > ( ) . unwrap ( ) ;
91- Ok ( ColumnarValue :: Array ( nulls_to_false (
92- BooleanArray :: from_unary ( array, |x| x. is_nan ( ) ) ,
93- ) ) )
94- }
95- DataType :: Float64 => {
96- let array = array. as_any ( ) . downcast_ref :: < Float64Array > ( ) . unwrap ( ) ;
97- Ok ( ColumnarValue :: Array ( nulls_to_false (
98- BooleanArray :: from_unary ( array, |x| x. is_nan ( ) ) ,
99- ) ) )
100- }
101- other => exec_err ! ( "Unsupported data type {other:?} for function isnan" ) ,
102- } ,
103- ColumnarValue :: Scalar ( sv) => match sv {
104- ScalarValue :: Float32 ( v) => Ok ( ColumnarValue :: Scalar ( ScalarValue :: Boolean (
105- Some ( v. is_some_and ( |x| x. is_nan ( ) ) ) ,
106- ) ) ) ,
107- ScalarValue :: Float64 ( v) => Ok ( ColumnarValue :: Scalar ( ScalarValue :: Boolean (
108- Some ( v. is_some_and ( |x| x. is_nan ( ) ) ) ,
109- ) ) ) ,
110- _ => exec_err ! (
111- "Unsupported data type {:?} for function isnan" ,
112- sv. data_type( )
113- ) ,
114- } ,
115- }
81+ let [ array] = take_function_args ( "isnan" , ColumnarValue :: values_to_arrays ( args) ?) ?;
82+
83+ let result = match array. data_type ( ) {
84+ DataType :: Float32 => isnan_array :: < Float32Type > ( & array) ,
85+ DataType :: Float64 => isnan_array :: < Float64Type > ( & array) ,
86+ other => return exec_err ! ( "Unsupported data type {other:?} for function isnan" ) ,
87+ } ;
88+ Ok ( ColumnarValue :: Array ( result) )
89+ }
90+
91+ fn isnan_array < T > ( array : & ArrayRef ) -> ArrayRef
92+ where
93+ T : ArrowPrimitiveType ,
94+ T :: Native : Float ,
95+ {
96+ let array = array. as_any ( ) . downcast_ref :: < PrimitiveArray < T > > ( ) . unwrap ( ) ;
97+ nulls_to_false ( BooleanArray :: from_unary ( array, |x| x. is_nan ( ) ) )
11698}
11799
118100/// Replaces null values with false in a BooleanArray.
@@ -131,6 +113,8 @@ fn nulls_to_false(is_nan: BooleanArray) -> ArrayRef {
131113#[ cfg( test) ]
132114mod tests {
133115 use super :: * ;
116+ use arrow:: array:: { Float32Array , Float64Array } ;
117+ use datafusion_common:: ScalarValue ;
134118
135119 #[ test]
136120 fn test_isnan_float64 ( ) {
0 commit comments