Generalized (arg)min, (arg)max: add nsmallest, nlargest, arg_nsmallest, arg_nlargest #10075

Open
@Huite

Is your feature request related to a problem?

I regularly need the N largest or N smallest values (or their indices) along some dimension.

Describe the solution you'd like

pandas provides nsmallest and nlargest for both Series and DataFrame.
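For reference, a small sketch of the pandas behaviour being referred to. The sample values are made up for illustration; `Series.nlargest` and `Series.nsmallest` return the selected values already sorted, with the original index labels attached, which is roughly what the "arg" variants would expose for xarray.

```python
import pandas as pd

# Made-up sample data; the default integer index doubles as the "arg" result.
s = pd.Series([3, 1, 4, 1, 5, 9, 2, 6])

largest = s.nlargest(3)    # three largest values, sorted descending
smallest = s.nsmallest(3)  # three smallest values, sorted ascending

print(largest)
print(smallest)
```

With the default `keep="first"`, ties are resolved in favour of the earliest occurrence.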

Something similar would be useful for Xarray, I reckon. And just as there are argmin and argmax next to min and max, having arg_nsmallest and arg_nlargest (or something similar) would be convenient as well.

It could match the existing method signatures, requiring an extra n argument:

    def nlargest(
        self,
        n: int,
        dim: Dims = None,
        *,
        skipna: bool | None = None,
        keep_attrs: bool | None = None,
        **kwargs: Any,
    ) -> Self:
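As a rough illustration only, here is how a value-returning `nlargest` for a single dimension might be sketched on top of `xr.apply_ufunc` and `np.partition`. The free function below and its argument order are hypothetical, not an existing Xarray API, and it ignores `skipna` and `keep_attrs` entirely, so NaNs are not handled.

```python
import numpy as np
import xarray as xr


def nlargest(da: xr.DataArray, n: int, dim: str) -> xr.DataArray:
    """Hypothetical helper: the ``n`` largest values along ``dim``, descending."""

    def _partition(values: np.ndarray) -> np.ndarray:
        # apply_ufunc moves the core dimension to the last axis.
        kth = values.shape[-1] - n
        # After partitioning, the last n entries are the n largest (unordered).
        top = np.partition(values, kth=kth, axis=-1)[..., -n:]
        # Sort ascending, then reverse to get descending order.
        return np.sort(top, axis=-1)[..., ::-1]

    return xr.apply_ufunc(
        _partition,
        da,
        input_core_dims=[[dim]],
        output_core_dims=[[dim]],
        # The dimension changes size, so it must be excluded from alignment.
        exclude_dims={dim},
    )


da = xr.DataArray([[3.0, 1.0, 2.0], [9.0, 7.0, 8.0]], dims=("x", "t"))
result = nlargest(da, n=2, dim="t")
```

Because `dim` is listed in `exclude_dims`, any coordinate along it is dropped from the result, which seems reasonable here since the output positions are ranks rather than original labels.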

The basic idea is to wrap numpy's or bottleneck's argpartition. I currently use this quick and dirty utility for a DataArray and a single dimension:

import dask.array
import numpy as np
import xarray as xr


def arg_nsmallest(da: xr.DataArray, dim: str, n: int) -> xr.DataArray:
    """
    Return the index or indices of the ``n`` smallest values along dimension ``dim``.

    Parameters
    ----------
    da: xr.DataArray
    dim: str
        Dimension over which to find the ``n`` smallest values.
    n: int
        The number of items to retrieve.
    
    Returns
    -------
    result: xr.DataArray
    """
    # Find the axis over which to apply the partition.
    axis = da.dims.index(dim)

    # Set up output coordinates.
    dim_index = np.arange(n)
    coords = da.coords.copy()
    coords[dim] = dim_index
    shape = list(da.shape)
    shape[axis] = n
    template = xr.DataArray(
        data=dask.array.zeros(shape, dtype=int),
        coords=coords,
        dims=da.dims,
    )
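    def _nsmallest(block: xr.DataArray) -> xr.DataArray:
        # NOTE: numpy (arg)partition moves NaNs to the back;
        # bottleneck partition does not!
        # kth=n - 1 puts the n smallest values in the first n positions (in no
        # particular order) and, unlike kth=n, also works when n equals the
        # size of ``dim``.
        smallest = np.argpartition(block.to_numpy(), kth=n - 1, axis=axis)
        # template.copy requires data of the template's shape, so this assumes
        # a single block (or an input that is not dask-backed).
        return template.copy(data=np.take(smallest, indices=np.arange(n), axis=axis))

    return xr.map_blocks(_nsmallest, da, template=template)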
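One detail worth spelling out: `np.argpartition` only guarantees that the first `n` positions hold the `n` smallest elements, not that they are ordered among themselves. A minimal NumPy-only sketch (with made-up data) of how the candidates could be ordered afterwards, at the cost of sorting just `n` items per row:

```python
import numpy as np

# Made-up sample data: two rows, partition along axis 1.
values = np.array([[5, 1, 4, 2, 3],
                   [9, 6, 8, 7, 0]])
n = 2

# Indices of the n smallest values per row, in arbitrary order.
part = np.argpartition(values, kth=n - 1, axis=1)
candidates = np.take(part, np.arange(n), axis=1)

# Order the n candidates by their values; this sorts only n items per row.
candidate_values = np.take_along_axis(values, candidates, axis=1)
order = np.argsort(candidate_values, axis=1)
arg_nsmallest = np.take_along_axis(candidates, order, axis=1)

print(arg_nsmallest)
```

Whether the result should be sorted at all is a design choice; pandas returns sorted output, so matching that would probably be the least surprising.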

Describe alternatives you've considered

In principle, the same can be achieved with e.g. xarray's argsort, but a full sort is much more costly when only, say, the three highest or lowest values are required. argsort also doesn't accept dimension names and isn't NaN-aware. With numpy's ordering, nsmallest is the more straightforward case: for nlargest, the NaNs that are moved to the end get in the way.
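To make the NaN problem for the "largest" case concrete, here is a small NumPy sketch with made-up data. Filling NaNs with `-inf` before partitioning is just one possible workaround, assumed here for illustration, and not necessarily what an Xarray implementation would do.

```python
import numpy as np

# Made-up sample data with two NaNs.
values = np.array([2.0, np.nan, 7.0, 1.0, np.nan, 5.0])
n = 2

# Naive approach: NaNs sort to the end, so the "largest" n are the NaNs.
naive = np.argsort(values)[-n:]

# Workaround: replace NaN by -inf so it can never rank among the largest.
filled = np.where(np.isnan(values), -np.inf, values)
part = np.argpartition(filled, kth=filled.size - n)[-n:]
# Order the n candidates from largest to smallest.
arg_nlargest = part[np.argsort(filled[part])[::-1]]

print(naive, arg_nlargest)
```

If fewer than `n` non-NaN values are present, this sketch returns positions of filled NaNs, so a real implementation would still need to mask those out.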

Additional context

No response
