Skip to content

Commit d65181c

Browse files
authored
Avoid rechunking when preferred_method="blockwise" (#394)
* Avoid rechunking when preferred_method="blockwise"
* Add test
* fix
1 parent 7421cb1 commit d65181c

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

flox/core.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,7 @@ def rechunk_for_blockwise(array: DaskArray, axis: T_Axis, labels: np.ndarray) ->
642642
DaskArray
643643
Rechunked array
644644
"""
645+
# TODO: this should be unnecessary?
645646
labels = factorize_((labels,), axes=())[0]
646647
chunks = array.chunks[axis]
647648
newchunks = _get_optimal_chunks_for_groups(chunks, labels)
@@ -2623,7 +2624,8 @@ def groupby_reduce(
26232624

26242625
partial_agg = partial(dask_groupby_agg, **kwargs)
26252626

2626-
if method == "blockwise" and by_.ndim == 1:
2627+
# if preferred method is already blockwise, no need to rechunk
2628+
if preferred_method != "blockwise" and method == "blockwise" and by_.ndim == 1:
26272629
array = rechunk_for_blockwise(array, axis=-1, labels=by_)
26282630

26292631
result, groups = partial_agg(

tests/test_core.py

+9
Original file line numberDiff line numberDiff line change
@@ -1997,3 +1997,12 @@ def test_agg_dtypes(func, engine):
19971997
)
19981998
expected = _get_array_func(func)(counts, dtype="uint8")
19991999
assert actual.dtype == np.uint8 == expected.dtype
2000+
2001+
2002+
@requires_dask
def test_blockwise_avoid_rechunk():
    """Check ``groupby_reduce(..., func="first")`` on unevenly chunked dask input.

    A six-element array of zeros, split into chunks of sizes 2 and 4, is
    grouped by unsorted single-character string labels (including the empty
    string). The groups must come back in sorted order with one zero per
    group.

    NOTE(review): presumably ``func="first"`` makes the preferred method
    "blockwise", so the array is reduced without being rechunked -- confirm
    against ``groupby_reduce``.
    """
    labels = np.array(["1", "1", "0", "", "0", ""], dtype="<U1")
    values = dask.array.zeros((6,), chunks=(2, 4), dtype=np.int64)

    result, found_groups = groupby_reduce(values, labels, func="first")

    assert_equal(found_groups, ["", "0", "1"])
    assert_equal(result, [0, 0, 0])

0 commit comments

Comments
 (0)