Description
What is your issue?
I think that this is a longstanding problem. Sorry if I missed an existing GitHub issue.
I was looking at a Dask-array-backed Xarray workload with @phofl, and we were both concerned about the performance we were seeing with groupby-aggregations called with out-of-order indices. Here is a minimal example:
import xarray as xr
import dask.array as da
import numpy as np
import pandas as pd
lat = np.linspace(-89.5, 89.5, 100)
lon = np.linspace(-179.375, 179.375, 100)
time = pd.date_range(
    start="1990-01-01", end="2000-12-31", freq="D",
)
arr = (
    xr.DataArray(
        da.random.random((100, 100, len(time)), chunks=(100, 100, 365)),
        dims=["lat", "lon", "time"],
        coords={"lat": lat, "lon": lon, "time": time},
        name="arr",
    )
    .to_dataset()
)
arr["arr"].data
def f(x):
    return x
result = arr.groupby("time.dayofyear").map(f)
result["arr"].data
Performance here is bad in a few ways:
- Output chunk sizes are very small (12 chunks turn into 4000 chunks)
- There are a lot of tasks
- There are a lot of layers (365 new layers)
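One way to put numbers on these three symptoms, as a minimal sketch. The helper name `graph_stats` is hypothetical (it is not part of Dask or Xarray), and it assumes the array exposes its graph through `__dask_graph__()`, which may behave differently across Dask versions:

```python
import dask.array as da


def graph_stats(darr):
    """Summarize chunk, task, and layer counts for a Dask array.

    Hypothetical helper, for illustration only.
    """
    graph = darr.__dask_graph__()
    return {
        # Total number of chunks across all dimensions
        "chunks": darr.npartitions,
        # Number of keys (tasks) in the graph
        "tasks": len(graph),
        # A HighLevelGraph has .layers; report None if the graph is a plain mapping
        "layers": len(graph.layers) if hasattr(graph, "layers") else None,
    }


# Small stand-in array: 4 chunks along the last axis
x = da.ones((4, 4, 100), chunks=(4, 4, 25))
stats = graph_stats(x)
print(stats)
```

Calling `graph_stats(arr["arr"].data)` and `graph_stats(result["arr"].data)` on the example above should make the before/after difference easy to compare.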
We think that what is happening here looks like this:
1. Slice the underlying array with a very out-of-order array to arrange groups to be close to each other
2. Iterate through each group and apply the function
3. Slice the underlying array with the inverse array to put everything back in the right place
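The pattern we have in mind can be sketched in plain NumPy. This is only an illustration of our guess at the three steps, not Xarray's actual implementation; the toy `values` and `labels` arrays are made up for the example:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-ins: 10 values along "time", each with a group label
values = rng.random(10)
labels = np.array([2, 0, 1, 0, 2, 1, 0, 2, 1, 0])


def f(x):
    return x  # identity, as in the example above


# Step 1: out-of-order take that makes each group contiguous
order = np.argsort(labels, kind="stable")
grouped = values[order]

# Step 2: iterate over the groups and apply the function to each one
_, counts = np.unique(labels, return_counts=True)
boundaries = np.cumsum(counts)[:-1]
pieces = [f(piece) for piece in np.split(grouped, boundaries)]
applied = np.concatenate(pieces)

# Step 3: take with the inverse permutation to restore the original order
inverse = np.argsort(order)
restored = applied[inverse]
```

Both takes use index arrays that jump around the grouped axis, which is the access pattern that seems to hurt when the array is chunked.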
For steps (1) and (3) above, performance is bad in a way that we can reduce to a Dask array performance issue. Here is a small reproducer for that:
x = da.random.random((100, 100, 10000))
x
idx = np.random.randint(0, x.shape[2], x.shape[2])
x[:, :, idx]
We think that we can make this better on our end, and can take that away as homework.
However, for step (2) we think that this probably has to be a change in xarray. Ideally xarray would call something like map_blocks, rather than iterate through each group. This would be a special case for dask arrays. Is this ok?
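To make the idea concrete, here is a rough sketch of what a `map_blocks`-based step (2) could look like on a bare Dask array. It assumes that each group can be made into exactly one chunk along the grouped axis and that the function preserves the block shape; the per-group function `f` below is a made-up example, and none of this reflects how xarray is structured internally:

```python
import dask.array as da
import numpy as np

labels = np.array([2, 0, 1, 0, 2, 1, 0, 2, 1, 0])
x = da.from_array(np.arange(20.0).reshape(2, 10), chunks=(2, 5))


def f(block):
    # Hypothetical per-group function: subtract the group's mean
    return block - block.mean(axis=-1, keepdims=True)


order = np.argsort(labels, kind="stable")
_, counts = np.unique(labels, return_counts=True)

# Make each group contiguous, then align chunk boundaries with group boundaries
grouped = x[:, order].rechunk({1: tuple(int(c) for c in counts)})

# A single map_blocks call applies f to every group, with no Python loop over groups
applied = grouped.map_blocks(f, dtype=x.dtype)

# Undo the permutation to restore the original order
result = applied[:, np.argsort(order)]
```

This adds one layer for the function application instead of one per group, although the two out-of-order takes remain.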
Also, we think that this has a lot of impact throughout xarray, but are not sure. Is this also the code path taken in sum/max/etc. (assuming that flox is not around)? Mostly we're curious how much we all should prioritize this.
Asks
Some questions:
- Does our understanding of the situation sound right?
- Is avoiding iteration through groups for dask arrays doable?
- Is anyone around to do this within xarray if we're also improving slicing on the dask array side?