Skip to content

Commit a9f51a1

Browse files
authored
Fix bug with NaNs in by and method='blockwise' (#384)
xref pydata/xarray#9320
1 parent 4dbadae commit a9f51a1

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

flox/core.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -2663,10 +2663,18 @@ def groupby_reduce(
26632663
groups = (groups[0][sorted_idx],)
26642664

26652665
if factorize_early:
2666+
assert len(groups) == 1
2667+
(groups_,) = groups
26662668
# nan group labels are factorized to -1, and preserved
26672669
# now we get rid of them by reindexing
2668-
# This also handles bins with no data
2669-
result = reindex_(result, from_=groups[0], to=expected_, fill_value=fill_value).reshape(
2670+
# First, for "blockwise", we can have -1 repeated in different blocks
2671+
# This breaks the reindexing so remove those first.
2672+
if method == "blockwise" and (mask := groups_ == -1).sum(axis=-1) > 1:
2673+
result = result[..., ~mask]
2674+
groups_ = groups_[..., ~mask]
2675+
2676+
# This reindex also handles bins with no data
2677+
result = reindex_(result, from_=groups_, to=expected_, fill_value=fill_value).reshape(
26702678
result.shape[:-1] + grp_shape
26712679
)
26722680
groups = final_groups

tests/test_core.py

+14
Original file line numberDiff line numberDiff line change
@@ -1929,3 +1929,17 @@ def test_ffill_bfill(chunks, size, add_nan_by, func):
19291929
expected = flox.groupby_scan(array.compute(), by, func=func)
19301930
actual = flox.groupby_scan(array, by, func=func)
19311931
assert_equal(expected, actual)
1932+
1933+
1934+
@requires_dask
1935+
def test_blockwise_nans():
1936+
array = dask.array.ones((1, 10), chunks=2)
1937+
by = np.array([-1, 0, -1, 1, -1, 2, -1, 3, 4, 4])
1938+
actual, actual_groups = flox.groupby_reduce(
1939+
array, by, func="sum", expected_groups=pd.RangeIndex(0, 5)
1940+
)
1941+
expected, expected_groups = flox.groupby_reduce(
1942+
array.compute(), by, func="sum", expected_groups=pd.RangeIndex(0, 5)
1943+
)
1944+
assert_equal(expected_groups, actual_groups)
1945+
assert_equal(expected, actual)

0 commit comments

Comments
 (0)