Skip to content

Commit 24dc7fd

Browse files
authored
Always reindex=True for all numpy inputs (#228)
1 parent 13d1062 commit 24dc7fd

File tree

2 files changed

+43
-11
lines changed

2 files changed

+43
-11
lines changed

flox/core.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -1519,9 +1519,15 @@ def _extract_result(result_dict: FinalResultsDict, key) -> np.ndarray:
15191519

15201520

15211521
def _validate_reindex(
1522-
reindex: bool | None, func, method: T_Method, expected_groups, any_by_dask: bool
1522+
reindex: bool | None,
1523+
func,
1524+
method: T_Method,
1525+
expected_groups,
1526+
any_by_dask: bool,
1527+
is_dask_array: bool,
15231528
) -> bool:
1524-
if reindex is True:
1529+
all_numpy = not is_dask_array and not any_by_dask
1530+
if reindex is True and not all_numpy:
15251531
if _is_arg_reduction(func):
15261532
raise NotImplementedError
15271533
if method in ["blockwise", "cohorts"]:
@@ -1530,6 +1536,9 @@ def _validate_reindex(
15301536
)
15311537

15321538
if reindex is None:
1539+
if all_numpy:
1540+
return True
1541+
15331542
if method == "blockwise" or _is_arg_reduction(func):
15341543
reindex = False
15351544

@@ -1796,7 +1805,9 @@ def groupby_reduce(
17961805
if method == "split-reduce":
17971806
method = "cohorts"
17981807

1799-
reindex = _validate_reindex(reindex, func, method, expected_groups, any_by_dask)
1808+
reindex = _validate_reindex(
1809+
reindex, func, method, expected_groups, any_by_dask, is_duck_dask_array(array)
1810+
)
18001811

18011812
if not is_duck_array(array):
18021813
array = np.asarray(array)

tests/test_core.py

+29-8
Original file line numberDiff line numberDiff line change
@@ -1236,7 +1236,7 @@ def test_subset_block_2d(flatblocks, expectidx):
12361236

12371237

12381238
@pytest.mark.parametrize(
1239-
"expected, reindex, func, expected_groups, any_by_dask",
1239+
"dask_expected, reindex, func, expected_groups, any_by_dask",
12401240
[
12411241
# argmax only False
12421242
[False, None, "argmax", None, False],
@@ -1252,22 +1252,43 @@ def test_subset_block_2d(flatblocks, expectidx):
12521252
[True, None, "sum", ([1], None), True],
12531253
],
12541254
)
1255-
def test_validate_reindex_map_reduce(expected, reindex, func, expected_groups, any_by_dask):
1256-
actual = _validate_reindex(reindex, func, "map-reduce", expected_groups, any_by_dask)
1257-
assert actual == expected
1255+
def test_validate_reindex_map_reduce(
1256+
dask_expected, reindex, func, expected_groups, any_by_dask
1257+
) -> None:
1258+
actual = _validate_reindex(
1259+
reindex, func, "map-reduce", expected_groups, any_by_dask, is_dask_array=True
1260+
)
1261+
assert actual is dask_expected
12581262

1263+
# always reindex with all numpy inputs
1264+
actual = _validate_reindex(
1265+
reindex, func, "map-reduce", expected_groups, any_by_dask=False, is_dask_array=False
1266+
)
1267+
assert actual
1268+
1269+
actual = _validate_reindex(
1270+
True, func, "map-reduce", expected_groups, any_by_dask=False, is_dask_array=False
1271+
)
1272+
assert actual
12591273

1260-
def test_validate_reindex():
1274+
1275+
def test_validate_reindex() -> None:
12611276
for method in ["map-reduce", "cohorts"]:
12621277
with pytest.raises(NotImplementedError):
1263-
_validate_reindex(True, "argmax", method, expected_groups=None, any_by_dask=False)
1278+
_validate_reindex(
1279+
True, "argmax", method, expected_groups=None, any_by_dask=False, is_dask_array=True
1280+
)
12641281

12651282
for method in ["blockwise", "cohorts"]:
12661283
with pytest.raises(ValueError):
1267-
_validate_reindex(True, "sum", method, expected_groups=None, any_by_dask=False)
1284+
_validate_reindex(
1285+
True, "sum", method, expected_groups=None, any_by_dask=False, is_dask_array=True
1286+
)
12681287

12691288
for func in ["sum", "argmax"]:
1270-
actual = _validate_reindex(None, func, method, expected_groups=None, any_by_dask=False)
1289+
actual = _validate_reindex(
1290+
None, func, method, expected_groups=None, any_by_dask=False, is_dask_array=True
1291+
)
12711292
assert actual is False
12721293

12731294

0 commit comments

Comments
 (0)