diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index ac8290b3d1b..d28103d0cb7 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -520,6 +520,7 @@ Index.stack Index.unstack Index.create_variables + Index.validate_dataarray_coord Index.to_pandas_index Index.isel Index.sel diff --git a/doc/api.rst b/doc/api.rst index 67c81aaf601..e4342175279 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -1641,6 +1641,7 @@ Exceptions .. autosummary:: :toctree: generated/ + CoordinateValidationError MergeError SerializationWarning diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 86c2106a796..a607bd3aaa8 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -90,6 +90,10 @@ New Features (:pull:`9498`). By `Spencer Clark `_. - Support reading to `GPU memory with Zarr `_ (:pull:`10078`). By `Deepak Cherian `_. +- Allow assigning index coordinates with non-array dimension(s) in a :py:class:`DataArray` by overriding + :py:meth:`Index.validate_dataarray_coord`. For example, this enables support for CF boundaries coordinate (e.g., + ``time(time)`` and ``time_bnds(time, nbnd)``) in a DataArray (:pull:`10137`). + By `Benoit Bovy `_. Performance ~~~~~~~~~~~ diff --git a/xarray/__init__.py b/xarray/__init__.py index 07e6fe5b207..d28947fa17d 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -28,7 +28,7 @@ ) from xarray.conventions import SerializationWarning, decode_cf from xarray.core.common import ALL_DIMS, full_like, ones_like, zeros_like -from xarray.core.coordinates import Coordinates +from xarray.core.coordinates import Coordinates, CoordinateValidationError from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset from xarray.core.datatree import DataTree @@ -128,6 +128,7 @@ "NamedArray", "Variable", # Exceptions + "CoordinateValidationError", "InvalidTreeError", "MergeError", "NotFoundInTreeError", diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 408e9e630ee..686314ab7dc 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -486,7 +486,7 @@ def identical(self, other: Self) -> bool: return self.to_dataset().identical(other.to_dataset()) def _update_coords( - self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index] + self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index] ) -> None: # redirect to DatasetCoordinates._update_coords self._data.coords._update_coords(coords, indexes) @@ -780,7 +780,7 @@ def to_dataset(self) -> Dataset: return self._data._copy_listed(names) def _update_coords( - self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index] + self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index] ) -> None: variables = self._data._variables.copy() variables.update(coords) @@ -880,7 +880,7 @@ def to_dataset(self) -> Dataset: return self._data.dataset._copy_listed(self._names) def _update_coords( - self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index] + self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index] ) -> None: from xarray.core.datatree import check_alignment @@ -964,22 +964,14 @@ def __getitem__(self, key: Hashable) -> T_DataArray: return self._data._getitem_coord(key) def _update_coords( - self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index] + self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index] ) -> None: - coords_plus_data = coords.copy() - coords_plus_data[_THIS_ARRAY] = self._data.variable - dims = calculate_dimensions(coords_plus_data) - if not set(dims) <= set(self.dims): - raise ValueError( - "cannot add coordinates with new dimensions to a DataArray" - ) - self._data._coords = coords + validate_dataarray_coords( + self._data.shape, Coordinates._construct_direct(coords, indexes), self.dims + ) - # TODO(shoyer): once ._indexes is always populated by a dict, modify - # it to update inplace instead. - original_indexes = dict(self._data.xindexes) - original_indexes.update(indexes) - self._data._indexes = original_indexes + self._data._coords = coords + self._data._indexes = indexes def _drop_coords(self, coord_names): # should drop indexed coordinates only @@ -1154,9 +1146,55 @@ def create_coords_with_default_indexes( return new_coords -def _coordinates_from_variable(variable: Variable) -> Coordinates: - from xarray.core.indexes import create_default_index_implicit +class CoordinateValidationError(ValueError): + """Error class for Xarray coordinate validation failures.""" + + +def validate_dataarray_coords( + shape: tuple[int, ...], + coords: Coordinates | Mapping[Hashable, Variable], + dim: tuple[Hashable, ...], +): + """Validate coordinates ``coords`` to include in a DataArray defined by + ``shape`` and dimensions ``dim``. + + If a coordinate is associated with an index, the validation is performed by + the index. By default the coordinate dimensions must match (a subset of) the + array dimensions (in any order) to conform to the DataArray model. The index + may override this behavior with other validation rules, though. + + Non-index coordinates must all conform to the DataArray model. Scalar + coordinates are always valid. + """ + sizes = dict(zip(dim, shape, strict=True)) + dim_set = set(dim) + + indexes: Mapping[Hashable, Index] + if isinstance(coords, Coordinates): + indexes = coords.xindexes + else: + indexes = {} + + for k, v in coords.items(): + if k in indexes: + indexes[k].validate_dataarray_coord(k, v, dim_set) + elif any(d not in dim for d in v.dims): + raise CoordinateValidationError( + f"coordinate {k} has dimensions {v.dims}, but these " + "are not a subset of the DataArray " + f"dimensions {dim}" + ) + + for d, s in v.sizes.items(): + if d in sizes and s != sizes[d]: + raise CoordinateValidationError( + f"conflicting sizes for dimension {d!r}: " + f"length {sizes[d]} on the data but length {s} on " + f"coordinate {k!r}" + ) + +def coordinates_from_variable(variable: Variable) -> Coordinates: (name,) = variable.dims new_index, index_vars = create_default_index_implicit(variable) indexes = {k: new_index for k in index_vars} diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index d39e01f20fc..f9ba56032f3 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -41,6 +41,7 @@ DataArrayCoordinates, assert_coordinate_consistent, create_coords_with_default_indexes, + validate_dataarray_coords, ) from xarray.core.dataset import Dataset from xarray.core.extension_array import PandasExtensionArray @@ -132,25 +133,6 @@ T_XarrayOther = TypeVar("T_XarrayOther", bound="DataArray" | Dataset) -def _check_coords_dims(shape, coords, dim): - sizes = dict(zip(dim, shape, strict=True)) - for k, v in coords.items(): - if any(d not in dim for d in v.dims): - raise ValueError( - f"coordinate {k} has dimensions {v.dims}, but these " - "are not a subset of the DataArray " - f"dimensions {dim}" - ) - - for d, s in v.sizes.items(): - if s != sizes[d]: - raise ValueError( - f"conflicting sizes for dimension {d!r}: " - f"length {sizes[d]} on the data but length {s} on " - f"coordinate {k!r}" - ) - - def _infer_coords_and_dims( shape: tuple[int, ...], coords: ( @@ -214,7 +196,7 @@ def _infer_coords_and_dims( var.dims = (dim,) new_coords[dim] = var.to_index_variable() - _check_coords_dims(shape, new_coords, dims_tuple) + validate_dataarray_coords(shape, new_coords, dims_tuple) return new_coords, dims_tuple diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 2ff04475959..1cc439623dd 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -48,6 +48,7 @@ ) from xarray.core.coordinates import ( Coordinates, + CoordinateValidationError, DatasetCoordinates, assert_coordinate_consistent, ) @@ -1159,7 +1160,18 @@ def _construct_dataarray(self, name: Hashable) -> DataArray: coords: dict[Hashable, Variable] = {} # preserve ordering for k in self._variables: - if k in self._coord_names and set(self._variables[k].dims) <= needed_dims: + var_dims = set(self._variables[k].dims) + if k in self._indexes: + try: + self._indexes[k].validate_dataarray_coord( + k, self._variables[k], needed_dims + ) + coords[k] = self._variables[k] + except CoordinateValidationError: + # failback to strict DataArray model check (index may be dropped later) + if var_dims <= needed_dims: + coords[k] = self._variables[k] + elif k in self._coord_names and var_dims <= needed_dims: coords[k] = self._variables[k] indexes = filter_indexes_from_coords(self._indexes, set(coords)) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 6f5472a014a..713addf7d92 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -23,7 +23,7 @@ DatasetGroupByAggregations, ) from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce -from xarray.core.coordinates import Coordinates, _coordinates_from_variable +from xarray.core.coordinates import Coordinates, coordinates_from_variable from xarray.core.duck_array_ops import where from xarray.core.formatting import format_array_flat from xarray.core.indexes import ( @@ -1138,7 +1138,7 @@ def _flox_reduce( new_coords.append( # Using IndexVariable here ensures we reconstruct PandasMultiIndex with # all associated levels properly. - _coordinates_from_variable( + coordinates_from_variable( IndexVariable( dims=grouper.name, data=output_index, diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 0b4eee7b21c..ee5976199ed 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -195,6 +195,59 @@ def create_variables( else: return {} + def validate_dataarray_coord( + self, + name: Hashable, + var: Variable, + dims: set[Hashable], + ): + """Validate an index coordinate variable to include in a DataArray. + + This method is called repeatedly for each Variable associated with + this index when creating a new DataArray (via its constructor or from a + Dataset) or updating an existing one. + + By default raises a :py:class:`CoordinateValidationError` if the + dimensions of the coordinate variable do conflict with the array + dimensions (DataArray model). + + This method may be overridden in Index subclasses, e.g., to validate + index coordinates even when they do not strictly conform with the + DataArray model. This is useful for example to include (n+1)-dimensional + cell boundary coordinates attached to an index. + + If the validation passes (i.e., no error raised), the coordinate will be + included in the DataArray regardless of its dimensions. + + If this method raises when a DataArray is constructed from a Dataset, + Xarray will fail back to propagating the coordinate + according to the default rules for DataArray --- i.e., the dimensions of every + coordinate variable must be a subset of DataArray.dims --- which may drop this index. + + Parameters + ---------- + name : Hashable + Name of a coordinate variable associated to this index. + var : Variable + Coordinate variable object. + dims: tuple + Dataarray's dimensions. + + Raises + ------ + CoordinateValidationError + When validation fails. + + """ + from xarray.core.coordinates import CoordinateValidationError + + if any(d not in dims for d in var.dims): + raise CoordinateValidationError( + f"coordinate {name} has dimensions {var.dims}, but these " + "are not a subset of the DataArray " + f"dimensions {dims}" + ) + def to_pandas_index(self) -> pd.Index: """Cast this xarray index to a pandas.Index object or raise a ``TypeError`` if this is not supported. diff --git a/xarray/groupers.py b/xarray/groupers.py index 025f8fae486..af94bf864e0 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -18,7 +18,7 @@ from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq from xarray.computation.apply_ufunc import apply_ufunc -from xarray.core.coordinates import Coordinates, _coordinates_from_variable +from xarray.core.coordinates import Coordinates, coordinates_from_variable from xarray.core.dataarray import DataArray from xarray.core.duck_array_ops import array_all, isnull from xarray.core.groupby import T_Group, _DummyGroup @@ -115,7 +115,7 @@ def __init__( if coords is None: assert not isinstance(self.unique_coord, _DummyGroup) - self.coords = _coordinates_from_variable(self.unique_coord) + self.coords = coordinates_from_variable(self.unique_coord) else: self.coords = coords @@ -252,7 +252,7 @@ def _factorize_unique(self) -> EncodedGroups: codes=codes, full_index=full_index, unique_coord=unique_coord, - coords=_coordinates_from_variable(unique_coord), + coords=coordinates_from_variable(unique_coord), ) def _factorize_dummy(self) -> EncodedGroups: @@ -280,7 +280,7 @@ def _factorize_dummy(self) -> EncodedGroups: else: if TYPE_CHECKING: assert isinstance(unique_coord, Variable) - coords = _coordinates_from_variable(unique_coord) + coords = coordinates_from_variable(unique_coord) return EncodedGroups( codes=codes, @@ -409,7 +409,7 @@ def factorize(self, group: T_Group) -> EncodedGroups: codes=codes, full_index=full_index, unique_coord=unique_coord, - coords=_coordinates_from_variable(unique_coord), + coords=coordinates_from_variable(unique_coord), ) @@ -543,7 +543,7 @@ def factorize(self, group: T_Group) -> EncodedGroups: group_indices=group_indices, full_index=full_index, unique_coord=unique_coord, - coords=_coordinates_from_variable(unique_coord), + coords=coordinates_from_variable(unique_coord), ) diff --git a/xarray/testing/assertions.py b/xarray/testing/assertions.py index 8a2dba9261f..15a239894fb 100644 --- a/xarray/testing/assertions.py +++ b/xarray/testing/assertions.py @@ -330,10 +330,13 @@ def _assert_indexes_invariants_checks( k: type(v) for k, v in indexes.items() } - index_vars = { - k for k, v in possible_coord_variables.items() if isinstance(v, IndexVariable) - } - assert indexes.keys() <= index_vars, (set(indexes), index_vars) + if check_default: + index_vars = { + k + for k, v in possible_coord_variables.items() + if isinstance(v, IndexVariable) + } + assert indexes.keys() <= index_vars, (set(indexes), index_vars) # check pandas index wrappers vs. coordinate data adapters for k, index in indexes.items(): @@ -395,13 +398,18 @@ def _assert_dataarray_invariants(da: DataArray, check_default_indexes: bool): assert isinstance(da._coords, dict), da._coords assert all(isinstance(v, Variable) for v in da._coords.values()), da._coords - assert all(set(v.dims) <= set(da.dims) for v in da._coords.values()), ( - da.dims, - {k: v.dims for k, v in da._coords.items()}, - ) - assert all( - isinstance(v, IndexVariable) for (k, v) in da._coords.items() if v.dims == (k,) - ), {k: type(v) for k, v in da._coords.items()} + + if check_default_indexes: + assert all(set(v.dims) <= set(da.dims) for v in da._coords.values()), ( + da.dims, + {k: v.dims for k, v in da._coords.items()}, + ) + assert all( + isinstance(v, IndexVariable) + for (k, v) in da._coords.items() + if v.dims == (k,) + ), {k: type(v) for k, v in da._coords.items()} + for k, v in da._coords.items(): _assert_variable_invariants(v, k) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 23fd90d2721..626911d344e 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -33,8 +33,12 @@ from xarray.coders import CFDatetimeCoder from xarray.core import dtypes from xarray.core.common import full_like -from xarray.core.coordinates import Coordinates -from xarray.core.indexes import Index, PandasIndex, filter_indexes_from_coords +from xarray.core.coordinates import Coordinates, CoordinateValidationError +from xarray.core.indexes import ( + Index, + PandasIndex, + filter_indexes_from_coords, +) from xarray.core.types import QueryEngineOptions, QueryParserOptions from xarray.core.utils import is_scalar from xarray.testing import _assert_internal_invariants @@ -418,9 +422,13 @@ def test_constructor_invalid(self) -> None: with pytest.raises(TypeError, match=r"is not hashable"): DataArray(data, dims=["x", []]) # type: ignore[list-item] - with pytest.raises(ValueError, match=r"conflicting sizes for dim"): + with pytest.raises( + CoordinateValidationError, match=r"conflicting sizes for dim" + ): DataArray([1, 2, 3], coords=[("x", [0, 1])]) - with pytest.raises(ValueError, match=r"conflicting sizes for dim"): + with pytest.raises( + CoordinateValidationError, match=r"conflicting sizes for dim" + ): DataArray([1, 2], coords={"x": [0, 1], "y": ("x", [1])}, dims="x") with pytest.raises(ValueError, match=r"conflicting MultiIndex"): @@ -529,6 +537,26 @@ class CustomIndex(Index): ... # test coordinate variables copied assert da.coords["x"] is not coords.variables["x"] + def test_constructor_extra_dim_index_coord(self) -> None: + class AnyIndex(Index): + def validate_dataarray_coord(self, name, var, dims): + # pass all index coordinates + pass + + idx = AnyIndex() + coords = Coordinates( + coords={ + "x": ("x", [1, 2]), + "x_bounds": (("x", "x_bnds"), [(0.5, 1.5), (1.5, 2.5)]), + }, + indexes={"x": idx, "x_bounds": idx}, + ) + + actual = DataArray([1.0, 2.0], coords=coords, dims="x") + + assert_identical(actual.coords, coords, check_default_indexes=False) + assert "x_bnds" not in actual.dims + def test_equals_and_identical(self) -> None: orig = DataArray(np.arange(5.0), {"a": 42}, dims="x") @@ -1602,11 +1630,11 @@ def test_assign_coords(self) -> None: # GH: 2112 da = xr.DataArray([0, 1, 2], dims="x") - with pytest.raises(ValueError): + with pytest.raises(CoordinateValidationError): da["x"] = [0, 1, 2, 3] # size conflict - with pytest.raises(ValueError): + with pytest.raises(CoordinateValidationError): da.coords["x"] = [0, 1, 2, 3] # size conflict - with pytest.raises(ValueError): + with pytest.raises(CoordinateValidationError): da.coords["x"] = ("y", [1, 2, 3]) # no new dimension to a DataArray def test_assign_coords_existing_multiindex(self) -> None: @@ -1634,6 +1662,28 @@ def test_assign_coords_no_default_index(self) -> None: assert_identical(actual.coords, coords, check_default_indexes=False) assert "y" not in actual.xindexes + def test_assign_coords_extra_dim_index_coord(self) -> None: + class AnyIndex(Index): + def validate_dataarray_coord(self, name, var, dims): + # pass all index coordinates + pass + + idx = AnyIndex() + coords = Coordinates( + coords={ + "x": ("x", [1, 2]), + "x_bounds": (("x", "x_bnds"), [(0.5, 1.5), (1.5, 2.5)]), + }, + indexes={"x": idx, "x_bounds": idx}, + ) + + da = DataArray([1.0, 2.0], dims="x") + actual = da.assign_coords(coords) + expected = DataArray([1.0, 2.0], coords=coords, dims="x") + + assert_identical(actual, expected, check_default_indexes=False) + assert "x_bnds" not in actual.dims + def test_coords_alignment(self) -> None: lhs = DataArray([1, 2, 3], [("x", [0, 1, 2])]) rhs = DataArray([2, 3, 4], [("x", [1, 2, 3])]) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index b273b7d1a0d..d3a51c404e4 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4206,6 +4206,27 @@ def test_getitem_multiple_dtype(self) -> None: dataset = Dataset({key: ("dim0", range(1)) for key in keys}) assert_identical(dataset, dataset[keys]) + def test_getitem_extra_dim_index_coord(self) -> None: + class AnyIndex(Index): + def validate_dataarray_coord(self, name, var, dims): + # pass all index coordinates + pass + + idx = AnyIndex() + coords = Coordinates( + coords={ + "x": ("x", [1, 2]), + "x_bounds": (("x", "x_bnds"), [(0.5, 1.5), (1.5, 2.5)]), + }, + indexes={"x": idx, "x_bounds": idx}, + ) + + ds = Dataset({"foo": (("x"), [1.0, 2.0])}, coords=coords) + actual = ds["foo"] + + assert_identical(actual.coords, coords, check_default_indexes=False) + assert "x_bnds" not in actual.dims + def test_virtual_variables_default_coords(self) -> None: dataset = Dataset({"foo": ("x", range(10))}) expected1 = DataArray(range(10), dims="x", name="x")