Skip to content

Commit 11feda2

Browse files
authored
Fix sparse reindexing some more. (#437)
* Allow empty groups with sparse reindexing * Fix sparse reindexing * fix docs * more test * Fix tests * Nicer error * test for errors
1 parent 89e8238 commit 11feda2

File tree

4 files changed

+112
-21
lines changed

4 files changed

+112
-21
lines changed

docs/source/user-stories/large-zonal-stats.ipynb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@
161161
" blockwise=False,\n",
162162
" array_type=ReindexArrayType.SPARSE_COO,\n",
163163
" ),\n",
164+
" fill_value=0,\n",
164165
")\n",
165166
"result"
166167
]

flox/core.py

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
from numpy.core.numeric import normalize_axis_tuple # type: ignore[no-redef]
6969

7070
HAS_NUMBAGG = module_available("numbagg", minversion="0.3.0")
71+
HAS_SPARSE = module_available("sparse")
7172

7273
if TYPE_CHECKING:
7374
try:
@@ -255,6 +256,12 @@ def _is_bool_supported_reduction(func: T_Agg) -> bool:
255256
)
256257

257258

259+
def _is_sparse_supported_reduction(func: T_Agg) -> bool:
260+
if isinstance(func, Aggregation):
261+
func = func.name
262+
return HAS_SPARSE and all(f not in func for f in ["first", "last", "prod", "var", "std"])
263+
264+
258265
def _get_expected_groups(by: T_By, sort: bool) -> T_ExpectIndex:
259266
if is_duck_dask_array(by):
260267
raise ValueError("Please provide expected_groups if not grouping by a numpy array.")
@@ -736,12 +743,12 @@ def rechunk_for_blockwise(array: DaskArray, axis: T_Axis, labels: np.ndarray) ->
736743
return array.rechunk({axis: newchunks})
737744

738745

739-
def reindex_numpy(array, from_, to, fill_value, dtype, axis):
746+
def reindex_numpy(array, from_: pd.Index, to: pd.Index, fill_value, dtype, axis: int):
740747
idx = from_.get_indexer(to)
741748
indexer = [slice(None, None)] * array.ndim
742749
indexer[axis] = idx
743750
reindexed = array[tuple(indexer)]
744-
if any(idx == -1):
751+
if (idx == -1).any():
745752
if fill_value is None:
746753
raise ValueError("Filling is required. fill_value cannot be None.")
747754
indexer[axis] = idx == -1
@@ -750,25 +757,43 @@ def reindex_numpy(array, from_, to, fill_value, dtype, axis):
750757
return reindexed
751758

752759

753-
def reindex_pydata_sparse_coo(array, from_, to, fill_value, dtype, axis):
760+
def reindex_pydata_sparse_coo(array, from_: pd.Index, to: pd.Index, fill_value, dtype, axis: int):
754761
import sparse
755762

756763
assert axis == -1
757764

758-
if fill_value is None:
765+
needs_reindex = (from_.get_indexer(to) == -1).any()
766+
if needs_reindex and fill_value is None:
759767
raise ValueError("Filling is required. fill_value cannot be None.")
768+
760769
idx = to.get_indexer(from_)
761-
assert (idx != -1).all() # FIXME
770+
mask = idx != -1 # indices along last axis to keep
771+
if mask.all():
772+
mask = slice(None)
762773
shape = array.shape
763-
ranges = np.broadcast_arrays(*np.ix_(*(tuple(np.arange(size) for size in shape[:axis]) + (idx,))))
764-
coords = np.stack(ranges, axis=0).reshape(array.ndim, -1)
765774

766-
data = array.data if isinstance(array, sparse.COO) else array.reshape(-1)
775+
if isinstance(array, sparse.COO):
776+
subset = array[..., mask]
777+
data = subset.data
778+
coords = subset.coords
779+
if subset.nnz > 0:
780+
coords[-1, :] = idx[mask][coords[-1, :]]
781+
if fill_value is None:
782+
# no reindexing is actually needed (dense case)
783+
# preserve the fill_value
784+
fill_value = array.fill_value
785+
else:
786+
ranges = np.broadcast_arrays(
787+
*np.ix_(*(tuple(np.arange(size) for size in shape[:axis]) + (idx[mask],)))
788+
)
789+
coords = np.stack(ranges, axis=0).reshape(array.ndim, -1)
790+
data = array[..., mask].reshape(-1)
767791

768792
reindexed = sparse.COO(
769793
coords=coords,
770794
data=data.astype(dtype, copy=False),
771795
shape=(*array.shape[:axis], to.size),
796+
fill_value=fill_value,
772797
)
773798

