Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Index.validate_dataarray_coord #10137

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api-hidden.rst
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,7 @@
Index.stack
Index.unstack
Index.create_variables
Index.validate_dataarray_coord
Index.to_pandas_index
Index.isel
Index.sel
Expand Down
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1641,6 +1641,7 @@ Exceptions
.. autosummary::
:toctree: generated/

CoordinateValidationError
MergeError
SerializationWarning

Expand Down
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ New Features
(:pull:`9498`). By `Spencer Clark <https://github.com/spencerkclark>`_.
- Support reading to `GPU memory with Zarr <https://zarr.readthedocs.io/en/stable/user-guide/gpu.html>`_ (:pull:`10078`).
By `Deepak Cherian <https://github.com/dcherian>`_.
- Allow assigning index coordinates with non-array dimension(s) in a :py:class:`DataArray` by overriding
:py:meth:`Index.validate_dataarray_coord`. For example, this enables support for CF boundary coordinates (e.g.,
``time(time)`` and ``time_bnds(time, nbnd)``) in a DataArray (:pull:`10137`).
By `Benoit Bovy <https://github.com/benbovy>`_.

Performance
~~~~~~~~~~~
Expand Down
3 changes: 2 additions & 1 deletion xarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
)
from xarray.conventions import SerializationWarning, decode_cf
from xarray.core.common import ALL_DIMS, full_like, ones_like, zeros_like
from xarray.core.coordinates import Coordinates
from xarray.core.coordinates import Coordinates, CoordinateValidationError
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.datatree import DataTree
Expand Down Expand Up @@ -128,6 +128,7 @@
"NamedArray",
"Variable",
# Exceptions
"CoordinateValidationError",
"InvalidTreeError",
"MergeError",
"NotFoundInTreeError",
Expand Down
76 changes: 57 additions & 19 deletions xarray/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ def identical(self, other: Self) -> bool:
return self.to_dataset().identical(other.to_dataset())

def _update_coords(
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
) -> None:
# redirect to DatasetCoordinates._update_coords
self._data.coords._update_coords(coords, indexes)
Expand Down Expand Up @@ -780,7 +780,7 @@ def to_dataset(self) -> Dataset:
return self._data._copy_listed(names)

def _update_coords(
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
) -> None:
variables = self._data._variables.copy()
variables.update(coords)
Expand Down Expand Up @@ -880,7 +880,7 @@ def to_dataset(self) -> Dataset:
return self._data.dataset._copy_listed(self._names)

def _update_coords(
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
) -> None:
from xarray.core.datatree import check_alignment

Expand Down Expand Up @@ -964,22 +964,14 @@ def __getitem__(self, key: Hashable) -> T_DataArray:
return self._data._getitem_coord(key)

def _update_coords(
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
) -> None:
coords_plus_data = coords.copy()
coords_plus_data[_THIS_ARRAY] = self._data.variable
dims = calculate_dimensions(coords_plus_data)
if not set(dims) <= set(self.dims):
raise ValueError(
"cannot add coordinates with new dimensions to a DataArray"
)
self._data._coords = coords
validate_dataarray_coords(
self._data.shape, Coordinates._construct_direct(coords, indexes), self.dims
)

# TODO(shoyer): once ._indexes is always populated by a dict, modify
# it to update inplace instead.
original_indexes = dict(self._data.xindexes)
original_indexes.update(indexes)
self._data._indexes = original_indexes
self._data._coords = coords
self._data._indexes = indexes

