Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Index.validate_dataarray_coord #10137

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api-hidden.rst
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,7 @@
Index.stack
Index.unstack
Index.create_variables
Index.validate_dataarray_coord
Index.to_pandas_index
Index.isel
Index.sel
Expand Down
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ New Features
By `Benoit Bovy <https://github.com/benbovy>`_.
- Support reading to `GPU memory with Zarr <https://zarr.readthedocs.io/en/stable/user-guide/gpu.html>`_ (:pull:`10078`).
By `Deepak Cherian <https://github.com/dcherian>`_.
- Allow assigning index coordinates with non-array dimension(s) in a :py:class:`DataArray` by overriding
:py:meth:`Index.validate_dataarray_coord`. For example, this enables support for CF boundaries coordinate (e.g.,
``time(time)`` and ``time_bnds(time, nbnd)``) in a DataArray (:pull:`10137`).
By `Benoit Bovy <https://github.com/benbovy>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
28 changes: 11 additions & 17 deletions xarray/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ def identical(self, other: Self) -> bool:
return self.to_dataset().identical(other.to_dataset())

def _update_coords(
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
) -> None:
# redirect to DatasetCoordinates._update_coords
self._data.coords._update_coords(coords, indexes)
Expand Down Expand Up @@ -780,7 +780,7 @@ def to_dataset(self) -> Dataset:
return self._data._copy_listed(names)

def _update_coords(
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
) -> None:
variables = self._data._variables.copy()
variables.update(coords)
Expand Down Expand Up @@ -880,7 +880,7 @@ def to_dataset(self) -> Dataset:
return self._data.dataset._copy_listed(self._names)

def _update_coords(
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
) -> None:
from xarray.core.datatree import check_alignment

Expand Down Expand Up @@ -964,22 +964,16 @@ def __getitem__(self, key: Hashable) -> T_DataArray:
return self._data._getitem_coord(key)

def _update_coords(
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
) -> None:
coords_plus_data = coords.copy()
coords_plus_data[_THIS_ARRAY] = self._data.variable
dims = calculate_dimensions(coords_plus_data)
if not set(dims) <= set(self.dims):
raise ValueError(
"cannot add coordinates with new dimensions to a DataArray"
)
self._data._coords = coords
from xarray.core.dataarray import check_dataarray_coords

# TODO(shoyer): once ._indexes is always populated by a dict, modify
# it to update inplace instead.
original_indexes = dict(self._data.xindexes)
original_indexes.update(indexes)
self._data._indexes = original_indexes
check_dataarray_coords(
self._data.shape, Coordinates._construct_direct(coords, indexes), self.dims
)

self._data._coords = coords
self._data._indexes = indexes

def _drop_coords(self, coord_names):
# should drop indexed coordinates only
Expand Down
23 changes: 19 additions & 4 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,18 +132,33 @@
T_XarrayOther = TypeVar("T_XarrayOther", bound="DataArray" | Dataset)


def _check_coords_dims(shape, coords, dim):
def check_dataarray_coords(
shape: tuple[int, ...],
coords: Coordinates | Mapping[Hashable, Variable],
dim: tuple[Hashable, ...],
):
sizes = dict(zip(dim, shape, strict=True))

indexes: Mapping[Hashable, Index]
if isinstance(coords, Coordinates):
indexes = coords.xindexes
else:
indexes = {}

dim_set = set(dim)

for k, v in coords.items():
if any(d not in dim for d in v.dims):
if k in indexes:
indexes[k].validate_dataarray_coord(k, v, dim_set)
elif any(d not in dim for d in v.dims):
raise ValueError(
f"coordinate {k} has dimensions {v.dims}, but these "
"are not a subset of the DataArray "
f"dimensions {dim}"
)

for d, s in v.sizes.items():
if s != sizes[d]:
if d in sizes and s != sizes[d]:
raise ValueError(
f"conflicting sizes for dimension {d!r}: "
f"length {sizes[d]} on the data but length {s} on "
Expand Down Expand Up @@ -214,7 +229,7 @@ def _infer_coords_and_dims(
var.dims = (dim,)
new_coords[dim] = var.to_index_variable()

_check_coords_dims(shape, new_coords, dims_tuple)
check_dataarray_coords(shape, new_coords, dims_tuple)

return new_coords, dims_tuple

Expand Down
13 changes: 12 additions & 1 deletion xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1212,7 +1212,18 @@ def _construct_dataarray(self, name: Hashable) -> DataArray:
coords: dict[Hashable, Variable] = {}
# preserve ordering
for k in self._variables:
if k in self._coord_names and set(self._variables[k].dims) <= needed_dims:
var_dims = set(self._variables[k].dims)
if k in self._indexes:
try:
self._indexes[k].validate_dataarray_coord(
k, self._variables[k], needed_dims
)
coords[k] = self._variables[k]
except ValueError:
# failback to strict DataArray model check (index may be dropped later)
if var_dims <= needed_dims:
coords[k] = self._variables[k]
elif k in self._coord_names and var_dims <= needed_dims:
coords[k] = self._variables[k]

indexes = filter_indexes_from_coords(self._indexes, set(coords))
Expand Down
47 changes: 47 additions & 0 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,53 @@ def create_variables(
else:
return {}

def validate_dataarray_coord(
self,
name: Hashable,
var: Variable,
dims: set[Hashable],
):
"""Validate an index coordinate to be included in a DataArray.

This method is called repeatedly for each coordinate associated with
this index when creating a new DataArray (via its constructor or from a
Dataset) or updating an existing one.

By default raises an error if the dimensions of the coordinate variable
do conflict with the array dimensions (DataArray model).

This method may be overridden in Index subclasses, e.g., to include index
coordinates that does not strictly conform with the DataArray model. This
is useful for example to include (n+1)-dimensional cell boundary
coordinates attached to an index.

When a DataArray is constructed from a Dataset, if the validation fails
Xarray will fail back to propagating the coordinate according to the
default rules for DataArray (i.e., depending on its dimensions), which
may drop this index.

Parameters
----------
name : Hashable
Name of a coordinate associated to this index.
var : Variable
Coordinate variable object.
dims: tuple
Dataarray's dimensions.

Raises
------
ValueError
When validation fails.

"""
if any(d not in dims for d in var.dims):
raise ValueError(
f"coordinate {name} has dimensions {var.dims}, but these "
"are not a subset of the DataArray "
f"dimensions {dims}"
)

def to_pandas_index(self) -> pd.Index:
"""Cast this xarray index to a pandas.Index object or raise a
``TypeError`` if this is not supported.
Expand Down
30 changes: 19 additions & 11 deletions xarray/testing/assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,10 +330,13 @@ def _assert_indexes_invariants_checks(
k: type(v) for k, v in indexes.items()
}

index_vars = {
k for k, v in possible_coord_variables.items() if isinstance(v, IndexVariable)
}
assert indexes.keys() <= index_vars, (set(indexes), index_vars)
if check_default:
index_vars = {
k
for k, v in possible_coord_variables.items()
if isinstance(v, IndexVariable)
}
assert indexes.keys() <= index_vars, (set(indexes), index_vars)

# check pandas index wrappers vs. coordinate data adapters
for k, index in indexes.items():
Expand Down Expand Up @@ -395,13 +398,18 @@ def _assert_dataarray_invariants(da: DataArray, check_default_indexes: bool):

assert isinstance(da._coords, dict), da._coords
assert all(isinstance(v, Variable) for v in da._coords.values()), da._coords
assert all(set(v.dims) <= set(da.dims) for v in da._coords.values()), (
da.dims,
{k: v.dims for k, v in da._coords.items()},
)
assert all(
isinstance(v, IndexVariable) for (k, v) in da._coords.items() if v.dims == (k,)
), {k: type(v) for k, v in da._coords.items()}

if check_default_indexes:
assert all(set(v.dims) <= set(da.dims) for v in da._coords.values()), (
da.dims,
{k: v.dims for k, v in da._coords.items()},
)
assert all(
isinstance(v, IndexVariable)
for (k, v) in da._coords.items()
if v.dims == (k,)
), {k: type(v) for k, v in da._coords.items()}

for k, v in da._coords.items():
_assert_variable_invariants(v, k)

Expand Down
42 changes: 42 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,26 @@ class CustomIndex(Index): ...
# test coordinate variables copied
assert da.coords["x"] is not coords.variables["x"]

def test_constructor_extra_dim_index_coord(self) -> None:
class AnyIndex(Index):
def validate_dataarray_coord(self, name, var, dims):
# pass all index coordinates
pass

idx = AnyIndex()
coords = Coordinates(
coords={
"x": ("x", [1, 2]),
"x_bounds": (("x", "x_bnds"), [(0.5, 1.5), (1.5, 2.5)]),
},
indexes={"x": idx, "x_bounds": idx},
)

actual = DataArray([1.0, 2.0], coords=coords, dims="x")

assert_identical(actual.coords, coords, check_default_indexes=False)
assert "x_bnds" not in actual.dims

def test_equals_and_identical(self) -> None:
orig = DataArray(np.arange(5.0), {"a": 42}, dims="x")

Expand Down Expand Up @@ -1634,6 +1654,28 @@ def test_assign_coords_no_default_index(self) -> None:
assert_identical(actual.coords, coords, check_default_indexes=False)
assert "y" not in actual.xindexes

def test_assign_coords_extra_dim_index_coord(self) -> None:
class AnyIndex(Index):
def validate_dataarray_coord(self, name, var, dims):
# pass all index coordinates
pass

idx = AnyIndex()
coords = Coordinates(
coords={
"x": ("x", [1, 2]),
"x_bounds": (("x", "x_bnds"), [(0.5, 1.5), (1.5, 2.5)]),
},
indexes={"x": idx, "x_bounds": idx},
)

da = DataArray([1.0, 2.0], dims="x")
actual = da.assign_coords(coords)
expected = DataArray([1.0, 2.0], coords=coords, dims="x")

assert_identical(actual, expected, check_default_indexes=False)
assert "x_bnds" not in actual.dims

def test_coords_alignment(self) -> None:
lhs = DataArray([1, 2, 3], [("x", [0, 1, 2])])
rhs = DataArray([2, 3, 4], [("x", [1, 2, 3])])
Expand Down
21 changes: 21 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4206,6 +4206,27 @@ def test_getitem_multiple_dtype(self) -> None:
dataset = Dataset({key: ("dim0", range(1)) for key in keys})
assert_identical(dataset, dataset[keys])

def test_getitem_extra_dim_index_coord(self) -> None:
class AnyIndex(Index):
def validate_dataarray_coord(self, name, var, dims):
# pass all index coordinates
pass

idx = AnyIndex()
coords = Coordinates(
coords={
"x": ("x", [1, 2]),
"x_bounds": (("x", "x_bnds"), [(0.5, 1.5), (1.5, 2.5)]),
},
indexes={"x": idx, "x_bounds": idx},
)

ds = Dataset({"foo": (("x"), [1.0, 2.0])}, coords=coords)
actual = ds["foo"]

assert_identical(actual.coords, coords, check_default_indexes=False)
assert "x_bnds" not in actual.dims

def test_virtual_variables_default_coords(self) -> None:
dataset = Dataset({"foo": ("x", range(10))})
expected1 = DataArray(range(10), dims="x", name="x")
Expand Down
Loading