Skip to content

Commit 307899a

Browse files
authored
Fix nanlen with strings (#344)
* Fix nanlen with strings Closes pydata/xarray#8853 * fix windows * Silence warnings
1 parent 20be463 commit 307899a

File tree

4 files changed

+30
-4
lines changed

4 files changed

+30
-4
lines changed

flox/aggregate_flox.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ def _lerp(a, b, *, t, dtype, out=None):
3737
"""
3838
if out is None:
3939
out = np.empty_like(a, dtype=dtype)
40-
diff_b_a = np.subtract(b, a)
40+
with np.errstate(invalid="ignore"):
41+
diff_b_a = np.subtract(b, a)
4142
# asanyarray is a stop-gap until gh-13105
4243
np.add(a, diff_b_a * t, out=out)
4344
np.subtract(b, diff_b_a * (1 - t), out=out, where=t >= 0.5)
@@ -95,7 +96,8 @@ def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=Non
9596

9697
# partition the complex array in-place
9798
labels_broadcast = np.broadcast_to(group_idx, array.shape)
98-
cmplx = labels_broadcast + 1j * array
99+
with np.errstate(invalid="ignore"):
100+
cmplx = labels_broadcast + 1j * array
99101
cmplx.partition(kth=kth, axis=-1)
100102
if is_scalar_q:
101103
a_ = cmplx.imag

flox/aggregate_npg.py

+2
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ def nanprod(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dt
8888

8989

9090
def _len(group_idx, array, engine, *, func, axis=-1, size=None, fill_value=None, dtype=None):
91+
if array.dtype.kind in "US":
92+
array = np.broadcast_to(np.array([1]), array.shape)
9193
result = _get_aggregate(engine).aggregate(
9294
group_idx,
9395
array,

flox/aggregate_numbagg.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -105,11 +105,24 @@ def nanstd(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None,
105105
)
106106

107107

108+
def nanlen(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None):
109+
if array.dtype.kind in "US":
110+
array = np.broadcast_to(np.array([1]), array.shape)
111+
return _numbagg_wrapper(
112+
group_idx,
113+
array,
114+
axis=axis,
115+
size=size,
116+
func="nancount",
117+
# fill_value=fill_value,
118+
# dtype=dtype,
119+
)
120+
121+
108122
nansum = partial(_numbagg_wrapper, func="nansum")
109123
nanmean = partial(_numbagg_wrapper, func="nanmean")
110124
nanprod = partial(_numbagg_wrapper, func="nanprod")
111125
nansum_of_squares = partial(_numbagg_wrapper, func="nansum_of_squares")
112-
nanlen = partial(_numbagg_wrapper, func="nancount")
113126
nanprod = partial(_numbagg_wrapper, func="nanprod")
114127
nanfirst = partial(_numbagg_wrapper, func="nanfirst")
115128
nanlast = partial(_numbagg_wrapper, func="nanlast")

tests/test_core.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -1127,7 +1127,7 @@ def test_group_by_datetime(engine, method):
11271127

11281128
edges = pd.date_range("1999-12-31", "2000-12-31", freq="ME").to_series().to_numpy()
11291129
actual, _ = groupby_reduce(daskarray, t.to_numpy(), isbin=True, expected_groups=edges, **kwargs)
1130-
expected = data.resample("M").mean().to_numpy()
1130+
expected = data.resample("ME").mean().to_numpy()
11311131
assert_equal(expected, actual)
11321132

11331133
actual, _ = groupby_reduce(
@@ -1688,3 +1688,12 @@ def test_multiple_quantiles(q, chunk, func, by_ndim):
16881688
if by_ndim == 2:
16891689
expected = expected.squeeze(axis=-2)
16901690
assert_equal(expected, actual, tolerance=1e-14)
1691+
1692+
1693+
@pytest.mark.parametrize("dtype", ["U3", "S3"])
1694+
def test_nanlen_string(dtype, engine):
1695+
array = np.array(["ABC", "DEF", "GHI", "JKL", "MNO", "PQR"], dtype=dtype)
1696+
by = np.array([0, 0, 1, 2, 1, 0])
1697+
expected = np.array([3, 2, 1], dtype=np.intp)
1698+
actual, *_ = groupby_reduce(array, by, func="count", engine=engine)
1699+
assert_equal(expected, actual)

0 commit comments

Comments
 (0)