@@ -23,7 +23,6 @@

"""

- import logging
import time
import warnings
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
@@ -46,17 +45,16 @@

import pism_ragis.processing as prp
from pism_ragis.analysis import delta_analysis
+ from pism_ragis.decorators import profileit, timeit
from pism_ragis.filtering import importance_sampling
from pism_ragis.likelihood import log_normal
+ from pism_ragis.logger import get_logger
+
+ logger = get_logger(__name__)

xr.set_options(keep_attrs=True)
plt.style.use("tableau-colorblind10")

- logger = logging.getLogger(__name__)
- logging.getLogger("matplotlib").disabled = True
-
- logging.basicConfig(filename="example.log", encoding="utf-8", level=logging.INFO)
-

sim_alpha = 0.5
sim_cmap = sns.color_palette("crest", n_colors=4).as_hex()[0:3:2]
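
The removed logging.basicConfig setup is superseded by the shared get_logger helper. Its implementation lives in pism_ragis.logger and is not part of this diff; a minimal sketch of what it presumably provides (a named logger with a handler attached once, with matplotlib chatter quieted as the old setup did):

    import logging

    def get_logger(name: str) -> logging.Logger:
        """Return a named logger with one stream handler (sketch, not the actual implementation)."""
        logging.getLogger("matplotlib").setLevel(logging.WARNING)  # silence matplotlib chatter
        log = logging.getLogger(name)
        if not log.handlers:  # avoid stacking handlers on repeated imports
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter("%(levelname)s:%(name)s:%(message)s"))
            log.addHandler(handler)
        return log
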
@@ -68,73 +66,75 @@
hist_cmap = ["#a6cee3", "#1f78b4"]


- # def timeit(func):
- #     def wrapper(*args, **kwargs):
- #         start_time = time.time()
- #         result = func(*args, **kwargs)
- #         end_time = time.time()
- #         time_elapsed = end_time - start_time
- #         print(f"{func.__name__} took {time_elapsed:.0f}s.")
- #         return result
-
- #     return wrapper
-
-
- # def timeit(func):
- #     @wraps(func)
- #     def timeit_wrapper(*args, **kwargs):
- #         start_time = time.perf_counter()
- #         result = func(*args, **kwargs)
- #         end_time = time.perf_counter()
- #         time_elapsed = end_time - start_time
- #         print(f"{func.__name__} took {time_elapsed:.1f}s.")
- #         return result
-
- #     return timeit_wrapper
-
-
- def timeit(func):
+ @timeit
+ def prepare_simulations(
+     filenames: List[Union[Path, str]],
+     config: Dict,
+     reference_year: float,
+     parallel: bool = True,
+     engine: str = "netcdf4",
+ ) -> xr.Dataset:
    """
-     Decorator that logs the time a function takes to execute.
+     Prepare simulations by loading and processing ensemble datasets.

-     This decorator logs the start time, end time, and the elapsed time
-     for the execution of the decorated function.
+     This function loads ensemble datasets from the specified filenames, processes them
+     according to the provided configuration, and returns the processed dataset. The
+     processing steps include sorting, dropping NaNs, standardizing variable names,
+     calculating cumulative variables, and normalizing cumulative variables.

    Parameters
    ----------
-     func : callable
-         The function to be decorated.
+     filenames : List[Union[Path, str]]
+         A list of file paths to the ensemble datasets.
+     config : Dict
+         A dictionary containing configuration settings for processing the datasets.
+     reference_year : float
+         The year to which cumulative variables are normalized.
+     parallel : bool, optional
+         Whether to load the datasets in parallel, by default True.
+     engine : str, optional
+         The engine to use for loading the datasets, by default "netcdf4".

    Returns
    -------
-     callable
-         The wrapped function with added timing functionality.
+     xr.Dataset
+         The processed xarray dataset.

    Examples
    --------
-     >>> @timeit
-     ... def example_function():
-     ...     time.sleep(1)
-     ...
-     >>> example_function()
-     INFO:__main__:Starting example_function
-     INFO:__main__:Finished example_function in 1.0001 seconds
+     >>> filenames = ["file1.nc", "file2.nc"]
+     >>> config = {
+     ...     "PISM Spatial": {...},
+     ...     "Cumulative Variables": {
+     ...         "cumulative_grounding_line_flux": "cumulative_gl_flux",
+     ...         "cumulative_smb": "cumulative_smb_flux"
+     ...     },
+     ...     "Flux Variables": {
+     ...         "grounding_line_flux": "gl_flux",
+     ...         "smb_flux": "smb_flux"
+     ...     }
+     ... }
+     >>> ds = prepare_simulations(filenames, config, reference_year=1986.0)
    """
+     ds = prp.load_ensemble(filenames, parallel=parallel, engine=engine).sortby("basin")
+     # ds = xr.apply_ufunc(np.vectorize(convert_bstrings_to_str), ds, dask="parallelized")
+     ds = ds.dropna(dim="exp_id")

-     @wraps(func)
-     def wrapper(*args, **kwargs):
-         start_time = time.time()
-         logger.info("Starting %s", func.__name__)
-         result = func(*args, **kwargs)
-         end_time = time.time()
-         elapsed_time = end_time - start_time
-         logger.info("Finished %s in %2.2f seconds", func.__name__, elapsed_time)
-         return result
-
-     return wrapper
+     ds = prp.standardize_variable_names(ds, config["PISM Spatial"])
+     ds[config["Cumulative Variables"]["cumulative_grounding_line_flux"]] = ds[
+         config["Flux Variables"]["grounding_line_flux"]
+     ].cumsum() / len(ds.time)
+     ds[config["Cumulative Variables"]["cumulative_smb"]] = ds[
+         config["Flux Variables"]["smb_flux"]
+     ].cumsum() / len(ds.time)
+     ds = prp.normalize_cumulative_variables(
+         ds,
+         list(config["Cumulative Variables"].values()),
+         reference_year=reference_year,
+     )
+     return ds


- @timeit
def config_to_dataframe(config: xr.DataArray):
    """
    Convert an xarray DataArray configuration to a pandas DataFrame.
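
The local timeit decorator is gone in favor of the one imported from pism_ragis.decorators. Its packaged implementation is not shown in this diff, but judging from the removed code it is presumably equivalent to the following, using the module-level logger:

    import time
    from functools import wraps

    def timeit(func):
        """Log the wall-clock time of the decorated function (reconstruction of the removed local version)."""

        @wraps(func)
        def wrapper(*args, **kwargs):
            start_time = time.time()
            logger.info("Starting %s", func.__name__)
            result = func(*args, **kwargs)
            elapsed_time = time.time() - start_time
            logger.info("Finished %s in %2.2f seconds", func.__name__, elapsed_time)
            return result

        return wrapper
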
@@ -157,7 +155,6 @@ def config_to_dataframe(config: xr.DataArray):
    return df


- @timeit
def convert_bstrings_to_str(element: Any) -> Any:
    """
    Convert byte strings to regular strings.
@@ -178,14 +175,14 @@ def convert_bstrings_to_str(element: Any) -> Any:
    return element


- @timeit
+ @profileit
def filter_outliers(
    ds: xr.Dataset,
    outlier_range: List[float],
    outlier_variable: str,
    freq: str = "YS",
-     subset: Dict[str, Union[str, int]] = {"basin": "GIS", "ensemble_id": "RAGIS"},
- ) -> Dict[str, xr.Dataset]:
+     subset: Dict[str, str | int] = {"basin": "GIS", "ensemble_id": "RAGIS"},
+ ) -> tuple[xr.Dataset, xr.Dataset]:
    """
    Filter outliers from a dataset based on a specified variable and range.

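
One caveat on the updated signature: the dict default for subset is created once at definition time and shared across all calls. That is harmless as long as filter_outliers never mutates it; the defensive variant would default to None:

    def filter_outliers(
        ds: xr.Dataset,
        outlier_range: List[float],
        outlier_variable: str,
        freq: str = "YS",
        subset: Dict[str, str | int] | None = None,
    ) -> tuple[xr.Dataset, xr.Dataset]:
        # Build the default inside the call so no state is shared between calls.
        if subset is None:
            subset = {"basin": "GIS", "ensemble_id": "RAGIS"}
        ...
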
@@ -249,10 +246,24 @@ def filter_outliers(
    filtered_ds = ds.sel(exp_id=filtered_exp_ids)
    outliers_ds = ds.sel(exp_id=outlier_exp_ids)

-     return {"filtered": filtered_ds, "outliers": outliers_ds}
+     return filtered_ds, outliers_ds


- @timeit
+ def plot_outliers(
+     filtered_da: xr.DataArray, outliers_da: xr.DataArray, filename: Path | str
+ ):
+     """
+     Plot retained (black) and outlier (red) ensemble members in one figure.
+     """
+     fig, ax = plt.subplots(1, 1)
+     if filtered_da.size > 0:
+         filtered_da.plot(hue="exp_id", color="k", add_legend=False, ax=ax, lw=0.5)
+     if outliers_da.size > 0:
+         outliers_da.plot(hue="exp_id", color="r", add_legend=False, ax=ax, lw=0.5)
+     fig.savefig(filename)
+
+
+ @profileit
def run_delta_analysis(
    ds: xr.Dataset,
    ensemble_df: pd.DataFrame,
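
filter_outliers and run_delta_analysis are now decorated with @profileit instead of @timeit. Like timeit, profileit comes from pism_ragis.decorators and its body is outside this diff; a plausible cProfile-based sketch, in which the output path and stats handling are guesses:

    import cProfile
    from functools import wraps

    def profileit(func):
        """Profile the decorated function with cProfile (hypothetical sketch)."""

        @wraps(func)
        def wrapper(*args, **kwargs):
            profiler = cProfile.Profile()
            profiler.enable()
            result = func(*args, **kwargs)
            profiler.disable()
            profiler.dump_stats(f"{func.__name__}.prof")  # assumed output location
            return result

        return wrapper
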
@@ -347,7 +359,6 @@ def run_delta_analysis(
    return all_delta_indices


- @timeit
def plot_obs_sims(
    obs: xr.Dataset,
    sim_prior: xr.Dataset,
@@ -356,7 +367,7 @@ def plot_obs_sims(
    filtering_var: str,
    filter_range: List[int] = [1990, 2019],
    fig_dir: Union[str, Path] = "figures",
-     reference_year: int = 1986,
+     reference_year: float = 1986.0,
    sim_alpha: float = 0.4,
    obs_alpha: float = 1.0,
    sigma: float = 2,
@@ -529,7 +540,7 @@ def plot_obs_sims_3(
    filtering_var: str,
    filter_range: List[int] = [1990, 2019],
    fig_dir: Union[str, Path] = "figures",
-     reference_year: int = 1986,
+     reference_year: float = 1986.0,
    sim_alpha: float = 0.4,
    obs_alpha: float = 1.0,
    sigma: float = 2,
@@ -736,7 +747,7 @@ def plot_obs_sims_3(
        "--obs_url",
        help="""Path to "observed" mass balance.""",
        type=str,
-         default="data/mass_balance/mankoff_greenland_mass_balance.nc",
+         default="data/mass_balance/combined_greenland_mass_balance.nc",
    )
    parser.add_argument(
        "--engine",
@@ -797,8 +808,8 @@ def plot_obs_sims_3(
    parser.add_argument(
        "--reference_year",
        help="""Reference year.""",
-         type=int,
-         default=1986,
+         type=float,
+         default=2004,
    )
    parser.add_argument(
        "--n_jobs",
@@ -819,7 +830,13 @@ def plot_obs_sims_3(
        nargs="*",
    )

-     options = parser.parse_args()
+     parser.add_argument(
+         "--log",
+         default="WARNING",
+         help="Set the logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)",
+     )
+
+     options, unknown = parser.parse_known_args()
    basin_files = options.FILES
    ensemble = options.ensemble
    engine = options.engine
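
Two small behavioral changes land here: parse_known_args() tolerates unrecognized flags (they are collected in unknown) instead of erroring out, and the new --log option is parsed but not applied anywhere in the hunks shown. Assuming get_logger returns a standard logging.Logger, wiring it up would be a single call:

    import logging

    # Hypothetical follow-up: apply the parsed level to this module's logger.
    logger.setLevel(getattr(logging, options.log.upper(), logging.WARNING))
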
@@ -873,51 +890,26 @@ def plot_obs_sims_3(
        k + "_uncertainty": v + "_uncertainty" for k, v in cumulative_vars.items()
    }

-     ds = prp.load_ensemble(basin_files, parallel=parallel, engine=engine).sortby(
-         "basin"
+     simulated_ds = prepare_simulations(
+         basin_files, ragis_config, reference_year, parallel=parallel, engine=engine
    )
-     # for v in ds.data_vars:
-     #     if ds[v].dtype.kind == "S":
-     #         ds[v] = ds[v].astype(str)
-     #     for c in ds.coords:
-     #         if ds[c].dtype.kind == "S":
-     #             ds.coords[c] = ds.coords[c].astype(str)

-     # ds = xr.apply_ufunc(np.vectorize(convert_bstrings_to_str), ds, dask="parallelized")
-     ds = ds.dropna(dim="exp_id")
+     # fig, ax = plt.subplots(1, 1)
+     # ds.sel(time=slice(str(filter_start_year), str(filter_end_year))).sel(
+     #     basin="GIS", ensemble_id=ensemble
+     # ).grounding_line_flux.plot(hue="exp_id", add_legend=False, ax=ax, lw=0.5)
+     # fig.savefig("grounding_line_flux_unfiltered.pdf")

-     ds = prp.standardize_variable_names(ds, ragis_config["PISM Spatial"])
-     ds[ragis_config["Cumulative Variables"]["cumulative_grounding_line_flux"]] = ds[
-         ragis_config["Flux Variables"]["grounding_line_flux"]
-     ].cumsum() / len(ds.time)
-     ds[ragis_config["Cumulative Variables"]["cumulative_smb"]] = ds[
-         ragis_config["Flux Variables"]["smb_flux"]
-     ].cumsum() / len(ds.time)
-     ds = prp.normalize_cumulative_variables(
-         ds,
-         list(ragis_config["Cumulative Variables"].values()),
-         reference_year=reference_year,
+     filtered_ds, outliers_ds = filter_outliers(
+         simulated_ds, outlier_range=outlier_range, outlier_variable=outlier_variable
    )
-
-     fig, ax = plt.subplots(1, 1)
-     ds.sel(time=slice(str(filter_start_year), str(filter_end_year))).sel(
-         basin="GIS", ensemble_id=ensemble
-     ).grounding_line_flux.plot(hue="exp_id", add_legend=False, ax=ax, lw=0.5)
-     fig.savefig("grounding_line_flux_unfiltered.pdf")
-
-     result = filter_outliers(
-         ds, outlier_range=outlier_range, outlier_variable=outlier_variable
+     plot_outliers(
+         filtered_ds.sel(basin="GIS", ensemble_id="RAGIS")[outlier_variable],
+         outliers_ds.sel(basin="GIS", ensemble_id="RAGIS")[outlier_variable],
+         Path(fig_dir) / Path(f"{outlier_variable}_filtering.pdf"),
    )
-     filtered_ds = result["filtered"]
-     outliers_ds = result["outliers"]

-     fig, ax = plt.subplots(1, 1)
-     ds.sel(time=slice(str(filter_start_year), str(filter_end_year))).sel(
-         basin="GIS", ensemble_id=ensemble
-     ).grounding_line_flux.plot(hue="exp_id", add_legend=False, ax=ax, lw=0.5)
-     fig.savefig("grounding_line_flux_filtered.pdf")
-
-     prior_config = ds.sel(pism_config_axis=params).pism_config
+     prior_config = simulated_ds.sel(pism_config_axis=params).pism_config
    prior = config_to_dataframe(prior_config)
    prior["Ensemble"] = "Prior"
@@ -1014,7 +1006,8 @@ def plot_obs_sims_3(
        .mean()
    )

-     simulated = filtered_ds.sel(basin=["CE", "CW", "GIS", "NE", "NO", "NW", "SE", "SW"])
+     simulated = filtered_ds
+
    simulated_resampled = (
        simulated.drop_vars(["pism_config", "run_stats"], errors="ignore")
        .resample(time=resampling_frequency)
@@ -1090,7 +1083,7 @@ def plot_obs_sims_3(
        config=ragis_config,
        filtering_var=obs_mean_var,
        filter_range=[filter_start_year, filter_end_year],
-         fig_dir=result_dir / Path("figures"),
+         fig_dir=fig_dir,
        obs_alpha=obs_alpha,
        sim_alpha=sim_alpha,
    )
@@ -1173,7 +1166,7 @@ def plot_obs_sims_3(
        "calving.rate_scaling.file"
    ].map(calving_dict)

-     to_analyze = ds.sel(time=slice("1980-01-01", "2020-01-01"))
+     to_analyze = simulated_ds.sel(time=slice("1980-01-01", "2020-01-01"))
    all_delta_indices = run_delta_analysis(
        to_analyze, ensemble_df, list(flux_vars.values())[:2], notebook=notebook
    )