Skip to content

Commit ce7b9b7

Browse files
authored
Better handling of RangeIndex for sparse reindexing (#440)
1 parent 6d34d62 commit ce7b9b7

File tree

2 files changed

+36
-11
lines changed

2 files changed

+36
-11
lines changed

flox/core.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -762,7 +762,19 @@ def reindex_pydata_sparse_coo(array, from_: pd.Index, to: pd.Index, fill_value,
762762

763763
assert axis == -1
764764

765-
needs_reindex = (from_.get_indexer(to) == -1).any()
765+
# Are there any elements in `to` that are not in `from_`.
766+
if isinstance(to, pd.RangeIndex) and len(to) > len(from_):
767+
# 1. pandas optimizes set difference between two RangeIndexes only
768+
# 2. We want to avoid realizing a very large numpy array in to memory.
769+
# This happens in the `else` clause.
770+
# There are potentially other tricks we can play, but this is a simple
771+
# and effective one. If a user is reindexing to sparse, then len(to) is
772+
# almost guaranteed to be > len(from_). If len(to) <= len(from_), then realizing
773+
# another array of the same shape should be fine.
774+
needs_reindex = True
775+
else:
776+
needs_reindex = (from_.get_indexer(to) == -1).any()
777+
766778
if needs_reindex and fill_value is None:
767779
raise ValueError("Filling is required. fill_value cannot be None.")
768780

@@ -2315,6 +2327,8 @@ def _factorize_multiple(
23152327
if any_by_dask:
23162328
import dask.array
23172329

2330+
from . import dask_array_ops # noqa
2331+
23182332
# unifying chunks will make sure all arrays in `by` are dask arrays
23192333
# with compatible chunks, even if there was originally a numpy array
23202334
inds = tuple(range(by[0].ndim))

tests/test_core.py

+21-10
Original file line numberDiff line numberDiff line change
@@ -2085,12 +2085,14 @@ def test_datetime_timedelta_first_last(engine, func) -> None:
20852085

20862086
@requires_dask
20872087
@requires_sparse
2088-
def test_reindex_sparse():
2088+
@pytest.mark.xdist_group(name="sparse-group")
2089+
@pytest.mark.parametrize("size", [2**62 - 1, 11])
2090+
def test_reindex_sparse(size):
20892091
import sparse
20902092

20912093
array = dask.array.ones((2, 12), chunks=(-1, 3))
20922094
func = "sum"
2093-
expected_groups = pd.Index(np.arange(11))
2095+
expected_groups = pd.RangeIndex(size)
20942096
by = dask.array.from_array(np.repeat(np.arange(6) * 2, 2), chunks=(3,))
20952097
dense = np.zeros((2, 11))
20962098
dense[..., np.arange(6) * 2] = 2
@@ -2110,14 +2112,23 @@ def mocked_reindex(*args, **kwargs):
21102112
assert isinstance(res, sparse.COO)
21112113
return res
21122114

2113-
with patch("flox.core.reindex_") as mocked_func:
2114-
mocked_func.side_effect = mocked_reindex
2115-
actual, *_ = groupby_reduce(
2116-
array, by, func=func, reindex=reindex, expected_groups=expected_groups, fill_value=0
2117-
)
2118-
assert_equal(actual, expected)
2119-
# once during graph construction, 10 times afterward
2120-
assert mocked_func.call_count > 1
2115+
# Define the error-raising property
2116+
def raise_error(self):
2117+
raise AttributeError("Access to '_data' is not allowed.")
2118+
2119+
with patch("flox.core.reindex_") as mocked_reindex_func:
2120+
with patch.object(pd.RangeIndex, "_data", property(raise_error)):
2121+
mocked_reindex_func.side_effect = mocked_reindex
2122+
actual, *_ = groupby_reduce(
2123+
array, by, func=func, reindex=reindex, expected_groups=expected_groups, fill_value=0
2124+
)
2125+
if size == 11:
2126+
assert_equal(actual, expected)
2127+
else:
2128+
actual.compute() # just compute
2129+
2130+
# once during graph construction, 10 times afterward
2131+
assert mocked_reindex_func.call_count > 1
21212132

21222133

21232134
def test_sparse_errors():

0 commit comments

Comments
 (0)