Skip to content

Commit 71c9538

Browse files
authored
Support multiple quantiles with xarray (#332)
* Support multiple quantiles with xarray * Fix test * type: ignore * Bug fix * Another bug fix * Fix typing * Bugfix and cleanup * fix typing * Another bug fix * More xarray testing * comment * xfail test
1 parent b12bcfa commit 71c9538

File tree

6 files changed

+144
-47
lines changed

6 files changed

+144
-47
lines changed

flox/aggregate_flox.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,12 @@ def quantile_(array, inv_idx, *, q, axis, skipna, dtype=None, out=None):
6060
if skipna:
6161
sizes = np.add.reduceat(notnull(array), inv_idx[:-1], axis=axis)
6262
else:
63-
sizes = np.reshape(np.diff(inv_idx), (1,) * (array.ndim - 1) + (inv_idx.size - 1,))
64-
nanmask = isnull(np.take_along_axis(array, sizes - 1, axis=axis))
63+
newshape = (1,) * (array.ndim - 1) + (inv_idx.size - 1,)
64+
sizes = np.reshape(np.diff(inv_idx), newshape)
65+
# NaNs get sorted to the end, so look at the last element in the group to decide
66+
# if there are NaNs
67+
last_group_elem = np.broadcast_to(inv_idx[1:] - 1, newshape)
68+
nanmask = isnull(np.take_along_axis(array, last_group_elem, axis=axis))
6569

6670
qin = q
6771
q = np.atleast_1d(qin)

flox/aggregations.py

+35-14
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
import copy
44
import logging
55
import warnings
6-
from functools import partial
6+
from dataclasses import dataclass
7+
from functools import cached_property, partial
78
from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict
89

910
import numpy as np
10-
from numpy.typing import DTypeLike
11+
from numpy.typing import ArrayLike, DTypeLike
1112

1213
from . import aggregate_flox, aggregate_npg, xrutils
1314
from . import xrdtypes as dtypes
@@ -151,6 +152,20 @@ def returns_empty_tuple(*args, **kwargs):
151152
return ()
152153

153154

155+
@dataclass
156+
class Dim:
157+
values: ArrayLike
158+
name: str | None
159+
160+
@cached_property
161+
def is_scalar(self) -> bool:
162+
return xrutils.is_scalar(self.values)
163+
164+
@cached_property
165+
def size(self) -> int:
166+
return 0 if self.is_scalar else len(self.values) # type: ignore[arg-type]
167+
168+
154169
class Aggregation:
155170
def __init__(
156171
self,
@@ -166,7 +181,7 @@ def __init__(
166181
dtypes=None,
167182
final_dtype: DTypeLike | None = None,
168183
reduction_type: Literal["reduce", "argreduce"] = "reduce",
169-
new_axes_func: Callable | None = None,
184+
new_dims_func: Callable | None = None,
170185
):
171186
"""
172187
Blueprint for computing grouped aggregations.
@@ -209,7 +224,7 @@ def __init__(
209224
per reduction in ``chunk`` as a tuple.
210225
final_dtype : DType, optional
211226
DType for output. By default, uses dtype of array being reduced.
212-
new_axes_func: Callable
227+
new_dims_func: Callable
213228
Function that receives finalize_kwargs and returns a tuple of ``Dim`` objects describing any new dimensions
214229
added by the reduction. For example, quantile for q=(0.5, 0.85) adds a new dimension of size 2,
215230
so returns ``(Dim(name="quantile", values=q),)``.
@@ -246,12 +261,17 @@ def __init__(
246261
# The following are set by _initialize_aggregation
247262
self.finalize_kwargs: dict[Any, Any] = {}
248263
self.min_count: int = 0
249-
self.new_axes_func: Callable = (
250-
returns_empty_tuple if new_axes_func is None else new_axes_func
264+
self.new_dims_func: Callable = (
265+
returns_empty_tuple if new_dims_func is None else new_dims_func
251266
)
252267

253-
def get_new_axes(self):
254-
return self.new_axes_func(**self.finalize_kwargs)
268+
@cached_property
269+
def new_dims(self) -> tuple[Dim]:
270+
return self.new_dims_func(**self.finalize_kwargs)
271+
272+
@cached_property
273+
def num_new_vector_dims(self) -> int:
274+
return len(tuple(dim for dim in self.new_dims if not dim.is_scalar))
255275

256276
def _normalize_dtype_fill_value(self, value, name):
257277
value = _atleast_1d(value)
@@ -511,8 +531,8 @@ def _pick_second(*x):
511531
)
512532

513533

514-
def quantile_new_axes_func(q):
515-
return tuple() if xrutils.is_scalar(q) else (len(q),)
534+
def quantile_new_dims_func(q) -> tuple[Dim]:
535+
return (Dim(name="quantile", values=q),)
516536

517537

518538
quantile = Aggregation(
@@ -521,15 +541,15 @@ def quantile_new_axes_func(q):
521541
chunk=None,
522542
combine=None,
523543
final_dtype=np.float64,
524-
new_axes_func=quantile_new_axes_func,
544+
new_dims_func=quantile_new_dims_func,
525545
)
526546
nanquantile = Aggregation(
527547
name="nanquantile",
528548
fill_value=dtypes.NA,
529549
chunk=None,
530550
combine=None,
531551
final_dtype=np.float64,
532-
new_axes_func=quantile_new_axes_func,
552+
new_dims_func=quantile_new_dims_func,
533553
)
534554
mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None)
535555
nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None)
@@ -638,9 +658,10 @@ def _initialize_aggregation(
638658
# where the identity element is 0, 1
639659
if min_count > 0:
640660
agg.min_count = min_count
641-
agg.chunk += ("nanlen",)
642661
agg.numpy += ("nanlen",)
643-
agg.combine += ("sum",)
662+
if agg.chunk != (None,):
663+
agg.chunk += ("nanlen",)
664+
agg.combine += ("sum",)
644665
agg.fill_value["intermediate"] += (0,)
645666
agg.fill_value["numpy"] += (0,)
646667
agg.dtype["intermediate"] += (np.intp,)

flox/core.py

+21-18
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
_atleast_1d,
3535
_initialize_aggregation,
3636
generic_aggregate,
37-
quantile_new_axes_func,
37+
quantile_new_dims_func,
3838
)
3939
from .cache import memoize
4040
from .xrutils import (
@@ -1006,7 +1006,9 @@ def chunk_reduce(
10061006
result = result[..., :-1]
10071007
# TODO: Figure out how to generalize this
10081008
if reduction in ("quantile", "nanquantile"):
1009-
new_dims_shape = quantile_new_axes_func(**kw)
1009+
new_dims_shape = tuple(
1010+
dim.size for dim in quantile_new_dims_func(**kw) if not dim.is_scalar
1011+
)
10101012
else:
10111013
new_dims_shape = tuple()
10121014
result = result.reshape(new_dims_shape + final_array_shape[:-1] + found_groups_shape)
@@ -1044,7 +1046,7 @@ def _finalize_results(
10441046
3. Mask using counts and fill with user-provided fill_value.
10451047
4. reindex to expected_groups
10461048
"""
1047-
squeezed = _squeeze_results(results, axis)
1049+
squeezed = _squeeze_results(results, tuple(agg.num_new_vector_dims + ax for ax in axis))
10481050

10491051
min_count = agg.min_count
10501052
if min_count > 0:
@@ -1671,7 +1673,7 @@ def dask_groupby_agg(
16711673
raise ValueError(f"Unknown method={method}.")
16721674

16731675
# Adjust output for any new dimensions added, example for multiple quantiles
1674-
new_dims_shape = agg.get_new_axes()
1676+
new_dims_shape = tuple(dim.size for dim in agg.new_dims if not dim.is_scalar)
16751677
new_inds = tuple(range(-len(new_dims_shape), 0))
16761678
out_inds = new_inds + inds[: -len(axis)] + (inds[-1],)
16771679
output_chunks = new_dims_shape + reduced.chunks[: -len(axis)] + group_chunks
@@ -2297,7 +2299,21 @@ def groupby_reduce(
22972299
# TODO: How else to narrow that array.chunks is there?
22982300
assert isinstance(array, DaskArray)
22992301

2300-
if agg.chunk[0] is None and method not in [None, "blockwise"]:
2302+
if (not any_by_dask and method is None) or method == "cohorts":
2303+
preferred_method, chunks_cohorts = find_group_cohorts(
2304+
by_,
2305+
[array.chunks[ax] for ax in range(-by_.ndim, 0)],
2306+
expected_groups=expected_,
2307+
# when provided with cohorts, we *always* 'merge'
2308+
merge=(method == "cohorts"),
2309+
)
2310+
else:
2311+
preferred_method = "map-reduce"
2312+
chunks_cohorts = {}
2313+
2314+
method = _choose_method(method, preferred_method, agg, by_, nax)
2315+
2316+
if agg.chunk[0] is None and method != "blockwise":
23012317
raise NotImplementedError(
23022318
f"Aggregation {agg.name!r} is only implemented for dask arrays when method='blockwise'."
23032319
f"Received method={method!r}"
@@ -2318,19 +2334,6 @@ def groupby_reduce(
23182334
f"Received method={method!r}"
23192335
)
23202336

2321-
if (not any_by_dask and method is None) or method == "cohorts":
2322-
preferred_method, chunks_cohorts = find_group_cohorts(
2323-
by_,
2324-
[array.chunks[ax] for ax in range(-by_.ndim, 0)],
2325-
expected_groups=expected_,
2326-
# when provided with cohorts, we *always* 'merge'
2327-
merge=(method == "cohorts"),
2328-
)
2329-
else:
2330-
preferred_method = "map-reduce"
2331-
chunks_cohorts = {}
2332-
2333-
method = _choose_method(method, preferred_method, agg, by_, nax)
23342337
# TODO: clean this up
23352338
reindex = _validate_reindex(
23362339
reindex, func, method, expected_, any_by_dask, is_duck_dask_array(array)

flox/xarray.py

+28-4
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from packaging.version import Version
1010
from xarray.core.duck_array_ops import _datetime_nanmin
1111

12-
from .aggregations import Aggregation, _atleast_1d
12+
from .aggregations import Aggregation, Dim, _atleast_1d, quantile_new_dims_func
1313
from .core import (
1414
_convert_expected_groups_to_index,
1515
_get_expected_groups,
@@ -74,7 +74,7 @@ def xarray_reduce(
7474
dim: Dims | ellipsis = None,
7575
fill_value=None,
7676
dtype: np.typing.DTypeLike = None,
77-
method: str = "map-reduce",
77+
method: str | None = None,
7878
engine: str | None = None,
7979
keep_attrs: bool | None = True,
8080
skipna: bool | None = None,
@@ -387,6 +387,17 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):
387387

388388
result, *groups = groupby_reduce(array, *by, func=func, **kwargs)
389389

390+
# Transpose the new quantile dimension to the end. This is ugly,
391+
# but new core dimensions are expected at the end :/
392+
# but groupby_reduce inserts them at the beginning
393+
if func in ["quantile", "nanquantile"]:
394+
(newdim,) = quantile_new_dims_func(**finalize_kwargs)
395+
if not newdim.is_scalar:
396+
# NOTE: _restore_dim_order will move any new dims to the end anyway.
397+
# This transpose simply makes it easy to specify output_core_dims
398+
# output dim order: (*broadcast_dims, *group_dims, quantile_dim)
399+
result = np.moveaxis(result, 0, -1)
400+
390401
# Output of count has an int dtype.
391402
if requires_numeric and func != "count":
392403
if is_npdatetime:
@@ -412,8 +423,18 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):
412423
input_core_dims = [[d for d in grouper_dims if d not in dim_tuple] + list(dim_tuple)]
413424
input_core_dims += [list(b.dims) for b in by_da]
414425

426+
newdims: tuple[Dim, ...] = (
427+
quantile_new_dims_func(**finalize_kwargs) if func in ["quantile", "nanquantile"] else ()
428+
)
429+
415430
output_core_dims = [d for d in input_core_dims[0] if d not in dim_tuple]
416431
output_core_dims.extend(group_names)
432+
vector_dims = [dim.name for dim in newdims if not dim.is_scalar]
433+
output_core_dims.extend(vector_dims)
434+
435+
output_sizes = group_sizes
436+
output_sizes.update({dim.name: dim.size for dim in newdims if dim.size != 0})
437+
417438
actual = xr.apply_ufunc(
418439
wrapper,
419440
ds_broad.drop_vars(tuple(missing_dim)).transpose(..., *grouper_dims),
@@ -424,7 +445,7 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):
424445
output_core_dims=[output_core_dims],
425446
dask="allowed",
426447
dask_gufunc_kwargs=dict(
427-
output_sizes=group_sizes, output_dtypes=[dtype] if dtype is not None else None
448+
output_sizes=output_sizes, output_dtypes=[dtype] if dtype is not None else None
428449
),
429450
keep_attrs=keep_attrs,
430451
kwargs={
@@ -451,6 +472,9 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):
451472
if all(d not in ds_broad[var].dims for d in dim_tuple):
452473
actual[var] = ds_broad[var]
453474

475+
for newdim in newdims:
476+
actual.coords[newdim.name] = newdim.values if newdim.is_scalar else np.array(newdim.values)
477+
454478
expect3: T_ExpectIndex | np.ndarray
455479
for name, expect2, by_ in zip(group_names, expected_groups_valid_list, by_da):
456480
# Can't remove this until xarray handles IntervalIndex:
@@ -492,7 +516,7 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):
492516
else:
493517
template = obj
494518

495-
if actual[var].ndim > 1:
519+
if actual[var].ndim > 1 + len(vector_dims):
496520
no_groupby_reorder = isinstance(
497521
obj, xr.Dataset
498522
) # do not re-order dataarrays inside datasets

tests/test_core.py

+29-9
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
254254
fill_value = np.nan
255255
tolerance = {"rtol": 1e-14, "atol": 1e-16}
256256
elif "quantile" in func:
257-
finalize_kwargs = [{"q": DEFAULT_QUANTILE}]
257+
finalize_kwargs = [{"q": DEFAULT_QUANTILE}, {"q": [DEFAULT_QUANTILE / 2, DEFAULT_QUANTILE]}]
258258
fill_value = None
259259
tolerance = None
260260
else:
@@ -265,6 +265,8 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
265265
array_func = _get_array_func(func)
266266

267267
for kwargs in finalize_kwargs:
268+
if "quantile" in func and isinstance(kwargs["q"], list) and engine != "flox":
269+
continue
268270
flox_kwargs = dict(func=func, engine=engine, finalize_kwargs=kwargs, fill_value=fill_value)
269271
with np.errstate(invalid="ignore", divide="ignore"):
270272
with warnings.catch_warnings():
@@ -289,10 +291,13 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
289291

290292
if func in BLOCKWISE_FUNCS:
291293
assert chunks == -1
292-
flox_kwargs["method"] = "blockwise"
293294

294295
actual, *groups = groupby_reduce(array, *by, **flox_kwargs)
295-
assert actual.ndim == expected.ndim == (array.ndim + nby - 1)
296+
if "quantile" in func and isinstance(kwargs["q"], list):
297+
assert actual.ndim == expected.ndim == (array.ndim + nby)
298+
else:
299+
assert actual.ndim == expected.ndim == (array.ndim + nby - 1)
300+
296301
expected_groups = tuple(np.array([idx + 1.0]) for idx in range(nby))
297302
for actual_group, expect in zip(groups, expected_groups):
298303
assert_equal(actual_group, expect)
@@ -598,6 +603,15 @@ def test_nanfirst_nanlast_disallowed_dask(axis, func):
598603

599604

600605
@requires_dask
606+
@pytest.mark.xfail
607+
@pytest.mark.parametrize("func", ["first", "last"])
608+
def test_first_last_allowed_dask(func):
609+
# blockwise should be fine... but doesn't work now.
610+
groupby_reduce(dask.array.empty((2, 3, 2)), np.ones((2, 3, 2)), func=func, axis=-1)
611+
612+
613+
@requires_dask
614+
@pytest.mark.xfail
601615
@pytest.mark.parametrize("func", ["first", "last"])
602616
def test_first_last_disallowed_dask(func):
603617
# blockwise is fine
@@ -1678,19 +1692,25 @@ def test_xarray_fill_value_behaviour():
16781692
assert_equal(expected, actual)
16791693

16801694

1681-
@pytest.mark.parametrize("q", (0.5, (0.5,), (0.5, 0.85)))
1695+
@pytest.mark.parametrize("q", (0.5, (0.5,), (0.5, 0.67, 0.85)))
16821696
@pytest.mark.parametrize("func", ["nanquantile", "quantile"])
16831697
@pytest.mark.parametrize("chunk", [pytest.param(True, marks=requires_dask), False])
1684-
def test_multiple_quantiles(q, chunk, func):
1698+
@pytest.mark.parametrize("by_ndim", [1, 2])
1699+
def test_multiple_quantiles(q, chunk, func, by_ndim):
16851700
array = np.array([[1, -1, np.nan, 3, 4, 10, 5], [1, np.nan, np.nan, 3, 4, np.nan, np.nan]])
16861701
labels = np.array([0, 0, 0, 1, 0, 1, 1])
1687-
axis = -1
1702+
if by_ndim == 2:
1703+
labels = np.broadcast_to(labels, (5, *labels.shape))
1704+
array = np.broadcast_to(np.expand_dims(array, -2), (2, 5, array.shape[-1]))
1705+
axis = tuple(range(-by_ndim, 0))
16881706

16891707
if chunk:
1690-
array = dask.array.from_array(array, chunks=(1, -1))
1708+
array = dask.array.from_array(array, chunks=(1,) + (-1,) * by_ndim)
16911709

16921710
actual, _ = groupby_reduce(array, labels, func=func, finalize_kwargs=dict(q=q), axis=axis)
16931711
sorted_array = array[..., [0, 1, 2, 4, 3, 5, 6]]
16941712
f = partial(getattr(np, func), q=q, axis=axis, keepdims=True)
1695-
expected = np.concatenate((f(sorted_array[..., :4]), f(sorted_array[..., 4:])), axis=axis)
1696-
assert_equal(expected, actual)
1713+
expected = np.concatenate((f(sorted_array[..., :4]), f(sorted_array[..., 4:])), axis=-1)
1714+
if by_ndim == 2:
1715+
expected = expected.squeeze(axis=-2)
1716+
assert_equal(expected, actual, tolerance=1e-14)

0 commit comments

Comments
 (0)