Skip to content

Commit 4164712

Browse files
authored
Check method only for dask reductions. (#241)
1 parent 622ddb2 commit 4164712

File tree

2 files changed

+40
-6
lines changed

2 files changed

+40
-6
lines changed

flox/core.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -1875,12 +1875,6 @@ def groupby_reduce(
18751875
axis_ = np.core.numeric.normalize_axis_tuple(axis, array.ndim) # type: ignore
18761876
nax = len(axis_)
18771877

1878-
if method in ["blockwise", "cohorts"] and nax != by_.ndim:
1879-
raise NotImplementedError(
1880-
"Must reduce along all dimensions of `by` when method != 'map-reduce'."
1881-
f"Received method={method!r}"
1882-
)
1883-
18841878
# TODO: make sure expected_groups is unique
18851879
if nax == 1 and by_.ndim > 1 and expected_groups is None:
18861880
if not any_by_dask:
@@ -1949,6 +1943,12 @@ def groupby_reduce(
19491943
f"\n\n Received: {func}"
19501944
)
19511945

1946+
if method in ["blockwise", "cohorts"] and nax != by_.ndim:
1947+
raise NotImplementedError(
1948+
"Must reduce along all dimensions of `by` when method != 'map-reduce'."
1949+
f"Received method={method!r}"
1950+
)
1951+
19521952
# TODO: just do this in dask_groupby_agg
19531953
# we always need some fill_value (see above) so choose the default if needed
19541954
if kwargs["fill_value"] is None:

tests/test_core.py

+34
Original file line numberDiff line numberDiff line change
@@ -1347,3 +1347,37 @@ def test_expected_index_conversion_passthrough_range_index(sort):
13471347
expected_groups=(index,), isbin=(False,), sort=(sort,)
13481348
)
13491349
assert actual[0] is index
1350+
1351+
1352+
def test_method_check_numpy():
1353+
bins = [-2, -1, 0, 1, 2]
1354+
field = np.ones((5, 3))
1355+
by = np.array([[-1.5, -1.5, 0.5, 1.5, 1.5] * 3]).reshape(5, 3)
1356+
actual, _ = groupby_reduce(
1357+
field,
1358+
by,
1359+
expected_groups=pd.IntervalIndex.from_breaks(bins),
1360+
func="count",
1361+
method="cohorts",
1362+
fill_value=np.nan,
1363+
)
1364+
expected = np.array([6, np.nan, 3, 6])
1365+
assert_equal(actual, expected)
1366+
1367+
actual, _ = groupby_reduce(
1368+
field,
1369+
by,
1370+
expected_groups=pd.IntervalIndex.from_breaks(bins),
1371+
func="count",
1372+
fill_value=np.nan,
1373+
method="cohorts",
1374+
axis=0,
1375+
)
1376+
expected = np.array(
1377+
[
1378+
[2.0, np.nan, 1.0, 2.0],
1379+
[2.0, np.nan, 1.0, 2.0],
1380+
[2.0, np.nan, 1.0, 2.0],
1381+
]
1382+
)
1383+
assert_equal(actual, expected)

0 commit comments

Comments
 (0)