Commit 6a5969f
Support nanfirst, nanlast with simple combine algo (#240)

* Support nanfirst, nanlast with simple combine algo. Closes #227
* Guard dask test
* Prepare for first, last support
1 parent 4164712 commit 6a5969f
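
For context, a minimal usage sketch of what this commit enables. The example is not part of the commit; it assumes an installed flox that includes this change and uses only the groupby_reduce call pattern exercised in the tests below.

import numpy as np
from flox.core import groupby_reduce

by = np.array([0, 0, 1, 1, 0])
array = np.array([np.nan, 1.0, 2.0, np.nan, 3.0])

# "nanfirst" selects the first non-NaN value in each group,
# "nanlast" selects the last non-NaN value in each group.
first, groups = groupby_reduce(array, by, func="nanfirst")
last, _ = groupby_reduce(array, by, func="nanlast")
# Per the semantics added in flox/xrutils.py below:
# groups == [0, 1], first == [1., 2.], last == [3., 2.]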

4 files changed: +171 -27 lines

flox/aggregations.py

+14

@@ -184,6 +184,8 @@ def __init__(
         self.chunk: FuncTuple = _atleast_1d(chunk)
         # how to aggregate results after first round of reduction
         self.combine: FuncTuple = _atleast_1d(combine)
+        # simpler reductions used with the "simple combine" algorithm
+        self.simple_combine = None
         # final aggregation
         self.aggregate: Callable | str = aggregate if aggregate else self.combine[0]
         # finalize results (see mean)

@@ -577,4 +579,16 @@ def _initialize_aggregation(
     else:
         agg.min_count = 0

+    simple_combine = []
+    for combine in agg.combine:
+        if isinstance(combine, str):
+            if combine in ["nanfirst", "nanlast"]:
+                simple_combine.append(getattr(xrutils, combine))
+            else:
+                simple_combine.append(getattr(np, combine))
+        else:
+            simple_combine.append(combine)
+
+    agg.simple_combine = tuple(simple_combine)
+
     return agg
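
The block added to _initialize_aggregation resolves each entry of agg.combine into a concrete callable: "nanfirst"/"nanlast" come from flox.xrutils (added below in this commit), any other string is looked up on numpy, and callables pass through unchanged. A standalone sketch of that resolution rule follows, with a hypothetical resolve helper written only to mirror the loop above:

import numpy as np
from flox import xrutils  # assumes the xrutils helpers added in this commit

def resolve(combine):
    # Mirrors the new loop: strings map to xrutils or numpy functions,
    # callables are kept as-is.
    if isinstance(combine, str):
        if combine in ["nanfirst", "nanlast"]:
            return getattr(xrutils, combine)
        return getattr(np, combine)
    return combine

assert resolve("sum") is np.sum
assert resolve("nanfirst") is xrutils.nanfirst
assert callable(resolve(lambda x, axis, keepdims: x))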

flox/core.py

+30 -6

@@ -94,6 +94,10 @@ def _is_minmax_reduction(func: T_Agg) -> bool:
     )


+def _is_first_last_reduction(func: T_Agg) -> bool:
+    return isinstance(func, str) and func in ["nanfirst", "nanlast", "first", "last"]
+
+
 def _get_expected_groups(by: T_By, sort: bool) -> pd.Index:
     if is_duck_dask_array(by):
         raise ValueError("Please provide expected_groups if not grouping by a numpy array.")

@@ -954,13 +958,13 @@ def _simple_combine(
     results: IntermediateDict = {"groups": unique_groups}
     results["intermediates"] = []
     axis_ = axis[:-1] + (DUMMY_AXIS,)
-    for idx, combine in enumerate(agg.combine):
+    for idx, combine in enumerate(agg.simple_combine):
         array = _conc2(x_chunk, key1="intermediates", key2=idx, axis=axis_)
         assert array.ndim >= 2
         with warnings.catch_warnings():
             warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered")
-            assert isinstance(combine, str)
-            result = getattr(np, combine)(array, axis=axis_, keepdims=True)
+            assert callable(combine)
+            result = combine(array, axis=axis_, keepdims=True)
         if is_aggregate:
             # squeeze out DUMMY_AXIS if this is the last step i.e. called from _aggregate
             result = result.squeeze(axis=DUMMY_AXIS)

@@ -1534,11 +1538,17 @@ def _validate_reindex(
             raise ValueError(
                 "reindex=True is not a valid choice for method='blockwise' or method='cohorts'."
             )
+        if func in ["first", "last"]:
+            raise ValueError("reindex must be None or False when func is 'first' or 'last.")

     if reindex is None:
         if all_numpy:
             return True

+        if func in ["first", "last"]:
+            # have to do the grouped_combine since there's no good fill_value
+            reindex = False
+
         if method == "blockwise" or _is_arg_reduction(func):
             reindex = False

@@ -1552,6 +1562,7 @@ def _validate_reindex(
             reindex = True

     assert isinstance(reindex, bool)
+
     return reindex


@@ -1875,6 +1886,21 @@ def groupby_reduce(
     axis_ = np.core.numeric.normalize_axis_tuple(axis, array.ndim)  # type: ignore
     nax = len(axis_)

+    has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_)
+
+    if _is_first_last_reduction(func):
+        if has_dask and nax != 1:
+            raise ValueError(
+                "For dask arrays: first, last, nanfirst, nanlast reductions are "
+                "only supported along a single axis. Please reshape appropriately."
+            )
+
+        elif nax not in [1, by_.ndim]:
+            raise ValueError(
+                "first, last, nanfirst, nanlast reductions are only supported "
+                "along a single axis or when reducing across all dimensions of `by`."
+            )
+
     # TODO: make sure expected_groups is unique
     if nax == 1 and by_.ndim > 1 and expected_groups is None:
         if not any_by_dask:

@@ -1898,8 +1924,6 @@ def groupby_reduce(
         axis_ = tuple(array.ndim + np.arange(-nax, 0))
         nax = len(axis_)

-    has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_)
-
     # When axis is a subset of possible values; then npg will
     # apply it to groups that don't exist along a particular axis (for e.g.)
     # since these count as a group that is absent. thoo!

@@ -1986,6 +2010,6 @@ def groupby_reduce(
         ).reshape(result.shape[:-1] + grp_shape)
         groups = final_groups

-    if _is_minmax_reduction(func) and is_bool_array:
+    if is_bool_array and (_is_minmax_reduction(func) or _is_first_last_reduction(func)):
         result = result.astype(bool)
     return (result, *groups)  # type: ignore[return-value] # Unpack not in mypy yet
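
A hedged sketch of the new axis validation in groupby_reduce: first/last style reductions are accepted along a single axis or across all axes of `by`, and anything in between raises. This mirrors test_first_last_disallowed below and assumes a flox build containing this commit.

import numpy as np
from flox.core import groupby_reduce

array = np.ones((2, 3, 2))
by = np.ones((2, 3, 2))

# Reducing over two of by's three axes hits the new `nax not in [1, by_.ndim]`
# branch and raises; axis=-1 or axis=None would be accepted for numpy inputs.
try:
    groupby_reduce(array, by, func="nanfirst", axis=(0, 1))
except ValueError as err:
    print(err)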

flox/xrutils.py

+34 -1

@@ -1,12 +1,12 @@
 # The functions defined here were copied based on the source code
 # defined in xarray

-
 import datetime
 from typing import Any, Iterable

 import numpy as np
 import pandas as pd
+from numpy.core.multiarray import normalize_axis_index  # type: ignore[attr-defined]

 try:
     import cftime

@@ -283,3 +283,36 @@ def _contains_cftime_datetimes(array) -> bool:
         return isinstance(sample, cftime.datetime)
     else:
         return False
+
+
+def _select_along_axis(values, idx, axis):
+    other_ind = np.ix_(*[np.arange(s) for s in idx.shape])
+    sl = other_ind[:axis] + (idx,) + other_ind[axis:]
+    return values[sl]
+
+
+def nanfirst(values, axis, keepdims=False):
+    if isinstance(axis, tuple):
+        (axis,) = axis
+    values = np.asarray(values)
+    axis = normalize_axis_index(axis, values.ndim)
+    idx_first = np.argmax(~pd.isnull(values), axis=axis)
+    result = _select_along_axis(values, idx_first, axis)
+    if keepdims:
+        return np.expand_dims(result, axis=axis)
+    else:
+        return result
+
+
+def nanlast(values, axis, keepdims=False):
+    if isinstance(axis, tuple):
+        (axis,) = axis
+    values = np.asarray(values)
+    axis = normalize_axis_index(axis, values.ndim)
+    rev = (slice(None),) * axis + (slice(None, None, -1),)
+    idx_last = -1 - np.argmax(~pd.isnull(values)[rev], axis=axis)
+    result = _select_along_axis(values, idx_last, axis)
+    if keepdims:
+        return np.expand_dims(result, axis=axis)
+    else:
+        return result
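
For reference, a small usage sketch of the helpers added above; the selected values follow directly from the argmax-over-~isnull logic in the diff, and keepdims retains the reduced axis with length one.

import numpy as np
from flox.xrutils import nanfirst, nanlast  # assumes this commit is applied

x = np.array([[np.nan, 1.0, 2.0],
              [3.0, np.nan, np.nan]])

nanfirst(x, axis=1)                # array([1., 3.]) -- first non-null per row
nanlast(x, axis=1)                 # array([2., 3.]) -- last non-null per row
nanlast(x, axis=1, keepdims=True)  # shape (2, 1)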

tests/test_core.py

+93 -20

@@ -3,13 +3,14 @@
 import itertools
 import warnings
 from functools import partial, reduce
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Callable

 import numpy as np
 import pandas as pd
 import pytest
 from numpy_groupies.aggregate_numpy import aggregate

+from flox import xrutils
 from flox.aggregations import Aggregation
 from flox.core import (
     _convert_expected_groups_to_index,

@@ -53,6 +54,7 @@ def dask_array_ones(*args):
     "sum",
     "nansum",
     "argmax",
+    "nanfirst",
     pytest.param("nanargmax", marks=(pytest.mark.skip,)),
     "prod",
     "nanprod",

@@ -70,6 +72,7 @@ def dask_array_ones(*args):
     pytest.param("nanargmin", marks=(pytest.mark.skip,)),
     "any",
     "all",
+    "nanlast",
     pytest.param("median", marks=(pytest.mark.skip,)),
     pytest.param("nanmedian", marks=(pytest.mark.skip,)),
 )

@@ -78,6 +81,21 @@ def dask_array_ones(*args):
     from flox.core import T_Engine, T_ExpectedGroupsOpt, T_Func2


+def _get_array_func(func: str) -> Callable:
+    if func == "count":
+
+        def npfunc(x):
+            x = np.asarray(x)
+            return (~np.isnan(x)).sum()
+
+    elif func in ["nanfirst", "nanlast"]:
+        npfunc = getattr(xrutils, func)
+    else:
+        npfunc = getattr(np, func)
+
+    return npfunc
+
+
 def test_alignment_error():
     da = np.ones((12,))
     labels = np.ones((5,))

@@ -217,6 +235,10 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
     if "arg" in func and add_nan_by:
         array_[..., nanmask] = np.nan
         expected = getattr(np, "nan" + func)(array_, axis=-1, **kwargs)
+    # elif func in ["first", "last"]:
+    #     expected = getattr(xrutils, f"nan{func}")(array_[..., ~nanmask], axis=-1, **kwargs)
+    elif func in ["nanfirst", "nanlast"]:
+        expected = getattr(xrutils, func)(array_[..., ~nanmask], axis=-1, **kwargs)
     else:
         expected = getattr(np, func)(array_[..., ~nanmask], axis=-1, **kwargs)
     for _ in range(nby):

@@ -241,7 +263,7 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
         call = partial(
             groupby_reduce, array, *by, method=method, reindex=reindex, **flox_kwargs
         )
-        if "arg" in func and reindex is True:
+        if ("arg" in func or func in ["first", "last"]) and reindex is True:
            # simple_combine with argreductions not supported right now
            with pytest.raises(NotImplementedError):
                call()

@@ -486,6 +508,28 @@ def test_dask_reduce_axis_subset():
     )


+@pytest.mark.parametrize("func", ["first", "last", "nanfirst", "nanlast"])
+@pytest.mark.parametrize("axis", [(0, 1)])
+def test_first_last_disallowed(axis, func):
+    with pytest.raises(ValueError):
+        groupby_reduce(np.empty((2, 3, 2)), np.ones((2, 3, 2)), func=func, axis=axis)
+
+
+@requires_dask
+@pytest.mark.parametrize("func", ["nanfirst", "nanlast"])
+@pytest.mark.parametrize("axis", [None, (0, 1, 2)])
+def test_nanfirst_nanlast_disallowed_dask(axis, func):
+    with pytest.raises(ValueError):
+        groupby_reduce(dask.array.empty((2, 3, 2)), np.ones((2, 3, 2)), func=func, axis=axis)
+
+
+@requires_dask
+@pytest.mark.parametrize("func", ["first", "last"])
+def test_first_last_disallowed_dask(func):
+    with pytest.raises(NotImplementedError):
+        groupby_reduce(dask.array.empty((2, 3, 2)), np.ones((2, 3, 2)), func=func, axis=-1)
+
+
 @requires_dask
 @pytest.mark.parametrize("func", ALL_FUNCS)
 @pytest.mark.parametrize(

@@ -495,8 +539,34 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
     if "arg" in func and engine == "flox":
         pytest.skip()

-    if not isinstance(axis, int) and "arg" in func and (axis is None or len(axis) > 1):
-        pytest.skip()
+    if not isinstance(axis, int):
+        if "arg" in func and (axis is None or len(axis) > 1):
+            pytest.skip()
+        if ("first" in func or "last" in func) and (axis is not None and len(axis) not in [1, 3]):
+            pytest.skip()
+
+    if func in ["all", "any"]:
+        fill_value = False
+    else:
+        fill_value = 123
+
+    if "var" in func or "std" in func:
+        tolerance = {"rtol": 1e-14, "atol": 1e-16}
+    else:
+        tolerance = None
+    # tests against the numpy output to make sure dask compute matches
+    by = np.broadcast_to(labels2d, (3, *labels2d.shape))
+    rng = np.random.default_rng(12345)
+    array = rng.random(by.shape)
+    kwargs = dict(
+        func=func, axis=axis, expected_groups=[0, 2], fill_value=fill_value, engine=engine
+    )
+    expected, _ = groupby_reduce(array, by, **kwargs)
+    if engine == "flox":
+        kwargs.pop("engine")
+        expected_npg, _ = groupby_reduce(array, by, **kwargs, engine="numpy")
+        assert_equal(expected_npg, expected)
+
     if func in ["all", "any"]:
         fill_value = False
     else:

@@ -513,17 +583,23 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
     kwargs = dict(
         func=func, axis=axis, expected_groups=[0, 2], fill_value=fill_value, engine=engine
     )
+    expected, _ = groupby_reduce(array, by, **kwargs)
+    if engine == "flox":
+        kwargs.pop("engine")
+        expected_npg, _ = groupby_reduce(array, by, **kwargs, engine="numpy")
+        assert_equal(expected_npg, expected)
+
+    if ("first" in func or "last" in func) and (
+        axis is None or (not isinstance(axis, int) and len(axis) != 1)
+    ):
+        return
+
     with raise_if_dask_computes():
         actual, _ = groupby_reduce(
             da.from_array(array, chunks=(-1, 2, 3)),
             da.from_array(by, chunks=(-1, 2, 2)),
             **kwargs,
         )
-    expected, _ = groupby_reduce(array, by, **kwargs)
-    if engine == "flox":
-        kwargs.pop("engine")
-        expected_npg, _ = groupby_reduce(array, by, **kwargs, engine="numpy")
-        assert_equal(expected_npg, expected)
     assert_equal(actual, expected, tolerance)


@@ -751,23 +827,17 @@ def test_fill_value_behaviour(func, chunks, fill_value, engine):
     if chunks is not None and not has_dask:
         pytest.skip()

-    if func == "count":
-
-        def npfunc(x):
-            x = np.asarray(x)
-            return (~np.isnan(x)).sum()
-
-    else:
-        npfunc = getattr(np, func)
-
+    npfunc = _get_array_func(func)
     by = np.array([1, 2, 3, 1, 2, 3])
     array = np.array([np.nan, 1, 1, np.nan, 1, 1])
     if chunks:
         array = dask.array.from_array(array, chunks)
     actual, _ = groupby_reduce(
         array, by, func=func, engine=engine, fill_value=fill_value, expected_groups=[0, 1, 2, 3]
     )
-    expected = np.array([fill_value, fill_value, npfunc([1.0, 1.0]), npfunc([1.0, 1.0])])
+    expected = np.array(
+        [fill_value, fill_value, npfunc([1.0, 1.0], axis=0), npfunc([1.0, 1.0], axis=0)]
+    )
     assert_equal(actual, expected)


@@ -832,6 +902,8 @@ def test_cohorts_nd_by(func, method, axis, engine):

     if axis is not None and method != "map-reduce":
         pytest.xfail()
+    if axis is None and ("first" in func or "last" in func):
+        pytest.skip()

     kwargs = dict(func=func, engine=engine, method=method, axis=axis, fill_value=fill_value)
     actual, groups = groupby_reduce(array, by, **kwargs)

@@ -897,7 +969,8 @@ def test_bool_reductions(func, engine):
         pytest.skip()
     groups = np.array([1, 1, 1])
     data = np.array([True, True, False])
-    expected = np.expand_dims(getattr(np, func)(data), -1)
+    npfunc = _get_array_func(func)
+    expected = np.expand_dims(npfunc(data, axis=0), -1)
     actual, _ = groupby_reduce(data, groups, func=func, engine=engine)
     assert_equal(expected, actual)