Implemented basic logger / profiling capabilities.
aaschwanden committed Oct 21, 2024
1 parent b480bd1 commit fd40272
Showing 6 changed files with 216 additions and 126 deletions.
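Note: only two of the six changed files are rendered below. The modules this commit imports from, pism_ragis/logger.py and pism_ragis/decorators.py, are presumably among the remaining four. A minimal sketch of what those helpers plausibly contain, inferred from the import sites and from the local timeit deleted in analyze_scalar.py below; the get_logger and profileit bodies (cProfile-based) are assumptions, not the committed code:

import cProfile
import io
import logging
import pstats
import time
from functools import wraps


def get_logger(name: str) -> logging.Logger:
    """Return a configured logger (handler and format choices are assumed)."""
    logger = logging.getLogger(name)
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    logger.setLevel(logging.INFO)
    return logger


def timeit(func):
    """Log elapsed wall-clock time; mirrors the local decorator deleted below."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        elapsed_time = time.time() - start_time
        logging.getLogger(func.__module__).info(
            "Finished %s in %2.2f seconds", func.__name__, elapsed_time
        )
        return result

    return wrapper


def profileit(func):
    """Profile a call with cProfile and log a cumulative-time summary (assumed)."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        profiler = cProfile.Profile()
        result = profiler.runcall(func, *args, **kwargs)
        stream = io.StringIO()
        pstats.Stats(profiler, stream=stream).sort_stats("cumulative").print_stats(10)
        logging.getLogger(func.__module__).info(stream.getvalue())
        return result

    return wrapper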
219 changes: 106 additions & 113 deletions analysis/analyze_scalar.py
@@ -23,7 +23,6 @@
"""

import logging
-import time
import warnings
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
@@ -46,17 +45,16 @@

import pism_ragis.processing as prp
from pism_ragis.analysis import delta_analysis
+from pism_ragis.decorators import profileit, timeit
from pism_ragis.filtering import importance_sampling
from pism_ragis.likelihood import log_normal
+from pism_ragis.logger import get_logger

+logger = get_logger(__name__)

xr.set_options(keep_attrs=True)
plt.style.use("tableau-colorblind10")

-logger = logging.getLogger(__name__)
logging.getLogger("matplotlib").disabled = True

-logging.basicConfig(filename="example.log", encoding="utf-8", level=logging.INFO)


sim_alpha = 0.5
sim_cmap = sns.color_palette("crest", n_colors=4).as_hex()[0:3:2]
@@ -68,73 +66,73 @@
hist_cmap = ["#a6cee3", "#1f78b4"]


-# def timeit(func):
-#     def wrapper(*args, **kwargs):
-#         start_time = time.time()
-#         result = func(*args, **kwargs)
-#         end_time = time.time()
-#         time_elapsed = end_time - start_time
-#         print(f"{func.__name__} took {time_elapsed:.0f}s.")
-#         return result
-
-#     return wrapper
-
-
-# def timeit(func):
-#     @wraps(func)
-#     def timeit_wrapper(*args, **kwargs):
-#         start_time = time.perf_counter()
-#         result = func(*args, **kwargs)
-#         end_time = time.perf_counter()
-#         time_elapsed = end_time - start_time
-#         print(f"{func.__name__} took {time_elapsed:.1f}s.")
-#         return result
-
-#     return timeit_wrapper
-
-
-def timeit(func):
-    """
-    Decorator that logs the time a function takes to execute.
-
-    This decorator logs the start time, end time, and the elapsed time
-    for the execution of the decorated function.
-
-    Parameters
-    ----------
-    func : callable
-        The function to be decorated.
-
-    Returns
-    -------
-    callable
-        The wrapped function with added timing functionality.
-
-    Examples
-    --------
-    >>> @timeit
-    ... def example_function():
-    ...     time.sleep(1)
-    ...
-    >>> example_function()
-    INFO:__main__:Starting example_function
-    INFO:__main__:Finished example_function in 1.0001 seconds
-    """
-
-    @wraps(func)
-    def wrapper(*args, **kwargs):
-        start_time = time.time()
-        logger.info("Starting %s", func.__name__)
-        result = func(*args, **kwargs)
-        end_time = time.time()
-        elapsed_time = end_time - start_time
-        logger.info("Finished %s in %2.2f seconds", func.__name__, elapsed_time)
-        return result
-
-    return wrapper
+@timeit
+def prepare_simulations(
+    filenames: List[Union[Path, str]],
+    config: Dict,
+    reference_year: float,
+    parallel: bool = True,
+    engine: str = "netcdf4",
+) -> xr.Dataset:
+    """
+    Prepare simulations by loading and processing ensemble datasets.
+
+    This function loads ensemble datasets from the specified filenames, processes them
+    according to the provided configuration, and returns the processed dataset. The
+    processing steps include sorting, dropping NaNs, standardizing variable names,
+    calculating cumulative variables, and normalizing cumulative variables.
+
+    Parameters
+    ----------
+    filenames : List[Union[Path, str]]
+        A list of file paths to the ensemble datasets.
+    config : Dict
+        A dictionary containing configuration settings for processing the datasets.
+    parallel : bool, optional
+        Whether to load the datasets in parallel, by default True.
+    engine : str, optional
+        The engine to use for loading the datasets, by default "netcdf4".
+
+    Returns
+    -------
+    xr.Dataset
+        The processed xarray dataset.
+
+    Examples
+    --------
+    >>> filenames = ["file1.nc", "file2.nc"]
+    >>> config = {
+    ...     "PISM Spatial": {...},
+    ...     "Cumulative Variables": {
+    ...         "cumulative_grounding_line_flux": "cumulative_gl_flux",
+    ...         "cumulative_smb": "cumulative_smb_flux"
+    ...     },
+    ...     "Flux Variables": {
+    ...         "grounding_line_flux": "gl_flux",
+    ...         "smb_flux": "smb_flux"
+    ...     }
+    ... }
+    >>> ds = prepare_simulations(filenames, config)
+    """
+    ds = prp.load_ensemble(filenames, parallel=parallel, engine=engine).sortby("basin")
+    # ds = xr.apply_ufunc(np.vectorize(convert_bstrings_to_str), ds, dask="parallelized")
+    ds = ds.dropna(dim="exp_id")
+
+    ds = prp.standardize_variable_names(ds, config["PISM Spatial"])
+    ds[config["Cumulative Variables"]["cumulative_grounding_line_flux"]] = ds[
+        config["Flux Variables"]["grounding_line_flux"]
+    ].cumsum() / len(ds.time)
+    ds[config["Cumulative Variables"]["cumulative_smb"]] = ds[
+        config["Flux Variables"]["smb_flux"]
+    ].cumsum() / len(ds.time)
+    ds = prp.normalize_cumulative_variables(
+        ds,
+        list(config["Cumulative Variables"].values()),
+        reference_year=reference_year,
+    )
+    return ds


@timeit
def config_to_dataframe(config: xr.DataArray):
"""
Convert an xarray DataArray configuration to a pandas DataFrame.
@@ -157,7 +155,6 @@ def config_to_dataframe(config: xr.DataArray):
return df


-@timeit
def convert_bstrings_to_str(element: Any) -> Any:
"""
Convert byte strings to regular strings.
@@ -178,14 +175,14 @@ def convert_bstrings_to_str(element: Any) -> Any:
return element


@timeit
+@profileit
def filter_outliers(
ds: xr.Dataset,
outlier_range: List[float],
outlier_variable: str,
freq: str = "YS",
-    subset: Dict[str, Union[str, int]] = {"basin": "GIS", "ensemble_id": "RAGIS"},
-) -> Dict[str, xr.Dataset]:
+    subset: Dict[str, str | int] = {"basin": "GIS", "ensemble_id": "RAGIS"},
+):
"""
Filter outliers from a dataset based on a specified variable and range.
@@ -249,10 +246,25 @@ def filter_outliers(
filtered_ds = ds.sel(exp_id=filtered_exp_ids)
outliers_ds = ds.sel(exp_id=outlier_exp_ids)

return {"filtered": filtered_ds, "outliers": outliers_ds}
return filtered_ds, outliers_ds


@timeit
+def plot_outliers(
+    filtered_da: xr.DataArray, outliers_da: xr.DataArray, filename: Path | str
+):
+    """
+    Plot outliers.
+    """
+    fig, ax = plt.subplots(1, 1)
+    if filtered_da.size > 0:
+        print(filtered_da)
+        filtered_da.plot(hue="exp_id", color="k", add_legend=False, ax=ax, lw=0.5)
+    if outliers_da.size > 0:
+        outliers_da.plot(hue="exp_id", color="r", add_legend=False, ax=ax, lw=0.5)
+    fig.savefig(filename)
+

+@profileit
def run_delta_analysis(
ds: xr.Dataset,
ensemble_df: pd.DataFrame,
@@ -347,7 +359,6 @@ def run_delta_analysis(
return all_delta_indices


-@timeit
def plot_obs_sims(
obs: xr.Dataset,
sim_prior: xr.Dataset,
@@ -356,7 +367,7 @@
filtering_var: str,
filter_range: List[int] = [1990, 2019],
fig_dir: Union[str, Path] = "figures",
-    reference_year: int = 1986,
+    reference_year: float = 1986.0,
sim_alpha: float = 0.4,
obs_alpha: float = 1.0,
sigma: float = 2,
@@ -529,7 +540,7 @@ def plot_obs_sims_3(
filtering_var: str,
filter_range: List[int] = [1990, 2019],
fig_dir: Union[str, Path] = "figures",
-    reference_year: int = 1986,
+    reference_year: float = 1986.0,
sim_alpha: float = 0.4,
obs_alpha: float = 1.0,
sigma: float = 2,
@@ -736,7 +747,7 @@ def plot_obs_sims_3(
"--obs_url",
help="""Path to "observed" mass balance.""",
type=str,
default="data/mass_balance/mankoff_greenland_mass_balance.nc",
default="data/mass_balance/combined_greenland_mass_balance.nc",
)
parser.add_argument(
"--engine",
@@ -797,8 +808,8 @@ def plot_obs_sims_3(
parser.add_argument(
"--reference_year",
help="""Reference year.""",
-type=int,
-default=1986,
+type=float,
+default=2004,
)
parser.add_argument(
"--n_jobs",
@@ -819,7 +830,13 @@
nargs="*",
)

-options = parser.parse_args()
+parser.add_argument(
+    "--log",
+    default="WARNING",
+    help="Set the logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)",
+)
+
+options, unknown = parser.parse_known_args()
basin_files = options.FILES
ensemble = options.ensemble
engine = options.engine
@@ -873,51 +890,26 @@ def plot_obs_sims_3(
k + "_uncertainty": v + "_uncertainty" for k, v in cumulative_vars.items()
}

-ds = prp.load_ensemble(basin_files, parallel=parallel, engine=engine).sortby(
-    "basin"
-)
-# for v in ds.data_vars:
-#     if ds[v].dtype.kind == "S":
-#         ds[v] = ds[v].astype(str)
-#     for c in ds.coords:
-#         if ds[c].dtype.kind == "S":
-#             ds.coords[c] = ds.coords[c].astype(str)
-
-# ds = xr.apply_ufunc(np.vectorize(convert_bstrings_to_str), ds, dask="parallelized")
-ds = ds.dropna(dim="exp_id")
-# fig, ax = plt.subplots(1, 1)
-# ds.sel(time=slice(str(filter_start_year), str(filter_end_year))).sel(
-#     basin="GIS", ensemble_id=ensemble
-# ).grounding_line_flux.plot(hue="exp_id", add_legend=False, ax=ax, lw=0.5)
-# fig.savefig("grounding_line_flux_unfiltered.pdf")
-
-ds = prp.standardize_variable_names(ds, ragis_config["PISM Spatial"])
-ds[ragis_config["Cumulative Variables"]["cumulative_grounding_line_flux"]] = ds[
-    ragis_config["Flux Variables"]["grounding_line_flux"]
-].cumsum() / len(ds.time)
-ds[ragis_config["Cumulative Variables"]["cumulative_smb"]] = ds[
-    ragis_config["Flux Variables"]["smb_flux"]
-].cumsum() / len(ds.time)
-ds = prp.normalize_cumulative_variables(
-    ds,
-    list(ragis_config["Cumulative Variables"].values()),
-    reference_year=reference_year,
-)
-
-fig, ax = plt.subplots(1, 1)
-ds.sel(time=slice(str(filter_start_year), str(filter_end_year))).sel(
-    basin="GIS", ensemble_id=ensemble
-).grounding_line_flux.plot(hue="exp_id", add_legend=False, ax=ax, lw=0.5)
-fig.savefig("grounding_line_flux_unfiltered.pdf")
-
-result = filter_outliers(
-    ds, outlier_range=outlier_range, outlier_variable=outlier_variable
-)
-filtered_ds = result["filtered"]
-outliers_ds = result["outliers"]
-
-fig, ax = plt.subplots(1, 1)
-ds.sel(time=slice(str(filter_start_year), str(filter_end_year))).sel(
-    basin="GIS", ensemble_id=ensemble
-).grounding_line_flux.plot(hue="exp_id", add_legend=False, ax=ax, lw=0.5)
-fig.savefig("grounding_line_flux_filtered.pdf")
-
-prior_config = ds.sel(pism_config_axis=params).pism_config
+simulated_ds = prepare_simulations(
+    basin_files, ragis_config, reference_year, parallel=parallel, engine=engine
+)
+
+filtered_ds, outliers_ds = filter_outliers(
+    simulated_ds, outlier_range=outlier_range, outlier_variable=outlier_variable
+)
+
+plot_outliers(
+    filtered_ds.sel(basin="GIS", ensemble_id="RAGIS")[outlier_variable],
+    outliers_ds.sel(basin="GIS", ensemble_id="RAGIS")[outlier_variable],
+    Path(fig_dir) / Path(f"{outlier_variable}_filtering.pdf"),
+)
+
+prior_config = simulated_ds.sel(pism_config_axis=params).pism_config
prior = config_to_dataframe(prior_config)
prior["Ensemble"] = "Prior"

@@ -1014,7 +1006,8 @@ def plot_obs_sims_3(
.mean()
)

-simulated = filtered_ds.sel(basin=["CE", "CW", "GIS", "NE", "NO", "NW", "SE", "SW"])
+simulated = filtered_ds

simulated_resampled = (
simulated.drop_vars(["pism_config", "run_stats"], errors="ignore")
.resample(time=resampling_frequency)
@@ -1090,7 +1083,7 @@ def plot_obs_sims_3(
config=ragis_config,
filtering_var=obs_mean_var,
filter_range=[filter_start_year, filter_end_year],
-fig_dir=result_dir / Path("figures"),
+fig_dir=fig_dir,
obs_alpha=obs_alpha,
sim_alpha=sim_alpha,
)
@@ -1173,7 +1166,7 @@ def plot_obs_sims_3(
"calving.rate_scaling.file"
].map(calving_dict)

-to_analyze = ds.sel(time=slice("1980-01-01", "2020-01-01"))
+to_analyze = simulated_ds.sel(time=slice("1980-01-01", "2020-01-01"))
all_delta_indices = run_delta_analysis(
to_analyze, ensemble_df, list(flux_vars.values())[:2], notebook=notebook
)
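The new --log option is parsed above, but none of the expanded hunks show it being consumed, and the old logging.basicConfig call was deleted. A typical wiring, offered only as an assumption about code outside the expanded hunks:

# Hypothetical: map the --log string to a numeric logging level.
level = getattr(logging, options.log.upper(), None)
if not isinstance(level, int):
    raise ValueError(f"Invalid log level: {options.log}")
logging.basicConfig(level=level)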
11 changes: 9 additions & 2 deletions data/03_prepare_mass_balance.py
@@ -110,7 +110,8 @@

fn = "mankoff_greenland_mass_balance.nc"
p_fn = p / fn
-ds.pint.dequantify().to_netcdf(p_fn, encoding=encoding)
+mankoff_ds = ds.pint.dequantify()
+mankoff_ds.to_netcdf(p_fn, encoding=encoding)

short_name = "GREENLAND_MASS_TELLUS_MASCON_CRI_TIME_SERIES_RL06.1_V3"
results = download_earthaccess(result_dir=p, short_name=short_name)
@@ -140,4 +141,10 @@
ds["cumulative_mass_balance_uncertainty"].attrs.update({"units": "Gt"})
fn = "grace_greenland_mass_balance.nc"
p_fn = p / fn
-ds.to_netcdf(fn)
+grace_ds = ds
+grace_ds.to_netcdf(fn)

fn = "combined_greenland_mass_balance.nc"
p_fn = p / fn
combined_ds = xr.merge([grace_ds, mankoff_ds])
combined_ds.to_netcdf(fn)
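For context, xr.merge aligns its inputs on shared coordinates and takes the union of their variables, so the combined file carries the Mankoff and GRACE series side by side. A toy illustration with hypothetical variable names (a variable present in both inputs with conflicting values would raise a MergeError):

import xarray as xr

a = xr.Dataset({"mass_balance": ("time", [1.0, 2.0])}, coords={"time": [0, 1]})
b = xr.Dataset({"mass_balance_uncertainty": ("time", [0.1, 0.2])}, coords={"time": [0, 1]})
combined = xr.merge([a, b])
print(list(combined.data_vars))  # ['mass_balance', 'mass_balance_uncertainty']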