Skip to content

Commit 19db5b3

Browse files
authored
Fix numbagg aggregations (#282)
* Fix numbagg version check Closes #281 * Enable numbagg for count * Better numbagg special-casing * Fixes. * A bunch of typing * Handle fill_value in core numbagg reduction. * Update flox/aggregate_numbagg.py * cleanup * [WIP] test hacky fix * [wip] * Cleanup functions * Fix casting * Fix fill_value masking * optimize * Update flox/aggregations.py * Small cleanup * Fix. * Fix typing * Another bugfix * Optimize seen_groups * Be careful about raveling * Fix benchmark skipping for numbagg * add test
1 parent 273d319 commit 19db5b3

File tree

8 files changed

+175
-74
lines changed

8 files changed

+175
-74
lines changed

asv_bench/benchmarks/reduce.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
numbagg_skip = []
1919
for name in expected_names:
2020
numbagg_skip.extend(
21-
list((func, expected_names[0], "numbagg") for func in funcs if func not in NUMBAGG_FUNCS)
21+
list((func, name, "numbagg") for func in funcs if func not in NUMBAGG_FUNCS)
2222
)
2323

2424

flox/aggregate_numbagg.py

+63-49
Original file line numberDiff line numberDiff line change
@@ -4,64 +4,75 @@
44
import numbagg.grouped
55
import numpy as np
66

7+
DEFAULT_FILL_VALUE = {
8+
"nansum": 0,
9+
"nanmean": np.nan,
10+
"nanvar": np.nan,
11+
"nanstd": np.nan,
12+
"nanmin": np.nan,
13+
"nanmax": np.nan,
14+
"nanany": False,
15+
"nanall": False,
16+
"nansum_of_squares": 0,
17+
"nanprod": 1,
18+
"nancount": 0,
19+
"nanargmax": np.nan,
20+
"nanargmin": np.nan,
21+
"nanfirst": np.nan,
22+
"nanlast": np.nan,
23+
}
24+
25+
CAST_TO = {
26+
# "nansum": {np.bool_: np.int64},
27+
"nanmean": {np.int_: np.float64},
28+
"nanvar": {np.int_: np.float64},
29+
"nanstd": {np.int_: np.float64},
30+
}
31+
32+
33+
FILLNA = {"nansum": 0, "nanprod": 1}
34+
735

836
def _numbagg_wrapper(
937
group_idx,
1038
array,
1139
*,
40+
func,
1241
axis=-1,
13-
func="sum",
1442
size=None,
1543
fill_value=None,
1644
dtype=None,
17-
numbagg_func=None,
1845
):
19-
return numbagg_func(
20-
array,
21-
group_idx,
22-
axis=axis,
23-
num_labels=size,
24-
# The following are unsupported
25-
# fill_value=fill_value,
26-
# dtype=dtype,
27-
)
46+
cast_to = CAST_TO.get(func, None)
47+
if cast_to:
48+
for from_, to_ in cast_to.items():
49+
if np.issubdtype(array.dtype, from_):
50+
array = array.astype(to_)
2851

52+
func_ = getattr(numbagg.grouped, f"group_{func}")
2953

30-
def nansum(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None):
31-
if np.issubdtype(array.dtype, np.bool_):
32-
array = array.astype(np.in64)
33-
return numbagg.grouped.group_nansum(
54+
result = func_(
3455
array,
3556
group_idx,
3657
axis=axis,
3758
num_labels=size,
59+
# The following are unsupported
3860
# fill_value=fill_value,
3961
# dtype=dtype,
40-
)
41-
62+
).astype(dtype, copy=False)
4263

43-
def nanmean(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None):
44-
if np.issubdtype(array.dtype, np.int_):
45-
array = array.astype(np.float64)
46-
return numbagg.grouped.group_nanmean(
47-
array,
48-
group_idx,
49-
axis=axis,
50-
num_labels=size,
51-
# fill_value=fill_value,
52-
# dtype=dtype,
53-
)
64+
return result
5465

5566

5667
def nanvar(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None, ddof=0):
5768
assert ddof != 0
58-
if np.issubdtype(array.dtype, np.int_):
59-
array = array.astype(np.float64)
60-
return numbagg.grouped.group_nanvar(
61-
array,
69+
70+
return _numbagg_wrapper(
6271
group_idx,
72+
array,
6373
axis=axis,
64-
num_labels=size,
74+
size=size,
75+
func="nanvar",
6576
# ddof=0,
6677
# fill_value=fill_value,
6778
# dtype=dtype,
@@ -70,30 +81,33 @@ def nanvar(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None,
7081

7182
def nanstd(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None, ddof=0):
7283
assert ddof != 0
73-
if np.issubdtype(array.dtype, np.int_):
74-
array = array.astype(np.float64)
75-
return numbagg.grouped.group_nanstd(
76-
array,
84+
85+
return _numbagg_wrapper(
7786
group_idx,
87+
array,
7888
axis=axis,
79-
num_labels=size,
89+
size=size,
90+
func="nanstd"
8091
# ddof=0,
8192
# fill_value=fill_value,
8293
# dtype=dtype,
8394
)
8495

8596

86-
nansum_of_squares = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nansum_of_squares)
87-
nanlen = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nancount)
88-
nanprod = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanprod)
89-
nanfirst = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanfirst)
90-
nanlast = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanlast)
91-
# nanargmax = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanargmax)
92-
# nanargmin = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanargmin)
93-
nanmax = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanmax)
94-
nanmin = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanmin)
95-
any = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanany)
96-
all = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanall)
97+
nansum = partial(_numbagg_wrapper, func="nansum")
98+
nanmean = partial(_numbagg_wrapper, func="nanmean")
99+
nanprod = partial(_numbagg_wrapper, func="nanprod")
100+
nansum_of_squares = partial(_numbagg_wrapper, func="nansum_of_squares")
101+
nanlen = partial(_numbagg_wrapper, func="nancount")
102+
nanprod = partial(_numbagg_wrapper, func="nanprod")
103+
nanfirst = partial(_numbagg_wrapper, func="nanfirst")
104+
nanlast = partial(_numbagg_wrapper, func="nanlast")
105+
# nanargmax = partial(_numbagg_wrapper, func="nanargmax)
106+
# nanargmin = partial(_numbagg_wrapper, func="nanargmin)
107+
nanmax = partial(_numbagg_wrapper, func="nanmax")
108+
nanmin = partial(_numbagg_wrapper, func="nanmin")
109+
any = partial(_numbagg_wrapper, func="nanany")
110+
all = partial(_numbagg_wrapper, func="nanall")
97111

