Skip to content

Commit 2fdf98a

Browse files
authored
Fix binning with "cohorts" & "split-reduce" (#94)
1 parent af70b2a commit 2fdf98a

File tree

2 files changed

+21
-7
lines changed

2 files changed

+21
-7
lines changed

flox/core.py

+15-6
Original file line number | Diff line number | Diff line change
@@ -1459,6 +1459,10 @@ def groupby_reduce(
14591459
by: tuple = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by)
14601460
nby = len(by)
14611461
by_is_dask = any(is_duck_dask_array(b) for b in by)
1462+
1463+
if method in ["split-reduce", "cohorts"] and by_is_dask:
1464+
raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.")
1465+
14621466
if not is_duck_array(array):
14631467
array = np.asarray(array)
14641468
if isinstance(isbin, bool):
@@ -1477,9 +1481,11 @@ def groupby_reduce(
14771481
# (pd.IntervalIndex or not)
14781482
expected_groups = _convert_expected_groups_to_index(expected_groups, isbin, sort)
14791483

1480-
# when grouping by multiple variables, we factorize early.
14811484
# TODO: could restrict this to dask-only
1482-
if nby > 1:
1485+
factorize_early = (nby > 1) or (
1486+
any(isbin) and method in ["split-reduce", "cohorts"] and is_duck_dask_array(array)
1487+
)
1488+
if factorize_early:
14831489
by, final_groups, grp_shape = _factorize_multiple(
14841490
by, expected_groups, by_is_dask=by_is_dask
14851491
)
@@ -1497,6 +1503,7 @@ def groupby_reduce(
14971503
if method in ["blockwise", "cohorts", "split-reduce"] and len(axis) != by.ndim:
14981504
raise NotImplementedError(
14991505
"Must reduce along all dimensions of `by` when method != 'map-reduce'."
1506+
f"Received method={method!r}"
15001507
)
15011508

15021509
# TODO: make sure expected_groups is unique
@@ -1617,10 +1624,12 @@ def groupby_reduce(
16171624
result = result[..., sorted_idx]
16181625
groups = (groups[0][sorted_idx],)
16191626

1620-
if nby > 1:
1627+
if factorize_early:
16211628
# nan group labels are factorized to -1, and preserved
1622-
# now we get rid of them
1623-
nanmask = groups[0] == -1
1629+
# now we get rid of them by reindexing
1630+
# This also handles bins with no data
1631+
result = reindex_(
1632+
result, from_=groups[0], to=expected_groups, fill_value=fill_value
1633+
).reshape(result.shape[:-1] + grp_shape)
16241634
groups = final_groups
1625-
result = result[..., ~nanmask].reshape(result.shape[:-1] + grp_shape)
16261635
return (result, *groups)

tests/test_core.py

+6-1
Original file line number | Diff line number | Diff line change
@@ -582,12 +582,16 @@ def test_npg_nanarg_bug(func):
582582
assert_equal(actual, expected)
583583

584584

585+
@pytest.mark.parametrize("method", ["split-reduce", "cohorts", "map-reduce"])
585586
@pytest.mark.parametrize("chunk_labels", [False, True])
586587
@pytest.mark.parametrize("chunks", ((), (1,), (2,)))
587-
def test_groupby_bins(chunk_labels, chunks, engine) -> None:
588+
def test_groupby_bins(chunk_labels, chunks, engine, method) -> None:
588589
array = [1, 1, 1, 1, 1, 1]
589590
labels = [0.2, 1.5, 1.9, 2, 3, 20]
590591

592+
if method in ["split-reduce", "cohorts"] and chunk_labels:
593+
pytest.xfail()
594+
591595
if chunks:
592596
if not has_dask:
593597
pytest.skip()
@@ -604,6 +608,7 @@ def test_groupby_bins(chunk_labels, chunks, engine) -> None:
604608
isbin=True,
605609
fill_value=0,
606610
engine=engine,
611+
method=method,
607612
)
608613
expected = np.array([3, 1, 0])
609614
for left, right in zip(groups, pd.IntervalIndex.from_arrays([1, 2, 4], [2, 4, 5]).to_numpy()):

0 commit comments

Comments
 (0)