Skip to content

Commit fd40272

Browse files
committed
Implemented basic logger / profiling capabilities.
1 parent b480bd1 commit fd40272

File tree

6 files changed

+216
-126
lines changed

6 files changed

+216
-126
lines changed

analysis/analyze_scalar.py

Lines changed: 106 additions & 113 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,6 @@
2323
2424
"""
2525

26-
import logging
2726
import time
2827
import warnings
2928
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
@@ -46,17 +45,16 @@
4645

4746
import pism_ragis.processing as prp
4847
from pism_ragis.analysis import delta_analysis
48+
from pism_ragis.decorators import profileit, timeit
4949
from pism_ragis.filtering import importance_sampling
5050
from pism_ragis.likelihood import log_normal
51+
from pism_ragis.logger import get_logger
52+
53+
logger = get_logger(__name__)
5154

5255
xr.set_options(keep_attrs=True)
5356
plt.style.use("tableau-colorblind10")
5457

55-
logger = logging.getLogger(__name__)
56-
logging.getLogger("matplotlib").disabled = True
57-
58-
logging.basicConfig(filename="example.log", encoding="utf-8", level=logging.INFO)
59-
6058

6159
sim_alpha = 0.5
6260
sim_cmap = sns.color_palette("crest", n_colors=4).as_hex()[0:3:2]
@@ -68,73 +66,73 @@
6866
hist_cmap = ["#a6cee3", "#1f78b4"]
6967

7068

71-
# def timeit(func):
72-
# def wrapper(*args, **kwargs):
73-
# start_time = time.time()
74-
# result = func(*args, **kwargs)
75-
# end_time = time.time()
76-
# time_elapsed = end_time - start_time
77-
# print(f"{func.__name__} took {time_elapsed:.0f}s.")
78-
# return result
79-
80-
# return wrapper
81-
82-
83-
# def timeit(func):
84-
# @wraps(func)
85-
# def timeit_wrapper(*args, **kwargs):
86-
# start_time = time.perf_counter()
87-
# result = func(*args, **kwargs)
88-
# end_time = time.perf_counter()
89-
# time_elapsed = end_time - start_time
90-
# print(f"{func.__name__} took {time_elapsed:.1f}s.")
91-
# return result
92-
93-
# return timeit_wrapper
94-
95-
96-
def timeit(func):
69+
@timeit
70+
def prepare_simulations(
71+
filenames: List[Union[Path, str]],
72+
config: Dict,
73+
reference_year: float,
74+
parallel: bool = True,
75+
engine: str = "netcdf4",
76+
) -> xr.Dataset:
9777
"""
98-
Decorator that logs the time a function takes to execute.
78+
Prepare simulations by loading and processing ensemble datasets.
9979
100-
This decorator logs the start time, end time, and the elapsed time
101-
for the execution of the decorated function.
80+
This function loads ensemble datasets from the specified filenames, processes them
81+
according to the provided configuration, and returns the processed dataset. The
82+
processing steps include sorting, dropping NaNs, standardizing variable names,
83+
calculating cumulative variables, and normalizing cumulative variables.
10284
10385
Parameters
10486
----------
105-
func : callable
106-
The function to be decorated.
87+
filenames : List[Union[Path, str]]
88+
A list of file paths to the ensemble datasets.
89+
config : Dict
90+
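A dictionary containing configuration settings for processing the datasets.
reference_year : float
The reference year passed to `prp.normalize_cumulative_variables` when normalizing the cumulative variables.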
91+
parallel : bool, optional
92+
Whether to load the datasets in parallel, by default True.
93+
engine : str, optional
94+
The engine to use for loading the datasets, by default "netcdf4".
10795
10896
Returns
10997
-------
110-
callable
111-
The wrapped function with added timing functionality.
98+
xr.Dataset
99+
The processed xarray dataset.
112100
113101
Examples
114102
--------
115-
>>> @timeit
116-
... def example_function():
117-
... time.sleep(1)
118-
...
119-
>>> example_function()
120-
INFO:__main__:Starting example_function
121-
INFO:__main__:Finished example_function in 1.0001 seconds
103+
>>> filenames = ["file1.nc", "file2.nc"]
104+
>>> config = {
105+
... "PISM Spatial": {...},
106+
... "Cumulative Variables": {
107+
... "cumulative_grounding_line_flux": "cumulative_gl_flux",
108+
... "cumulative_smb": "cumulative_smb_flux"
109+
... },
110+
... "Flux Variables": {
111+
... "grounding_line_flux": "gl_flux",
112+
... "smb_flux": "smb_flux"
113+
... }
114+
... }
115+
>>> ds = prepare_simulations(filenames, config, reference_year=2004.0)
122116
"""
117+
ds = prp.load_ensemble(filenames, parallel=parallel, engine=engine).sortby("basin")
118+
# ds = xr.apply_ufunc(np.vectorize(convert_bstrings_to_str), ds, dask="parallelized")
119+
ds = ds.dropna(dim="exp_id")
123120

124-
@wraps(func)
125-
def wrapper(*args, **kwargs):
126-
start_time = time.time()
127-
logger.info("Starting %s", func.__name__)
128-
result = func(*args, **kwargs)
129-
end_time = time.time()
130-
elapsed_time = end_time - start_time
131-
logger.info("Finished %s in %2.2f seconds", func.__name__, elapsed_time)
132-
return result
133-
134-
return wrapper
121+
ds = prp.standardize_variable_names(ds, config["PISM Spatial"])
122+
ds[config["Cumulative Variables"]["cumulative_grounding_line_flux"]] = ds[
123+
config["Flux Variables"]["grounding_line_flux"]
124+
].cumsum() / len(ds.time)
125+
ds[config["Cumulative Variables"]["cumulative_smb"]] = ds[
126+
config["Flux Variables"]["smb_flux"]
127+
].cumsum() / len(ds.time)
128+
ds = prp.normalize_cumulative_variables(
129+
ds,
130+
list(config["Cumulative Variables"].values()),
131+
reference_year=reference_year,
132+
)
133+
return ds
135134

136135

137-
@timeit
138136
def config_to_dataframe(config: xr.DataArray):
139137
"""
140138
Convert an xarray DataArray configuration to a pandas DataFrame.
@@ -157,7 +155,6 @@ def config_to_dataframe(config: xr.DataArray):
157155
return df
158156

159157

160-
@timeit
161158
def convert_bstrings_to_str(element: Any) -> Any:
162159
"""
163160
Convert byte strings to regular strings.
@@ -178,14 +175,14 @@ def convert_bstrings_to_str(element: Any) -> Any:
178175
return element
179176

180177

181-
@timeit
178+
@profileit
182179
def filter_outliers(
183180
ds: xr.Dataset,
184181
outlier_range: List[float],
185182
outlier_variable: str,
186183
freq: str = "YS",
187-
subset: Dict[str, Union[str, int]] = {"basin": "GIS", "ensemble_id": "RAGIS"},
188-
) -> Dict[str, xr.Dataset]:
184+
subset: Dict[str, str | int] = {"basin": "GIS", "ensemble_id": "RAGIS"},
185+
):
189186
"""
190187
Filter outliers from a dataset based on a specified variable and range.
191188
@@ -249,10 +246,25 @@ def filter_outliers(
249246
filtered_ds = ds.sel(exp_id=filtered_exp_ids)
250247
outliers_ds = ds.sel(exp_id=outlier_exp_ids)
251248

252-
return {"filtered": filtered_ds, "outliers": outliers_ds}
249+
return filtered_ds, outliers_ds
253250

254251

255-
@timeit
252+
def plot_outliers(
253+
filtered_da: xr.DataArray, outliers_da: xr.DataArray, filename: Path | str
254+
):
255+
"""
256+
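Plot the filtered members (black) and the outlier members (red) on a single axis, one line per `exp_id`, and save the figure to `filename`.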
257+
"""
258+
fig, ax = plt.subplots(1, 1)
259+
if filtered_da.size > 0:
260+
print(filtered_da)
261+
filtered_da.plot(hue="exp_id", color="k", add_legend=False, ax=ax, lw=0.5)
262+
if outliers_da.size > 0:
263+
outliers_da.plot(hue="exp_id", color="r", add_legend=False, ax=ax, lw=0.5)
264+
fig.savefig(filename)
265+
266+
267+
@profileit
256268
def run_delta_analysis(
257269
ds: xr.Dataset,
258270
ensemble_df: pd.DataFrame,
@@ -347,7 +359,6 @@ def run_delta_analysis(
347359
return all_delta_indices
348360

349361

350-
@timeit
351362
def plot_obs_sims(
352363
obs: xr.Dataset,
353364
sim_prior: xr.Dataset,
@@ -356,7 +367,7 @@ def plot_obs_sims(
356367
filtering_var: str,
357368
filter_range: List[int] = [1990, 2019],
358369
fig_dir: Union[str, Path] = "figures",
359-
reference_year: int = 1986,
370+
reference_year: float = 1986.0,
360371
sim_alpha: float = 0.4,
361372
obs_alpha: float = 1.0,
362373
sigma: float = 2,
@@ -529,7 +540,7 @@ def plot_obs_sims_3(
529540
filtering_var: str,
530541
filter_range: List[int] = [1990, 2019],
531542
fig_dir: Union[str, Path] = "figures",
532-
reference_year: int = 1986,
543+
reference_year: float = 1986.0,
533544
sim_alpha: float = 0.4,
534545
obs_alpha: float = 1.0,
535546
sigma: float = 2,
@@ -736,7 +747,7 @@ def plot_obs_sims_3(
736747
"--obs_url",
737748
help="""Path to "observed" mass balance.""",
738749
type=str,
739-
default="data/mass_balance/mankoff_greenland_mass_balance.nc",
750+
default="data/mass_balance/combined_greenland_mass_balance.nc",
740751
)
741752
parser.add_argument(
742753
"--engine",
@@ -797,8 +808,8 @@ def plot_obs_sims_3(
797808
parser.add_argument(
798809
"--reference_year",
799810
help="""Reference year.""",
800-
type=int,
801-
default=1986,
811+
type=float,
812+
default=2004,
802813
)
803814
parser.add_argument(
804815
"--n_jobs",
@@ -819,7 +830,13 @@ def plot_obs_sims_3(
819830
nargs="*",
820831
)
821832

822-
options = parser.parse_args()
833+
parser.add_argument(
834+
"--log",
835+
default="WARNING",
836+
help="Set the logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)",
837+
)
838+
839+
options, unknown = parser.parse_known_args()
823840
basin_files = options.FILES
824841
ensemble = options.ensemble
825842
engine = options.engine
@@ -873,51 +890,26 @@ def plot_obs_sims_3(
873890
k + "_uncertainty": v + "_uncertainty" for k, v in cumulative_vars.items()
874891
}
875892

876-
ds = prp.load_ensemble(basin_files, parallel=parallel, engine=engine).sortby(
877-
"basin"
893+
simulated_ds = prepare_simulations(
894+
basin_files, ragis_config, reference_year, parallel=parallel, engine=engine
878895
)
879-
# for v in ds.data_vars:
880-
# if ds[v].dtype.kind == "S":
881-
# ds[v] = ds[v].astype(str)
882-
# for c in ds.coords:
883-
# if ds[c].dtype.kind == "S":
884-
# ds.coords[c] = ds.coords[c].astype(str)
885896

886-
# ds = xr.apply_ufunc(np.vectorize(convert_bstrings_to_str), ds, dask="parallelized")
887-
ds = ds.dropna(dim="exp_id")
897+
# fig, ax = plt.subplots(1, 1)
898+
# ds.sel(time=slice(str(filter_start_year), str(filter_end_year))).sel(
899+
# basin="GIS", ensemble_id=ensemble
900+
# ).grounding_line_flux.plot(hue="exp_id", add_legend=False, ax=ax, lw=0.5)
901+
# fig.savefig("grounding_line_flux_unfiltered.pdf")
888902

889-
ds = prp.standardize_variable_names(ds, ragis_config["PISM Spatial"])
890-
ds[ragis_config["Cumulative Variables"]["cumulative_grounding_line_flux"]] = ds[
891-
ragis_config["Flux Variables"]["grounding_line_flux"]
892-
].cumsum() / len(ds.time)
893-
ds[ragis_config["Cumulative Variables"]["cumulative_smb"]] = ds[
894-
ragis_config["Flux Variables"]["smb_flux"]
895-
].cumsum() / len(ds.time)
896-
ds = prp.normalize_cumulative_variables(
897-
ds,
898-
list(ragis_config["Cumulative Variables"].values()),
899-
reference_year=reference_year,
903+
filtered_ds, outliers_ds = filter_outliers(
904+
simulated_ds, outlier_range=outlier_range, outlier_variable=outlier_variable
900905
)
901-
902-
fig, ax = plt.subplots(1, 1)
903-
ds.sel(time=slice(str(filter_start_year), str(filter_end_year))).sel(
904-
basin="GIS", ensemble_id=ensemble
905-
).grounding_line_flux.plot(hue="exp_id", add_legend=False, ax=ax, lw=0.5)
906-
fig.savefig("grounding_line_flux_unfiltered.pdf")
907-
908-
result = filter_outliers(
909-
ds, outlier_range=outlier_range, outlier_variable=outlier_variable
906+
plot_outliers(
907+
filtered_ds.sel(basin="GIS", ensemble_id="RAGIS")[outlier_variable],
908+
outliers_ds.sel(basin="GIS", ensemble_id="RAGIS")[outlier_variable],
909+
Path(fig_dir) / Path(f"{outlier_variable}_filtering.pdf"),
910910
)
911-
filtered_ds = result["filtered"]
912-
outliers_ds = result["outliers"]
913911

914-
fig, ax = plt.subplots(1, 1)
915-
ds.sel(time=slice(str(filter_start_year), str(filter_end_year))).sel(
916-
basin="GIS", ensemble_id=ensemble
917-
).grounding_line_flux.plot(hue="exp_id", add_legend=False, ax=ax, lw=0.5)
918-
fig.savefig("grounding_line_flux_filtered.pdf")
919-
920-
prior_config = ds.sel(pism_config_axis=params).pism_config
912+
prior_config = simulated_ds.sel(pism_config_axis=params).pism_config
921913
prior = config_to_dataframe(prior_config)
922914
prior["Ensemble"] = "Prior"
923915

@@ -1014,7 +1006,8 @@ def plot_obs_sims_3(
10141006
.mean()
10151007
)
10161008

1017-
simulated = filtered_ds.sel(basin=["CE", "CW", "GIS", "NE", "NO", "NW", "SE", "SW"])
1009+
simulated = filtered_ds
1010+
10181011
simulated_resampled = (
10191012
simulated.drop_vars(["pism_config", "run_stats"], errors="ignore")
10201013
.resample(time=resampling_frequency)
@@ -1090,7 +1083,7 @@ def plot_obs_sims_3(
10901083
config=ragis_config,
10911084
filtering_var=obs_mean_var,
10921085
filter_range=[filter_start_year, filter_end_year],
1093-
fig_dir=result_dir / Path("figures"),
1086+
fig_dir=fig_dir,
10941087
obs_alpha=obs_alpha,
10951088
sim_alpha=sim_alpha,
10961089
)
@@ -1173,7 +1166,7 @@ def plot_obs_sims_3(
11731166
"calving.rate_scaling.file"
11741167
].map(calving_dict)
11751168

1176-
to_analyze = ds.sel(time=slice("1980-01-01", "2020-01-01"))
1169+
to_analyze = simulated_ds.sel(time=slice("1980-01-01", "2020-01-01"))
11771170
all_delta_indices = run_delta_analysis(
11781171
to_analyze, ensemble_df, list(flux_vars.values())[:2], notebook=notebook
11791172
)

data/03_prepare_mass_balance.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,8 @@
110110

111111
fn = "mankoff_greenland_mass_balance.nc"
112112
p_fn = p / fn
113-
ds.pint.dequantify().to_netcdf(p_fn, encoding=encoding)
113+
mankoff_ds = ds.pint.dequantify()
114+
mankoff_ds.to_netcdf(p_fn, encoding=encoding)
114115

115116
short_name = "GREENLAND_MASS_TELLUS_MASCON_CRI_TIME_SERIES_RL06.1_V3"
116117
results = download_earthaccess(result_dir=p, short_name=short_name)
@@ -140,4 +141,10 @@
140141
ds["cumulative_mass_balance_uncertainty"].attrs.update({"units": "Gt"})
141142
fn = "grace_greenland_mass_balance.nc"
142143
p_fn = p / fn
143-
ds.to_netcdf(fn)
144+
grace_ds = ds
145+
grace_ds.to_netcdf(fn)
146+
147+
fn = "combined_greenland_mass_balance.nc"
148+
p_fn = p / fn
149+
combined_ds = xr.merge([grace_ds, mankoff_ds])
150+
combined_ds.to_netcdf(fn)

0 commit comments

Comments
 (0)