Skip to content

Commit 4bbff18

Browse files
authored
fix(rust, python): explicit nan comparison in min/max agg (#5403)
1 parent 982c10e commit 4bbff18

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

polars/polars-arrow/src/kernels/rolling/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ type WindowSize = usize;
2222
type Len = usize;
2323

2424
#[inline]
25+
/// NaN will be smaller than every valid value
2526
pub fn compare_fn_nan_min<T>(a: &T, b: &T) -> Ordering
2627
where
2728
T: PartialOrd + IsFloat,
@@ -42,6 +43,7 @@ where
4243
}
4344

4445
#[inline]
46+
/// NaN will be larger than every valid value
4547
pub fn compare_fn_nan_max<T>(a: &T, b: &T) -> Ordering
4648
where
4749
T: PartialOrd + IsFloat,

polars/polars-core/src/chunked_array/ops/aggregate.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
//! Implementations of the ChunkAgg trait.
2+
use std::cmp::Ordering;
23
use std::ops::Add;
34

45
use arrow::compute;
56
use arrow::types::simd::Simd;
67
use num::{Float, ToPrimitive};
8+
use polars_arrow::kernels::rolling::{compare_fn_nan_max, compare_fn_nan_min};
79
use polars_arrow::prelude::QuantileInterpolOptions;
810

911
use crate::chunked_array::ChunkedArray;
@@ -88,7 +90,13 @@ where
8890
IsSorted::Not => self
8991
.downcast_iter()
9092
.filter_map(compute::aggregate::min_primitive)
91-
.fold_first_(|acc, v| if acc < v { acc } else { v }),
93+
.fold_first_(|acc, v| {
94+
if matches!(compare_fn_nan_max(&acc, &v), Ordering::Less) {
95+
acc
96+
} else {
97+
v
98+
}
99+
}),
92100
}
93101
}
94102

@@ -111,7 +119,13 @@ where
111119
IsSorted::Not => self
112120
.downcast_iter()
113121
.filter_map(compute::aggregate::max_primitive)
114-
.fold_first_(|acc, v| if acc > v { acc } else { v }),
122+
.fold_first_(|acc, v| {
123+
if matches!(compare_fn_nan_min(&acc, &v), Ordering::Greater) {
124+
acc
125+
} else {
126+
v
127+
}
128+
}),
115129
}
116130
}
117131

0 commit comments

Comments
 (0)