Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compatibility with xr.DataTree #607

Merged
merged 15 commits into from
Mar 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ This release implements using :py:class:`DataTree` from `xarray-datatree` to han
- Add `xarray-datatree` as dependency (`#554 <https://github.com/MESMER-group/mesmer/pull/554>`_)
- Add calibration integration tests for multiple scenarios and change parameter files to netcdfs with new naming structure (`#537 <https://github.com/MESMER-group/mesmer/pull/537>`_)
- Add new integration tests for drawing realisations (`#599 <https://github.com/MESMER-group/mesmer/pull/599>`_)
- Port the functionality to xarray's :py:class:`DataTree` implementation (`#607 <https://github.com/MESMER-group/mesmer/pull/607>`_).

By `Victoria Bauer`_ and `Mathias Hauser`_.

Expand Down
6 changes: 3 additions & 3 deletions ci/install-upstream-wheels.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ conda uninstall -y --force \
regionmask \
scikit-learn \
scipy \
statsmodels
# xarray # re-add after removing upper pin for xarray-datatree
statsmodels \
xarray

# keep cartopy & matplotlib: we don't have tests that use them
# keep joblib: we want to move away from pickle files
Expand All @@ -32,8 +32,8 @@ python -m pip install --no-deps --upgrade \
git+https://github.com/fatiando/pooch \
git+https://github.com/geopandas/geopandas \
git+https://github.com/properscoring/properscoring \
git+https://github.com/pydata/xarray \
git+https://github.com/pypa/packaging \
git+https://github.com/pyproj4/pyproj \
git+https://github.com/regionmask/regionmask \
git+https://github.com/SciTools/nc-time-axis
# git+https://github.com/pydata/xarray \
46 changes: 46 additions & 0 deletions mesmer/core/_datatreecompat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Compatibility layer for ``DataTree`` across xarray versions.

