diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py
index aff676674d0..6dcdc3c01b3 100644
--- a/xarray/core/dask_array_ops.py
+++ b/xarray/core/dask_array_ops.py
@@ -134,3 +134,15 @@ def push(array, n, axis, method="blelloch"):
     pushed_array = da.where(valid_positions, pushed_array, np.nan)
 
     return pushed_array
+
+
+def topk(a, k, axis):
+    import dask.array as da
+
+    return da.topk(a, k=k, axis=axis)
+
+
+def argtopk(a, k, axis):
+    import dask.array as da
+
+    return da.argtopk(a, k=k, axis=axis)
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index c245a0f9b55..484239b3c23 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -6288,6 +6288,37 @@ def argmax(
         else:
             return self._replace_maybe_drop_dims(result)
 
+    def argtopk(
+        self,
+        k: int,
+        dim: Dims = None,
+        *,
+        keep_attrs: bool | None = None,
+        skipna: bool | None = None,
+    ) -> Self | dict[Hashable, Self]:
+        """
+        TODO docstring
+        """
+        result = self.variable.argtopk(k, dim, keep_attrs, skipna)
+        if isinstance(result, dict):
+            return {k: self._replace_maybe_drop_dims(v) for k, v in result.items()}
+        else:
+            return self._replace_maybe_drop_dims(result)
+
+    def topk(
+        self,
+        k: int,
+        dim: Dims = None,
+        *,
+        keep_attrs: bool | None = None,
+        skipna: bool | None = None,
+    ) -> Self:
+        """
+        TODO docstring
+        """
+        result = self.variable.topk(k, dim, keep_attrs, skipna)
+        return self._replace_maybe_drop_dims(result)
+
     def query(
         self,
         queries: Mapping[Any, Any] | None = None,
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 1197a27d4d1..fb156b243cd 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -9727,6 +9727,26 @@ def argmax(self, dim: Hashable | None = None, **kwargs) -> Self:
                 "Dataset.argmin() with a sequence or ... for dim"
             )
 
+    def argtopk(self, k: int, dim: Hashable | None = None, **kwargs) -> Self:
+        """
+        TODO docstring
+        """
+        from xarray.core.missing import _apply_over_vars_with_dim
+
+        func = duck_array_ops.argtopk
+        new = _apply_over_vars_with_dim(func, self, dim=dim, k=k)
+        return new
+
+    def topk(self, k: int, dim: Hashable | None = None, **kwargs) -> Self:
+        """
+        TODO docstring
+        """
+        from xarray.core.missing import _apply_over_vars_with_dim
+
+        func = duck_array_ops.topk
+        new = _apply_over_vars_with_dim(func, self, dim=dim, k=k)
+        return new
+
     def eval(
         self,
         statement: str,
diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py
index faec5ded04e..8e1dd5a414b 100644
--- a/xarray/core/duck_array_ops.py
+++ b/xarray/core/duck_array_ops.py
@@ -27,7 +27,12 @@
 from xarray.core import dask_array_compat, dask_array_ops, dtypes, nputils
 from xarray.core.array_api_compat import get_array_namespace
 from xarray.core.options import OPTIONS
-from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available
+from xarray.core.utils import (
+    is_duck_array,
+    is_duck_dask_array,
+    module_available,
+    to_0d_object_array,
+)
 from xarray.namedarray.parallelcompat import get_chunked_array_type
 from xarray.namedarray.pycompat import array_type, is_chunked_array
 
@@ -875,3 +880,74 @@ def chunked_nanfirst(darray, axis):
 
 def chunked_nanlast(darray, axis):
     return _chunked_first_or_last(darray, axis, op=nputils.nanlast)
+
+
+def argtopk(values, k, axis=None, skipna=None):
+    if is_chunked_array(values):
+        func = dask_array_ops.argtopk
+    else:
+        func = nputils.argtopk
+
+    # Borrowed from nanops
+    xp = get_array_namespace(values)
+    if skipna or (
+        skipna is None
+        and (
+            dtypes.isdtype(values.dtype, ("complex floating", "real floating"), xp=xp)
+            or dtypes.is_object(values.dtype)
+        )
+    ):
+        valid_count = count(values, axis=axis)
+
+        if k < 0:
+            fill_value = dtypes.get_pos_infinity(values.dtype)
+        else:
+            fill_value = dtypes.get_neg_infinity(values.dtype)
+
+        filled_values = fillna(values, fill_value)
+    else:
+        return func(values, k=k, axis=axis)
+
+    data = func(filled_values, k=k, axis=axis)
+
+    # TODO This will evaluate dask arrays and might be costly.
+    if array_any(valid_count == 0):
+        raise ValueError("All-NaN slice encountered")
+    return data
+
+
+def topk(values, k, axis=None, skipna=None):
+    if is_chunked_array(values):
+        func = dask_array_ops.topk
+    else:
+        func = nputils.topk
+
+    # Borrowed from nanops
+    xp = get_array_namespace(values)
+    if skipna or (
+        skipna is None
+        and (
+            dtypes.isdtype(values.dtype, ("complex floating", "real floating"), xp=xp)
+            or dtypes.is_object(values.dtype)
+        )
+    ):
+        valid_count = count(values, axis=axis)
+
+        if k < 0:
+            fill_value = dtypes.get_pos_infinity(values.dtype)
+        else:
+            fill_value = dtypes.get_neg_infinity(values.dtype)
+
+        filled_values = fillna(values, fill_value)
+    else:
+        return func(values, k=k, axis=axis)
+
+    data = func(filled_values, k=k, axis=axis)
+
+    if not hasattr(data, "dtype"):  # scalar case
+        data = fill_value if valid_count == 0 else data
+        # we've computed a single min, max value of type object.
+        # don't let np.array turn a tuple back into an array
+        return to_0d_object_array(data)
+
+    return where_method(data, valid_count != 0)
diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py
index 3211ab296e6..da28a6c5fa1 100644
--- a/xarray/core/nputils.py
+++ b/xarray/core/nputils.py
@@ -302,6 +302,62 @@ def least_squares(lhs, rhs, rcond=None, skipna=False):
     return coeffs, residuals
 
 
+def topk(values, k: int, axis: int):
+    """Extract the k largest elements from a on the given axis.
+    If k is negative, extract the -k smallest elements instead.
+    The returned elements are sorted.
+    """
+    if axis < 0:
+        axis = values.ndim + axis
+
+    if abs(k) >= values.shape[axis]:
+        b = np.sort(values, axis=axis)
+    else:
+        a = np.partition(values, -k, axis=axis)
+        k_slice = slice(-k, None) if k > 0 else slice(-k)
+        b = a[tuple(k_slice if i == axis else slice(None) for i in range(values.ndim))]
+        b.sort(axis=axis)
+    if k < 0:
+        return b
+    return b[
+        tuple(
+            slice(None, None, -1) if i == axis else slice(None)
+            for i in range(values.ndim)
+        )
+    ]
+
+
+def argtopk(values, k: int, axis: int):
+    """Extract the indices of the k largest elements from a on the given axis.
+    If k is negative, extract the indices of the -k smallest elements instead.
+    The returned elements are argsorted.
+    """
+    if axis < 0:
+        axis = values.ndim + axis
+
+    if abs(k) >= values.shape[axis]:
+        idx3 = np.argsort(values, axis=axis)
+    else:
+        idx = np.argpartition(values, -k, axis=axis)
+        k_slice = slice(-k, None) if k > 0 else slice(-k)
+        idx = idx[
+            tuple(k_slice if i == axis else slice(None) for i in range(values.ndim))
+        ]
+        a = np.take_along_axis(values, idx, axis)
+        idx2 = np.argsort(a, axis=axis)
+        idx3 = np.take_along_axis(idx, idx2, axis)
+    if k < 0:
+        return idx3
+    # ``idx`` is only bound in the partition branch above, so take the rank from
+    # ``values`` to keep the full-sort branch (abs(k) >= axis length) working.
+    return idx3[
+        tuple(
+            slice(None, None, -1) if i == axis else slice(None)
+            for i in range(values.ndim)
+        )
+    ]
+
+
 nanmin = _create_method("nanmin")
 nanmax = _create_method("nanmax")
 nanmean = _create_method("nanmean")
diff --git a/xarray/core/variable.py b/xarray/core/variable.py
index ed860dc0e6b..376a84ea5cf 100644
--- a/xarray/core/variable.py
+++ b/xarray/core/variable.py
@@ -2511,6 +2511,146 @@ def argmax(
         """
         return self._unravel_argminmax("argmax", dim, axis, keep_attrs, skipna)
 
+    def _topk_stack(
+        self,
+        topk_funcname: str,
+        dim: Dims,
+    ) -> Variable:
+        # Get a name for the new dimension that does not conflict with any existing
+        # dimension
+        newdimname = f"_unravel_{topk_funcname}_dim_0"
+        count = 1
+        while newdimname in self.dims:
+            newdimname = f"_unravel_{topk_funcname}_dim_{count}"
+            count += 1
+        return self.stack({newdimname: dim})
+
+    def _topk_helper(
+        self,
+        topk_funcname: str,
+        k: int,
+        dim: str,
+        dtype: Any,
+        keep_attrs: bool | None = None,
+        skipna: bool | None = None,
+    ) -> Variable:
+        from xarray.core.computation import apply_ufunc
+
+        topk_func = getattr(duck_array_ops, topk_funcname)
+        # apply_ufunc moves the dimension to the back.
+        kwargs = {"k": k, "axis": -1, "skipna": skipna}
+
+        result = apply_ufunc(
+            topk_func,
+            self,
+            input_core_dims=[[dim]],
+            exclude_dims={dim},
+            output_core_dims=[[topk_funcname]],
+            output_dtypes=[dtype],
+            dask_gufunc_kwargs=dict(output_sizes={topk_funcname: k}),
+            dask="allowed",
+            kwargs=kwargs,
+        )
+
+        keep_attrs_ = (
+            _get_keep_attrs(default=False) if keep_attrs is None else keep_attrs
+        )
+
+        if keep_attrs_:
+            result.attrs = self._attrs
+        return result
+
+    def topk(
+        self,
+        k: int,
+        dim: Dims = None,
+        keep_attrs: bool | None = None,
+        skipna: bool | None = None,
+    ) -> Variable | dict[Hashable, Variable]:
+        """
+        TODO docstring
+        """
+        # topk accepts only an integer axis like argmin or argmax,
+        # not tuples, so we need to stack multiple dimensions.
+        if dim is ... or dim is None:
+            # Return dimension for 1D data.
+            if self.ndim == 1:
+                dim = self.dims[0]
+            else:
+                dim = self.dims
+
+        if isinstance(dim, str):
+            stacked = self
+        else:
+            stacked = self._topk_stack("topk", dim)
+            dim = stacked.dims[-1]
+
+        result = stacked._topk_helper(
+            "topk", k=k, dim=dim, dtype=self.dtype, keep_attrs=keep_attrs, skipna=skipna
+        )
+        return result
+
+    def argtopk(
+        self,
+        k: int,
+        dim: Dims = None,
+        keep_attrs: bool | None = None,
+        skipna: bool | None = None,
+    ) -> Variable | dict[Hashable, Variable]:
+        """
+        TODO docstring
+        """
+        # argtopk accepts only an integer axis like argmin or argmax,
+        # not tuples, so we need to stack multiple dimensions.
+        if dim is ... or dim is None:
+            # Return dimension for 1D data.
+            if self.ndim == 1:
+                dim = self.dims[0]
+            else:
+                dim = self.dims
+
+        if isinstance(dim, str):
+            return self._topk_helper(
+                "argtopk",
+                k=k,
+                dim=dim,
+                dtype=np.intp,
+                keep_attrs=keep_attrs,
+                skipna=skipna,
+            )
+
+        stacked = self._topk_stack("topk", dim)
+        newdimname = stacked.dims[-1]
+
+        result_flat_indices = stacked._topk_helper(
+            "argtopk",
+            k=k,
+            dim=newdimname,
+            dtype=np.intp,
+            keep_attrs=keep_attrs,
+            skipna=skipna,
+        )
+
+        reduce_shape = tuple(self.sizes[d] for d in dim)
+
+        result_unravelled_indices = duck_array_ops.unravel_index(
+            result_flat_indices.data, reduce_shape
+        )
+
+        result_dims = [d for d in stacked.dims if d != newdimname] + ["argtopk"]
+        result = {
+            d: Variable(dims=result_dims, data=i)
+            for d, i in zip(dim, result_unravelled_indices, strict=True)
+        }
+
+        if keep_attrs is None:
+            keep_attrs = _get_keep_attrs(default=False)
+        if keep_attrs:
+            for v in result.values():
+                v.attrs = self.attrs
+
+        return result
+
     def _as_sparse(self, sparse_format=_default, fill_value=_default) -> Variable:
         """
         Use sparse-array as backend.
diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 75d6d919e19..d420854369a 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -5283,6 +5283,67 @@ def test_argmax_dim( for key in expected2: assert_identical(result2[key], expected2[key]) + def test_topk( + self, + x: np.ndarray, + minindex: int | float, + maxindex: int | float, + nanindex: int | None, + ) -> None: + ar = xr.DataArray( + x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs + ) + + # Test top 1 (should match max) + # Select using the indices + expected0 = ar.max().expand_dims("topk") + result0 = ar.topk(k=1) + assert_identical(result0, expected0) + + # Test keep_attrs + result1 = ar.topk(k=1, keep_attrs=True) + expected1 = expected0.copy() + expected1.attrs = ar.attrs + assert_identical(result1, expected1) + + # Test top -1 (should match min) + expected2 = ar.min().expand_dims("topk") + result2 = ar.topk(k=-1) + assert_identical(result2, expected2) + + # Test skipna=False + result3 = ar.topk(k=1, skipna=False) + expected3 = ar.max(skipna=False).expand_dims("topk") + assert_identical(result3, expected3) + + result4 = ar.topk(k=-1, skipna=False) + expected4 = ar.min(skipna=False).expand_dims("topk") + assert_identical(result4, expected4) + + def create_expected(ar, k): + if k < 0: + _sorted = ar.sortby(ar) + else: + _sorted = ar.sortby(-ar) + selected = _sorted.isel(x=slice(0, k)).drop_vars("x").rename({"x": "topk"}) + selected.attrs = {} + return selected + + k = 2 + result5 = ar.topk(k=k) + expected5 = create_expected(ar, k=k) + assert_identical(result5, expected5) + + k = x.size + result6 = ar.topk(k=k) + expected6 = create_expected(ar, k=k) + assert_identical(result6, expected6) + + k = x.size + 1 + result7 = ar.topk(k=k) + expected7 = create_expected(ar, k=k) + assert_identical(result7, expected7) + @pytest.mark.parametrize( ["x", "minindex", "maxindex", "nanindex"],