Skip to content

Commit cb640f9

Browse files
authored
Fixes for latest numpy_groupies (#85)
* Remove nanvar, nanstd compatibility code
* Remove argreduction compatibility code
* Fix argreduction along axis with single chunk; fix argreductions
* Update min numpy_groupies version
* Restore tests
* Fix numpy tests
* Skip nanarg* tests instead of xfail
1 parent 5798990 commit cb640f9

8 files changed

+67
-60
lines changed

ci/environment.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ dependencies:
1313
- pytest-xdist
1414
- xarray
1515
- pre-commit
16-
- numpy_groupies
16+
- numpy_groupies>=0.9.15
1717
- pooch
1818
- toolz
1919
- numba

ci/minimal-requirements.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ dependencies:
88
- pytest
99
- pytest-cov
1010
- pytest-xdist
11-
- numpy_groupies
11+
- numpy_groupies>=0.9.15
1212
- pandas
1313
- pooch
1414
- toolz

ci/no-dask.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ dependencies:
1212
- xarray
1313
- numpydoc
1414
- pre-commit
15-
- numpy_groupies
15+
- numpy_groupies>=0.9.15
1616
- pooch
1717
- toolz
1818
- numba

ci/no-xarray.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ dependencies:
1212
- dask-core
1313
- numpydoc
1414
- pre-commit
15-
- numpy_groupies
15+
- numpy_groupies>=0.9.15
1616
- pooch
1717
- toolz
1818
- numba

flox/aggregations.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,14 @@
99
from . import aggregate_flox, aggregate_npg, xrdtypes as dtypes, xrutils
1010

1111

12+
def _is_arg_reduction(func: str | Aggregation) -> bool:
13+
if isinstance(func, str) and func in ["argmin", "argmax", "nanargmax", "nanargmin"]:
14+
return True
15+
if isinstance(func, Aggregation) and func.reduction_type == "argreduce":
16+
return True
17+
return False
18+
19+
1220
def generic_aggregate(
1321
group_idx,
1422
array,
@@ -488,7 +496,11 @@ def _initialize_aggregation(
488496
agg.fill_value[func] = _get_fill_value(agg.dtype[func], agg.fill_value[func])
489497

490498
fv = fill_value if fill_value is not None else agg.fill_value[agg.name]
491-
agg.fill_value["numpy"] = (fv,)
499+
if _is_arg_reduction(agg):
500+
# this allows us to unravel_index easily. we have to do that nearly every time.
501+
agg.fill_value["numpy"] = (0,)
502+
else:
503+
agg.fill_value["numpy"] = (fv,)
492504

493505
if finalize_kwargs is not None:
494506
assert isinstance(finalize_kwargs, dict)

flox/core.py

+46-51
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,7 @@ def chunk_argreduce(
509509
dask.array.reductions.argtopk
510510
"""
511511
array, idx = array_plus_idx
512+
by = np.broadcast_to(by, array.shape)
512513

513514
results = chunk_reduce(
514515
array,
@@ -522,17 +523,22 @@ def chunk_argreduce(
522523
sort=sort,
523524
)
524525
if not isnull(results["groups"]).all():
525-
# will not work for empty groups...
526-
# glorious
527526
idx = np.broadcast_to(idx, array.shape)
527+
528+
# array, by get flattened to 1D before passing to npg
529+
# so the indexes need to be unraveled
528530
newidx = np.unravel_index(results["intermediates"][1], array.shape)
531+
532+
# Now index into the actual "global" indexes `idx`
529533
results["intermediates"][1] = idx[newidx]
530534

531535
if reindex and expected_groups is not None:
532536
results["intermediates"][1] = reindex_(
533537
results["intermediates"][1], results["groups"].squeeze(), expected_groups, fill_value=0
534538
)
535539

540+
assert results["intermediates"][0].shape == results["intermediates"][1].shape
541+
536542
return results
537543

538544

@@ -879,34 +885,45 @@ def _grouped_combine(
879885
array_idx = tuple(
880886
_conc2(x_chunk, key1="intermediates", key2=idx, axis=axis) for idx in (0, 1)
881887
)
882-
results = chunk_argreduce(
883-
array_idx,
884-
groups,
885-
func=agg.combine[slicer], # count gets treated specially next
886-
axis=axis,
887-
expected_groups=None,
888-
fill_value=agg.fill_value["intermediate"][slicer],
889-
dtype=agg.dtype["intermediate"][slicer],
890-
engine=engine,
891-
sort=sort,
892-
)
888+
889+
# for a single element along axis, we don't want to run the argreduction twice
890+
# This happens when we are reducing along an axis with a single chunk.
891+
avoid_reduction = array_idx[0].shape[axis[0]] == 1
892+
if avoid_reduction:
893+
results = {"groups": groups, "intermediates": list(array_idx)}
894+
else:
895+
results = chunk_argreduce(
896+
array_idx,
897+
groups,
898+
func=agg.combine[slicer], # count gets treated specially next
899+
axis=axis,
900+
expected_groups=None,
901+
fill_value=agg.fill_value["intermediate"][slicer],
902+
dtype=agg.dtype["intermediate"][slicer],
903+
engine=engine,
904+
sort=sort,
905+
)
893906

894907
if agg.chunk[-1] == "nanlen":
895908
counts = _conc2(x_chunk, key1="intermediates", key2=2, axis=axis)
896-
# sum the counts
897-
results["intermediates"].append(
898-
chunk_reduce(
899-
counts,
900-
groups,
901-
func="sum",
902-
axis=axis,
903-
expected_groups=None,
904-
fill_value=(0,),
905-
dtype=(np.intp,),
906-
engine=engine,
907-
sort=sort,
908-
)["intermediates"][0]
909-
)
909+
910+
if avoid_reduction:
911+
results["intermediates"].append(counts)
912+
else:
913+
# sum the counts
914+
results["intermediates"].append(
915+
chunk_reduce(
916+
counts,
917+
groups,
918+
func="sum",
919+
axis=axis,
920+
expected_groups=None,
921+
fill_value=(0,),
922+
dtype=(np.intp,),
923+
engine=engine,
924+
sort=sort,
925+
)["intermediates"][0]
926+
)
910927

911928
elif agg.reduction_type == "reduce":
912929
# Here we reduce the intermediates individually
@@ -1006,24 +1023,7 @@ def _reduce_blockwise(array, by, agg, *, axis, expected_groups, fill_value, engi
10061023
) # type: ignore
10071024

10081025
if _is_arg_reduction(agg):
1009-
if array.ndim > 1:
1010-
# default fill_value is -1; we can't unravel that;
1011-
# so replace -1 with 0; unravel; then replace 0 with -1
1012-
# UGH!
1013-
idx = results["intermediates"][0]
1014-
mask = idx == agg.fill_value["numpy"][0]
1015-
idx[mask] = 0
1016-
# Fix npg bug where argmax with nD array, 1D group_idx, axis=-1
1017-
# will return wrong indices
1018-
idx = np.unravel_index(idx, array.shape)[-1]
1019-
idx[mask] = agg.fill_value["numpy"][0]
1020-
results["intermediates"][0] = idx
1021-
elif agg.name in ["nanvar", "nanstd"]:
1022-
# TODO: Fix npg bug where all-NaN rows are 0 instead of NaN
1023-
value, counts = results["intermediates"]
1024-
mask = counts <= 0
1025-
value[mask] = np.nan
1026-
results["intermediates"][0] = value
1026+
results["intermediates"][0] = np.unravel_index(results["intermediates"][0], array.shape)[-1]
10271027

10281028
result = _finalize_results(
10291029
results, agg, axis, expected_groups, fill_value=fill_value, reindex=reindex
@@ -1530,12 +1530,7 @@ def groupby_reduce(
15301530
# The only way to do this consistently is mask out using min_count
15311531
# Consider np.sum([np.nan]) = np.nan, np.nansum([np.nan]) = 0
15321532
if min_count is None:
1533-
if (
1534-
len(axis) < by.ndim
1535-
or fill_value is not None
1536-
# TODO: Fix npg bug where all-NaN rows are 0 instead of NaN
1537-
or (not has_dask and isinstance(func, str) and func in ["nanvar", "nanstd"])
1538-
):
1533+
if len(axis) < by.ndim or fill_value is not None:
15391534
min_count = 1
15401535

15411536
# TODO: set in xarray?

setup.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ include_package_data = True
2828
python_requires = >=3.7
2929
install_requires =
3030
pandas
31-
numpy_groupies
31+
    numpy_groupies>=0.9.15
3232
toolz
3333
importlib-metadata; python_version < '3.8'
3434

tests/test_core.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ def dask_array_ones(*args):
5353
"min",
5454
"nanmin",
5555
"argmax",
56-
pytest.param("nanargmax", marks=(pytest.mark.xfail,)),
56+
pytest.param("nanargmax", marks=(pytest.mark.skip,)),
5757
"argmin",
58-
pytest.param("nanargmin", marks=(pytest.mark.xfail,)),
58+
pytest.param("nanargmin", marks=(pytest.mark.skip,)),
5959
"any",
6060
"all",
6161
pytest.param("median", marks=(pytest.mark.skip,)),
@@ -142,7 +142,7 @@ def gen_array_by(size, func):
142142
return array, by
143143

144144

145-
@pytest.mark.parametrize("chunks", [None, 3, 4])
145+
@pytest.mark.parametrize("chunks", [None, -1, 3, 4])
146146
@pytest.mark.parametrize("nby", [1, 2, 3])
147147
@pytest.mark.parametrize("size", ((12,), (12, 9)))
148148
@pytest.mark.parametrize("add_nan_by", [True, False])

0 commit comments

Comments
 (0)