Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 0 additions & 35 deletions e3sm_diags/driver/arm_diags_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,11 +278,6 @@ def _run_diag_annual_cycle(parameter: ARMDiagsParameter) -> ARMDiagsParameter:
vars_funcs = _get_vars_funcs_for_derived_var(ds_ref, var)
target_var = list(vars_funcs.keys())[0][0]

# NOTE: The bounds dimension can be "nv", which is not
# currently recognized as a valid bounds dimension
# by xcdat. We rename it to "bnds" to make it compatible.
ds_ref = _rename_bounds_dim(ds_ref)

ds_ref_climo = ds_ref.temporal.climatology(target_var, "month")
da_ref = vars_funcs[(target_var,)](ds_ref_climo[target_var]).rename(
var
Expand Down Expand Up @@ -535,33 +530,3 @@ def _save_metrics_to_json(parameter: ARMDiagsParameter, metrics_dict: Dict[str,
json.dump(metrics_dict, outfile)

logger.info(f"Metrics saved in: {abs_path}")


def _rename_bounds_dim(ds: xr.Dataset) -> xr.Dataset:
"""
Renames the bounds dimension "nv" to "bnds" in the given xarray.Dataset for
xCDAT compatibility.

This is a temporary workaround to ensure compatibility with xCDAT's bounds
handling. The bounds dimension "nv" is commonly used in datasets to
represent the number of vertices in a polygon, but xCDAT expects the
bounds dimension to be in `xcdat.bounds.VALID_BOUNDS_DIMS`. This function
renames "nv" to "bnds" to align with xCDAT's expectations.

Parameters
----------
ds : xr.Dataset
The input xarray.Dataset which may contain a bounds dimension named "nv".

Returns
-------
xr.Dataset
A new xarray.Dataset with the "nv" dimension renamed to "bnds" if it
existed; otherwise, the original dataset copy.
"""
ds_new = ds.copy()

if "nv" in ds_new.dims:
ds_new = ds_new.rename({"nv": "bnds"})

return ds_new
7 changes: 0 additions & 7 deletions e3sm_diags/parser/enso_diags_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,3 @@ def add_arguments(self):
help="End year for the timeseries files.",
required=False,
)

self.parser.add_argument(
"--plot_type",
dest="plot_type",
help="Type of plot to generate: 'map' or 'scatter'.",
required=False,
)
27 changes: 3 additions & 24 deletions e3sm_diags/plot/aerosol_aeronet_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from e3sm_diags.logger import _setup_child_logger
from e3sm_diags.parameter.core_parameter import CoreParameter
from e3sm_diags.plot.lat_lon_plot import _add_colormap
from e3sm_diags.plot.utils import _save_main_plot, _save_single_subplot
from e3sm_diags.plot.utils import _save_plot

matplotlib.use("Agg")
import matplotlib.pyplot as plt # isort:skip # noqa: E402
Expand All @@ -22,28 +22,7 @@
]
# Border padding relative to subplot axes for saving individual panels
# (left, bottom, right, top) in page coordinates.
BORDER_PADDING_COLORMAP = (-0.06, 0.25, 0.13, 0.25)
BORDER_PADDING_SCATTER = (-0.08, -0.04, 0.15, 0.04)


def _save_plot_aerosol_aeronet(fig, parameter):
"""Save aerosol_aeronet plots using the _save_single_subplot helper function.

This function handles the special case where different border padding is needed
for each panel by calling _save_single_subplot twice with panel-specific
configurations (BORDER_PADDING_COLORMAP for panel 0, BORDER_PADDING_SCATTER for panel 1).
"""
# Save the main plot
_save_main_plot(parameter)

# Save subplots with different border padding by calling general function
# for each panel individually
if parameter.output_format_subplot:
# Save colormap panel (panel 0) with its specific border padding
_save_single_subplot(fig, parameter, 0, PANEL_CFG[0], BORDER_PADDING_COLORMAP)

# Save scatter panel (panel 1) with its specific border padding
_save_single_subplot(fig, parameter, 1, PANEL_CFG[1], BORDER_PADDING_SCATTER)
BORDER_PADDING = (-0.06, -0.03, 0.13, 0.03)


def plot(
Expand Down Expand Up @@ -129,4 +108,4 @@ def plot(

plt.loglog(ref_site_arr, test_site_arr, "kx", markersize=3.0, mfc="none")

_save_plot_aerosol_aeronet(fig, parameter)
_save_plot(fig, parameter, PANEL_CFG, BORDER_PADDING)
27 changes: 6 additions & 21 deletions e3sm_diags/plot/enso_diags_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@
_get_x_ticks,
_get_y_ticks,
_make_lon_cyclic,
_save_main_plot,
_save_plot,
_save_single_subplot,
)

matplotlib.use("Agg")
Expand All @@ -43,19 +41,6 @@
# Use 179.99 as central longitude due to https://github.com/SciTools/cartopy/issues/946
PROJECTION = ccrs.PlateCarree(central_longitude=179.99)

# Border padding relative to subplot axes for saving individual panels
# (left, bottom, right, top) in page coordinates
ENSO_BORDER_PADDING_MAP = (-0.07, -0.025, 0.17, 0.022)


def _save_plot_scatter(fig: plt.Figure, parameter: EnsoDiagsParameter):
"""Save the scatter plot using the shared _save_single_subplot function."""
_save_main_plot(parameter)

# Save the single subplot using shared helper (panel_config=None for full figure)
if parameter.output_format_subplot:
_save_single_subplot(fig, parameter, 0, None, None)


def plot_scatter(
parameter: EnsoDiagsParameter, x: MetricsDictScatter, y: MetricsDictScatter
Expand Down Expand Up @@ -153,7 +138,7 @@ def plot_scatter(
plt.ylabel("{} anomaly ({})".format(y["var"], y["units"]))
plt.legend()

_save_plot_scatter(fig, parameter)
_save_plot(fig, parameter)

plt.close()

Expand Down Expand Up @@ -207,7 +192,7 @@ def plot_map(
)
_plot_diff_rmse_and_corr(fig, metrics_dict["diff"]) # type: ignore

_save_plot(fig, parameter, DEFAULT_PANEL_CFG, ENSO_BORDER_PADDING_MAP)
_save_plot(fig, parameter)

plt.close()

Expand Down Expand Up @@ -316,17 +301,17 @@ def _add_colormap(
top_text = "Max\nMin\nMean\nSTD"
fig.text(
DEFAULT_PANEL_CFG[subplot_num][0] + 0.6635,
DEFAULT_PANEL_CFG[subplot_num][1] + 0.2,
DEFAULT_PANEL_CFG[subplot_num][1] + 0.2107,
top_text,
ha="left",
fontdict={"fontsize": 9},
fontdict={"fontsize": SECONDARY_TITLE_FONTSIZE},
)
fig.text(
DEFAULT_PANEL_CFG[subplot_num][0] + 0.7635,
DEFAULT_PANEL_CFG[subplot_num][1] + 0.2,
DEFAULT_PANEL_CFG[subplot_num][1] + 0.2107,
"%.2f\n%.2f\n%.2f\n%.2f" % metrics_values, # type: ignore
ha="right",
fontdict={"fontsize": 9},
fontdict={"fontsize": SECONDARY_TITLE_FONTSIZE},
)

# Hatch text
Expand Down
13 changes: 11 additions & 2 deletions e3sm_diags/plot/mp_partition_plot.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os

import matplotlib

from e3sm_diags.driver.utils.io import _get_output_dir
from e3sm_diags.logger import _setup_child_logger
from e3sm_diags.plot.utils import _save_main_plot

matplotlib.use("Agg")
import matplotlib.pyplot as plt # isort:skip # noqa: E402
Expand Down Expand Up @@ -87,4 +89,11 @@ def plot(metrics_dict, parameter):
ax.legend(loc="upper left")
ax.set_title("Mixed-phase Partition LCF [30S - 70S]")

_save_main_plot(parameter)
for f in parameter.output_format:
f = f.lower().split(".")[-1]
fnm = os.path.join(
_get_output_dir(parameter),
f"{parameter.output_file}" + "." + f,
)
plt.savefig(fnm)
logger.info(f"Plot saved in: {fnm}")
10 changes: 3 additions & 7 deletions e3sm_diags/plot/qbo_plot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import List, Literal, Tuple, TypedDict
from typing import List, Literal, TypedDict

import matplotlib
import numpy as np
Expand All @@ -18,15 +18,11 @@
PANEL_CFG = [
(0.075, 0.75, 0.6, 0.175),
(0.075, 0.525, 0.6, 0.175),
(0.735, 0.525, 0.2, 0.4),
(0.725, 0.525, 0.2, 0.4),
(0.075, 0.285, 0.85, 0.175),
(0.075, 0.04, 0.85, 0.175),
]

# Border padding relative to subplot axes for saving individual panels
# (left, bottom, right, top) in page coordinates
QBO_BORDER_PADDING: Tuple[float, float, float, float] = (-0.07, -0.03, 0.009, 0.03)

LABEL_SIZE = 14
CMAP = plt.cm.RdBu_r

Expand Down Expand Up @@ -207,7 +203,7 @@ def plot(parameter: QboParameter, test_dict, ref_dict):
fig.suptitle(parameter.main_title, x=0.5, y=0.97, fontsize=15)

# Save figure
_save_plot(fig, parameter, PANEL_CFG, QBO_BORDER_PADDING)
_save_plot(fig, parameter, PANEL_CFG)

plt.close()

Expand Down
66 changes: 54 additions & 12 deletions e3sm_diags/plot/tc_analysis_plot.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import os

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib
import numpy as np
import xcdat as xc
from cartopy.mpl.ticker import LatitudeFormatter, LongitudeFormatter

from e3sm_diags.driver.utils.io import _get_output_dir
from e3sm_diags.logger import _setup_child_logger
from e3sm_diags.plot.utils import MAIN_TITLE_FONTSIZE, _save_main_plot
from e3sm_diags.plot.utils import MAIN_TITLE_FONTSIZE

matplotlib.use("agg")
import matplotlib.pyplot as plt # isort:skip # noqa: E402
Expand Down Expand Up @@ -100,8 +103,16 @@ def plot(test, ref, parameter, basin_dict):
y=0.99,
)

parameter.output_file = "tc-intensity"
_save_main_plot(parameter)
output_file_name = "tc-intensity"
for f in parameter.output_format:
f = f.lower().split(".")[-1]
fnm = os.path.join(
_get_output_dir(parameter),
output_file_name + "." + f,
)
plt.savefig(fnm)
logger.info(f"Plot saved in: {fnm}")
plt.close()

# TC frequency of each basins
fig = plt.figure(figsize=(12, 7))
Expand Down Expand Up @@ -139,8 +150,16 @@ def plot(test, ref, parameter, basin_dict):
ax.set_ylabel("Fraction")
ax.set_title("Relative frequency of TCs for each ocean basins")

parameter.output_file = "tc-frequency"
_save_main_plot(parameter)
output_file_name = "tc-frequency"
for f in parameter.output_format:
f = f.lower().split(".")[-1]
fnm = os.path.join(
_get_output_dir(parameter),
output_file_name + "." + f,
)
plt.savefig(fnm)
logger.info(f"Plot saved in: {fnm}")
plt.close()

fig1 = plt.figure(figsize=(12, 6))
ax = fig1.add_subplot(111)
Expand Down Expand Up @@ -171,9 +190,16 @@ def plot(test, ref, parameter, basin_dict):
ax.set_title(
"Distribution of accumulated cyclone energy (ACE) among various ocean basins"
)

parameter.output_file = "ace-distribution"
_save_main_plot(parameter)
output_file_name = "ace-distribution"
for f in parameter.output_format:
f = f.lower().split(".")[-1]
fnm = os.path.join(
_get_output_dir(parameter),
output_file_name + "." + f,
)
plt.savefig(fnm)
logger.info(f"Plot saved in: {fnm}")
plt.close()

fig, axes = plt.subplots(2, 3, figsize=(12, 6), sharex=True, sharey=True)
fig.subplots_adjust(hspace=0.4, wspace=0.15)
Expand Down Expand Up @@ -205,8 +231,16 @@ def plot(test, ref, parameter, basin_dict):
y=0.99,
)

parameter.output_file = "tc-frequency-annual-cycle"
_save_main_plot(parameter)
output_file_name = "tc-frequency-annual-cycle"
for f in parameter.output_format:
f = f.lower().split(".")[-1]
fnm = os.path.join(
_get_output_dir(parameter),
output_file_name + "." + f,
)
plt.savefig(fnm)
logger.info(f"Plot saved in: {fnm}")
plt.close()

##########################################################
# Plot TC tracks density
Expand Down Expand Up @@ -252,9 +286,17 @@ def plot_map(test_data, ref_data, region, parameter):

# Figure title
fig.suptitle(PLOT_INFO[region]["title"], x=0.5, y=0.9, fontsize=14)
parameter.output_file = "{}-density-map".format(region)
output_file_name = "{}-density-map".format(region)

_save_main_plot(parameter)
for f in parameter.output_format:
f = f.lower().split(".")[-1]
fnm = os.path.join(
_get_output_dir(parameter),
output_file_name + "." + f,
)
plt.savefig(fnm)
logger.info(f"Plot saved in: {fnm}")
plt.close()


def plot_panel(n, fig, proj, var, var_num_years, region, title):
Expand Down
Loading