Skip to content

Commit 15324a7

Browse files
Fix reordering of dataarray dimensions inside dataset (#289)
* Fix groupby behaviour for dataarray dimensions inside dataset

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 6062e35 commit 15324a7

File tree

2 files changed

+40
-2
lines changed

2 files changed

+40
-2
lines changed

flox/xarray.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,12 @@
2828
Dims = Union[str, Iterable[Hashable], None]
2929

3030

31-
def _restore_dim_order(result, obj, by):
31+
def _restore_dim_order(result, obj, by, no_groupby_reorder=False):
3232
def lookup_order(dimension):
3333
if dimension == by.name and by.ndim == 1:
3434
(dimension,) = by.dims
35+
if no_groupby_reorder:
36+
return -1e6 # some arbitrarily low value
3537
if dimension in obj.dims:
3638
axis = obj.get_axis_num(dimension)
3739
else:
@@ -491,7 +493,12 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):
491493
template = obj
492494

493495
if actual[var].ndim > 1:
494-
actual[var] = _restore_dim_order(actual[var], template, by_da[0])
496+
no_groupby_reorder = isinstance(
497+
obj, xr.Dataset
498+
) # do not re-order dataarrays inside datasets
499+
actual[var] = _restore_dim_order(
500+
actual[var], template, by_da[0], no_groupby_reorder=no_groupby_reorder
501+
)
495502

496503
if missing_dim:
497504
for k, v in missing_dim.items():

tests/test_xarray.py

+31
Original file line numberDiff line numberDiff line change
@@ -598,3 +598,34 @@ def test_fill_value_xarray_binning():
598598
actual = data_array.groupby_bins("y", bins=4).mean()
599599

600600
xr.testing.assert_identical(expected, actual)
601+
602+
603+
def test_groupby_2d_dataset():
604+
d = {
605+
"coords": {
606+
"bit_index": {"dims": ("bit_index",), "attrs": {"name": "bit_index"}, "data": [0, 1]},
607+
"index": {"dims": ("index",), "data": [0, 6, 8, 10, 14]},
608+
"clifford": {"dims": ("index",), "attrs": {}, "data": [1, 1, 4, 10, 4]},
609+
},
610+
"dims": {"bit_index": 2, "index": 5},
611+
"data_vars": {
612+
"counts": {
613+
"dims": ("bit_index", "index"),
614+
"attrs": {
615+
"name": "counts",
616+
},
617+
"data": [[18, 30, 45, 70, 38], [382, 370, 355, 330, 362]],
618+
}
619+
},
620+
}
621+
622+
ds = xr.Dataset.from_dict(d)
623+
624+
with xr.set_options(use_flox=False):
625+
expected = ds.groupby("clifford").mean()
626+
with xr.set_options(use_flox=True):
627+
actual = ds.groupby("clifford").mean()
628+
assert (
629+
expected.counts.dims == actual.counts.dims
630+
) # https://github.com/pydata/xarray/issues/8292
631+
xr.testing.assert_identical(expected, actual)

0 commit comments

Comments
 (0)