Skip to content

Commit 227ce04

Browse files
authored
Fix grouping by multiindex (#106)
1 parent 5b7edbe commit 227ce04

File tree

2 files changed

+26
-9
lines changed

2 files changed

+26
-9
lines changed

flox/xarray.py

+15-9
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66
import pandas as pd
77
import xarray as xr
8+
from packaging.version import Version
89

910
from .aggregations import Aggregation, _atleast_1d
1011
from .core import (
@@ -345,12 +346,16 @@ def wrapper(array, *by, func, skipna, **kwargs):
345346
expect = expect.to_numpy()
346347
if isinstance(actual, xr.Dataset) and name in actual:
347348
actual = actual.drop_vars(name)
348-
actual[name] = expect
349-
350-
# if grouping by multi-indexed variable, then restore it
351-
for name, index in ds.indexes.items():
352-
if name in actual.indexes and isinstance(index, pd.MultiIndex):
353-
actual[name] = index
349+
# When grouping by MultiIndex, expect is an pd.Index wrapping
350+
# an object array of tuples
351+
if name in ds.indexes and isinstance(ds.indexes[name], pd.MultiIndex):
352+
levelnames = ds.indexes[name].names
353+
expect = pd.MultiIndex.from_tuples(expect.values, names=levelnames)
354+
actual[name] = expect
355+
if Version(xr.__version__) > Version("2022.03.0"):
356+
actual = actual.set_coords(levelnames)
357+
else:
358+
actual[name] = expect
354359

355360
if unindexed_dims:
356361
actual = actual.drop_vars(unindexed_dims)
@@ -361,7 +366,8 @@ def wrapper(array, *by, func, skipna, **kwargs):
361366
template = obj
362367
else:
363368
template = obj[var]
364-
actual[var] = _restore_dim_order(actual[var], template, by[0])
369+
if actual[var].ndim > 1:
370+
actual[var] = _restore_dim_order(actual[var], template, by[0])
365371

366372
if missing_dim:
367373
for k, v in missing_dim.items():
@@ -370,9 +376,9 @@ def wrapper(array, *by, func, skipna, **kwargs):
370376
}
371377
# The expand_dims is for backward compat with xarray's questionable behaviour
372378
if missing_group_dims:
373-
actual[k] = v.expand_dims(missing_group_dims)
379+
actual[k] = v.expand_dims(missing_group_dims).variable
374380
else:
375-
actual[k] = v
381+
actual[k] = v.variable
376382

377383
if isinstance(obj, xr.DataArray):
378384
return obj._from_temp_dataset(actual)

tests/test_xarray.py

+11
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,17 @@ def test_multi_index_groupby_sum(engine):
321321
actual = xarray_reduce(stacked, "space", dim="z", func="sum", engine=engine)
322322
assert_equal(expected, actual.unstack("space"))
323323

324+
actual = xarray_reduce(stacked.foo, "space", dim="z", func="sum", engine=engine)
325+
assert_equal(expected.foo, actual.unstack("space"))
326+
327+
ds = xr.Dataset(
328+
dict(a=(("z",), np.ones(10))),
329+
coords=dict(b=(("z"), np.arange(2).repeat(5)), c=(("z"), np.arange(5).repeat(2))),
330+
).set_index(bc=["b", "c"])
331+
expected = ds.groupby("bc").sum()
332+
actual = xarray_reduce(ds, "bc", func="sum")
333+
assert_equal(expected, actual)
334+
324335

325336
@pytest.mark.parametrize("chunks", (None, 2))
326337
def test_xarray_groupby_bins(chunks, engine):

0 commit comments

Comments
 (0)