Skip to content

Commit 5e0b793

Browse files
dcherian, LunarLanding, pre-commit-ci[bot]
authored
Fix bug where we had extra groups in expected_groups. (#112)
* Fix bug where we had extra groups in expected_groups. This affected _factorize_multiple. Closes #111
* Fix extra expected groups (#113)
* fix dask case
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: LunarLanding <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent a0b9d1f commit 5e0b793

File tree

2 files changed

+40
-3
lines changed

2 files changed

+40
-3
lines changed

flox/core.py

+6-3
Original file line number | Diff line number | Diff line change
@@ -1310,11 +1310,12 @@ def _lazy_factorize_wrapper(*by, **kwargs):
13101310
return group_idx
13111311

13121312

1313-
def _factorize_multiple(by, expected_groups, by_is_dask):
1313+
def _factorize_multiple(by, expected_groups, by_is_dask, reindex):
13141314
kwargs = dict(
13151315
expected_groups=expected_groups,
13161316
axis=None, # always None, we offset later if necessary.
13171317
fastpath=True,
1318+
reindex=reindex,
13181319
)
13191320
if by_is_dask:
13201321
import dask.array
@@ -1325,7 +1326,9 @@ def _factorize_multiple(by, expected_groups, by_is_dask):
13251326
meta=np.array((), dtype=np.int64),
13261327
**kwargs,
13271328
)
1328-
found_groups = tuple(None if is_duck_dask_array(b) else pd.unique(b) for b in by)
1329+
found_groups = tuple(
1330+
None if is_duck_dask_array(b) else pd.unique(np.array(b).reshape(-1)) for b in by
1331+
)
13291332
grp_shape = tuple(len(e) for e in expected_groups)
13301333
else:
13311334
group_idx, found_groups, grp_shape = factorize_(by, **kwargs)
@@ -1489,7 +1492,7 @@ def groupby_reduce(
14891492
)
14901493
if factorize_early:
14911494
by, final_groups, grp_shape = _factorize_multiple(
1492-
by, expected_groups, by_is_dask=by_is_dask
1495+
by, expected_groups, by_is_dask=by_is_dask, reindex=reindex
14931496
)
14941497
expected_groups = (pd.RangeIndex(np.prod(grp_shape)),)
14951498

tests/test_xarray.py

+34
Original file line number | Diff line number | Diff line change
@@ -451,3 +451,37 @@ def test_groupby_bins_indexed_coordinate():
451451
method="split-reduce",
452452
)
453453
xr.testing.assert_allclose(expected, actual)
454+
455+
456+
@pytest.mark.parametrize("chunk", (True, False))
457+
def test_mixed_grouping(chunk):
458+
if not has_dask and chunk:
459+
pytest.skip()
460+
# regression test for https://github.com/dcherian/flox/pull/111
461+
sa = 10
462+
sb = 13
463+
sc = 3
464+
465+
x = xr.Dataset(
466+
{
467+
"v0": xr.DataArray(
468+
((np.arange(sa * sb * sc) / sa) % 1).reshape((sa, sb, sc)),
469+
dims=("a", "b", "c"),
470+
),
471+
"v1": xr.DataArray((np.arange(sa * sb) % 3).reshape(sa, sb), dims=("a", "b")),
472+
}
473+
)
474+
if chunk:
475+
x["v0"] = x["v0"].chunk({"a": 5})
476+
477+
r = xarray_reduce(
478+
x["v0"],
479+
x["v1"],
480+
x["v0"],
481+
expected_groups=(np.arange(6), np.linspace(0, 1, num=5)),
482+
isbin=[False, True],
483+
func="count",
484+
dim="b",
485+
fill_value=0,
486+
)
487+
assert (r.sel(v1=[3, 4, 5]) == 0).all().data

0 commit comments

Comments
 (0)