Skip to content

Commit 096c080

Browse files
authored
Enable nanargmax, nanargmin (#171)
* Support nanargmin, nanargmax
* Fix test
* Add blockwise test
* Fix blockwise test
* Apply suggestions from code review
1 parent 6a5969f commit 096c080

File tree

4 files changed

+25
-13
lines changed

4 files changed

+25
-13
lines changed

flox/aggregations.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -421,7 +421,7 @@ def _pick_second(*x):
421421
chunk=("nanmax", "nanargmax"), # order is important
422422
combine=("max", "argmax"),
423423
reduction_type="argreduce",
424-
fill_value=(dtypes.NINF, -1),
424+
fill_value=(dtypes.NINF, 0),
425425
final_fill_value=-1,
426426
finalize=_pick_second,
427427
dtypes=(None, np.intp),
@@ -434,7 +434,7 @@ def _pick_second(*x):
434434
chunk=("nanmin", "nanargmin"), # order is important
435435
combine=("min", "argmin"),
436436
reduction_type="argreduce",
437-
fill_value=(dtypes.INF, -1),
437+
fill_value=(dtypes.INF, 0),
438438
final_fill_value=-1,
439439
finalize=_pick_second,
440440
dtypes=(None, np.intp),

flox/core.py

+5-2
Original file line number | Diff line number | Diff line change
@@ -1323,8 +1323,11 @@ def dask_groupby_agg(
13231323
by = dask.array.from_array(by, chunks=chunks)
13241324
_, (array, by) = dask.array.unify_chunks(array, inds, by, inds[-by.ndim :])
13251325

1326-
# preprocess the array: for argreductions, this zips the index together with the array block
1327-
if agg.preprocess:
1326+
# preprocess the array:
1327+
# - for argreductions, this zips the index together with the array block
1328+
# - not necessary for blockwise with argreductions
1329+
# - if this is needed later, we can fix this then
1330+
if agg.preprocess and method != "blockwise":
13281331
array = agg.preprocess(array, axis=axis)
13291332

13301333
# 1. We first apply the groupby-reduction blockwise to generate "intermediates"

tests/conftest.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import pytest
22

33

4-
@pytest.fixture(scope="module", params=["flox"])
4+
@pytest.fixture(scope="module", params=["flox", "numpy", "numba"])
55
def engine(request):
66
if request.param == "numba":
77
try:

tests/test_core.py

+17-8
Original file line number | Diff line number | Diff line change
@@ -55,7 +55,7 @@ def dask_array_ones(*args):
5555
"nansum",
5656
"argmax",
5757
"nanfirst",
58-
pytest.param("nanargmax", marks=(pytest.mark.skip,)),
58+
"nanargmax",
5959
"prod",
6060
"nanprod",
6161
"mean",
@@ -69,7 +69,7 @@ def dask_array_ones(*args):
6969
"min",
7070
"nanmin",
7171
"argmin",
72-
pytest.param("nanargmin", marks=(pytest.mark.skip,)),
72+
"nanargmin",
7373
"any",
7474
"all",
7575
"nanlast",
@@ -233,8 +233,13 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
233233
# computing silences a bunch of dask warnings
234234
array_ = array.compute() if chunks is not None else array
235235
if "arg" in func and add_nan_by:
236+
# NaNs are in by, but we can't call np.argmax([..., NaN, .. ])
237+
# That would return index of the NaN
238+
# This way, we insert NaNs where there are NaNs in by, and
239+
# call np.nanargmax
240+
func_ = f"nan{func}" if "nan" not in func else func
236241
array_[..., nanmask] = np.nan
237-
expected = getattr(np, "nan" + func)(array_, axis=-1, **kwargs)
242+
expected = getattr(np, func_)(array_, axis=-1, **kwargs)
238243
# elif func in ["first", "last"]:
239244
# expected = getattr(xrutils, f"nan{func}")(array_[..., ~nanmask], axis=-1, **kwargs)
240245
elif func in ["nanfirst", "nanlast"]:
@@ -259,6 +264,9 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
259264

260265
params = list(itertools.product(["map-reduce"], [True, False, None]))
261266
params.extend(itertools.product(["cohorts"], [False, None]))
267+
if chunks == -1:
268+
params.extend([("blockwise", None)])
269+
262270
for method, reindex in params:
263271
call = partial(
264272
groupby_reduce, array, *by, method=method, reindex=reindex, **flox_kwargs
@@ -269,11 +277,12 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
269277
call()
270278
continue
271279
actual, *groups = call()
272-
if "arg" not in func:
273-
# make sure we use simple combine
274-
assert any("simple-combine" in key for key in actual.dask.layers.keys())
275-
else:
276-
assert any("grouped-combine" in key for key in actual.dask.layers.keys())
280+
if method != "blockwise":
281+
if "arg" not in func:
282+
# make sure we use simple combine
283+
assert any("simple-combine" in key for key in actual.dask.layers.keys())
284+
else:
285+
assert any("grouped-combine" in key for key in actual.dask.layers.keys())
277286
for actual_group, expect in zip(groups, expected_groups):
278287
assert_equal(actual_group, expect, tolerance)
279288
if "arg" in func:

0 commit comments

Comments (0)