Skip to content

Commit aab2646

Browse files
TomNicholasdcherianpre-commit-ci[bot]
authored
Preserve multiindex (#216)
* test based on example * fix * Update tests/test_xarray.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: Deepak Cherian <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 33136a8 commit aab2646

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

flox/xarray.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,7 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):
432432

433433
# restore non-dim coord variables without the core dimension
434434
# TODO: shouldn't apply_ufunc handle this?
435-
for var in set(ds_broad.variables) - set(ds_broad.dims):
435+
for var in set(ds_broad.variables) - set(ds_broad.xindexes) - set(ds_broad.dims):
436436
if all(d not in ds_broad[var].dims for d in dim_tuple):
437437
actual[var] = ds_broad[var]
438438

tests/test_xarray.py

+31
Original file line numberDiff line numberDiff line change
@@ -594,3 +594,34 @@ def test_dtype_accumulation(use_flox, chunk):
594594
assert np.issubdtype(actual.dtype, np.float64)
595595
assert np.issubdtype(actual.compute().dtype, np.float64)
596596
xr.testing.assert_allclose(expected, actual, **tolerance64)
597+
598+
599+
def test_preserve_multiindex():
600+
"""Regression test for GH issue #215"""
601+
602+
vort = xr.DataArray(
603+
name="vort",
604+
data=np.random.uniform(size=(4, 2)),
605+
dims=["i", "face"],
606+
coords={"i": ("i", np.arange(4)), "face": ("face", np.arange(2))},
607+
)
608+
609+
vort = (
610+
vort.coarsen(i=2)
611+
.construct(i=("i_region_coarse", "i_region"))
612+
.stack(region=["face", "i_region_coarse"])
613+
)
614+
615+
bins = [np.linspace(0, 1, 10)]
616+
bin_intervals = tuple(pd.IntervalIndex.from_breaks(b) for b in bins)
617+
618+
hist = xarray_reduce(
619+
xr.DataArray(1), # weights
620+
vort, # variables we want to bin
621+
func="count", # count occurrences falling in bins
622+
expected_groups=bin_intervals, # bins for each variable
623+
dim=["i_region"], # broadcast dimensions
624+
fill_value=0, # fill empty bins with 0 counts
625+
)
626+
627+
assert "region" in hist.coords

0 commit comments

Comments
 (0)