6
6
import pandas as pd
7
7
import scipy
8
8
import xarray as xr
9
- from datatree import DataTree , map_over_subtree
10
9
10
+ from mesmer .core ._datatreecompat import DataTree , map_over_datasets
11
11
from mesmer .core .datatree import (
12
12
collapse_datatree_into_dataset ,
13
13
)
@@ -51,13 +51,19 @@ def _scen_ens_inputs_to_dt(objs: Sequence) -> DataTree:
51
51
return dt
52
52
53
53
54
- def _extract_and_apply_to_da (func : Callable , ds : xr . Dataset , ** kwargs ) -> xr . Dataset :
54
+ def _extract_and_apply_to_da (func : Callable ) -> Callable :
55
55
56
- name , * others = ds .data_vars
57
- if others :
58
- raise ValueError ("Dataset must have only one data variable." )
56
+ def inner (ds : xr .Dataset , ** kwargs ) -> xr .Dataset :
59
57
60
- return func (ds [name ], ** kwargs )
58
+ name , * others = ds .data_vars
59
+ if others :
60
+ raise ValueError ("Dataset must have only one data variable." )
61
+
62
+ x = func (ds [name ], ** kwargs )
63
+
64
+ return x .to_dataset () if isinstance (x , xr .DataArray ) else x
65
+
66
+ return inner
61
67
62
68
63
69
def select_ar_order_scen_ens (
@@ -137,8 +143,10 @@ def _select_ar_order_scen_ens_dt(
137
143
then over all scenarios.
138
144
"""
139
145
140
- ar_order_scen = map_over_subtree (_extract_and_apply_to_da )(
141
- select_ar_order , dt , dim = dim , maxlag = maxlag , ic = ic
146
+ ar_order_scen = map_over_datasets (
147
+ _extract_and_apply_to_da (select_ar_order ),
148
+ dt ,
149
+ kwargs = {"dim" : dim , "maxlag" : maxlag , "ic" : ic },
142
150
)
143
151
144
152
# TODO: think about weighting?
@@ -147,7 +155,7 @@ def _ens_quantile(ds, ens_dim):
147
155
return ds .quantile (dim = ens_dim , q = 0.5 , method = "nearest" )
148
156
return ds
149
157
150
- ar_order_ens_median = map_over_subtree (_ens_quantile )( ar_order_scen , ens_dim )
158
+ ar_order_ens_median = map_over_datasets (_ens_quantile , ar_order_scen , ens_dim )
151
159
152
160
ar_order_ens_median_ds = collapse_datatree_into_dataset (
153
161
ar_order_ens_median , dim = "scen"
@@ -237,8 +245,10 @@ def _fit_auto_regression_scen_ens_dt(
237
245
If no ensemble members are provided, the mean is calculated over scenarios only.
238
246
"""
239
247
240
- ar_params_scen = map_over_subtree (_extract_and_apply_to_da )(
241
- fit_auto_regression , dt , dim = dim , lags = int (lags )
248
+ ar_params_scen = map_over_datasets (
249
+ _extract_and_apply_to_da (fit_auto_regression ),
250
+ dt ,
251
+ kwargs = {"dim" : dim , "lags" : int (lags )},
242
252
)
243
253
244
254
# TODO: think about weighting! see https://github.com/MESMER-group/mesmer/issues/307
@@ -247,7 +257,7 @@ def _ens_mean(ds, ens_dim):
247
257
return ds .mean (ens_dim )
248
258
return ds
249
259
250
- ar_params_scen = map_over_subtree (_ens_mean )( ar_params_scen , ens_dim )
260
+ ar_params_scen = map_over_datasets (_ens_mean , ar_params_scen , ens_dim )
251
261
252
262
ar_params_scen = collapse_datatree_into_dataset (ar_params_scen , dim = "scen" )
253
263
@@ -413,6 +423,44 @@ def draw_auto_regression_uncorrelated(
413
423
n_time x n_coeffs x n_realisations.
414
424
415
425
"""
426
+
427
+ if isinstance (seed , DataTree ):
428
+ return map_over_datasets (
429
+ _draw_auto_regression_uncorrelated ,
430
+ seed ,
431
+ ar_params ,
432
+ kwargs = {
433
+ "time" : time ,
434
+ "realisation" : realisation ,
435
+ "buffer" : buffer ,
436
+ "time_dim" : time_dim ,
437
+ "realisation_dim" : realisation_dim ,
438
+ },
439
+ )
440
+
441
+ else :
442
+ return _draw_auto_regression_uncorrelated (
443
+ seed ,
444
+ ar_params ,
445
+ time = time ,
446
+ realisation = realisation ,
447
+ buffer = buffer ,
448
+ time_dim = time_dim ,
449
+ realisation_dim = realisation_dim ,
450
+ )["samples" ]
451
+
452
+
453
+ def _draw_auto_regression_uncorrelated (
454
+ seed : int | xr .Dataset ,
455
+ ar_params : xr .Dataset ,
456
+ * ,
457
+ time : int | xr .DataArray | pd .Index ,
458
+ realisation : int | xr .DataArray | pd .Index ,
459
+ buffer : int ,
460
+ time_dim : str = "time" ,
461
+ realisation_dim : str = "realisation" ,
462
+ ) -> xr .DataArray :
463
+
416
464
# NOTE: we use variance and not std since we use multivariate normal
417
465
# also to draw univariate realizations
418
466
# check the input
@@ -450,7 +498,7 @@ def draw_auto_regression_uncorrelated(
450
498
# remove the "__gridpoint__" dim again
451
499
result = result .squeeze (dim = "__gridpoint__" , drop = True )
452
500
453
- return result .rename ("samples" )
501
+ return result .rename ("samples" ). to_dataset ()
454
502
455
503
456
504
def draw_auto_regression_correlated (
@@ -513,6 +561,48 @@ def draw_auto_regression_correlated(
513
561
514
562
"""
515
563
564
+ if isinstance (seed , DataTree ):
565
+
566
+ return map_over_datasets (
567
+ _draw_auto_regression_correlated ,
568
+ seed ,
569
+ ar_params ,
570
+ covariance ,
571
+ kwargs = {
572
+ "time" : time ,
573
+ "realisation" : realisation ,
574
+ "buffer" : buffer ,
575
+ "time_dim" : time_dim ,
576
+ "realisation_dim" : realisation_dim ,
577
+ },
578
+ )
579
+
580
+ else :
581
+
582
+ return _draw_auto_regression_correlated (
583
+ seed ,
584
+ ar_params ,
585
+ covariance ,
586
+ time = time ,
587
+ realisation = realisation ,
588
+ buffer = buffer ,
589
+ time_dim = time_dim ,
590
+ realisation_dim = realisation_dim ,
591
+ )["samples" ]
592
+
593
+
594
+ def _draw_auto_regression_correlated (
595
+ seed : int | xr .Dataset ,
596
+ ar_params : xr .Dataset ,
597
+ covariance : xr .DataArray ,
598
+ * ,
599
+ time : int | xr .DataArray | pd .Index ,
600
+ realisation : int | xr .DataArray | pd .Index ,
601
+ buffer : int ,
602
+ time_dim : str = "time" ,
603
+ realisation_dim : str = "realisation" ,
604
+ ) -> xr .DataArray :
605
+
516
606
# check the input
517
607
_check_dataset_form (ar_params , "ar_params" , required_vars = {"intercept" , "coeffs" })
518
608
_check_dataarray_form (ar_params .intercept , "intercept" , ndim = 1 )
@@ -538,7 +628,7 @@ def draw_auto_regression_correlated(
538
628
realisation_dim = realisation_dim ,
539
629
)
540
630
541
- return result .rename ("samples" )
631
+ return result .rename ("samples" ). to_dataset ()
542
632
543
633
544
634
def _draw_ar_corr_xr_internal (
@@ -943,6 +1033,50 @@ def draw_auto_regression_monthly(
943
1033
correlated innovations. The array has shape n_timesteps x n_gridpoints.
944
1034
945
1035
"""
1036
+
1037
+ if isinstance (seed , DataTree ):
1038
+
1039
+ return map_over_datasets (
1040
+ _draw_auto_regression_monthly ,
1041
+ seed ,
1042
+ ar_params ,
1043
+ covariance ,
1044
+ kwargs = {
1045
+ "time" : time ,
1046
+ "n_realisations" : n_realisations ,
1047
+ "buffer" : buffer ,
1048
+ "time_dim" : time_dim ,
1049
+ "realisation_dim" : realisation_dim ,
1050
+ },
1051
+ )
1052
+
1053
+ else :
1054
+ return _draw_auto_regression_monthly (
1055
+ seed ,
1056
+ ar_params ,
1057
+ covariance ,
1058
+ time = time ,
1059
+ n_realisations = n_realisations ,
1060
+ buffer = buffer ,
1061
+ time_dim = time_dim ,
1062
+ realisation_dim = realisation_dim ,
1063
+ )["samples" ]
1064
+
1065
+
1066
+ def _draw_auto_regression_monthly (
1067
+ seed ,
1068
+ ar_params : xr .Dataset ,
1069
+ covariance : xr .DataArray ,
1070
+ * ,
1071
+ time : xr .DataArray | pd .Index ,
1072
+ n_realisations : int ,
1073
+ buffer : int ,
1074
+ time_dim : str = "time" ,
1075
+ realisation_dim : str = "realisation" ,
1076
+ ) -> xr .DataArray :
1077
+
1078
+ # NOTE: seed must be the first positional argument for map_over_datasets to work
1079
+
946
1080
# check input
947
1081
_check_dataset_form (ar_params , "ar_params" , required_vars = {"intercept" , "slope" })
948
1082
month_dim , gridcell_dim = ar_params .intercept .dims
@@ -975,7 +1109,7 @@ def draw_auto_regression_monthly(
975
1109
realisation_dim = realisation_dim ,
976
1110
)
977
1111
978
- return result .rename ("samples" )
1112
+ return result .rename ("samples" ). to_dataset ()
979
1113
980
1114
981
1115
def _draw_ar_corr_monthly_xr_internal (
0 commit comments