Skip to content

Commit 40b7941

Browse files
authored
compatibility with xr.DataTree (#607)
* compatibility with xr.DataTree * fix imports * restore mesmer-x tests * simplify skip_empty_nodes * Update tests/integration/test_calibrate_mesmer_newcodepath.py * simplify map_over_datasets after pydata/xarray#10012 * informative error for wrong xr version * add kwargs dict * fix version check for dev version * changelog
1 parent 615a0c7 commit 40b7941

15 files changed

+321
-85
lines changed

CHANGELOG.rst

+1
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ This release implements using :py:class:`DataTree` from `xarray-datatree` to han
9191
- Add `xarray-datatree` as dependency (`#554 <https://github.com/MESMER-group/mesmer/pull/554>`_)
9292
- Add calibration integration tests for multiple scenarios and change parameter files to netcdfs with new naming structure (`#537 <https://github.com/MESMER-group/mesmer/pull/537>`_)
9393
- Add new integration tests for drawing realisations (`#599 <https://github.com/MESMER-group/mesmer/pull/599>`_)
94+
- Port the functionality to xarray's :py:class:`DataTree` implementation (`#607 <https://github.com/MESMER-group/mesmer/pull/607>`_).
9495

9596
By `Victoria Bauer`_ and `Mathias Hauser`_.
9697

ci/install-upstream-wheels.sh

+3-3
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ conda uninstall -y --force \
1212
regionmask \
1313
scikit-learn \
1414
scipy \
15-
statsmodels
16-
# xarray # re-add after removing upper pin for xarray-datatree
15+
statsmodels \
16+
xarray
1717

1818
# keep cartopy & matplotlib: we don't have tests that use them
1919
# keep joblib: we want to move away from pickle files
@@ -32,8 +32,8 @@ python -m pip install --no-deps --upgrade \
3232
git+https://github.com/fatiando/pooch \
3333
git+https://github.com/geopandas/geopandas \
3434
git+https://github.com/properscoring/properscoring \
35+
git+https://github.com/pydata/xarray \
3536
git+https://github.com/pypa/packaging \
3637
git+https://github.com/pyproj4/pyproj \
3738
git+https://github.com/regionmask/regionmask \
3839
git+https://github.com/SciTools/nc-time-axis
39-
# git+https://github.com/pydata/xarray \

mesmer/core/_datatreecompat.py

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import functools
2+
3+
import xarray as xr
4+
from packaging.version import Version
5+
6+
if Version(xr.__version__) < Version("2024.10"):
7+
8+
from datatree import DataTree, map_over_subtree, open_datatree
9+
10+
def map_over_datasets(func, *args, kwargs=None):
11+
"compatibility layer for older xarray versions"
12+
13+
if kwargs is None:
14+
kwargs = {}
15+
16+
return map_over_subtree(func)(*args, **kwargs)
17+
18+
elif Version(xr.__version__) > Version("2025.01"):
19+
20+
def skip_empty_nodes(func):
21+
@functools.wraps(func)
22+
def _func(ds, *args, **kwargs):
23+
if not ds:
24+
return ds
25+
return func(ds, *args, **kwargs)
26+
27+
return _func
28+
29+
from xarray import DataTree, open_datatree
30+
from xarray import map_over_datasets as _map_over_datasets
31+
32+
def map_over_datasets(func, *args, kwargs=None):
33+
34+
return _map_over_datasets(skip_empty_nodes(func), *args, kwargs=kwargs)
35+
36+
else:
37+
raise ImportError(
38+
f"xarray version {xr.__version__} not supported - please upgrade to v2025.02 ("
39+
"or later) or downgrade to v2024.09"
40+
)
41+
42+
__all__ = [
43+
"DataTree",
44+
"map_over_datasets",
45+
"open_datatree",
46+
]

mesmer/core/datatree.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from typing import overload
22

33
import xarray as xr
4-
from datatree import DataTree
4+
5+
from mesmer.core._datatreecompat import DataTree, map_over_datasets
56

67

78
def _extract_single_dataarray_from_dt(dt: DataTree, name: str = "node") -> xr.DataArray:
@@ -157,14 +158,23 @@ def stack_datatrees_for_linear_regression(
157158
predictors_stacked = DataTree()
158159
for key, pred in predictors.items():
159160
# 1) broadcast to target
160-
pred_broadcast = pred.broadcast_like(target, exclude=exclude_dim)
161+
# TODO: use DataTree method again, once available
162+
# pred_broadcast = pred.broadcast_like(target, exclude=exclude_dim)
163+
pred_broadcast = map_over_datasets(
164+
xr.Dataset.broadcast_like, pred, target, kwargs={"exclude": exclude_dim}
165+
)
166+
161167
# 2) collapse into DataSets
162168
predictor_ds = collapse_datatree_into_dataset(pred_broadcast, dim=collapse_dim)
163169
# 3) stack
164170
predictors_stacked[key] = DataTree(
165171
predictor_ds.stack(stack_dim, create_index=False)
166172
)
167-
predictors_stacked = predictors_stacked.dropna(dim=stacked_dim)
173+
# TODO: use DataTree method again, once available
174+
# predictors_stacked = predictors_stacked.dropna(dim=stacked_dim)
175+
predictors_stacked = map_over_datasets(
176+
xr.Dataset.dropna, predictors_stacked, kwargs={"dim": stacked_dim}
177+
)
168178

169179
# prepare target
170180
# 1) collapse into DataSet

mesmer/core/weighted.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
import numpy as np
44
import xarray as xr
5-
from datatree import DataTree, map_over_subtree
5+
6+
from mesmer.core._datatreecompat import DataTree, map_over_datasets
67

78

89
def _weighted_if_dim(obj, weights, dims):
@@ -190,6 +191,6 @@ def _create_weights(ds: xr.Dataset) -> xr.Dataset:
190191

191192
return xr.Dataset({"weights": weights})
192193

193-
weights = map_over_subtree(_create_weights)(dt)
194+
weights = map_over_datasets(_create_weights, dt)
194195

195196
return weights

mesmer/stats/_auto_regression.py

+149-15
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
import pandas as pd
77
import scipy
88
import xarray as xr
9-
from datatree import DataTree, map_over_subtree
109

10+
from mesmer.core._datatreecompat import DataTree, map_over_datasets
1111
from mesmer.core.datatree import (
1212
collapse_datatree_into_dataset,
1313
)
@@ -51,13 +51,19 @@ def _scen_ens_inputs_to_dt(objs: Sequence) -> DataTree:
5151
return dt
5252

5353

54-
def _extract_and_apply_to_da(func: Callable, ds: xr.Dataset, **kwargs) -> xr.Dataset:
54+
def _extract_and_apply_to_da(func: Callable) -> Callable:
5555

56-
name, *others = ds.data_vars
57-
if others:
58-
raise ValueError("Dataset must have only one data variable.")
56+
def inner(ds: xr.Dataset, **kwargs) -> xr.Dataset:
5957

60-
return func(ds[name], **kwargs)
58+
name, *others = ds.data_vars
59+
if others:
60+
raise ValueError("Dataset must have only one data variable.")
61+
62+
x = func(ds[name], **kwargs)
63+
64+
return x.to_dataset() if isinstance(x, xr.DataArray) else x
65+
66+
return inner
6167

6268

6369
def select_ar_order_scen_ens(
@@ -137,8 +143,10 @@ def _select_ar_order_scen_ens_dt(
137143
then over all scenarios.
138144
"""
139145

140-
ar_order_scen = map_over_subtree(_extract_and_apply_to_da)(
141-
select_ar_order, dt, dim=dim, maxlag=maxlag, ic=ic
146+
ar_order_scen = map_over_datasets(
147+
_extract_and_apply_to_da(select_ar_order),
148+
dt,
149+
kwargs={"dim": dim, "maxlag": maxlag, "ic": ic},
142150
)
143151

144152
# TODO: think about weighting?
@@ -147,7 +155,7 @@ def _ens_quantile(ds, ens_dim):
147155
return ds.quantile(dim=ens_dim, q=0.5, method="nearest")
148156
return ds
149157

150-
ar_order_ens_median = map_over_subtree(_ens_quantile)(ar_order_scen, ens_dim)
158+
ar_order_ens_median = map_over_datasets(_ens_quantile, ar_order_scen, ens_dim)
151159

152160
ar_order_ens_median_ds = collapse_datatree_into_dataset(
153161
ar_order_ens_median, dim="scen"
@@ -237,8 +245,10 @@ def _fit_auto_regression_scen_ens_dt(
237245
If no ensemble members are provided, the mean is calculated over scenarios only.
238246
"""
239247

240-
ar_params_scen = map_over_subtree(_extract_and_apply_to_da)(
241-
fit_auto_regression, dt, dim=dim, lags=int(lags)
248+
ar_params_scen = map_over_datasets(
249+
_extract_and_apply_to_da(fit_auto_regression),
250+
dt,
251+
kwargs={"dim": dim, "lags": int(lags)},
242252
)
243253

244254
# TODO: think about weighting! see https://github.com/MESMER-group/mesmer/issues/307
@@ -247,7 +257,7 @@ def _ens_mean(ds, ens_dim):
247257
return ds.mean(ens_dim)
248258
return ds
249259

250-
ar_params_scen = map_over_subtree(_ens_mean)(ar_params_scen, ens_dim)
260+
ar_params_scen = map_over_datasets(_ens_mean, ar_params_scen, ens_dim)
251261

252262
ar_params_scen = collapse_datatree_into_dataset(ar_params_scen, dim="scen")
253263

@@ -413,6 +423,44 @@ def draw_auto_regression_uncorrelated(
413423
n_time x n_coeffs x n_realisations.
414424
415425
"""
426+
427+
if isinstance(seed, DataTree):
428+
return map_over_datasets(
429+
_draw_auto_regression_uncorrelated,
430+
seed,
431+
ar_params,
432+
kwargs={
433+
"time": time,
434+
"realisation": realisation,
435+
"buffer": buffer,
436+
"time_dim": time_dim,
437+
"realisation_dim": realisation_dim,
438+
},
439+
)
440+
441+
else:
442+
return _draw_auto_regression_uncorrelated(
443+
seed,
444+
ar_params,
445+
time=time,
446+
realisation=realisation,
447+
buffer=buffer,
448+
time_dim=time_dim,
449+
realisation_dim=realisation_dim,
450+
)["samples"]
451+
452+
453+
def _draw_auto_regression_uncorrelated(
454+
seed: int | xr.Dataset,
455+
ar_params: xr.Dataset,
456+
*,
457+
time: int | xr.DataArray | pd.Index,
458+
realisation: int | xr.DataArray | pd.Index,
459+
buffer: int,
460+
time_dim: str = "time",
461+
realisation_dim: str = "realisation",
462+
) -> xr.DataArray:
463+
416464
# NOTE: we use variance and not std since we use multivariate normal
417465
# also to draw univariate realizations
418466
# check the input
@@ -450,7 +498,7 @@ def draw_auto_regression_uncorrelated(
450498
# remove the "__gridpoint__" dim again
451499
result = result.squeeze(dim="__gridpoint__", drop=True)
452500

453-
return result.rename("samples")
501+
return result.rename("samples").to_dataset()
454502

455503

456504
def draw_auto_regression_correlated(
@@ -513,6 +561,48 @@ def draw_auto_regression_correlated(
513561
514562
"""
515563

564+
if isinstance(seed, DataTree):
565+
566+
return map_over_datasets(
567+
_draw_auto_regression_correlated,
568+
seed,
569+
ar_params,
570+
covariance,
571+
kwargs={
572+
"time": time,
573+
"realisation": realisation,
574+
"buffer": buffer,
575+
"time_dim": time_dim,
576+
"realisation_dim": realisation_dim,
577+
},
578+
)
579+
580+
else:
581+
582+
return _draw_auto_regression_correlated(
583+
seed,
584+
ar_params,
585+
covariance,
586+
time=time,
587+
realisation=realisation,
588+
buffer=buffer,
589+
time_dim=time_dim,
590+
realisation_dim=realisation_dim,
591+
)["samples"]
592+
593+
594+
def _draw_auto_regression_correlated(
595+
seed: int | xr.Dataset,
596+
ar_params: xr.Dataset,
597+
covariance: xr.DataArray,
598+
*,
599+
time: int | xr.DataArray | pd.Index,
600+
realisation: int | xr.DataArray | pd.Index,
601+
buffer: int,
602+
time_dim: str = "time",
603+
realisation_dim: str = "realisation",
604+
) -> xr.DataArray:
605+
516606
# check the input
517607
_check_dataset_form(ar_params, "ar_params", required_vars={"intercept", "coeffs"})
518608
_check_dataarray_form(ar_params.intercept, "intercept", ndim=1)
@@ -538,7 +628,7 @@ def draw_auto_regression_correlated(
538628
realisation_dim=realisation_dim,
539629
)
540630

541-
return result.rename("samples")
631+
return result.rename("samples").to_dataset()
542632

543633

544634
def _draw_ar_corr_xr_internal(
@@ -943,6 +1033,50 @@ def draw_auto_regression_monthly(
9431033
correlated innovations. The array has shape n_timesteps x n_gridpoints.
9441034
9451035
"""
1036+
1037+
if isinstance(seed, DataTree):
1038+
1039+
return map_over_datasets(
1040+
_draw_auto_regression_monthly,
1041+
seed,
1042+
ar_params,
1043+
covariance,
1044+
kwargs={
1045+
"time": time,
1046+
"n_realisations": n_realisations,
1047+
"buffer": buffer,
1048+
"time_dim": time_dim,
1049+
"realisation_dim": realisation_dim,
1050+
},
1051+
)
1052+
1053+
else:
1054+
return _draw_auto_regression_monthly(
1055+
seed,
1056+
ar_params,
1057+
covariance,
1058+
time=time,
1059+
n_realisations=n_realisations,
1060+
buffer=buffer,
1061+
time_dim=time_dim,
1062+
realisation_dim=realisation_dim,
1063+
)["samples"]
1064+
1065+
1066+
def _draw_auto_regression_monthly(
1067+
seed,
1068+
ar_params: xr.Dataset,
1069+
covariance: xr.DataArray,
1070+
*,
1071+
time: xr.DataArray | pd.Index,
1072+
n_realisations: int,
1073+
buffer: int,
1074+
time_dim: str = "time",
1075+
realisation_dim: str = "realisation",
1076+
) -> xr.DataArray:
1077+
1078+
# NOTE: seed must be the first positional argument for map_over_datasets to work
1079+
9461080
# check input
9471081
_check_dataset_form(ar_params, "ar_params", required_vars={"intercept", "slope"})
9481082
month_dim, gridcell_dim = ar_params.intercept.dims
@@ -975,7 +1109,7 @@ def draw_auto_regression_monthly(
9751109
realisation_dim=realisation_dim,
9761110
)
9771111

978-
return result.rename("samples")
1112+
return result.rename("samples").to_dataset()
9791113

9801114

9811115
def _draw_ar_corr_monthly_xr_internal(

0 commit comments

Comments
 (0)