Skip to content

Commit 4a52799

Browse files
dcherian, andersy005, benbovy
authored
Drop multi-indexes when assigning to a multi-indexed variable (#6798)
Co-authored-by: Anderson Banihirwe <[email protected]> Co-authored-by: Benoit Bovy <[email protected]>
1 parent 9f8d47c commit 4a52799

File tree

5 files changed

+77
-1
lines changed

5 files changed

+77
-1
lines changed

xarray/core/coordinates.py

+54-1
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,14 @@
11
from __future__ import annotations
22

3+
import warnings
34
from contextlib import contextmanager
45
from typing import TYPE_CHECKING, Any, Hashable, Iterator, Mapping, Sequence, cast
56

67
import numpy as np
78
import pandas as pd
89

910
from . import formatting
10-
from .indexes import Index, Indexes, assert_no_index_corrupted
11+
from .indexes import Index, Indexes, PandasMultiIndex, assert_no_index_corrupted
1112
from .merge import merge_coordinates_without_align, merge_coords
1213
from .utils import Frozen, ReprObject
1314
from .variable import Variable, calculate_dimensions
@@ -57,6 +58,9 @@ def variables(self):
5758
def _update_coords(self, coords, indexes):
5859
raise NotImplementedError()
5960

61+
def _maybe_drop_multiindex_coords(self, coords):
62+
raise NotImplementedError()
63+
6064
def __iter__(self) -> Iterator[Hashable]:
6165
# needs to be in the same order as the dataset variables
6266
for k in self.variables:
@@ -154,6 +158,7 @@ def to_index(self, ordered_dims: Sequence[Hashable] = None) -> pd.Index:
154158

155159
def update(self, other: Mapping[Any, Any]) -> None:
156160
other_vars = getattr(other, "variables", other)
161+
self._maybe_drop_multiindex_coords(set(other_vars))
157162
coords, indexes = merge_coords(
158163
[self.variables, other_vars], priority_arg=1, indexes=self.xindexes
159164
)
@@ -304,6 +309,15 @@ def _update_coords(
304309
original_indexes.update(indexes)
305310
self._data._indexes = original_indexes
306311

312+
def _maybe_drop_multiindex_coords(self, coords: set[Hashable]) -> None:
    """Drops variables in coords, and any associated variables as well.

    Delegates to ``drop_coords`` using the wrapped object's ``_variables``
    and ``xindexes``, then stores the returned mappings back on the wrapped
    object as ``_variables`` and ``_indexes``.

    Parameters
    ----------
    coords : set of hashable
        Names of the coordinates that are about to be updated.
    """
    # NOTE(review): presumably a type-narrowing guard rather than input
    # validation -- confirm before replacing it with an explicit raise.
    assert self._data.xindexes is not None
    variables, indexes = drop_coords(
        coords, self._data._variables, self._data.xindexes
    )
    # Both attributes are replaced together so they stay in sync.
    self._data._variables = variables
    self._data._indexes = indexes
320+
307321
def __delitem__(self, key: Hashable) -> None:
308322
if key in self:
309323
del self._data[key]
@@ -372,6 +386,14 @@ def _update_coords(
372386
original_indexes.update(indexes)
373387
self._data._indexes = original_indexes
374388

389+
def _maybe_drop_multiindex_coords(self, coords: set[Hashable]) -> None:
    """Drops variables in coords, and any associated variables as well.

    Delegates to ``drop_coords`` using the wrapped object's ``_coords`` and
    ``xindexes``, then stores the returned mappings back on the wrapped
    object as ``_coords`` and ``_indexes``.

    Parameters
    ----------
    coords : set of hashable
        Names of the coordinates that are about to be updated.
    """
    variables, indexes = drop_coords(
        coords, self._data._coords, self._data.xindexes
    )
    # Both attributes are replaced together so they stay in sync.
    self._data._coords = variables
    self._data._indexes = indexes
396+
375397
@property
376398
def variables(self):
377399
return Frozen(self._data._coords)
@@ -397,6 +419,37 @@ def _ipython_key_completions_(self):
397419
return self._data._ipython_key_completions_()
398420

399421

422+
def drop_coords(
    coords_to_drop: set[Hashable], variables, indexes: Indexes
) -> tuple[dict, dict]:
    """Drop index variables associated with variables in coords_to_drop.

    The inputs are not modified: the function works on ``dict`` copies of
    ``variables`` and ``indexes`` and returns those copies.

    Parameters
    ----------
    coords_to_drop : set of hashable
        Names of the coordinates that are about to be replaced.
    variables
        Mapping of names to variables; must provide ``.copy()``.
    indexes : Indexes
        Mapping of names to indexes; must provide ``.copy()`` and
        ``.get_all_coords(key)``.

    Returns
    -------
    tuple of (dict, dict)
        The new variables and the new indexes, in that order.

    Warns
    -----
    DeprecationWarning
        When the dimension coordinate of a ``PandasMultiIndex`` is being
        replaced while at least one coordinate of that index is not in
        ``coords_to_drop``.
    """
    # Only warn when we're dropping the dimension with the multi-indexed coordinate
    # If asked to drop a subset of the levels in a multi-index, we raise an error
    # later but skip the warning here.
    new_variables = dict(variables.copy())
    new_indexes = dict(indexes.copy())
    # Only names that actually carry an index can affect a multi-index.
    for key in coords_to_drop & set(indexes):
        maybe_midx = indexes[key]
        idx_coord_names = set(indexes.get_all_coords(key))
        if (
            isinstance(maybe_midx, PandasMultiIndex)
            and key == maybe_midx.dim
            and (idx_coord_names - coords_to_drop)
        ):
            warnings.warn(
                f"Updating MultiIndexed coordinate {key!r} would corrupt indices for "
                f"other variables: {list(maybe_midx.index.names)!r}. "
                f"This will raise an error in the future. Use `.drop_vars({idx_coord_names!r})` before "
                "assigning new coordinate values.",
                DeprecationWarning,
                stacklevel=4,
            )
            # Remove every coordinate of the affected index, not just ``key``.
            for k in idx_coord_names:
                del new_variables[k]
                del new_indexes[k]
    return new_variables, new_indexes
451+
452+
400453
def assert_coordinate_consistent(
401454
obj: DataArray | Dataset, coords: Mapping[Any, Variable]
402455
) -> None:

xarray/core/dataset.py

+1
Original file line number | Diff line number | Diff line change
@@ -5764,6 +5764,7 @@ def assign(
57645764
data = self.copy()
57655765
# do all calculations first...
57665766
results: CoercibleMapping = data._calc_assign_results(variables)
5767+
data.coords._maybe_drop_multiindex_coords(set(results.keys()))
57675768
# ... and then assign
57685769
data.update(results)
57695770
return data

xarray/core/indexes.py

+3
Original file line number | Diff line number | Diff line change
@@ -1085,6 +1085,9 @@ def dims(self) -> Mapping[Hashable, int]:
10851085

10861086
return Frozen(self._dims)
10871087

1088+
def copy(self):
    """Return a new object of the same type holding copies of the mappings.

    ``_indexes`` and ``_variables`` are each copied into a fresh ``dict``
    (the values themselves are not copied) and passed, in that order, to
    ``type(self)``.
    """
    return type(self)(dict(self._indexes), dict(self._variables))
1090+
10881091
def get_unique(self) -> list[T_PandasOrXarrayIndex]:
10891092
"""Return a list of unique indexes, preserving order."""
10901093

xarray/tests/test_dataarray.py

+7
Original file line number | Diff line number | Diff line change
@@ -1499,6 +1499,13 @@ def test_assign_coords(self) -> None:
14991499
with pytest.raises(ValueError):
15001500
da.coords["x"] = ("y", [1, 2, 3]) # no new dimension to a DataArray
15011501

1502+
def test_assign_coords_existing_multiindex(self) -> None:
    """``assign_coords(x=...)`` on ``self.mda`` must emit a DeprecationWarning
    whose message matches "Updating MultiIndexed coordinate"."""
    data = self.mda
    with pytest.warns(
        DeprecationWarning, match=r"Updating MultiIndexed coordinate"
    ):
        data.assign_coords(x=range(4))
1508+
15021509
def test_coords_alignment(self) -> None:
15031510
lhs = DataArray([1, 2, 3], [("x", [0, 1, 2])])
15041511
rhs = DataArray([2, 3, 4], [("x", [1, 2, 3])])

xarray/tests/test_dataset.py

+12
Original file line number | Diff line number | Diff line change
@@ -3967,6 +3967,18 @@ def test_assign_multiindex_level(self) -> None:
39673967
data.assign(level_1=range(4))
39683968
data.assign_coords(level_1=range(4))
39693969

3970+
def test_assign_coords_existing_multiindex(self) -> None:
    """Both ``assign_coords(x=...)`` and ``assign(x=...)`` on the object from
    ``create_test_multiindex()`` must emit a DeprecationWarning whose message
    matches "Updating MultiIndexed coordinate"."""
    data = create_test_multiindex()
    with pytest.warns(
        DeprecationWarning, match=r"Updating MultiIndexed coordinate"
    ):
        data.assign_coords(x=range(4))

    # ``assign`` is checked separately from ``assign_coords``.
    with pytest.warns(
        DeprecationWarning, match=r"Updating MultiIndexed coordinate"
    ):
        data.assign(x=range(4))
3981+
39703982
def test_assign_all_multiindex_coords(self) -> None:
39713983
data = create_test_multiindex()
39723984
actual = data.assign(x=range(4), level_1=range(4), level_2=range(4))

0 commit comments

Comments
 (0)