Skip to content

Add Index.should_add_coord_to_array #10137

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 35 commits into from
May 6, 2025
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
0707a8b
typing fixes and tweaks
benbovy Mar 17, 2025
75086ef
add Index.validate_dataarray_coord()
benbovy Mar 17, 2025
8aaf2b8
Dataset._construct_dataarray: validate index coord
benbovy Mar 17, 2025
c9b4baa
DataArray init: validate index coord
benbovy Mar 17, 2025
a47523f
clean-up old TODO
benbovy Mar 17, 2025
551808a
refactor dataarray coord update
benbovy Mar 17, 2025
818b7f5
docstring tweaks
benbovy Mar 17, 2025
e8df9b5
add tests
benbovy Mar 13, 2025
678c013
assert invariants: skip check IndexVariable ...
benbovy Mar 14, 2025
0f822b5
update cherry-picked tests
benbovy Mar 17, 2025
43c44ea
update assert datarray invariants
benbovy Mar 17, 2025
3b33263
doc: add Index.validate_dataarray_coords to API
benbovy Mar 17, 2025
a8e6e20
typo
benbovy Mar 17, 2025
f1440c4
update whats new
benbovy Mar 17, 2025
5da014e
add CoordinateValidationError
benbovy Mar 18, 2025
6026656
docstrings tweaks
benbovy Mar 18, 2025
1eeec9c
nit refactor
benbovy Mar 18, 2025
426ddce
small refactor
benbovy Mar 18, 2025
5c0cc0f
Merge branch 'main' into index-validate-dataarray-coords
benbovy Mar 27, 2025
4399036
docstrings improvements
benbovy Mar 31, 2025
828a4cc
docstrings improvements
benbovy Mar 31, 2025
273d70c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 31, 2025
f49c83a
Merge branch 'main' into index-validate-dataarray-coords
benbovy Apr 23, 2025
3e55af0
refactor index check method
benbovy Apr 24, 2025
073c0a2
small refactor
benbovy Apr 24, 2025
8d43dcc
forgot updating API docs and whats new
benbovy Apr 24, 2025
4e7c70a
nit docstrings
benbovy Apr 24, 2025
b0f6782
Merge branch 'main' into index-validate-dataarray-coords
dcherian Apr 26, 2025
bf557f8
Merge branch 'main' into index-validate-dataarray-coords
dcherian Apr 26, 2025
df828b8
Merge branch 'main' into index-validate-dataarray-coords
dcherian Apr 28, 2025
fa574bc
rename method to Index.should_add_coord_to_array
benbovy May 5, 2025
524b7dc
Merge branch 'main' into index-validate-dataarray-coords
benbovy May 5, 2025
15e4159
review suggestion
benbovy May 6, 2025
67bf943
review suggestion 2
benbovy May 6, 2025
d3cbb3a
more docstrings tweaks
benbovy May 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api-hidden.rst
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,7 @@
Index.stack
Index.unstack
Index.create_variables
Index.validate_dataarray_coord
Index.to_pandas_index
Index.isel
Index.sel
Expand Down
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ New Features
By `Benoit Bovy <https://github.com/benbovy>`_.
- Support reading to `GPU memory with Zarr <https://zarr.readthedocs.io/en/stable/user-guide/gpu.html>`_ (:pull:`10078`).
By `Deepak Cherian <https://github.com/dcherian>`_.
- Allow assigning index coordinates with non-array dimension(s) in a :py:class:`DataArray` by overriding
:py:meth:`Index.validate_dataarray_coord`. For example, this enables support for CF boundaries coordinate (e.g.,
``time(time)`` and ``time_bnds(time, nbnd)``) in a DataArray (:pull:`10137`).
By `Benoit Bovy <https://github.com/benbovy>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
28 changes: 11 additions & 17 deletions xarray/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ def identical(self, other: Self) -> bool:
return self.to_dataset().identical(other.to_dataset())

def _update_coords(
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
) -> None:
# redirect to DatasetCoordinates._update_coords
self._data.coords._update_coords(coords, indexes)
Expand Down Expand Up @@ -780,7 +780,7 @@ def to_dataset(self) -> Dataset:
return self._data._copy_listed(names)

def _update_coords(
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
) -> None:
variables = self._data._variables.copy()
variables.update(coords)
Expand Down Expand Up @@ -880,7 +880,7 @@ def to_dataset(self) -> Dataset:
return self._data.dataset._copy_listed(self._names)

def _update_coords(
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
) -> None:
from xarray.core.datatree import check_alignment

Expand Down Expand Up @@ -964,22 +964,16 @@ def __getitem__(self, key: Hashable) -> T_DataArray:
return self._data._getitem_coord(key)

def _update_coords(
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
) -> None:
coords_plus_data = coords.copy()
coords_plus_data[_THIS_ARRAY] = self._data.variable
dims = calculate_dimensions(coords_plus_data)
if not set(dims) <= set(self.dims):
raise ValueError(
"cannot add coordinates with new dimensions to a DataArray"
)
self._data._coords = coords
from xarray.core.dataarray import check_dataarray_coords

# TODO(shoyer): once ._indexes is always populated by a dict, modify
# it to update inplace instead.
original_indexes = dict(self._data.xindexes)
original_indexes.update(indexes)
self._data._indexes = original_indexes
check_dataarray_coords(
self._data.shape, Coordinates._construct_direct(coords, indexes), self.dims
)

self._data._coords = coords
self._data._indexes = indexes

def _drop_coords(self, coord_names):
# should drop indexed coordinates only
Expand Down
23 changes: 19 additions & 4 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,18 +132,33 @@
T_XarrayOther = TypeVar("T_XarrayOther", bound="DataArray" | Dataset)


def _check_coords_dims(shape, coords, dim):
def check_dataarray_coords(
shape: tuple[int, ...],
coords: Coordinates | Mapping[Hashable, Variable],
dim: tuple[Hashable, ...],
):
sizes = dict(zip(dim, shape, strict=True))

indexes: Mapping[Hashable, Index]
if isinstance(coords, Coordinates):
indexes = coords.xindexes
else:
indexes = {}

dim_set = set(dim)

for k, v in coords.items():
if any(d not in dim for d in v.dims):
if k in indexes:
indexes[k].validate_dataarray_coord(k, v, dim_set)
elif any(d not in dim for d in v.dims):
raise ValueError(
f"coordinate {k} has dimensions {v.dims}, but these "
"are not a subset of the DataArray "
f"dimensions {dim}"
)

for d, s in v.sizes.items():
if s != sizes[d]:
if d in sizes and s != sizes[d]:
raise ValueError(
f"conflicting sizes for dimension {d!r}: "
f"length {sizes[d]} on the data but length {s} on "
Expand Down Expand Up @@ -214,7 +229,7 @@ def _infer_coords_and_dims(
var.dims = (dim,)
new_coords[dim] = var.to_index_variable()

_check_coords_dims(shape, new_coords, dims_tuple)
check_dataarray_coords(shape, new_coords, dims_tuple)

return new_coords, dims_tuple

Expand Down
13 changes: 12 additions & 1 deletion xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1212,7 +1212,18 @@ def _construct_dataarray(self, name: Hashable) -> DataArray:
coords: dict[Hashable, Variable] = {}
# preserve ordering
for k in self._variables:
if k in self._coord_names and set(self._variables[k].dims) <= needed_dims:
var_dims = set(self._variables[k].dims)
if k in self._indexes:
try:
self._indexes[k].validate_dataarray_coord(
k, self._variables[k], needed_dims
)
coords[k] = self._variables[k]
except ValueError:
# failback to strict DataArray model check (index may be dropped later)
if var_dims <= needed_dims:
coords[k] = self._variables[k]
elif k in self._coord_names and var_dims <= needed_dims:
coords[k] = self._variables[k]

indexes = filter_indexes_from_coords(self._indexes, set(coords))
Expand Down
47 changes: 47 additions & 0 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,53 @@ def create_variables(
else:
return {}

def validate_dataarray_coord(
self,
name: Hashable,
var: Variable,
dims: set[Hashable],
):
"""Validate an index coordinate to be included in a DataArray.

This method is called repeatedly for each coordinate associated with
this index when creating a new DataArray (via its constructor or from a
Dataset) or updating an existing one.

By default raises an error if the dimensions of the coordinate variable
do conflict with the array dimensions (DataArray model).

This method may be overridden in Index subclasses, e.g., to include index
coordinates that does not strictly conform with the DataArray model. This
is useful for example to include (n+1)-dimensional cell boundary
coordinates attached to an index.

When a DataArray is constructed from a Dataset, if the validation fails
Xarray will fail back to propagating the coordinate according to the
default rules for DataArray (i.e., depending on its dimensions), which
may drop this index.

Parameters
----------
name : Hashable
Name of a coordinate associated to this index.
var : Variable
Coordinate variable object.
dims: tuple
Dataarray's dimensions.

Raises
------
ValueError
When validation fails.

"""
if any(d not in dims for d in var.dims):
raise ValueError(
f"coordinate {name} has dimensions {var.dims}, but these "
"are not a subset of the DataArray "
f"dimensions {dims}"
)

def to_pandas_index(self) -> pd.Index:
"""Cast this xarray index to a pandas.Index object or raise a
``TypeError`` if this is not supported.
Expand Down
30 changes: 19 additions & 11 deletions xarray/testing/assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,10 +330,13 @@ def _assert_indexes_invariants_checks(
k: type(v) for k, v in indexes.items()
}

index_vars = {
k for k, v in possible_coord_variables.items() if isinstance(v, IndexVariable)
}
assert indexes.keys() <= index_vars, (set(indexes), index_vars)
if check_default:
index_vars = {
k
for k, v in possible_coord_variables.items()
if isinstance(v, IndexVariable)
}
assert indexes.keys() <= index_vars, (set(indexes), index_vars)

# check pandas index wrappers vs. coordinate data adapters
for k, index in indexes.items():
Expand Down Expand Up @@ -395,13 +398,18 @@ def _assert_dataarray_invariants(da: DataArray, check_default_indexes: bool):

assert isinstance(da._coords, dict), da._coords
assert all(isinstance(v, Variable) for v in da._coords.values()), da._coords
assert all(set(v.dims) <= set(da.dims) for v in da._coords.values()), (
da.dims,
{k: v.dims for k, v in da._coords.items()},
)
assert all(
isinstance(v, IndexVariable) for (k, v) in da._coords.items() if v.dims == (k,)
), {k: type(v) for k, v in da._coords.items()}

if check_default_indexes:
assert all(set(v.dims) <= set(da.dims) for v in da._coords.values()), (
da.dims,
{k: v.dims for k, v in da._coords.items()},
)
assert all(
isinstance(v, IndexVariable)
for (k, v) in da._coords.items()
if v.dims == (k,)
), {k: type(v) for k, v in da._coords.items()}

for k, v in da._coords.items():
_assert_variable_invariants(v, k)

Expand Down
42 changes: 42 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,26 @@ class CustomIndex(Index): ...
# test coordinate variables copied
assert da.coords["x"] is not coords.variables["x"]

def test_constructor_extra_dim_index_coord(self) -> None:
class AnyIndex(Index):
def validate_dataarray_coord(self, name, var, dims):
# pass all index coordinates
pass

idx = AnyIndex()
coords = Coordinates(
coords={
"x": ("x", [1, 2]),
"x_bounds": (("x", "x_bnds"), [(0.5, 1.5), (1.5, 2.5)]),
},
indexes={"x": idx, "x_bounds": idx},
)

actual = DataArray([1.0, 2.0], coords=coords, dims="x")

assert_identical(actual.coords, coords, check_default_indexes=False)
assert "x_bnds" not in actual.dims

def test_equals_and_identical(self) -> None:
orig = DataArray(np.arange(5.0), {"a": 42}, dims="x")

Expand Down Expand Up @@ -1634,6 +1654,28 @@ def test_assign_coords_no_default_index(self) -> None:
assert_identical(actual.coords, coords, check_default_indexes=False)
assert "y" not in actual.xindexes

def test_assign_coords_extra_dim_index_coord(self) -> None:
class AnyIndex(Index):
def validate_dataarray_coord(self, name, var, dims):
# pass all index coordinates
pass

idx = AnyIndex()
coords = Coordinates(
coords={
"x": ("x", [1, 2]),
"x_bounds": (("x", "x_bnds"), [(0.5, 1.5), (1.5, 2.5)]),
},
indexes={"x": idx, "x_bounds": idx},
)

da = DataArray([1.0, 2.0], dims="x")
actual = da.assign_coords(coords)
expected = DataArray([1.0, 2.0], coords=coords, dims="x")

assert_identical(actual, expected, check_default_indexes=False)
assert "x_bnds" not in actual.dims

def test_coords_alignment(self) -> None:
lhs = DataArray([1, 2, 3], [("x", [0, 1, 2])])
rhs = DataArray([2, 3, 4], [("x", [1, 2, 3])])
Expand Down
21 changes: 21 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4206,6 +4206,27 @@ def test_getitem_multiple_dtype(self) -> None:
dataset = Dataset({key: ("dim0", range(1)) for key in keys})
assert_identical(dataset, dataset[keys])

def test_getitem_extra_dim_index_coord(self) -> None:
class AnyIndex(Index):
def validate_dataarray_coord(self, name, var, dims):
# pass all index coordinates
pass

idx = AnyIndex()
coords = Coordinates(
coords={
"x": ("x", [1, 2]),
"x_bounds": (("x", "x_bnds"), [(0.5, 1.5), (1.5, 2.5)]),
},
indexes={"x": idx, "x_bounds": idx},
)

ds = Dataset({"foo": (("x"), [1.0, 2.0])}, coords=coords)
actual = ds["foo"]

assert_identical(actual.coords, coords, check_default_indexes=False)
assert "x_bnds" not in actual.dims

def test_virtual_variables_default_coords(self) -> None:
dataset = Dataset({"foo": ("x", range(10))})
expected1 = DataArray(range(10), dims="x", name="x")
Expand Down
Loading