Exposes ``DataTree``, ``map_over_datasets`` and ``open_datatree`` under a
single, uniform interface regardless of whether they come from the standalone
``xarray-datatree`` package (xarray < 2024.10) or from xarray itself
(xarray > 2025.01).
"""

import functools

import xarray as xr
from packaging.version import Version

if Version(xr.__version__) < Version("2024.10"):

    # old xarray: DataTree lives in the separate `xarray-datatree` package
    from datatree import DataTree, map_over_subtree, open_datatree

    def map_over_datasets(func, *args, kwargs=None):
        """compatibility layer for older xarray versions

        Emulates the ``xarray.map_over_datasets(func, *args, kwargs=...)``
        calling convention on top of the older ``map_over_subtree`` decorator.
        """

        if kwargs is None:
            kwargs = {}

        return map_over_subtree(func)(*args, **kwargs)

elif Version(xr.__version__) > Version("2025.01"):

    def skip_empty_nodes(func):
        # Guard wrapper: return data-less nodes unchanged so ``func`` is only
        # applied to nodes that actually contain data (presumably matching the
        # old ``map_over_subtree`` behavior of skipping empty nodes — TODO confirm).
        @functools.wraps(func)
        def _func(ds, *args, **kwargs):
            if not ds:
                return ds
            return func(ds, *args, **kwargs)

        return _func

    from xarray import DataTree, open_datatree
    from xarray import map_over_datasets as _map_over_datasets

    def map_over_datasets(func, *args, kwargs=None):
        # same signature as xarray's map_over_datasets, but with empty nodes skipped

        return _map_over_datasets(skip_empty_nodes(func), *args, kwargs=kwargs)

else:
    # xarray 2024.10 - 2025.01 fall in neither branch and are deliberately
    # rejected (per the message below their DataTree support is not handled here)
    raise ImportError(
        f"xarray version {xr.__version__} not supported - please upgrade to v2025.02 ("
        "or later) or downgrade to v2024.09"
    )

# public API of this compat module
__all__ = [
    "DataTree",
    "map_over_datasets",
    "open_datatree",
]
16 changes: 13 additions & 3 deletions mesmer/core/datatree.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import overload

import xarray as xr
from datatree import DataTree

from mesmer.core._datatreecompat import DataTree, map_over_datasets


def _extract_single_dataarray_from_dt(dt: DataTree, name: str = "node") -> xr.DataArray:
Expand Down Expand Up @@ -157,14 +158,23 @@ def stack_datatrees_for_linear_regression(
predictors_stacked = DataTree()
for key, pred in predictors.items():
# 1) broadcast to target
pred_broadcast = pred.broadcast_like(target, exclude=exclude_dim)
# TODO: use DataTree method again, once available
# pred_broadcast = pred.broadcast_like(target, exclude=exclude_dim)
pred_broadcast = map_over_datasets(
xr.Dataset.broadcast_like, pred, target, kwargs={"exclude": exclude_dim}
)

# 2) collapse into DataSets
predictor_ds = collapse_datatree_into_dataset(pred_broadcast, dim=collapse_dim)
# 3) stack
predictors_stacked[key] = DataTree(
predictor_ds.stack(stack_dim, create_index=False)
)
predictors_stacked = predictors_stacked.dropna(dim=stacked_dim)
# TODO: use DataTree method again, once available
# predictors_stacked = predictors_stacked.dropna(dim=stacked_dim)
predictors_stacked = map_over_datasets(
xr.Dataset.dropna, predictors_stacked, kwargs={"dim": stacked_dim}
)

# prepare target
# 1) collapse into DataSet
Expand Down
5 changes: 3 additions & 2 deletions mesmer/core/weighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import numpy as np
import xarray as xr
from datatree import DataTree, map_over_subtree

from mesmer.core._datatreecompat import DataTree, map_over_datasets


def _weighted_if_dim(obj, weights, dims):
Expand Down Expand Up @@ -190,6 +191,6 @@ def _create_weights(ds: xr.Dataset) -> xr.Dataset:

return xr.Dataset({"weights": weights})

weights = map_over_subtree(_create_weights)(dt)
weights = map_over_datasets(_create_weights, dt)

return weights
164 changes: 149 additions & 15 deletions mesmer/stats/_auto_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import pandas as pd
import scipy
import xarray as xr
from datatree import DataTree, map_over_subtree

from mesmer.core._datatreecompat import DataTree, map_over_datasets
from mesmer.core.datatree import (
collapse_datatree_into_dataset,
)
Expand Down Expand Up @@ -51,13 +51,19 @@ def _scen_ens_inputs_to_dt(objs: Sequence) -> DataTree:
return dt


def _extract_and_apply_to_da(func: Callable, ds: xr.Dataset, **kwargs) -> xr.Dataset:
def _extract_and_apply_to_da(func: Callable) -> Callable:

name, *others = ds.data_vars
if others:
raise ValueError("Dataset must have only one data variable.")
def inner(ds: xr.Dataset, **kwargs) -> xr.Dataset:

return func(ds[name], **kwargs)
name, *others = ds.data_vars
if others:
raise ValueError("Dataset must have only one data variable.")

x = func(ds[name], **kwargs)

return x.to_dataset() if isinstance(x, xr.DataArray) else x

return inner


def select_ar_order_scen_ens(
Expand Down Expand Up @@ -137,8 +143,10 @@ def _select_ar_order_scen_ens_dt(
then over all scenarios.
"""