def _drop_coords(self, coord_names):
# should drop indexed coordinates only
Expand Down Expand Up @@ -1154,9 +1146,55 @@ def create_coords_with_default_indexes(
return new_coords


def _coordinates_from_variable(variable: Variable) -> Coordinates:
from xarray.core.indexes import create_default_index_implicit
class CoordinateValidationError(ValueError):
    """Error class for Xarray coordinate validation failures.

    Raised when a coordinate cannot be included in a DataArray, e.g. because
    its dimensions are not a subset of the array dimensions or because its
    size along a dimension conflicts with the array shape.

    It subclasses ``ValueError``, so code catching ``ValueError`` also catches
    this error.
    """


def validate_dataarray_coords(
    shape: tuple[int, ...],
    coords: Coordinates | Mapping[Hashable, Variable],
    dim: tuple[Hashable, ...],
):
    """Check that ``coords`` may be attached to a DataArray of the given
    ``shape`` and dimensions ``dim``.

    Coordinates that have an index are validated by that index (via
    ``Index.validate_dataarray_coord``), which may accept coordinates that do
    not strictly follow the DataArray model. Every other coordinate must only
    use dimensions found in ``dim``; scalar coordinates therefore always pass.

    For every coordinate, the size along each dimension shared with the array
    must match the array size along that dimension.

    Raises
    ------
    CoordinateValidationError
        If a non-index coordinate has a dimension not in ``dim``, or if any
        coordinate has a size conflicting with ``shape``.
    ValueError
        If ``dim`` and ``shape`` have different lengths (strict ``zip``).
    """
    array_sizes = dict(zip(dim, shape, strict=True))
    array_dims = set(dim)

    # Only a Coordinates object carries indexes; a plain mapping has none.
    indexes: Mapping[Hashable, Index] = (
        coords.xindexes if isinstance(coords, Coordinates) else {}
    )

    for name, var in coords.items():
        if name in indexes:
            # Delegate the dimension check to the index, which may relax it.
            indexes[name].validate_dataarray_coord(name, var, array_dims)
        elif not all(d in dim for d in var.dims):
            raise CoordinateValidationError(
                f"coordinate {name} has dimensions {var.dims}, but these "
                "are not a subset of the DataArray "
                f"dimensions {dim}"
            )

        for d, length in var.sizes.items():
            # Dimensions unknown to the array (only possible for coordinates
            # accepted by an index) have no size to compare against.
            expected = array_sizes.get(d, length)
            if length != expected:
                raise CoordinateValidationError(
                    f"conflicting sizes for dimension {d!r}: "
                    f"length {expected} on the data but length {length} on "
                    f"coordinate {name!r}"
                )


def coordinates_from_variable(variable: Variable) -> Coordinates:
(name,) = variable.dims
new_index, index_vars = create_default_index_implicit(variable)
indexes = {k: new_index for k in index_vars}
Expand Down
22 changes: 2 additions & 20 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
DataArrayCoordinates,
assert_coordinate_consistent,
create_coords_with_default_indexes,
validate_dataarray_coords,
)
from xarray.core.dataset import Dataset
from xarray.core.extension_array import PandasExtensionArray
Expand Down Expand Up @@ -132,25 +133,6 @@
T_XarrayOther = TypeVar("T_XarrayOther", bound="DataArray" | Dataset)


def _check_coords_dims(shape, coords, dim):
sizes = dict(zip(dim, shape, strict=True))
for k, v in coords.items():
if any(d not in dim for d in v.dims):
raise ValueError(
f"coordinate {k} has dimensions {v.dims}, but these "
"are not a subset of the DataArray "
f"dimensions {dim}"
)

for d, s in v.sizes.items():
if s != sizes[d]:
raise ValueError(
f"conflicting sizes for dimension {d!r}: "
f"length {sizes[d]} on the data but length {s} on "
f"coordinate {k!r}"
)


