
Groupby-map is slow with out of order indices #9220

Open
@mrocklin

Description

What is your issue?

I think that this is a longstanding problem. Sorry if I missed an existing GitHub issue.

I was looking at a Dask-array-backed Xarray workload with @phofl and we were both concerned about the performance we were seeing with groupby-aggregations called with out-of-order indices. Here is a minimal example:

import xarray as xr
import dask.array as da
import numpy as np
import pandas as pd

lat = np.linspace(-89.5, 89.5, 100)
lon = np.linspace(-179.375, 179.375, 100)
time = pd.date_range(
    start="1990-01-01", end="2000-12-31", freq="D",
)

arr = (
    xr.DataArray(
        da.random.random((100, 100, len(time)), chunks=(100, 100, 365)),
        dims=["lat", "lon", "time"],
        coords={"lat": lat, "lon": lon, "time": time},
        name="arr"
    )
    .to_dataset()
)

arr["arr"].data
[screenshot: dask array repr of arr["arr"].data, 12 chunks]
def f(x):
    return x

result = arr.groupby("time.dayofyear").map(f)
result["arr"].data
[screenshot: dask array repr of result["arr"].data, ~4000 small chunks]

Performance here is bad in a few ways:

  • Output chunk sizes are very small (12 chunks turn into 4000 chunks)
  • There are a lot of tasks
  • There are a lot of layers (365 new layers)

We think that what is happening here looks like this:

  1. slice underlying array with a very out-of-order array to arrange groups to be close to each other
  2. Iterate through each group and apply function
  3. slice the underlying array with the inverse array to put everything back in the right place
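
The three steps above can be sketched in plain NumPy (an illustration of our understanding, not xarray's actual code): gather the groups into contiguous runs with a stable argsort, apply `f` to each run, then scatter everything back with the inverse permutation.

```python
import numpy as np

x = np.random.random((4, 4, 20))
labels = np.random.randint(0, 3, size=20)      # stand-in for "time.dayofyear"

order = np.argsort(labels, kind="stable")      # step 1: out-of-order take
gathered = x[:, :, order]

def f(group):                                  # identity, as in the issue
    return group

# step 2: apply f to each contiguous group slice
bounds = np.searchsorted(labels[order], np.arange(labels.max() + 2))
applied = np.concatenate(
    [f(gathered[:, :, bounds[i]:bounds[i + 1]]) for i in range(labels.max() + 1)],
    axis=-1,
)

inverse = np.empty_like(order)                 # step 3: inverse take
inverse[order] = np.arange(order.size)
result = applied[:, :, inverse]

assert np.array_equal(result, x)               # identity f round-trips exactly
```

Steps (1) and (3) are the two fancy-indexing operations whose dask-array cost is reproduced below.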

For steps (1) and (3) above performance is bad in a way that we can reduce to a dask array performance issue. Here is a small reproducer for that:

x = da.random.random((100, 100, 10000))
x
[screenshot: dask array repr of x]
idx = np.random.randint(0, x.shape[2], x.shape[2])
x[:, :, idx]
[screenshot: dask array repr of x[:, :, idx], many small chunks]
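
The fragmentation is visible in the block counts directly (a sketch extending the reproducer above; exact chunk counts will vary by dask version):

```python
import numpy as np
import dask.array as da

x = da.random.random((100, 100, 10000), chunks=(100, 100, 1000))
idx = np.random.randint(0, x.shape[2], x.shape[2])

shuffled = x[:, :, idx]            # out-of-order take
monotone = x[:, :, np.sort(idx)]   # same indices, sorted

# An out-of-order take typically fragments the sliced axis into many
# small chunks; a monotone take roughly preserves the input chunking.
print("input blocks along axis 2:   ", x.numblocks[2])
print("shuffled blocks along axis 2:", shuffled.numblocks[2])
print("monotone blocks along axis 2:", monotone.numblocks[2])
```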

We think that we can make this better on our end, and can take that away as homework.

However, for step (2) we think the change probably has to happen in xarray. Ideally xarray would call something like map_blocks rather than iterating through each group. This would be a special case for Dask arrays. Is that OK?
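
A rough sketch of the kind of blockwise approach we have in mind (hypothetical, not xarray's API or a concrete proposal): after a single gather, rechunk so each group occupies one block along the grouped axis, then apply `f` once per block with `da.map_blocks` instead of once per group slice.

```python
import numpy as np
import dask.array as da

def f(block):
    return block  # the per-group function; here the identity

labels = np.random.randint(0, 365, size=4018)   # stand-in for "time.dayofyear"
order = np.argsort(labels, kind="stable")

x = da.random.random((10, 10, 4018), chunks=(10, 10, 1000))
gathered = x[:, :, order]

# rechunk so each group is exactly one block along the grouped axis
sizes = tuple(int(s) for s in np.bincount(labels) if s > 0)
gathered = gathered.rechunk({2: sizes})

applied = da.map_blocks(f, gathered, dtype=gathered.dtype)

# scatter back to the original order with the inverse permutation
inverse = np.empty_like(order)
inverse[order] = np.arange(order.size)
result = applied[:, :, inverse]

assert np.allclose(result.compute(), x.compute())  # identity f round-trips
```

This still pays for the gather/scatter takes, but the apply step becomes one `map_blocks` layer rather than one slice-plus-apply per group.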

Also, we think that this has a lot of impact throughout xarray, but we are not sure. Is this also the code path taken by sum/max/etc. (assuming that flox is not around)? Mostly we're curious how much we all should prioritize this.

Asks

Some questions:

  • Does our understanding of the situation sound right?
  • Is avoiding iteration through groups for dask arrays doable?
  • Is anyone around to do this within xarray if we're also improving slicing on the dask array side?
