Skip to content

Commit 622ddb2

Browse files
authored
Optimize broadcasting (#230)
* Optimize broadcasting xref pydata/xarray#7730
* reorder
* Fix tests
* Another optimization
* fixes
* fix
1 parent aa358a5 commit 622ddb2

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

flox/xarray.py

+13-7
Original file line number | Diff line number | Diff line change
@@ -257,8 +257,6 @@ def xarray_reduce(
257257
more_drop.update(idx_other_names)
258258
maybe_drop.update(more_drop)
259259

260-
ds = ds.drop_vars([var for var in maybe_drop if var in ds.variables])
261-
262260
if dim is Ellipsis:
263261
if nby > 1:
264262
raise NotImplementedError("Multiple by are not allowed when dim is Ellipsis.")
@@ -275,17 +273,23 @@ def xarray_reduce(
275273
# broadcast to make sure grouper dimensions are present in the array.
276274
exclude_dims = tuple(d for d in ds.dims if d not in grouper_dims and d not in dim_tuple)
277275

276+
if any(d not in grouper_dims and d not in obj.dims for d in dim_tuple):
277+
raise ValueError(f"Cannot reduce over absent dimensions {dim}.")
278+
278279
try:
279280
xr.align(ds, *by_da, join="exact", copy=False)
280281
except ValueError as e:
281282
raise ValueError(
282283
"Object being grouped must be exactly aligned with every array in `by`."
283284
) from e
284285

285-
ds_broad = xr.broadcast(ds, *by_da, exclude=exclude_dims)[0]
286-
287-
if any(d not in grouper_dims and d not in obj.dims for d in dim_tuple):
288-
raise ValueError(f"Cannot reduce over absent dimensions {dim}.")
286+
needs_broadcast = any(
287+
not set(grouper_dims).issubset(set(variable.dims)) for variable in ds.data_vars.values()
288+
)
289+
if needs_broadcast:
290+
ds_broad = xr.broadcast(ds, *by_da, exclude=exclude_dims)[0]
291+
else:
292+
ds_broad = ds
289293

290294
dims_not_in_groupers = tuple(d for d in dim_tuple if d not in grouper_dims)
291295
if dims_not_in_groupers == tuple(dim_tuple) and not any(isbins):
@@ -305,6 +309,8 @@ def xarray_reduce(
305309
else:
306310
return result
307311

312+
ds = ds.drop_vars([var for var in maybe_drop if var in ds.variables])
313+
308314
axis = tuple(range(-len(dim_tuple), 0))
309315

310316
# Set expected_groups and convert to index since we need coords, sizes
@@ -432,7 +438,7 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):
432438

433439
# restore non-dim coord variables without the core dimension
434440
# TODO: shouldn't apply_ufunc handle this?
435-
for var in set(ds_broad.variables) - set(ds_broad._indexes) - set(ds_broad.dims):
441+
for var in set(ds_broad._coord_names) - set(ds_broad._indexes) - set(ds_broad.dims):
436442
if all(d not in ds_broad[var].dims for d in dim_tuple):
437443
actual[var] = ds_broad[var]
438444

tests/test_xarray.py

+2
Original file line number | Diff line number | Diff line change
@@ -343,6 +343,8 @@ def test_multi_index_groupby_sum(engine):
343343
expected = ds.sum("z")
344344
stacked = ds.stack(space=["x", "y"])
345345
actual = xarray_reduce(stacked, "space", dim="z", func="sum", engine=engine)
346+
expected_xarray = stacked.groupby("space").sum("z")
347+
assert_equal(expected_xarray, actual)
346348
assert_equal(expected, actual.unstack("space"))
347349

348350
actual = xarray_reduce(stacked.foo, "space", dim="z", func="sum", engine=engine)

0 commit comments

Comments
 (0)