Skip to content

Commit 619a390

Browse files
authored
Support aggregating sparse arrays. (#442)
1 parent b32a602 commit 619a390

17 files changed

+414
-53
lines changed

.github/workflows/upstream-dev-ci.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ jobs:
7878
git+https://github.com/Unidata/cftime
7979
python -m pip install \
8080
git+https://github.com/dask/dask \
81-
git+https://github.com/ml31415/numpy-groupies
81+
git+https://github.com/ml31415/numpy-groupies \
82+
git+https://github.com/pydata/sparse
8283
8384
- name: Install flox
8485
run: |

ci/env-numpy1.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dependencies:
1111
- pandas
1212
- numpy<2
1313
- scipy
14+
- sparse
1415
- lxml # for mypy coverage report
1516
- matplotlib
1617
- pip

ci/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dependencies:
1111
- pandas
1212
- numpy>=1.22
1313
- scipy
14+
- sparse
1415
- lxml # for mypy coverage report
1516
- matplotlib
1617
- pip

ci/no-dask.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ dependencies:
88
- cftime
99
- numpy>=1.22
1010
- scipy
11+
- sparse
1112
- pip
1213
- pytest
1314
- pytest-cov

ci/no-numba.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dependencies:
1111
- pandas
1212
- numpy>=1.22
1313
- scipy
14+
- sparse
1415
- lxml # for mypy coverage report
1516
- matplotlib
1617
- pip

ci/no-xarray.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ dependencies:
77
- pandas
88
- numpy>=1.22
99
- scipy
10+
- sparse
1011
- pip
1112
- pytest
1213
- pytest-cov

docs/source/arrays.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
# Duck Array Support
22

3+
## Sparse Arrays
4+
5+
`sparse.COO` arrays from the `pydata/sparse` project are supported using algorithms that work on the underlying dense data.
6+
See `aggregate_sparse.py` for details.
7+
At the moment the following reductions are supported: `sum`, `nansum`, `min`, `nanmin`, `max`, `nanmax`, `count`.
8+
9+
## Other array types
10+
311
Aggregating over other array types will work if the array type supports one of the following methods: [ufunc.reduceat](https://numpy.org/doc/stable/reference/generated/numpy.ufunc.reduceat.html) or [ufunc.at](https://numpy.org/doc/stable/reference/generated/numpy.ufunc.at.html).
412

513
| Reduction | `method="numpy"` | `method="flox"` |

flox/aggregate_flox.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def _np_grouped_op(
146146
# assumes input is sorted, which I do in core._prepare_for_flox
147147
aux = group_idx
148148

149-
flag = np.concatenate((np.array([True], like=array), aux[1:] != aux[:-1]))
149+
flag = np.concatenate((np.asarray([True], like=aux), aux[1:] != aux[:-1]))
150150
uniques = aux[flag]
151151
(inv_idx,) = flag.nonzero()
152152

@@ -165,7 +165,7 @@ def _np_grouped_op(
165165
out = np.full((nq,) + array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype)
166166
kwargs["group_idx"] = group_idx
167167

168-
if (len(uniques) == size) and (uniques == np.arange(size, like=array)).all():
168+
if (len(uniques) == size) and (uniques == np.arange(size, like=aux)).all():
169169
# The previous version of this if condition
170170
# ((uniques[1:] - uniques[:-1]) == 1).all():
171171
# does not work when group_idx is [1, 2] for e.g.
@@ -257,7 +257,7 @@ def ffill(group_idx, array, *, axis, **kwargs):
257257
ndim = array.ndim
258258
assert axis == (ndim - 1), (axis, ndim - 1)
259259

260-
flag = np.concatenate((np.array([True], like=array), group_idx[1:] != group_idx[:-1]))
260+
flag = np.concatenate((np.asarray([True], like=group_idx), group_idx[1:] != group_idx[:-1]))
261261
(group_starts,) = flag.nonzero()
262262

263263
# https://stackoverflow.com/questions/41190852/most-efficient-way-to-forward-fill-nan-values-in-numpy-array

flox/aggregate_sparse.py

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
# Unlike the other aggregate_* submodules, this one simply defines a wrapper function
2+
# because we run the groupby on the underlying dense data.
3+
4+
from functools import partial
5+
6+
import numpy as np
7+
import sparse
8+
9+
from flox.core import _factorize_multiple, _is_sparse_supported_reduction, factorize_
10+
from flox.xrdtypes import INF, NINF, _get_fill_value
11+
from flox.xrutils import notnull
12+
13+
14+
def nanadd(a, b):
    """
    Elementwise addition that skips NaNs.

    Annoyingly, there is no numpy ufunc for nan-skipping elementwise addition
    unlike np.fmin, np.fmax :(

    Where exactly one operand is NaN the other operand is returned; where both
    are NaN the result is NaN; everywhere else the result is ``a + b``.

    Adapted from https://stackoverflow.com/a/50642947/1707127
    """
    # Decide from the operands, not from ``a + b``: the sum is also NaN for
    # inf + (-inf), where nothing is missing and the NaN must be preserved
    # rather than being replaced by ``a``.
    either_nan = np.isnan(a) | np.isnan(b)
    return np.where(either_nan, np.where(np.isnan(a), b, a), a + b)
22+
23+
24+
# Elementwise binary operation used by ``_sparse_agg`` to combine the
# aggregated stored values with the contribution of the un-stored
# (fill-valued) elements, keyed by reduction name.
BINARY_OPS = dict(
    sum=np.add,
    nansum=nanadd,
    max=np.maximum,
    nanmax=np.fmax,
    min=np.minimum,
    nanmin=np.fmin,
)
# "Repeated application" counterpart of the binary op: applying addition
# ``n`` times to the same fill value is a multiplication by ``n``.
HYPER_OPS = dict.fromkeys(("sum", "nansum"), np.multiply)
# Identity element of each reduction; ``_sparse_agg`` uses it as the
# fill value of the intermediate sparse result.
IDENTITY = {
    **dict.fromkeys(("sum", "nansum"), 0),
    **dict.fromkeys(("prod", "nanprod"), 1),
    **dict.fromkeys(("max", "nanmax"), NINF),
    **dict.fromkeys(("min", "nanmin"), INF),
}
43+
44+
45+
def _sparse_agg(
    group_idx: np.ndarray,
    array: sparse.COO,
    func: str,
    engine: str,
    axis: int = -1,
    size: int | None = None,
    fill_value=None,
    dtype=None,
    **kwargs,
):
    """Wrapper function, that unwraps the underlying dense arrays, executes the groupby,
    and constructs the output sparse array.

    Parameters
    ----------
    group_idx : np.ndarray
        Group labels along the reduced axis. It is indexed with the
        coordinates of the stored elements along ``axis``.
    array : sparse.COO
        Array to aggregate.
    func : str
        Reduction name; must be a key of ``BINARY_OPS``.
    engine : str
        Forwarded unchanged to ``generic_aggregate``.
    axis : int
        Axis to reduce along. Only the last axis (``-1`` or ``array.ndim - 1``)
        is supported.
    size : int, optional
        Length of the last axis of the output, and the ``size`` passed to
        ``generic_aggregate`` when counting group members.
        NOTE(review): presumably the number of groups -- confirm with callers.
    fill_value, dtype
        Forwarded to ``generic_aggregate`` for the reduction of the stored values.
    **kwargs
        Accepted for signature compatibility and ignored.

    Returns
    -------
    The result of ``BINARY_OPS[func]`` applied to the sparse array of aggregated
    stored values and the contribution of the un-stored (fill-valued) elements.

    Raises
    ------
    ValueError
        If ``array`` is not a ``sparse.COO``, if ``func`` is rejected by
        ``_is_sparse_supported_reduction``, or if ``axis`` is not the last axis.
    NotImplementedError
        If ``func`` has no entry in ``BINARY_OPS``.
    """
    from flox.aggregations import generic_aggregate

    if not isinstance(array, sparse.COO):
        raise ValueError("Sparse aggregations only supported for sparse.COO arrays")

    if not _is_sparse_supported_reduction(func):
        raise ValueError(f"{func} is unsupported for sparse arrays.")

    # Validate up front with real exceptions: these checks used to be `assert`s
    # placed after the expensive work, and asserts are stripped under `python -O`.
    if axis not in (-1, array.ndim - 1):
        raise ValueError(
            f"Sparse aggregations are only supported along the last axis. Received axis={axis!r}."
        )
    if func not in BINARY_OPS:
        raise NotImplementedError(f"{func} is not implemented for sparse arrays.")
    binop = BINARY_OPS[func]

    # Group label of every stored element.
    group_idx_subset = group_idx[array.coords[axis, :]]
    if array.ndim > 1:
        # Group jointly by the coordinates along the leading axes and the group label,
        # so that every output cell becomes its own group.
        new_by = tuple(array.coords[:axis, :]) + (group_idx_subset,)
    else:
        new_by = (group_idx_subset,)
    codes, groups, shape = _factorize_multiple(
        new_by, expected_groups=(None,) * len(new_by), any_by_dask=False
    )
    # factorize again so we can construct a sparse result
    sparse_codes, sparse_groups, sparse_shape, _, sparse_size, _ = factorize_(codes, axes=(0,))

    dense_result = generic_aggregate(
        sparse_codes,
        array.data,
        func=func,
        engine=engine,
        dtype=dtype,
        size=sparse_size,
        fill_value=fill_value,
    )
    dense_counts = generic_aggregate(
        sparse_codes,
        array.data,
        # This counts is used to handle fill_value, so we need a count
        # of populated data, regardless of NaN value
        func="len",
        engine=engine,
        dtype=int,
        size=sparse_size,
        fill_value=0,
    )
    assert len(sparse_groups) == 1
    # Translate the populated combined codes back to per-axis coordinates of the output.
    result_coords = np.stack(tuple(g[i] for g, i in zip(groups, np.unravel_index(*sparse_groups, shape))))

    full_shape = array.shape[:-1] + (size,)
    # Number of stored elements that went into each output cell.
    count = sparse.COO(coords=result_coords, data=dense_counts, shape=full_shape, fill_value=0)

    # Number of elements (stored or not) belonging to each group along the reduced axis,
    # with leading length-1 axes added so it broadcasts against `count`.
    grouped_count = generic_aggregate(
        group_idx, group_idx, engine=engine, func="len", dtype=np.int64, size=size, fill_value=0
    )
    total_count = sparse.COO.from_numpy(
        np.expand_dims(grouped_count, tuple(range(array.ndim - 1))), fill_value=0
    )

    ident = _get_fill_value(array.dtype, IDENTITY[func])
    # Number of un-stored (fill-valued) elements in each output cell.
    diff_count = total_count - count
    hyper_op = HYPER_OPS.get(func, None)
    if hyper_op is not None:
        fill = hyper_op(diff_count, array.fill_value) if (diff_count > 0).any() else ident
    elif "max" in func or "min" in func:
        # Note that fill_value for total_count, and count is 0.
        # So the fill_value for the `fill` result is the False branch i.e. `ident`
        fill = np.where(diff_count > 0, array.fill_value, ident)
    else:
        raise NotImplementedError

    result = sparse.COO(coords=result_coords, data=dense_result, shape=full_shape, fill_value=ident)
    with_fill = binop(result, fill)
    return with_fill
128+
129+
130+
def nanlen(
    group_idx: np.ndarray,
    array: sparse.COO,
    engine: str,
    axis: int = -1,
    size: int | None = None,
    fill_value=None,
    dtype=None,
    **kwargs,
):
    """Count the non-null elements of ``array`` in each group.

    The stored data and the fill value of ``array`` are both replaced by
    their ``notnull`` flags, and the resulting array is summed per group
    through ``_sparse_agg``. The ``fill_value`` argument and ``**kwargs``
    are accepted but unused; ``_sparse_agg`` always receives ``fill_value=0``.
    """
    valid_data = notnull(array.data)
    valid_fill = notnull(array.fill_value)
    validity = sparse.COO(
        coords=array.coords, data=valid_data, shape=array.shape, fill_value=valid_fill
    )
    return _sparse_agg(
        group_idx,
        validity,
        func="sum",
        engine=engine,
        axis=axis,
        size=size,
        fill_value=0,
        dtype=dtype,
    )
149+
150+
151+
def mean(
    group_idx: np.ndarray,
    array: sparse.COO,
    engine: str,
    axis: int = -1,
    size: int | None = None,
    fill_value=None,
    dtype=None,
    **kwargs,
):
    """Grouped mean: the grouped ``"sum"`` divided by the ``nanlen`` count.

    ``fill_value`` is only applied to the sum; the count always uses a fill
    value of 0. ``**kwargs`` is accepted but unused.
    """
    common = dict(engine=engine, axis=axis, size=size, dtype=dtype)
    totals = _sparse_agg(group_idx, array, func="sum", fill_value=fill_value, **common)
    n_valid = nanlen(group_idx, array, fill_value=0, **common)
    return totals / n_valid
168+
169+
170+
def nanmean(
    group_idx: np.ndarray,
    array: sparse.COO,
    engine: str,
    axis: int = -1,
    size: int | None = None,
    fill_value=None,
    dtype=None,
    **kwargs,
):
    """Grouped NaN-skipping mean: the grouped ``"nansum"`` divided by the ``nanlen`` count.

    ``fill_value`` is only applied to the sum; the count always uses a fill
    value of 0. ``**kwargs`` is accepted but unused.
    """
    common = dict(engine=engine, axis=axis, size=size, dtype=dtype)
    totals = _sparse_agg(group_idx, array, func="nansum", fill_value=fill_value, **common)
    n_valid = nanlen(group_idx, array, fill_value=0, **common)
    return totals / n_valid
194+
195+
196+
# Reduction entry points looked up by name on this module. Each one is
# ``_sparse_agg`` with ``func`` pre-bound; the names intentionally shadow
# the builtins ``sum``, ``max`` and ``min`` within this module.
sum, nansum = (partial(_sparse_agg, func=name) for name in ("sum", "nansum"))
max, nanmax = (partial(_sparse_agg, func=name) for name in ("max", "nanmax"))
min, nanmin = (partial(_sparse_agg, func=name) for name in ("min", "nanmin"))

flox/aggregations.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from . import aggregate_flox, aggregate_npg, xrutils
1616
from . import xrdtypes as dtypes
17+
from .lib import sparse_array_type
1718

1819
if TYPE_CHECKING:
1920
FuncTuple = tuple[Callable | str, ...]
@@ -72,7 +73,14 @@ def generic_aggregate(
7273
if func in ["nanfirst", "nanlast"] and array.dtype.kind in "US":
7374
func = func[3:]
7475

75-
if engine == "flox":
76+
if is_sparse := isinstance(array, sparse_array_type):
77+
# this is not an infinite loop because aggregate_sparse will call
78+
# generic_aggregate with dense data
79+
from flox import aggregate_sparse
80+
81+
method = partial(getattr(aggregate_sparse, func), engine=engine)
82+
83+
elif engine == "flox":
7684
try:
7785
method = getattr(aggregate_flox, func)
7886
except AttributeError:
@@ -105,7 +113,9 @@ def generic_aggregate(
105113
f"Expected engine to be one of ['flox', 'numpy', 'numba', 'numbagg']. Received {engine} instead."
106114
)
107115

108-
group_idx = np.asarray(group_idx, like=array)
116+
# UGLY! but this avoids auto-densification errors
117+
if not is_sparse:
118+
group_idx = np.asarray(group_idx, like=array)
109119

110120
with warnings.catch_warnings():
111121
warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered")

0 commit comments

Comments
 (0)