98112
# sum = nansum
99113
# mean = nanmean

flox/aggregations.py

+9-13
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import copy
44
import warnings
55
from functools import partial
6-
from typing import TYPE_CHECKING, Any, Callable, TypedDict
6+
from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict
77

88
import numpy as np
99
from numpy.typing import DTypeLike
@@ -13,6 +13,7 @@
1313

1414
if TYPE_CHECKING:
1515
FuncTuple = tuple[Callable | str, ...]
16+
OptionalFuncTuple = tuple[Callable | str | None, ...]
1617

1718

1819
def _is_arg_reduction(func: str | Aggregation) -> bool:
@@ -152,7 +153,7 @@ def __init__(
152153
final_fill_value=dtypes.NA,
153154
dtypes=None,
154155
final_dtype: DTypeLike | None = None,
155-
reduction_type="reduce",
156+
reduction_type: Literal["reduce", "argreduce"] = "reduce",
156157
):
157158
"""
158159
Blueprint for computing grouped aggregations.
@@ -203,11 +204,11 @@ def __init__(
203204
self.reduction_type = reduction_type
204205
self.numpy: FuncTuple = (numpy,) if numpy else (self.name,)
205206
# initialize blockwise reduction
206-
self.chunk: FuncTuple = _atleast_1d(chunk)
207+
self.chunk: OptionalFuncTuple = _atleast_1d(chunk)
207208
# how to aggregate results after first round of reduction
208-
self.combine: FuncTuple = _atleast_1d(combine)
209+
self.combine: OptionalFuncTuple = _atleast_1d(combine)
209210
# simpler reductions used with the "simple combine" algorithm
210-
self.simple_combine: tuple[Callable, ...] = ()
211+
self.simple_combine: OptionalFuncTuple = ()
211212
# finalize results (see mean)
212213
self.finalize: Callable | None = finalize
213214

@@ -279,13 +280,7 @@ def __repr__(self) -> str:
279280
sum_ = Aggregation("sum", chunk="sum", combine="sum", fill_value=0)
280281
nansum = Aggregation("nansum", chunk="nansum", combine="sum", fill_value=0)
281282
prod = Aggregation("prod", chunk="prod", combine="prod", fill_value=1, final_fill_value=1)
282-
nanprod = Aggregation(
283-
"nanprod",
284-
chunk="nanprod",
285-
combine="prod",
286-
fill_value=1,
287-
final_fill_value=dtypes.NA,
288-
)
283+
nanprod = Aggregation("nanprod", chunk="nanprod", combine="prod", fill_value=1)
289284

290285

291286
def _mean_finalize(sum_, count):
@@ -579,6 +574,7 @@ def _initialize_aggregation(
579574
}
580575

581576
# Replace sentinel fill values according to dtype
577+
agg.fill_value["user"] = fill_value
582578
agg.fill_value["intermediate"] = tuple(
583579
_get_fill_value(dt, fv)
584580
for dt, fv in zip(agg.dtype["intermediate"], agg.fill_value["intermediate"])
@@ -613,7 +609,7 @@ def _initialize_aggregation(
613609
else:
614610
agg.min_count = 0
615611

616-
simple_combine: list[Callable] = []
612+
simple_combine: list[Callable | None] = []
617613
for combine in agg.combine:
618614
if isinstance(combine, str):
619615
if combine in ["nanfirst", "nanlast"]:

flox/core.py

+52-10
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,26 @@
8686
DUMMY_AXIS = -2
8787

8888

89+
def _postprocess_numbagg(result, *, func, fill_value, size, seen_groups):
90+
"""Account for numbagg not providing a fill_value kwarg."""
91+
from .aggregate_numbagg import DEFAULT_FILL_VALUE
92+
93+
if not isinstance(func, str) or func not in DEFAULT_FILL_VALUE:
94+
return result
95+
# The condition needs to be
96+
# len(found_groups) < size; if so we mask with fill_value (?)
97+
default_fv = DEFAULT_FILL_VALUE[func]
98+
needs_masking = fill_value is not None and not np.array_equal(
99+
fill_value, default_fv, equal_nan=True
100+
)
101+
groups = np.arange(size)
102+
if needs_masking:
103+
mask = np.isin(groups, seen_groups, assume_unique=True, invert=True)
104+
if mask.any():
105+
result[..., groups[mask]] = fill_value
106+
return result
107+
108+
89109
def _issorted(arr: np.ndarray) -> bool:
90110
return bool((arr[:-1] <= arr[1:]).all())
91111

@@ -780,7 +800,11 @@ def chunk_reduce(
780800
group_idx, grps, found_groups_shape, _, size, props = factorize_(
781801
(by,), axes, expected_groups=(expected_groups,), reindex=reindex, sort=sort
782802
)
783-
groups = grps[0]
803+
(groups,) = grps
804+
805+
# do this *before* possible broadcasting below.
806+
# factorize_ has already taken care of offsetting
807+
seen_groups = _unique(group_idx)
784808

785809
order = "C"
786810
if nax > 1:
@@ -850,6 +874,16 @@ def chunk_reduce(
850874
result = generic_aggregate(
851875
group_idx, array, axis=-1, engine=engine, func=reduction, **kw_func
852876
).astype(dt, copy=False)
877+
if engine == "numbagg":
878+
result = _postprocess_numbagg(
879+
result,
880+
func=reduction,
881+
size=size,
882+
fill_value=fv,
883+
# Unfortunately, we cannot reuse found_groups, it has not
884+
# been "offset" and is really expected_groups in nearly all cases
885+
seen_groups=seen_groups,
886+
)
853887
if np.any(props.nanmask):
854888
# remove NaN group label which should be last
855889
result = result[..., :-1]
@@ -1053,6 +1087,8 @@ def _grouped_combine(
10531087
"""Combine intermediates step of tree reduction."""
10541088
from dask.utils import deepmap
10551089

1090+
combine = agg.combine
1091+
10561092
if isinstance(x_chunk, dict):
10571093
# Only one block at final step; skip one extra groupby
10581094
return x_chunk
@@ -1093,7 +1129,8 @@ def _grouped_combine(
10931129
results = chunk_argreduce(
10941130
array_idx,
10951131
groups,
1096-
func=agg.combine[slicer], # count gets treated specially next
1132+
# count gets treated specially next
1133+
func=combine[slicer], # type: ignore[arg-type]
10971134
axis=axis,
10981135
expected_groups=None,
10991136
fill_value=agg.fill_value["intermediate"][slicer],
@@ -1127,9 +1164,10 @@ def _grouped_combine(
11271164
elif agg.reduction_type == "reduce":
11281165
# Here we reduce the intermediates individually
11291166
results = {"groups": None, "intermediates": []}
1130-
for idx, (combine, fv, dtype) in enumerate(
1131-
zip(agg.combine, agg.fill_value["intermediate"], agg.dtype["intermediate"])
1167+
for idx, (combine_, fv, dtype) in enumerate(
1168+
zip(combine, agg.fill_value["intermediate"], agg.dtype["intermediate"])
11321169
):
1170+
assert combine_ is not None
11331171
array = _conc2(x_chunk, key1="intermediates", key2=idx, axis=axis)
11341172
if array.shape[-1] == 0:
11351173
# all empty when combined
@@ -1143,7 +1181,7 @@ def _grouped_combine(
11431181
_results = chunk_reduce(
11441182
array,
11451183
groups,
1146-
func=combine,
1184+
func=combine_,
11471185
axis=axis,
11481186
expected_groups=None,
11491187
fill_value=(fv,),
@@ -1788,8 +1826,13 @@ def _choose_engine(by, agg: Aggregation):
17881826

17891827
# numbagg only supports nan-skipping reductions
17901828
# without dtype specified
1791-
if HAS_NUMBAGG and "nan" in agg.name:
1792-
if not_arg_reduce and dtype is None:
1829+
has_blockwise_nan_skipping = (agg.chunk[0] is None and "nan" in agg.name) or any(
1830+
(isinstance(func, str) and "nan" in func) for func in agg.chunk
1831+
)
1832+
if HAS_NUMBAGG:
1833+
if agg.name in ["all", "any"] or (
1834+
not_arg_reduce and has_blockwise_nan_skipping and dtype is None
1835+
):
17931836
return "numbagg"
17941837

17951838
if not_arg_reduce and (not is_duck_dask_array(by) and _issorted(by)):
@@ -2050,7 +2093,7 @@ def groupby_reduce(
20502093
nax = len(axis_)
20512094

20522095
# When axis is a subset of possible values; then npg will
2053-
# apply it to groups that don't exist along a particular axis (for e.g.)
2096+
# apply the fill_value to groups that don't exist along a particular axis (for e.g.)
20542097
# since these count as a group that is absent. thoo!
20552098
# fill_value applies to all-NaN groups as well as labels in expected_groups that are not found.
20562099
# The only way to do this consistently is mask out using min_count
@@ -2090,8 +2133,7 @@ def groupby_reduce(
20902133
# TODO: How else to narrow that array.chunks is there?
20912134
assert isinstance(array, DaskArray)
20922135

2093-
# TODO: fix typing of FuncTuple in Aggregation
2094-
if agg.chunk[0] is None and method != "blockwise": # type: ignore[unreachable]
2136+
if agg.chunk[0] is None and method != "blockwise":
20952137
raise NotImplementedError(
20962138
f"Aggregation {agg.name!r} is only implemented for dask arrays when method='blockwise'."
20972139
f"Received method={method!r}"

flox/xrutils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,6 @@ def module_available(module: str, minversion: Optional[str] = None) -> bool:
339339
has = importlib.util.find_spec(module) is not None
340340
if has:
341341
mod = importlib.import_module(module)
342-
return Version(mod.__version__) < Version(minversion) if minversion is not None else True
342+
return Version(mod.__version__) >= Version(minversion) if minversion is not None else True
343343
else:
344344
return False

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ show_error_codes = true
111111
warn_unused_ignores = true
112112
warn_unreachable = true
113113
enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
114+
exclude=["asv_bench/pkgs"]
114115

115116
[[tool.mypy.overrides]]
116117
module=[

0 commit comments

Comments
 (0)