Skip to content

Commit 676c0f0

Browse files
authored
Bugfix for cohorts where not all expected_groups are present (#316)
1 parent 8ea0cd1 commit 676c0f0

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

flox/core.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -340,9 +340,10 @@ def invert(x) -> tuple[np.ndarray, ...]:
340340
# TODO: we can optimize this to loop over chunk_cohorts instead
341341
# by zeroing out rows that are already in a cohort
342342
for rowidx in order:
343-
cohort_ = containment.indices[
343+
cohidx = containment.indices[
344344
slice(containment.indptr[rowidx], containment.indptr[rowidx + 1])
345345
]
346+
cohort_ = present_labels[cohidx]
346347
cohort = [elem for elem in cohort_ if elem not in merged_keys]
347348
if not cohort:
348349
continue

tests/test_core.py

+10
Original file line numberDiff line numberDiff line change
@@ -857,6 +857,16 @@ def test_find_group_cohorts(expected, labels, chunks: tuple[int]) -> None:
857857
assert actual == expected, (actual, expected)
858858

859859

860+
@requires_dask
def test_find_cohorts_missing_groups():
    """Cohorts must still reduce correctly when some expected_groups are absent from ``by``."""
    # Labels deliberately omit group 0 and include NaNs (ignored members).
    labels = np.array(
        [np.nan, np.nan, np.nan, 2.0, 2.0, 1.0, 1.0, 2.0, 2.0, 1.0, np.nan, np.nan]
    )
    common = dict(func="sum", expected_groups=[0, 1, 2], fill_value=123)
    data = dask.array.ones_like(labels, chunks=(3,))
    # The cohorts code path must agree with the plain numpy reduction.
    actual, _ = groupby_reduce(data, labels, method="cohorts", **common)
    expected, _ = groupby_reduce(data.compute(), labels, **common)
    assert_equal(expected, actual)
868+
869+
860870
@pytest.mark.parametrize("chunksize", [12, 13, 14, 24, 36, 48, 72, 71])
861871
def test_verify_complex_cohorts(chunksize: int) -> None:
862872
time = pd.Series(pd.date_range("2016-01-01", "2018-12-31 23:59", freq="H"))

0 commit comments

Comments (0)