Commit d2a68d0

Update apply_ufunc output_sizes error message (#7509)
* test for error message
* fix
* whatsnew

---------

Co-authored-by: Deepak Cherian <[email protected]>
1 parent 5b6d757 commit d2a68d0

3 files changed, +24 -1 lines changed


doc/whats-new.rst

+2
@@ -58,6 +58,8 @@ Bug fixes
   By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_ and `Scott Chamberlin <https://github.com/scottcha>`_.
 - Handle ``keep_attrs`` option in binary operators of :py:meth:`Dataset` (:issue:`7390`, :pull:`7391`).
   By `Aron Gergely <https://github.com/arongergely>`_.
+- Improve error message when using dask in :py:func:`apply_ufunc` with ``output_sizes`` not supplied. (:pull:`7509`)
+  By `Tom Nicholas <https://github.com/TomNicholas>`_.
 - :py:func:`xarray.Dataset.to_zarr` now drops variable encodings that have been added by xarray during reading
   a dataset. (:issue:`7129`, :pull:`7500`).
   By `Hauke Schulz <https://github.com/observingClouds>`_.

xarray/core/computation.py

+3 -1
@@ -723,7 +723,9 @@ def apply_variable_ufunc(
                 dask_gufunc_kwargs["output_sizes"] = output_sizes_renamed
 
             for key in signature.all_output_core_dims:
-                if key not in signature.all_input_core_dims and key not in output_sizes:
+                if (
+                    key not in signature.all_input_core_dims or key in exclude_dims
+                ) and key not in output_sizes:
                     raise ValueError(
                         f"dimension '{key}' in 'output_core_dims' needs corresponding (dim, size) in 'output_sizes'"
                     )
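
The effect of the changed condition can be sketched in isolation. The helper names below are hypothetical and are not part of xarray; they only restate the old and new predicates from the hunk above as plain functions, so the case that used to slip through (an output core dim that reuses an input core dim name listed in `exclude_dims`) is easy to see.

```python
def missing_size_old(key, input_core_dims, output_sizes):
    # Old check: only output dims that are absent from the inputs
    # were required to appear in output_sizes.
    return key not in input_core_dims and key not in output_sizes


def missing_size_new(key, input_core_dims, exclude_dims, output_sizes):
    # New check: a dim listed in exclude_dims may change size, so it also
    # needs an entry in output_sizes even when an input shares its name.
    return (
        key not in input_core_dims or key in exclude_dims
    ) and key not in output_sizes


# Scenario from the new test: same dim names on input and output,
# both excluded, and no output_sizes supplied.
input_core_dims = {"i", "j"}
exclude_dims = {"i", "j"}
output_sizes = {}

old_flags = [missing_size_old(k, input_core_dims, output_sizes) for k in ("i", "j")]
new_flags = [
    missing_size_new(k, input_core_dims, exclude_dims, output_sizes)
    for k in ("i", "j")
]
```

Under the old predicate neither dimension is flagged, so no `ValueError` is raised at this point; under the new predicate both are flagged.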

xarray/tests/test_computation.py

+19
@@ -1220,6 +1220,25 @@ def func(da):
     assert_identical(expected.chunk(), actual)
 
 
+@requires_dask
+def test_apply_dask_new_output_sizes_not_supplied_same_dim_names() -> None:
+    # test for missing output_sizes kwarg sneaking through
+    # see GH discussion 7503
+
+    data = np.random.randn(4, 4, 3, 2)
+    da = xr.DataArray(data=data, dims=("x", "y", "i", "j")).chunk(x=1, y=1)
+
+    with pytest.raises(ValueError, match="output_sizes"):
+        xr.apply_ufunc(
+            np.linalg.pinv,
+            da,
+            input_core_dims=[["i", "j"]],
+            output_core_dims=[["i", "j"]],
+            exclude_dims=set(("i", "j")),
+            dask="parallelized",
+        )
+
+
 def pandas_median(x):
     return pd.Series(x).median()
 
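
For contrast with the failing call in the new test, here is a minimal sketch of a call that does supply the sizes. It is not part of this commit. It assumes xarray and dask are installed, and it assumes the sizes `{"i": 2, "j": 3}`, which follow from `np.linalg.pinv` returning an array with the last two axes swapped (a `(3, 2)` input gives a `(2, 3)` result). `output_dtypes` is passed explicitly so the example does not rely on dtype inference.

```python
import numpy as np
import xarray as xr

data = np.random.randn(4, 4, 3, 2)
da = xr.DataArray(data=data, dims=("x", "y", "i", "j")).chunk({"x": 1, "y": 1})

result = xr.apply_ufunc(
    np.linalg.pinv,
    da,
    input_core_dims=[["i", "j"]],
    output_core_dims=[["i", "j"]],
    exclude_dims={"i", "j"},
    dask="parallelized",
    output_dtypes=[da.dtype],
    # Assumed sizes: pinv swaps the trailing (3, 2) axes to (2, 3).
    dask_gufunc_kwargs={"output_sizes": {"i": 2, "j": 3}},
)
```

The sizes are passed inside `dask_gufunc_kwargs`, the same dictionary that `apply_variable_ufunc` reads `output_sizes` from in the hunk above.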