ar_order_scen = map_over_subtree(_extract_and_apply_to_da)(
select_ar_order, dt, dim=dim, maxlag=maxlag, ic=ic
ar_order_scen = map_over_datasets(
_extract_and_apply_to_da(select_ar_order),
dt,
kwargs={"dim": dim, "maxlag": maxlag, "ic": ic},
)

# TODO: think about weighting?
Expand All @@ -147,7 +155,7 @@ def _ens_quantile(ds, ens_dim):
return ds.quantile(dim=ens_dim, q=0.5, method="nearest")
return ds

ar_order_ens_median = map_over_subtree(_ens_quantile)(ar_order_scen, ens_dim)
ar_order_ens_median = map_over_datasets(_ens_quantile, ar_order_scen, ens_dim)

ar_order_ens_median_ds = collapse_datatree_into_dataset(
ar_order_ens_median, dim="scen"
Expand Down Expand Up @@ -237,8 +245,10 @@ def _fit_auto_regression_scen_ens_dt(
If no ensemble members are provided, the mean is calculated over scenarios only.
"""

ar_params_scen = map_over_subtree(_extract_and_apply_to_da)(
fit_auto_regression, dt, dim=dim, lags=int(lags)
ar_params_scen = map_over_datasets(
_extract_and_apply_to_da(fit_auto_regression),
dt,
kwargs={"dim": dim, "lags": int(lags)},
)

# TODO: think about weighting! see https://github.com/MESMER-group/mesmer/issues/307
Expand All @@ -247,7 +257,7 @@ def _ens_mean(ds, ens_dim):
return ds.mean(ens_dim)
return ds

ar_params_scen = map_over_subtree(_ens_mean)(ar_params_scen, ens_dim)
ar_params_scen = map_over_datasets(_ens_mean, ar_params_scen, ens_dim)

ar_params_scen = collapse_datatree_into_dataset(ar_params_scen, dim="scen")

Expand Down Expand Up @@ -413,6 +423,44 @@ def draw_auto_regression_uncorrelated(
n_time x n_coeffs x n_realisations.

"""

if isinstance(seed, DataTree):
return map_over_datasets(
_draw_auto_regression_uncorrelated,
seed,
ar_params,
kwargs={
"time": time,
"realisation": realisation,
"buffer": buffer,
"time_dim": time_dim,
"realisation_dim": realisation_dim,
},
)

else:
return _draw_auto_regression_uncorrelated(
seed,
ar_params,
time=time,
realisation=realisation,
buffer=buffer,
time_dim=time_dim,
realisation_dim=realisation_dim,
)["samples"]


def _draw_auto_regression_uncorrelated(
seed: int | xr.Dataset,
ar_params: xr.Dataset,
*,
time: int | xr.DataArray | pd.Index,
realisation: int | xr.DataArray | pd.Index,
buffer: int,
time_dim: str = "time",
realisation_dim: str = "realisation",
) -> xr.DataArray:

# NOTE: we use variance and not std since we use multivariate normal
# also to draw univariate realizations
# check the input
Expand Down Expand Up @@ -450,7 +498,7 @@ def draw_auto_regression_uncorrelated(
# remove the "__gridpoint__" dim again
result = result.squeeze(dim="__gridpoint__", drop=True)

return result.rename("samples")
return result.rename("samples").to_dataset()


def draw_auto_regression_correlated(
Expand Down Expand Up @@ -513,6 +561,48 @@ def draw_auto_regression_correlated(

"""

if isinstance(seed, DataTree):

return map_over_datasets(
_draw_auto_regression_correlated,
seed,
ar_params,
covariance,
kwargs={
"time": time,
"realisation": realisation,
"buffer": buffer,
"time_dim": time_dim,
"realisation_dim": realisation_dim,
},
)

else:

return _draw_auto_regression_correlated(
seed,
ar_params,
covariance,
time=time,
realisation=realisation,
buffer=buffer,
time_dim=time_dim,
realisation_dim=realisation_dim,
)["samples"]


def _draw_auto_regression_correlated(
seed: int | xr.Dataset,
ar_params: xr.Dataset,
covariance: xr.DataArray,
*,
time: int | xr.DataArray | pd.Index,
realisation: int | xr.DataArray | pd.Index,
buffer: int,
time_dim: str = "time",
realisation_dim: str = "realisation",
) -> xr.DataArray:

# check the input
_check_dataset_form(ar_params, "ar_params", required_vars={"intercept", "coeffs"})
_check_dataarray_form(ar_params.intercept, "intercept", ndim=1)
Expand All @@ -538,7 +628,7 @@ def draw_auto_regression_correlated(
realisation_dim=realisation_dim,
)

return result.rename("samples")
return result.rename("samples").to_dataset()


def _draw_ar_corr_xr_internal(
Expand Down Expand Up @@ -943,6 +1033,50 @@ def draw_auto_regression_monthly(
correlated innovations. The array has shape n_timesteps x n_gridpoints.

"""

if isinstance(seed, DataTree):

return map_over_datasets(
_draw_auto_regression_monthly,
seed,
ar_params,
covariance,
kwargs={
"time": time,
"n_realisations": n_realisations,
"buffer": buffer,
"time_dim": time_dim,
"realisation_dim": realisation_dim,
},
)

else:
return _draw_auto_regression_monthly(
seed,
ar_params,
covariance,
time=time,
n_realisations=n_realisations,
buffer=buffer,
time_dim=time_dim,
realisation_dim=realisation_dim,
)["samples"]


def _draw_auto_regression_monthly(
seed,
ar_params: xr.Dataset,
covariance: xr.DataArray,
*,
time: xr.DataArray | pd.Index,
n_realisations: int,
buffer: int,
time_dim: str = "time",
realisation_dim: str = "realisation",
) -> xr.DataArray:

# NOTE: seed must be the first positional argument for map_over_datasets to work

# check input
_check_dataset_form(ar_params, "ar_params", required_vars={"intercept", "slope"})
month_dim, gridcell_dim = ar_params.intercept.dims
Expand Down Expand Up @@ -975,7 +1109,7 @@ def draw_auto_regression_monthly(
realisation_dim=realisation_dim,
)

return result.rename("samples")
return result.rename("samples").to_dataset()


def _draw_ar_corr_monthly_xr_internal(
Expand Down
Loading
Loading