Skip to content

Commit a60532c

Browse files
Fix layout for sub-plot saved in pdf format (#984)
* fix padding for zonal_mean_2d and meridional_mean_2d * fix qbo and tropical subseasonal plots * fix parser/plot for enso_diags * update for aerosol_aeronet * fix pre-commit errors * refactor for review comments; more refining for enso * Add a temporary workaround to rename nv bounds dim * Fix pre-commit issue * Refactor plotting utilities for maintainability --------- Co-authored-by: tomvothecoder <[email protected]>
1 parent 4364561 commit a60532c

File tree

9 files changed

+210
-177
lines changed

9 files changed

+210
-177
lines changed

e3sm_diags/driver/arm_diags_driver.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,11 @@ def _run_diag_annual_cycle(parameter: ARMDiagsParameter) -> ARMDiagsParameter:
278278
vars_funcs = _get_vars_funcs_for_derived_var(ds_ref, var)
279279
target_var = list(vars_funcs.keys())[0][0]
280280

281+
# NOTE: The bounds dimension can be "nv", which is not
282+
# currently recognized as a valid bounds dimension
283+
# by xcdat. We rename it to "bnds" to make it compatible.
284+
ds_ref = _rename_bounds_dim(ds_ref)
285+
281286
ds_ref_climo = ds_ref.temporal.climatology(target_var, "month")
282287
da_ref = vars_funcs[(target_var,)](ds_ref_climo[target_var]).rename(
283288
var
@@ -530,3 +535,33 @@ def _save_metrics_to_json(parameter: ARMDiagsParameter, metrics_dict: Dict[str,
530535
json.dump(metrics_dict, outfile)
531536

532537
logger.info(f"Metrics saved in: {abs_path}")
538+
539+
540+
def _rename_bounds_dim(ds: xr.Dataset) -> xr.Dataset:
541+
"""
542+
Renames the bounds dimension "nv" to "bnds" in the given xarray.Dataset for
543+
xCDAT compatibility.
544+
545+
This is a temporary workaround to ensure compatibility with xCDAT's bounds
546+
handling. The bounds dimension "nv" is commonly used in datasets to
547+
represent the number of vertices in a polygon, but xCDAT expects the
548+
bounds dimension to be in `xcdat.bounds.VALID_BOUNDS_DIMS`. This function
549+
renames "nv" to "bnds" to align with xCDAT's expectations.
550+
551+
Parameters
552+
----------
553+
ds : xr.Dataset
554+
The input xarray.Dataset which may contain a bounds dimension named "nv".
555+
556+
Returns
557+
-------
558+
xr.Dataset
559+
A new xarray.Dataset with the "nv" dimension renamed to "bnds" if it
560+
existed; otherwise, the original dataset copy.
561+
"""
562+
ds_new = ds.copy()
563+
564+
if "nv" in ds_new.dims:
565+
ds_new = ds_new.rename({"nv": "bnds"})
566+
567+
return ds_new

e3sm_diags/parser/enso_diags_parser.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,10 @@ def add_arguments(self):
5050
help="End year for the timeseries files.",
5151
required=False,
5252
)
53+
54+
self.parser.add_argument(
55+
"--plot_type",
56+
dest="plot_type",
57+
help="Type of plot to generate: 'map' or 'scatter'.",
58+
required=False,
59+
)

e3sm_diags/plot/aerosol_aeronet_plot.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from e3sm_diags.logger import _setup_child_logger
77
from e3sm_diags.parameter.core_parameter import CoreParameter
88
from e3sm_diags.plot.lat_lon_plot import _add_colormap
9-
from e3sm_diags.plot.utils import _save_plot
9+
from e3sm_diags.plot.utils import _save_main_plot, _save_single_subplot
1010

1111
matplotlib.use("Agg")
1212
import matplotlib.pyplot as plt # isort:skip # noqa: E402
@@ -22,7 +22,28 @@
2222
]
2323
# Border padding relative to subplot axes for saving individual panels
2424
# (left, bottom, right, top) in page coordinates.
25-
BORDER_PADDING = (-0.06, -0.03, 0.13, 0.03)
25+
BORDER_PADDING_COLORMAP = (-0.06, 0.25, 0.13, 0.25)
26+
BORDER_PADDING_SCATTER = (-0.08, -0.04, 0.15, 0.04)
27+
28+
29+
def _save_plot_aerosol_aeronet(fig, parameter):
30+
"""Save aerosol_aeronet plots using the _save_single_subplot helper function.
31+
32+
This function handles the special case where different border padding is needed
33+
for each panel by calling _save_single_subplot twice with panel-specific
34+
configurations (BORDER_PADDING_COLORMAP for panel 0, BORDER_PADDING_SCATTER for panel 1).
35+
"""
36+
# Save the main plot
37+
_save_main_plot(parameter)
38+
39+
# Save subplots with different border padding by calling general function
40+
# for each panel individually
41+
if parameter.output_format_subplot:
42+
# Save colormap panel (panel 0) with its specific border padding
43+
_save_single_subplot(fig, parameter, 0, PANEL_CFG[0], BORDER_PADDING_COLORMAP)
44+
45+
# Save scatter panel (panel 1) with its specific border padding
46+
_save_single_subplot(fig, parameter, 1, PANEL_CFG[1], BORDER_PADDING_SCATTER)
2647

2748

2849
def plot(
@@ -108,4 +129,4 @@ def plot(
108129

109130
plt.loglog(ref_site_arr, test_site_arr, "kx", markersize=3.0, mfc="none")
110131

111-
_save_plot(fig, parameter, PANEL_CFG, BORDER_PADDING)
132+
_save_plot_aerosol_aeronet(fig, parameter)

e3sm_diags/plot/enso_diags_plot.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323
_get_x_ticks,
2424
_get_y_ticks,
2525
_make_lon_cyclic,
26+
_save_main_plot,
2627
_save_plot,
28+
_save_single_subplot,
2729
)
2830

2931
matplotlib.use("Agg")
@@ -41,6 +43,19 @@
4143
# Use 179.99 as central longitude due to https://github.com/SciTools/cartopy/issues/946
4244
PROJECTION = ccrs.PlateCarree(central_longitude=179.99)
4345

46+
# Border padding relative to subplot axes for saving individual panels
47+
# (left, bottom, right, top) in page coordinates
48+
ENSO_BORDER_PADDING_MAP = (-0.07, -0.025, 0.17, 0.022)
49+
50+
51+
def _save_plot_scatter(fig: plt.Figure, parameter: EnsoDiagsParameter):
52+
"""Save the scatter plot using the shared _save_single_subplot function."""
53+
_save_main_plot(parameter)
54+
55+
# Save the single subplot using shared helper (panel_config=None for full figure)
56+
if parameter.output_format_subplot:
57+
_save_single_subplot(fig, parameter, 0, None, None)
58+
4459

4560
def plot_scatter(
4661
parameter: EnsoDiagsParameter, x: MetricsDictScatter, y: MetricsDictScatter
@@ -138,7 +153,7 @@ def plot_scatter(
138153
plt.ylabel("{} anomaly ({})".format(y["var"], y["units"]))
139154
plt.legend()
140155

141-
_save_plot(fig, parameter)
156+
_save_plot_scatter(fig, parameter)
142157

143158
plt.close()
144159

@@ -192,7 +207,7 @@ def plot_map(
192207
)
193208
_plot_diff_rmse_and_corr(fig, metrics_dict["diff"]) # type: ignore
194209

195-
_save_plot(fig, parameter)
210+
_save_plot(fig, parameter, DEFAULT_PANEL_CFG, ENSO_BORDER_PADDING_MAP)
196211

197212
plt.close()
198213

@@ -301,17 +316,17 @@ def _add_colormap(
301316
top_text = "Max\nMin\nMean\nSTD"
302317
fig.text(
303318
DEFAULT_PANEL_CFG[subplot_num][0] + 0.6635,
304-
DEFAULT_PANEL_CFG[subplot_num][1] + 0.2107,
319+
DEFAULT_PANEL_CFG[subplot_num][1] + 0.2,
305320
top_text,
306321
ha="left",
307-
fontdict={"fontsize": SECONDARY_TITLE_FONTSIZE},
322+
fontdict={"fontsize": 9},
308323
)
309324
fig.text(
310325
DEFAULT_PANEL_CFG[subplot_num][0] + 0.7635,
311-
DEFAULT_PANEL_CFG[subplot_num][1] + 0.2107,
326+
DEFAULT_PANEL_CFG[subplot_num][1] + 0.2,
312327
"%.2f\n%.2f\n%.2f\n%.2f" % metrics_values, # type: ignore
313328
ha="right",
314-
fontdict={"fontsize": SECONDARY_TITLE_FONTSIZE},
329+
fontdict={"fontsize": 9},
315330
)
316331

317332
# Hatch text

e3sm_diags/plot/mp_partition_plot.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
import os
2-
31
import matplotlib
42

5-
from e3sm_diags.driver.utils.io import _get_output_dir
63
from e3sm_diags.logger import _setup_child_logger
4+
from e3sm_diags.plot.utils import _save_main_plot
75

86
matplotlib.use("Agg")
97
import matplotlib.pyplot as plt # isort:skip # noqa: E402
@@ -89,11 +87,4 @@ def plot(metrics_dict, parameter):
8987
ax.legend(loc="upper left")
9088
ax.set_title("Mixed-phase Partition LCF [30S - 70S]")
9189

92-
for f in parameter.output_format:
93-
f = f.lower().split(".")[-1]
94-
fnm = os.path.join(
95-
_get_output_dir(parameter),
96-
f"{parameter.output_file}" + "." + f,
97-
)
98-
plt.savefig(fnm)
99-
logger.info(f"Plot saved in: {fnm}")
90+
_save_main_plot(parameter)

e3sm_diags/plot/qbo_plot.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import List, Literal, TypedDict
3+
from typing import List, Literal, Tuple, TypedDict
44

55
import matplotlib
66
import numpy as np
@@ -18,11 +18,15 @@
1818
PANEL_CFG = [
1919
(0.075, 0.75, 0.6, 0.175),
2020
(0.075, 0.525, 0.6, 0.175),
21-
(0.725, 0.525, 0.2, 0.4),
21+
(0.735, 0.525, 0.2, 0.4),
2222
(0.075, 0.285, 0.85, 0.175),
2323
(0.075, 0.04, 0.85, 0.175),
2424
]
2525

26+
# Border padding relative to subplot axes for saving individual panels
27+
# (left, bottom, right, top) in page coordinates
28+
QBO_BORDER_PADDING: Tuple[float, float, float, float] = (-0.07, -0.03, 0.009, 0.03)
29+
2630
LABEL_SIZE = 14
2731
CMAP = plt.cm.RdBu_r
2832

@@ -203,7 +207,7 @@ def plot(parameter: QboParameter, test_dict, ref_dict):
203207
fig.suptitle(parameter.main_title, x=0.5, y=0.97, fontsize=15)
204208

205209
# Save figure
206-
_save_plot(fig, parameter, PANEL_CFG)
210+
_save_plot(fig, parameter, PANEL_CFG, QBO_BORDER_PADDING)
207211

208212
plt.close()
209213

e3sm_diags/plot/tc_analysis_plot.py

Lines changed: 12 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
1-
import os
2-
31
import cartopy.crs as ccrs
42
import cartopy.feature as cfeature
53
import matplotlib
64
import numpy as np
75
import xcdat as xc
86
from cartopy.mpl.ticker import LatitudeFormatter, LongitudeFormatter
97

10-
from e3sm_diags.driver.utils.io import _get_output_dir
118
from e3sm_diags.logger import _setup_child_logger
12-
from e3sm_diags.plot.utils import MAIN_TITLE_FONTSIZE
9+
from e3sm_diags.plot.utils import MAIN_TITLE_FONTSIZE, _save_main_plot
1310

1411
matplotlib.use("agg")
1512
import matplotlib.pyplot as plt # isort:skip # noqa: E402
@@ -103,16 +100,8 @@ def plot(test, ref, parameter, basin_dict):
103100
y=0.99,
104101
)
105102

106-
output_file_name = "tc-intensity"
107-
for f in parameter.output_format:
108-
f = f.lower().split(".")[-1]
109-
fnm = os.path.join(
110-
_get_output_dir(parameter),
111-
output_file_name + "." + f,
112-
)
113-
plt.savefig(fnm)
114-
logger.info(f"Plot saved in: {fnm}")
115-
plt.close()
103+
parameter.output_file = "tc-intensity"
104+
_save_main_plot(parameter)
116105

117106
# TC frequency of each basins
118107
fig = plt.figure(figsize=(12, 7))
@@ -150,16 +139,8 @@ def plot(test, ref, parameter, basin_dict):
150139
ax.set_ylabel("Fraction")
151140
ax.set_title("Relative frequency of TCs for each ocean basins")
152141

153-
output_file_name = "tc-frequency"
154-
for f in parameter.output_format:
155-
f = f.lower().split(".")[-1]
156-
fnm = os.path.join(
157-
_get_output_dir(parameter),
158-
output_file_name + "." + f,
159-
)
160-
plt.savefig(fnm)
161-
logger.info(f"Plot saved in: {fnm}")
162-
plt.close()
142+
parameter.output_file = "tc-frequency"
143+
_save_main_plot(parameter)
163144

164145
fig1 = plt.figure(figsize=(12, 6))
165146
ax = fig1.add_subplot(111)
@@ -190,16 +171,9 @@ def plot(test, ref, parameter, basin_dict):
190171
ax.set_title(
191172
"Distribution of accumulated cyclone energy (ACE) among various ocean basins"
192173
)
193-
output_file_name = "ace-distribution"
194-
for f in parameter.output_format:
195-
f = f.lower().split(".")[-1]
196-
fnm = os.path.join(
197-
_get_output_dir(parameter),
198-
output_file_name + "." + f,
199-
)
200-
plt.savefig(fnm)
201-
logger.info(f"Plot saved in: {fnm}")
202-
plt.close()
174+
175+
parameter.output_file = "ace-distribution"
176+
_save_main_plot(parameter)
203177

204178
fig, axes = plt.subplots(2, 3, figsize=(12, 6), sharex=True, sharey=True)
205179
fig.subplots_adjust(hspace=0.4, wspace=0.15)
@@ -231,16 +205,8 @@ def plot(test, ref, parameter, basin_dict):
231205
y=0.99,
232206
)
233207

234-
output_file_name = "tc-frequency-annual-cycle"
235-
for f in parameter.output_format:
236-
f = f.lower().split(".")[-1]
237-
fnm = os.path.join(
238-
_get_output_dir(parameter),
239-
output_file_name + "." + f,
240-
)
241-
plt.savefig(fnm)
242-
logger.info(f"Plot saved in: {fnm}")
243-
plt.close()
208+
parameter.output_file = "tc-frequency-annual-cycle"
209+
_save_main_plot(parameter)
244210

245211
##########################################################
246212
# Plot TC tracks density
@@ -286,17 +252,9 @@ def plot_map(test_data, ref_data, region, parameter):
286252

287253
# Figure title
288254
fig.suptitle(PLOT_INFO[region]["title"], x=0.5, y=0.9, fontsize=14)
289-
output_file_name = "{}-density-map".format(region)
255+
parameter.output_file = "{}-density-map".format(region)
290256

291-
for f in parameter.output_format:
292-
f = f.lower().split(".")[-1]
293-
fnm = os.path.join(
294-
_get_output_dir(parameter),
295-
output_file_name + "." + f,
296-
)
297-
plt.savefig(fnm)
298-
logger.info(f"Plot saved in: {fnm}")
299-
plt.close()
257+
_save_main_plot(parameter)
300258

301259

302260
def plot_panel(n, fig, proj, var, var_num_years, region, title):

0 commit comments

Comments
 (0)