def _infer_coords_and_dims(
shape: tuple[int, ...],
coords: (
Expand Down Expand Up @@ -214,7 +196,7 @@ def _infer_coords_and_dims(
var.dims = (dim,)
new_coords[dim] = var.to_index_variable()

_check_coords_dims(shape, new_coords, dims_tuple)
validate_dataarray_coords(shape, new_coords, dims_tuple)

return new_coords, dims_tuple

Expand Down
14 changes: 13 additions & 1 deletion xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
)
from xarray.core.coordinates import (
Coordinates,
CoordinateValidationError,
DatasetCoordinates,
assert_coordinate_consistent,
)
Expand Down Expand Up @@ -1159,7 +1160,18 @@ def _construct_dataarray(self, name: Hashable) -> DataArray:
coords: dict[Hashable, Variable] = {}
# preserve ordering
for k in self._variables:
if k in self._coord_names and set(self._variables[k].dims) <= needed_dims:
var_dims = set(self._variables[k].dims)
if k in self._indexes:
try:
self._indexes[k].validate_dataarray_coord(
k, self._variables[k], needed_dims
)
coords[k] = self._variables[k]
except CoordinateValidationError:
# fall back to strict DataArray model check (index may be dropped later)
if var_dims <= needed_dims:
coords[k] = self._variables[k]
elif k in self._coord_names and var_dims <= needed_dims:
coords[k] = self._variables[k]

indexes = filter_indexes_from_coords(self._indexes, set(coords))
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
DatasetGroupByAggregations,
)
from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce
from xarray.core.coordinates import Coordinates, _coordinates_from_variable
from xarray.core.coordinates import Coordinates, coordinates_from_variable
from xarray.core.duck_array_ops import where
from xarray.core.formatting import format_array_flat
from xarray.core.indexes import (
Expand Down Expand Up @@ -1138,7 +1138,7 @@ def _flox_reduce(
new_coords.append(
# Using IndexVariable here ensures we reconstruct PandasMultiIndex with
# all associated levels properly.
_coordinates_from_variable(
coordinates_from_variable(
IndexVariable(
dims=grouper.name,
data=output_index,
Expand Down
53 changes: 53 additions & 0 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,59 @@ def create_variables(
else:
return {}

def validate_dataarray_coord(
    self,
    name: Hashable,
    var: Variable,
    dims: set[Hashable],
) -> None:
    """Validate an index coordinate variable to include in a DataArray.

    This method is called once for each coordinate associated with this
    index when creating a new DataArray (via its constructor or from a
    Dataset) or updating an existing one.

    By default, it raises a :py:class:`CoordinateValidationError` if the
    dimensions of the coordinate variable are not a subset of the array
    dimensions (DataArray model).

    This method may be overridden in Index subclasses, e.g., to validate
    index coordinates even when they do not strictly conform with the
    DataArray model. This is useful for example to include (n+1)-dimensional
    cell boundary coordinates attached to an index.

    If the validation passes (i.e., no error raised), the coordinate will be
    included in the DataArray regardless of its dimensions.

    When a DataArray is constructed from a Dataset (variable access), if the
    validation fails Xarray will fall back to propagating the coordinate
    according to the default rules for DataArray (i.e., depending on its
    dimensions), which may drop this index.

    Parameters
    ----------
    name : Hashable
        Name of a coordinate variable associated with this index.
    var : Variable
        Coordinate variable object.
    dims : set of Hashable
        DataArray's dimensions.

    Raises
    ------
    CoordinateValidationError
        When validation fails.

    """
    if all(d in dims for d in var.dims):
        return

    # Kept function-local as in the original, presumably to avoid a circular
    # import with xarray.core.coordinates (confirm). It is only needed when
    # validation fails, so the success path does not pay for it.
    from xarray.core.coordinates import CoordinateValidationError

    raise CoordinateValidationError(
        f"coordinate {name} has dimensions {var.dims}, but these "
        "are not a subset of the DataArray "
        f"dimensions {dims}"
    )

def to_pandas_index(self) -> pd.Index:
"""Cast this xarray index to a pandas.Index object or raise a
``TypeError`` if this is not supported.
Expand Down
12 changes: 6 additions & 6 deletions xarray/groupers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq
from xarray.computation.apply_ufunc import apply_ufunc
from xarray.core.coordinates import Coordinates, _coordinates_from_variable
from xarray.core.coordinates import Coordinates, coordinates_from_variable
from xarray.core.dataarray import DataArray
from xarray.core.duck_array_ops import array_all, isnull
from xarray.core.groupby import T_Group, _DummyGroup
Expand Down Expand Up @@ -115,7 +115,7 @@ def __init__(

if coords is None:
assert not isinstance(self.unique_coord, _DummyGroup)
self.coords = _coordinates_from_variable(self.unique_coord)
self.coords = coordinates_from_variable(self.unique_coord)
else:
self.coords = coords

Expand Down Expand Up @@ -252,7 +252,7 @@ def _factorize_unique(self) -> EncodedGroups:
codes=codes,
full_index=full_index,
unique_coord=unique_coord,
coords=_coordinates_from_variable(unique_coord),
coords=coordinates_from_variable(unique_coord),
)

def _factorize_dummy(self) -> EncodedGroups:
Expand Down Expand Up @@ -280,7 +280,7 @@ def _factorize_dummy(self) -> EncodedGroups:
else:
if TYPE_CHECKING:
assert isinstance(unique_coord, Variable)
coords = _coordinates_from_variable(unique_coord)
coords = coordinates_from_variable(unique_coord)

return EncodedGroups(
codes=codes,
Expand Down Expand Up @@ -409,7 +409,7 @@ def factorize(self, group: T_Group) -> EncodedGroups:
codes=codes,
full_index=full_index,
unique_coord=unique_coord,
coords=_coordinates_from_variable(unique_coord),
coords=coordinates_from_variable(unique_coord),
)


Expand Down Expand Up @@ -543,7 +543,7 @@ def factorize(self, group: T_Group) -> EncodedGroups:
group_indices=group_indices,
full_index=full_index,
unique_coord=unique_coord,
coords=_coordinates_from_variable(unique_coord),
coords=coordinates_from_variable(unique_coord),
)


Expand Down
Loading
Loading