774799
return reindexed
@@ -795,7 +820,11 @@ def reindex_(
795820

796821
if array.shape[axis] == 0:
797822
# all groups were NaN
798-
reindexed = np.full(array.shape[:-1] + (len(to),), fill_value, dtype=array.dtype)
823+
shape = array.shape[:-1] + (len(to),)
824+
if array_type in (ReindexArrayType.AUTO, ReindexArrayType.NUMPY):
825+
reindexed = np.full(shape, fill_value, dtype=array.dtype)
826+
else:
827+
raise NotImplementedError
799828
return reindexed
800829

801830
from_ = pd.Index(from_)
@@ -1044,7 +1073,7 @@ def chunk_argreduce(
10441073
sort=sort,
10451074
user_dtype=user_dtype,
10461075
)
1047-
if not isnull(results["groups"]).all():
1076+
if not all(isnull(results["groups"])):
10481077
idx = np.broadcast_to(idx, array.shape)
10491078

10501079
# array, by get flattened to 1D before passing to npg
@@ -1288,7 +1317,7 @@ def _finalize_results(
12881317
fill_value = agg.fill_value["user"]
12891318
if min_count > 0:
12901319
count_mask = counts < min_count
1291-
if count_mask.any():
1320+
if count_mask.any() or reindex.array_type is ReindexArrayType.SPARSE_COO:
12921321
# For one count_mask.any() prevents promoting bool to dtype(fill_value) unless
12931322
# necessary
12941323
if fill_value is None:
@@ -2815,6 +2844,15 @@ def groupby_reduce(
28152844
array.dtype,
28162845
)
28172846

2847+
if reindex.array_type is ReindexArrayType.SPARSE_COO:
2848+
if not HAS_SPARSE:
2849+
raise ImportError("Package 'sparse' must be installed to reindex to a sparse.COO array.")
2850+
if not _is_sparse_supported_reduction(func):
2851+
raise NotImplementedError(
2852+
f"Aggregation {func=!r} is not supported when reindexing to a sparse array. "
2853+
"Please raise an issue"
2854+
)
2855+
28182856
if TYPE_CHECKING:
28192857
assert isinstance(reindex, ReindexStrategy)
28202858
assert method is not None

flox/xrutils.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,9 @@ def notnull(data):
159159
return out
160160

161161

162-
def isnull(data):
162+
def isnull(data: Any):
163+
if data is None:
164+
return False
163165
if not is_duck_array(data):
164166
data = np.asarray(data)
165167
scalar_type = data.dtype.type
@@ -177,7 +179,7 @@ def isnull(data):
177179
else:
178180
# at this point, array should have dtype=object
179181
if isinstance(data, (np.ndarray, dask_array_type)): # noqa
180-
return pd.isnull(data)
182+
return pd.isnull(data) # type: ignore[arg-type]
181183
else:
182184
# Not reachable yet, but intended for use with other duck array
183185
# types. For full consistency with pandas, we should accept None as
@@ -374,9 +376,10 @@ def _select_along_axis(values, idx, axis):
374376
def nanfirst(values, axis, keepdims=False):
375377
if isinstance(axis, tuple):
376378
(axis,) = axis
377-
values = np.asarray(values)
379+
if not is_duck_array(values):
380+
values = np.asarray(values)
378381
axis = normalize_axis_index(axis, values.ndim)
379-
idx_first = np.argmax(~pd.isnull(values), axis=axis)
382+
idx_first = np.argmax(~isnull(values), axis=axis)
380383
result = _select_along_axis(values, idx_first, axis)
381384
if keepdims:
382385
return np.expand_dims(result, axis=axis)
@@ -387,10 +390,11 @@ def nanfirst(values, axis, keepdims=False):
387390
def nanlast(values, axis, keepdims=False):
388391
if isinstance(axis, tuple):
389392
(axis,) = axis
390-
values = np.asarray(values)
393+
if not is_duck_array(values):
394+
values = np.asarray(values)
391395
axis = normalize_axis_index(axis, values.ndim)
392396
rev = (slice(None),) * axis + (slice(None, None, -1),)
393-
idx_last = -1 - np.argmax(~pd.isnull(values)[rev], axis=axis)
397+
idx_last = -1 - np.argmax(~isnull(values)[rev], axis=axis)
394398
result = _select_along_axis(values, idx_last, axis)
395399
if keepdims:
396400
return np.expand_dims(result, axis=axis)

tests/test_core.py

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
_choose_engine,
2525
_convert_expected_groups_to_index,
2626
_get_optimal_chunks_for_groups,
27+
_is_sparse_supported_reduction,
2728
_normalize_indexes,
2829
_validate_reindex,
2930
factorize_,
@@ -43,6 +44,7 @@
4344
assert_equal_tuple,
4445
has_cubed,
4546
has_dask,
47+
has_sparse,
4648
raise_if_dask_computes,
4749
requires_cubed,
4850
requires_dask,
@@ -74,6 +76,10 @@ def dask_array_ones(*args):
7476

7577

7678
DEFAULT_QUANTILE = 0.9
79+
REINDEX_SPARSE_STRAT = ReindexStrategy(blockwise=False, array_type=ReindexArrayType.SPARSE_COO)
80+
REINDEX_SPARSE_PARAM = pytest.param(
81+
REINDEX_SPARSE_STRAT, marks=(requires_dask, pytest.mark.skipif(not has_sparse, reason="no sparse"))
82+
)
7783

7884
if TYPE_CHECKING:
7985
from flox.core import T_Agg, T_Engine, T_ExpectedGroupsOpt, T_Method
@@ -320,13 +326,20 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
320326
if not has_dask or chunks is None or func in BLOCKWISE_FUNCS:
321327
continue
322328

323-
params = list(itertools.product(["map-reduce"], [True, False, None]))
329+
params = list(
330+
itertools.product(
331+
["map-reduce"],
332+
[True, False, None, REINDEX_SPARSE_STRAT],
333+
)
334+
)
324335
params.extend(itertools.product(["cohorts"], [False, None]))
325336
if chunks == -1:
326337
params.extend([("blockwise", None)])
327338

328339
combine_error = RuntimeError("This combine should not have been called.")
329340
for method, reindex in params:
341+
if isinstance(reindex, ReindexStrategy) and not _is_sparse_supported_reduction(func):
342+
continue
330343
call = partial(
331344
groupby_reduce,
332345
array,
@@ -360,6 +373,10 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
360373
assert_equal(actual_group, expect, tolerance)
361374
if "arg" in func:
362375
assert actual.dtype.kind == "i"
376+
if isinstance(reindex, ReindexStrategy):
377+
import sparse
378+
379+
expected = sparse.COO.from_numpy(expected)
363380
assert_equal(actual, expected, tolerance)
364381

365382

@@ -447,7 +464,7 @@ def test_numpy_reduce_nd_md():
447464

448465

449466
@requires_dask
450-
@pytest.mark.parametrize("reindex", [None, False, True])
467+
@pytest.mark.parametrize("reindex", [None, False, True, REINDEX_SPARSE_PARAM])
451468
@pytest.mark.parametrize("func", ALL_FUNCS)
452469
@pytest.mark.parametrize("add_nan", [False, True])
453470
@pytest.mark.parametrize("dtype", (float,))
@@ -470,6 +487,9 @@ def test_groupby_agg_dask(func, shape, array_chunks, group_chunks, add_nan, dtyp
470487
if "arg" in func and (engine in ["flox", "numbagg"] or reindex):
471488
pytest.skip()
472489

490+
if isinstance(reindex, ReindexStrategy) and not _is_sparse_supported_reduction(func):
491+
pytest.skip()
492+
473493
rng = np.random.default_rng(12345)
474494
array = dask.array.from_array(rng.random(shape), chunks=array_chunks).astype(dtype)
475495
array = dask.array.ones(shape, chunks=array_chunks)
@@ -775,6 +795,7 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
775795
(None, None),
776796
pytest.param(False, (2, 2, 3), marks=requires_dask),
777797
pytest.param(True, (2, 2, 3), marks=requires_dask),
798+
pytest.param(REINDEX_SPARSE_PARAM, (2, 2, 3), marks=requires_dask),
778799
],
779800
)
780801
@pytest.mark.parametrize(
@@ -821,7 +842,13 @@ def _maybe_chunk(arr):
821842
@requires_dask
822843
@pytest.mark.parametrize(
823844
"expected_groups, reindex",
824-
[(None, None), (None, False), ([0, 1, 2], True), ([0, 1, 2], False)],
845+
[
846+
(None, None),
847+
(None, False),
848+
([0, 1, 2], True),
849+
([0, 1, 2], False),
850+
pytest.param([0, 1, 2], REINDEX_SPARSE_PARAM),
851+
],
825852
)
826853
def test_groupby_all_nan_blocks_dask(expected_groups, reindex, engine):
827854
labels = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0])
@@ -2085,7 +2112,28 @@ def mocked_reindex(*args, **kwargs):
20852112

20862113
with patch("flox.core.reindex_") as mocked_func:
20872114
mocked_func.side_effect = mocked_reindex
2088-
actual, *_ = groupby_reduce(array, by, func=func, reindex=reindex, expected_groups=expected_groups)
2115+
actual, *_ = groupby_reduce(
2116+
array, by, func=func, reindex=reindex, expected_groups=expected_groups, fill_value=0
2117+
)
20892118
assert_equal(actual, expected)
20902119
# once during graph construction, 10 times afterward
20912120
assert mocked_func.call_count > 1
2121+
2122+
2123+
def test_sparse_errors():
2124+
call = partial(
2125+
groupby_reduce,
2126+
[1, 2, 3],
2127+
[0, 1, 1],
2128+
reindex=REINDEX_SPARSE_STRAT,
2129+
fill_value=0,
2130+
expected_groups=[0, 1, 2],
2131+
)
2132+
2133+
if not has_sparse:
2134+
with pytest.raises(ImportError):
2135+
call(func="sum")
2136+
2137+
else:
2138+
with pytest.raises(ValueError):
2139+
call(func="first")

0 commit comments

Comments
 (0)