Skip to content

Commit 2da7f55

Browse files
authored
Use flox for grouped first, last. (#10148)
* Optimize grouped first, last. 1. Use flox where possible. Closes #9647 * simplify * add whats-new * typing * more typing * Fix * docstrings and types
1 parent 3ca8824 commit 2da7f55

File tree

5 files changed

+102
-20
lines changed

5 files changed

+102
-20
lines changed

doc/whats-new.rst

+5-2
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ New Features
4141
- Support reading to `GPU memory with Zarr <https://zarr.readthedocs.io/en/stable/user-guide/gpu.html>`_ (:pull:`10078`).
4242
By `Deepak Cherian <https://github.com/dcherian>`_.
4343

44+
Performance
45+
~~~~~~~~~~~
46+
- :py:meth:`DatasetGroupBy.first` and :py:meth:`DatasetGroupBy.last` can now use ``flox`` if available. (:issue:`9647`)
47+
By `Deepak Cherian <https://github.com/dcherian>`_.
48+
4449
Breaking changes
4550
~~~~~~~~~~~~~~~~
4651
- Rolled back code that would attempt to catch integer overflow when encoding
@@ -174,8 +179,6 @@ New Features
174179
:py:class:`pandas.DatetimeIndex` (:pull:`9965`). By `Spencer Clark
175180
<https://github.com/spencerkclark>`_ and `Kai Mühlbauer
176181
<https://github.com/kmuehlbauer>`_.
177-
- :py:meth:`DatasetGroupBy.first` and :py:meth:`DatasetGroupBy.last` can now use ``flox`` if available. (:issue:`9647`)
178-
By `Deepak Cherian <https://github.com/dcherian>`_.
179182

180183
Breaking changes
181184
~~~~~~~~~~~~~~~~

xarray/core/groupby.py

+66-11
Original file line numberDiff line numberDiff line change
@@ -993,7 +993,7 @@ def _flox_reduce(
993993
dim: Dims,
994994
keep_attrs: bool | None = None,
995995
**kwargs: Any,
996-
):
996+
) -> T_Xarray:
997997
"""Adaptor function that translates our groupby API to that of flox."""
998998
import flox
999999
from flox.xarray import xarray_reduce
@@ -1116,6 +1116,8 @@ def _flox_reduce(
11161116
# flox always assigns an index so we must drop it here if we don't need it.
11171117
to_drop.append(grouper.name)
11181118
continue
1119+
# TODO: We can't simply use `self.encoded.coords` here because it corresponds to `unique_coord`,
1120+
# NOT `full_index`. We would need to construct a new Coordinates object, that corresponds to `full_index`.
11191121
new_coords.append(
11201122
# Using IndexVariable here ensures we reconstruct PandasMultiIndex with
11211123
# all associated levels properly.
@@ -1361,7 +1363,12 @@ def where(self, cond, other=dtypes.NA) -> T_Xarray:
13611363
"""
13621364
return ops.where_method(self, cond, other)
13631365

1364-
def _first_or_last(self, op, skipna, keep_attrs):
1366+
def _first_or_last(
1367+
self,
1368+
op: Literal["first" | "last"],
1369+
skipna: bool | None,
1370+
keep_attrs: bool | None,
1371+
):
13651372
if all(
13661373
isinstance(maybe_slice, slice)
13671374
and (maybe_slice.stop == maybe_slice.start + 1)
@@ -1372,17 +1379,65 @@ def _first_or_last(self, op, skipna, keep_attrs):
13721379
return self._obj
13731380
if keep_attrs is None:
13741381
keep_attrs = _get_keep_attrs(default=True)
1375-
return self.reduce(
1376-
op, dim=[self._group_dim], skipna=skipna, keep_attrs=keep_attrs
1377-
)
1382+
if (
1383+
module_available("flox", minversion="0.10.0")
1384+
and OPTIONS["use_flox"]
1385+
and contains_only_chunked_or_numpy(self._obj)
1386+
):
1387+
result = self._flox_reduce(
1388+
dim=None, func=op, skipna=skipna, keep_attrs=keep_attrs
1389+
)
1390+
else:
1391+
result = self.reduce(
1392+
getattr(duck_array_ops, op),
1393+
dim=[self._group_dim],
1394+
skipna=skipna,
1395+
keep_attrs=keep_attrs,
1396+
)
1397+
return result
13781398

1379-
def first(self, skipna: bool | None = None, keep_attrs: bool | None = None):
1380-
"""Return the first element of each group along the group dimension"""
1381-
return self._first_or_last(duck_array_ops.first, skipna, keep_attrs)
1399+
def first(
1400+
self, skipna: bool | None = None, keep_attrs: bool | None = None
1401+
) -> T_Xarray:
1402+
"""
1403+
Return the first element of each group along the group dimension
13821404
1383-
def last(self, skipna: bool | None = None, keep_attrs: bool | None = None):
1384-
"""Return the last element of each group along the group dimension"""
1385-
return self._first_or_last(duck_array_ops.last, skipna, keep_attrs)
1405+
Parameters
1406+
----------
1407+
skipna : bool or None, optional
1408+
If True, skip missing values (as marked by NaN). By default, only
1409+
skips missing values for float dtypes; other dtypes either do not
1410+
have a sentinel missing value (int) or ``skipna=True`` has not been
1411+
implemented (object, datetime64 or timedelta64).
1412+
keep_attrs : bool or None, optional
1413+
If True, ``attrs`` will be copied from the original
1414+
object to the new one. If False, the new object will be
1415+
returned without attributes.
1416+
1417+
"""
1418+
return self._first_or_last("first", skipna, keep_attrs)
1419+
1420+
def last(
1421+
self, skipna: bool | None = None, keep_attrs: bool | None = None
1422+
) -> T_Xarray:
1423+
"""
1424+
Return the last element of each group along the group dimension
1425+
1426+
Parameters
1427+
----------
1428+
skipna : bool or None, optional
1429+
If True, skip missing values (as marked by NaN). By default, only
1430+
skips missing values for float dtypes; other dtypes either do not
1431+
have a sentinel missing value (int) or ``skipna=True`` has not been
1432+
implemented (object, datetime64 or timedelta64).
1433+
keep_attrs : bool or None, optional
1434+
If True, ``attrs`` will be copied from the original
1435+
object to the new one. If False, the new object will be
1436+
returned without attributes.
1437+
1438+
1439+
"""
1440+
return self._first_or_last("last", skipna, keep_attrs)
13861441

13871442
def assign_coords(self, coords=None, **coords_kwargs):
13881443
"""Assign coordinates by group.

xarray/core/resample.py

+21-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import warnings
44
from collections.abc import Callable, Hashable, Iterable, Sequence
5-
from typing import TYPE_CHECKING, Any
5+
from typing import TYPE_CHECKING, Any, Literal
66

77
from xarray.core._aggregations import (
88
DataArrayResampleAggregations,
@@ -55,8 +55,11 @@ def _flox_reduce(
5555
keep_attrs: bool | None = None,
5656
**kwargs,
5757
) -> T_Xarray:
58-
result = super()._flox_reduce(dim=dim, keep_attrs=keep_attrs, **kwargs)
59-
result = result.rename({RESAMPLE_DIM: self._group_dim})
58+
result: T_Xarray = (
59+
super()
60+
._flox_reduce(dim=dim, keep_attrs=keep_attrs, **kwargs)
61+
.rename({RESAMPLE_DIM: self._group_dim}) # type: ignore[assignment]
62+
)
6063
return result
6164

6265
def shuffle_to_chunks(self, chunks: T_Chunks = None):
@@ -103,6 +106,21 @@ def shuffle_to_chunks(self, chunks: T_Chunks = None):
103106
(grouper,) = self.groupers
104107
return self._shuffle_obj(chunks).drop_vars(RESAMPLE_DIM)
105108

109+
def _first_or_last(
110+
self, op: Literal["first", "last"], skipna: bool | None, keep_attrs: bool | None
111+
) -> T_Xarray:
112+
from xarray.core.dataset import Dataset
113+
114+
result = super()._first_or_last(op=op, skipna=skipna, keep_attrs=keep_attrs)
115+
if isinstance(result, Dataset):
116+
# Can't do this in the base class because group_dim is RESAMPLE_DIM
117+
# which is not present in the original object
118+
for var in result.data_vars:
119+
result._variables[var] = result._variables[var].transpose(
120+
*self._obj._variables[var].dims
121+
)
122+
return result
123+
106124
def _drop_coords(self) -> T_Xarray:
107125
"""Drop non-dimension coordinates along the resampled dimension."""
108126
obj = self._obj

xarray/groupers.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,11 @@ def _factorize_unique(self) -> EncodedGroups:
242242
unique_coord = Variable(
243243
dims=codes.name, data=unique_values, attrs=self.group.attrs
244244
)
245-
full_index = pd.Index(unique_values)
245+
full_index = (
246+
unique_values
247+
if isinstance(unique_values, pd.MultiIndex)
248+
else pd.Index(unique_values)
249+
)
246250

247251
return EncodedGroups(
248252
codes=codes,

xarray/tests/test_groupby.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -1618,6 +1618,8 @@ def test_groupby_first_and_last(self) -> None:
16181618
expected = array # should be a no-op
16191619
assert_identical(expected, actual)
16201620

1621+
# TODO: groupby_bins too
1622+
16211623
def make_groupby_multidim_example_array(self) -> DataArray:
16221624
return DataArray(
16231625
[[[0, 1], [2, 3]], [[5, 10], [15, 20]]],
@@ -2374,13 +2376,13 @@ def test_resample_and_first(self) -> None:
23742376
# upsampling
23752377
expected_time = pd.date_range("2000-01-01", freq="3h", periods=19)
23762378
expected = ds.reindex(time=expected_time)
2377-
actual = ds.resample(time="3h")
2379+
rs = ds.resample(time="3h")
23782380
for how in ["mean", "sum", "first", "last"]:
2379-
method = getattr(actual, how)
2381+
method = getattr(rs, how)
23802382
result = method()
23812383
assert_equal(expected, result)
23822384
for method in [np.mean]:
2383-
result = actual.reduce(method)
2385+
result = rs.reduce(method)
23842386
assert_equal(expected, result)
23852387

23862388
def test_resample_min_count(self) -> None:

0 commit comments

Comments
 (0)