Skip to content

Commit 5b7edbe

Browse files
authored
Use isnull instead of isnan: engine="flox" (#105)
1 parent 01b1afe commit 5b7edbe

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

flox/aggregate_flox.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import numpy as np
44

5+
from .xrutils import isnull
6+
57

68
def _np_grouped_op(group_idx, array, op, axis=-1, size=None, fill_value=None, dtype=None, out=None):
79
"""
@@ -36,7 +38,7 @@ def _np_grouped_op(group_idx, array, op, axis=-1, size=None, fill_value=None, dt
3638

3739

3840
def _nan_grouped_op(group_idx, array, func, fillna, *args, **kwargs):
39-
result = func(group_idx, np.where(np.isnan(array), fillna, array), *args, **kwargs)
41+
result = func(group_idx, np.where(isnull(array), fillna, array), *args, **kwargs)
4042
# np.nanmax([np.nan, np.nan]) = np.nan
4143
# To recover this behaviour, we need to search for the fillna value
4244
# (either np.inf or -np.inf), and replace with NaN
@@ -74,7 +76,7 @@ def sum_of_squares(group_idx, array, *, axis=-1, size=None, fill_value=None, dty
7476
def nansum_of_squares(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None):
7577
return sum_of_squares(
7678
group_idx,
77-
np.where(np.isnan(array), 0, array),
79+
np.where(isnull(array), 0, array),
7880
size=size,
7981
fill_value=fill_value,
8082
axis=axis,
@@ -83,7 +85,7 @@ def nansum_of_squares(group_idx, array, *, axis=-1, size=None, fill_value=None,
8385

8486

8587
def nanlen(group_idx, array, *args, **kwargs):
86-
return sum(group_idx, (~np.isnan(array)).astype(int), *args, **kwargs)
88+
return sum(group_idx, (~isnull(array)).astype(int), *args, **kwargs)
8789

8890

8991
def mean(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None):

tests/test_xarray.py

+13
Original file line numberDiff line numberDiff line change
@@ -400,3 +400,16 @@ def test_cache():
400400

401401
xarray_reduce(ds, "labels", func="mean", method="blockwise")
402402
assert len(cache.data) == 2
403+
404+
405+
@pytest.mark.parametrize("use_cftime", [True, False])
406+
def test_datetime_array_reduce(use_cftime):
407+
408+
time = xr.DataArray(
409+
xr.date_range("2009-01-01", "2012-12-31", use_cftime=use_cftime),
410+
dims=("time",),
411+
name="time",
412+
)
413+
expected = time.resample(time="YS").count() # fails
414+
actual = resample_reduce(time.resample(time="YS"), func="count", engine="flox")
415+
assert_equal(expected, actual)

0 commit comments

Comments
 (0)