Skip to content

Commit e9f6d52

Browse files
committed
Fix
1 parent a8bb264 commit e9f6d52

2 files changed

Lines changed: 44 additions & 39 deletions

File tree

datafusion/spark/src/function/math/isnan.rs

Lines changed: 23 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,17 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use std::any::Any;
1918
use std::sync::Arc;
2019

21-
use arrow::array::{Array, ArrayRef, BooleanArray, Float32Array, Float64Array};
22-
use arrow::datatypes::DataType;
20+
use arrow::array::{Array, ArrayRef, BooleanArray, PrimitiveArray};
21+
use arrow::datatypes::{ArrowPrimitiveType, DataType, Float32Type, Float64Type};
2322
use datafusion_common::utils::take_function_args;
24-
use datafusion_common::{Result, ScalarValue, exec_err};
23+
use datafusion_common::{Result, exec_err};
2524
use datafusion_expr::{
2625
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
2726
Volatility,
2827
};
28+
use num_traits::Float;
2929

3030
/// Spark-compatible `isnan` expression
3131
/// <https://spark.apache.org/docs/latest/api/sql/index.html#isnan>
@@ -60,10 +60,6 @@ impl SparkIsNaN {
6060
}
6161

6262
impl ScalarUDFImpl for SparkIsNaN {
63-
fn as_any(&self) -> &dyn Any {
64-
self
65-
}
66-
6763
fn name(&self) -> &str {
6864
"isnan"
6965
}
@@ -82,37 +78,23 @@ impl ScalarUDFImpl for SparkIsNaN {
8278
}
8379

8480
fn spark_isnan(args: &[ColumnarValue]) -> Result<ColumnarValue> {
85-
let [value] = take_function_args("isnan", args)?;
86-
87-
match value {
88-
ColumnarValue::Array(array) => match array.data_type() {
89-
DataType::Float32 => {
90-
let array = array.as_any().downcast_ref::<Float32Array>().unwrap();
91-
Ok(ColumnarValue::Array(nulls_to_false(
92-
BooleanArray::from_unary(array, |x| x.is_nan()),
93-
)))
94-
}
95-
DataType::Float64 => {
96-
let array = array.as_any().downcast_ref::<Float64Array>().unwrap();
97-
Ok(ColumnarValue::Array(nulls_to_false(
98-
BooleanArray::from_unary(array, |x| x.is_nan()),
99-
)))
100-
}
101-
other => exec_err!("Unsupported data type {other:?} for function isnan"),
102-
},
103-
ColumnarValue::Scalar(sv) => match sv {
104-
ScalarValue::Float32(v) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(
105-
Some(v.is_some_and(|x| x.is_nan())),
106-
))),
107-
ScalarValue::Float64(v) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(
108-
Some(v.is_some_and(|x| x.is_nan())),
109-
))),
110-
_ => exec_err!(
111-
"Unsupported data type {:?} for function isnan",
112-
sv.data_type()
113-
),
114-
},
115-
}
81+
let [array] = take_function_args("isnan", ColumnarValue::values_to_arrays(args)?)?;
82+
83+
let result = match array.data_type() {
84+
DataType::Float32 => isnan_array::<Float32Type>(&array),
85+
DataType::Float64 => isnan_array::<Float64Type>(&array),
86+
other => return exec_err!("Unsupported data type {other:?} for function isnan"),
87+
};
88+
Ok(ColumnarValue::Array(result))
89+
}
90+
91+
fn isnan_array<T>(array: &ArrayRef) -> ArrayRef
92+
where
93+
T: ArrowPrimitiveType,
94+
T::Native: Float,
95+
{
96+
let array = array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
97+
nulls_to_false(BooleanArray::from_unary(array, |x| x.is_nan()))
11698
}
11799

118100
/// Replaces null values with false in a BooleanArray.
@@ -131,6 +113,8 @@ fn nulls_to_false(is_nan: BooleanArray) -> ArrayRef {
131113
#[cfg(test)]
132114
mod tests {
133115
use super::*;
116+
use arrow::array::{Float32Array, Float64Array};
117+
use datafusion_common::ScalarValue;
134118

135119
#[test]
136120
fn test_isnan_float64() {

datafusion/sqllogictest/test_files/spark/math/isnan.slt

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,24 @@ true
7070
false
7171
false
7272
false
73+
74+
# Spark's isnan accepts any numeric type by implicitly casting to double, so
75+
# integers, bigints, and decimals all return false (they can never be NaN).
76+
# This test guards against the signature being accidentally narrowed in a way
77+
# that would diverge from Spark.
78+
query BBBB
79+
SELECT
80+
isnan(1) AS int_lit,
81+
isnan(CAST(1 AS BIGINT)) AS bigint,
82+
isnan(CAST(1 AS DECIMAL(10, 2))) AS dec,
83+
isnan(CAST(NULL AS INT)) AS null_int;
84+
----
85+
false false false false
86+
87+
# Untyped NULL literal: relies on DataFusion coercing Null -> Float64 to satisfy
88+
# the Exact(Float32)/Exact(Float64) signature. Without that coercion, planning
89+
# would fail and we'd diverge from Spark, which returns false here.
90+
query B
91+
SELECT isnan(NULL);
92+
----
93+
false

0 commit comments

Comments
 (0)