Skip to content

Commit 273d319

Browse files
authored
Set order='F' when raveling group_idx after broadcast (#286)
* Set order='F' when raveling group_idx after broadcast

  This greatly improves the dim=... case, at least for engine="flox". xref #281

  I'm not sure whether it is a regression for engine="numpy".

  We accept a single bad reshape of `array` in exchange for not argsorting both `array` and `group_idx`, for a ~10-20x speedup:

  ```
  ds = xr.tutorial.load_dataset('air_temperature')
  ds.groupby('lon').count(..., engine="flox")
  ```

* This is an improvement only for engine=flox
* Update tests
* Fix benchmark
* type ignore
1 parent 92f4780 commit 273d319

File tree

3 files changed

+15
-5
lines changed

3 files changed

+15
-5
lines changed

asv_bench/benchmarks/reduce.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def setup(self, *args, **kwargs):
110110
class ChunkReduce2DAllAxes(ChunkReduce):
111111
def setup(self, *args, **kwargs):
112112
self.array = np.ones((N, N))
113-
self.labels = np.repeat(np.arange(N // 5), repeats=5)
113+
self.labels = np.repeat(np.arange(N // 5), repeats=5)[np.newaxis, :]
114114
self.axis = None
115115
setup_jit()
116116

flox/core.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -782,16 +782,25 @@ def chunk_reduce(
782782
)
783783
groups = grps[0]
784784

785+
order = "C"
785786
if nax > 1:
786787
needs_broadcast = any(
787788
group_idx.shape[ax] != array.shape[ax] and group_idx.shape[ax] == 1
788789
for ax in range(-nax, 0)
789790
)
790791
if needs_broadcast:
792+
# This is the dim=... case, it's a lot faster to ravel group_idx
793+
# in fortran order since group_idx is then sorted
794+
# I'm seeing 400ms -> 23ms for engine="flox"
795+
# Of course we are slower to ravel `array` but we avoid argsorting
796+
# both `array` *and* `group_idx` in _prepare_for_flox
791797
group_idx = np.broadcast_to(group_idx, array.shape[-by.ndim :])
798+
if engine == "flox":
799+
group_idx = group_idx.reshape(-1, order="F")
800+
order = "F"
792801
# always reshape to 1D along group dimensions
793802
newshape = array.shape[: array.ndim - by.ndim] + (math.prod(array.shape[-by.ndim :]),)
794-
array = array.reshape(newshape)
803+
array = array.reshape(newshape, order=order) # type: ignore[call-overload]
795804
group_idx = group_idx.reshape(-1)
796805

797806
assert group_idx.ndim == 1
@@ -1814,7 +1823,8 @@ def groupby_reduce(
18141823
Array to be reduced, possibly nD
18151824
*by : ndarray or DaskArray
18161825
Array of labels to group over. Must be aligned with ``array`` so that
1817-
``array.shape[-by.ndim :] == by.shape``
1826+
``array.shape[-by.ndim :] == by.shape`` or any disagreements in that
1827+
equality check are for dimensions of size 1 in `by`.
18181828
func : {"all", "any", "count", "sum", "nansum", "mean", "nanmean", \
18191829
"max", "nanmax", "min", "nanmin", "argmax", "nanargmax", "argmin", "nanargmin", \
18201830
"quantile", "nanquantile", "median", "nanmedian", "mode", "nanmode", \

tests/test_core.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def test_groupby_reduce(
203203
def gen_array_by(size, func):
204204
by = np.ones(size[-1])
205205
rng = np.random.default_rng(12345)
206-
array = rng.random(size)
206+
array = rng.random(tuple(6 if s == 1 else s for s in size))
207207
if "nan" in func and "nanarg" not in func:
208208
array[[1, 4, 5], ...] = np.nan
209209
elif "nanarg" in func and len(size) > 1:
@@ -222,8 +222,8 @@ def gen_array_by(size, func):
222222
pytest.param(4, marks=requires_dask),
223223
],
224224
)
225+
@pytest.mark.parametrize("size", ((1, 12), (12,), (12, 9)))
225226
@pytest.mark.parametrize("nby", [1, 2, 3])
226-
@pytest.mark.parametrize("size", ((12,), (12, 9)))
227227
@pytest.mark.parametrize("add_nan_by", [True, False])
228228
@pytest.mark.parametrize("func", ALL_FUNCS)
229229
def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):

0 commit comments

Comments
 (0)