Skip to content

Commit e60e93a

Browse files
committed
refactor for review comments; more refining for enso
1 parent 4014818 commit e60e93a

File tree

4 files changed

+75
-101
lines changed

4 files changed

+75
-101
lines changed

e3sm_diags/plot/aerosol_aeronet_plot.py

Lines changed: 16 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
1+
import os
2+
13
import matplotlib
24
import numpy as np
35
import xarray as xr
46

7+
from e3sm_diags.driver.utils.io import _get_output_dir
58
from e3sm_diags.driver.utils.type_annotations import MetricsDict
69
from e3sm_diags.logger import _setup_child_logger
710
from e3sm_diags.parameter.core_parameter import CoreParameter
811
from e3sm_diags.plot.lat_lon_plot import _add_colormap
12+
from e3sm_diags.plot.utils import _save_single_subplot
913

1014
matplotlib.use("Agg")
1115
import matplotlib.pyplot as plt # isort:skip # noqa: E402
@@ -26,13 +30,12 @@
2630

2731

2832
def _save_plot_aerosol_aeronet(fig, parameter):
29-
"""Save aerosol_aeronet plots with different border padding for each panel."""
30-
import os
31-
32-
from matplotlib.transforms import Bbox
33-
34-
from e3sm_diags.driver.utils.io import _get_output_dir
33+
"""Save aerosol_aeronet plots using the _save_single_subplot helper function.
3534
35+
This function handles the special case where different border padding is needed
36+
for each panel by calling _save_single_subplot twice with panel-specific
37+
configurations (BORDER_PADDING_COLORMAP for panel 0, BORDER_PADDING_SCATTER for panel 1).
38+
"""
3639
# Save the main plot
3740
for f in parameter.output_format:
3841
f = f.lower().split(".")[-1]
@@ -43,28 +46,14 @@ def _save_plot_aerosol_aeronet(fig, parameter):
4346
plt.savefig(fnm)
4447
logger.info(f"Plot saved in: {fnm}")
4548

46-
# Save individual subplots with different border padding
47-
border_paddings = [BORDER_PADDING_COLORMAP, BORDER_PADDING_SCATTER]
49+
# Save subplots with different border padding by calling general function
50+
# for each panel individually
51+
if parameter.output_format_subplot:
52+
# Save colormap panel (panel 0) with its specific border padding
53+
_save_single_subplot(fig, parameter, 0, PANEL_CFG[0], BORDER_PADDING_COLORMAP)
4854

49-
for f in parameter.output_format_subplot:
50-
fnm = os.path.join(
51-
_get_output_dir(parameter),
52-
parameter.output_file,
53-
)
54-
page = fig.get_size_inches()
55-
56-
for idx, (panel, border_padding) in enumerate(zip(PANEL_CFG, border_paddings)):
57-
# Extent of subplot
58-
subpage = np.array(panel).reshape(2, 2)
59-
subpage[1, :] = subpage[0, :] + subpage[1, :]
60-
subpage = subpage + np.array(border_padding).reshape(2, 2)
61-
subpage_list = list(((subpage) * page).flatten())
62-
extent = Bbox.from_extents(*subpage_list)
63-
64-
# Save subplot
65-
fname = fnm + ".%i." % idx + f
66-
plt.savefig(fname, bbox_inches=extent)
67-
logger.info(f"Sub-plot saved in: {fname}")
55+
# Save scatter panel (panel 1) with its specific border padding
56+
_save_single_subplot(fig, parameter, 1, PANEL_CFG[1], BORDER_PADDING_SCATTER)
6857

6958

7059
def plot(

e3sm_diags/plot/enso_diags_plot.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import os
34
from typing import TYPE_CHECKING, List, Tuple
45

56
import cartopy.crs as ccrs
@@ -10,6 +11,7 @@
1011
from numpy.polynomial.polynomial import polyfit
1112

1213
from e3sm_diags.derivations.default_regions_xr import REGION_SPECS
14+
from e3sm_diags.driver.utils.io import _get_output_dir
1315
from e3sm_diags.logger import _setup_child_logger
1416
from e3sm_diags.parameter.enso_diags_parameter import EnsoDiagsParameter
1517
from e3sm_diags.plot.utils import (
@@ -24,6 +26,7 @@
2426
_get_y_ticks,
2527
_make_lon_cyclic,
2628
_save_plot,
29+
_save_single_subplot,
2730
)
2831

2932
matplotlib.use("Agg")
@@ -43,15 +46,12 @@
4346

4447
# Border padding relative to subplot axes for saving individual panels
4548
# (left, bottom, right, top) in page coordinates
46-
ENSO_BORDER_PADDING_MAP = (-0.07, -0.025, 0.2, 0.035)
49+
ENSO_BORDER_PADDING_MAP = (-0.07, -0.025, 0.17, 0.022)
4750

4851

4952
def _save_plot_scatter(fig: plt.Figure, parameter: EnsoDiagsParameter):
50-
"""Save the scatter plot using a simplified approach for single panel plots."""
51-
import os
52-
53-
from e3sm_diags.driver.utils.io import _get_output_dir
54-
53+
"""Save the scatter plot using the shared _save_single_subplot function."""
54+
# Save the main plot
5555
for f in parameter.output_format:
5656
f = f.lower().split(".")[-1]
5757
fnm = os.path.join(
@@ -61,14 +61,9 @@ def _save_plot_scatter(fig: plt.Figure, parameter: EnsoDiagsParameter):
6161
plt.savefig(fnm)
6262
logger.info(f"Plot saved in: {fnm}")
6363

64-
# Save individual subplots (single panel for scatter)
65-
for f in parameter.output_format_subplot:
66-
fnm = os.path.join(
67-
_get_output_dir(parameter),
68-
parameter.output_file + ".0." + f,
69-
)
70-
plt.savefig(fnm)
71-
logger.info(f"Sub-plot saved in: {fnm}")
64+
# Save the single subplot using shared helper (panel_config=None for full figure)
65+
if parameter.output_format_subplot:
66+
_save_single_subplot(fig, parameter, 0, None, None)
7267

7368

7469
def plot_scatter(
@@ -330,17 +325,17 @@ def _add_colormap(
330325
top_text = "Max\nMin\nMean\nSTD"
331326
fig.text(
332327
DEFAULT_PANEL_CFG[subplot_num][0] + 0.6635,
333-
DEFAULT_PANEL_CFG[subplot_num][1] + 0.2107,
328+
DEFAULT_PANEL_CFG[subplot_num][1] + 0.2,
334329
top_text,
335330
ha="left",
336-
fontdict={"fontsize": SECONDARY_TITLE_FONTSIZE},
331+
fontdict={"fontsize": 9},
337332
)
338333
fig.text(
339334
DEFAULT_PANEL_CFG[subplot_num][0] + 0.7635,
340-
DEFAULT_PANEL_CFG[subplot_num][1] + 0.2107,
335+
DEFAULT_PANEL_CFG[subplot_num][1] + 0.2,
341336
"%.2f\n%.2f\n%.2f\n%.2f" % metrics_values, # type: ignore
342337
ha="right",
343-
fontdict={"fontsize": SECONDARY_TITLE_FONTSIZE},
338+
fontdict={"fontsize": 9},
344339
)
345340

346341
# Hatch text

e3sm_diags/plot/tropical_subseasonal_plot.py

Lines changed: 2 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from e3sm_diags.driver.utils.io import _get_output_dir
1212
from e3sm_diags.logger import _setup_child_logger
1313
from e3sm_diags.parameter.core_parameter import CoreParameter
14+
from e3sm_diags.plot.utils import _save_plot
1415

1516
matplotlib.use("Agg")
1617
import matplotlib.pyplot as plt # isort:skip # noqa: E402
@@ -227,61 +228,6 @@ def create_colormap_clevs(cmapSpec, clevs):
227228
return cmapSpecUse, normSpecUse
228229

229230

230-
def _save_plot(fig: plt.figure, parameter: CoreParameter):
231-
"""Save the plot using the figure object and parameter configs.
232-
233-
This function creates the output filename to save the plot. It also
234-
saves each individual subplot if the reference name is an empty string ("").
235-
236-
Parameters
237-
----------
238-
fig : plt.figure
239-
The plot figure.
240-
parameter : CoreParameter
241-
The CoreParameter with file configurations.
242-
"""
243-
for f in parameter.output_format:
244-
f = f.lower().split(".")[-1]
245-
fnm = os.path.join(
246-
_get_output_dir(parameter),
247-
parameter.output_file + "." + f,
248-
)
249-
plt.savefig(fnm)
250-
logger.info(f"Plot saved in: {fnm}")
251-
252-
# Save individual subplots
253-
if parameter.ref_name == "":
254-
panels = [PANEL[0]]
255-
else:
256-
panels = PANEL
257-
258-
for f in parameter.output_format_subplot:
259-
fnm = os.path.join(
260-
_get_output_dir(parameter),
261-
parameter.output_file,
262-
)
263-
page = fig.get_size_inches()
264-
265-
for idx, panel in enumerate(panels):
266-
# Extent of subplot
267-
subpage = np.array(panel).reshape(2, 2)
268-
subpage[1, :] = subpage[0, :] + subpage[1, :]
269-
subpage = subpage + np.array(BORDER_PADDING).reshape(2, 2)
270-
subpage = list(((subpage) * page).flatten()) # type: ignore
271-
extent = matplotlib.transforms.Bbox.from_extents(*subpage)
272-
273-
# Save subplot
274-
fname = fnm + ".%i." % idx + f
275-
plt.savefig(fname, bbox_inches=extent)
276-
277-
orig_fnm = os.path.join(
278-
_get_output_dir(parameter),
279-
parameter.output_file,
280-
)
281-
fname = orig_fnm + ".%i." % idx + f
282-
logger.info(f"Sub-plot saved in: {fname}")
283-
284-
285231
def _wave_frequency_plot( # noqa: C901
286232
subplot_num: int,
287233
var: xr.DataArray,
@@ -770,6 +716,6 @@ def plot(
770716
do_zoom=do_zoom,
771717
)
772718

773-
_save_plot(fig, parameter)
719+
_save_plot(fig, parameter, PANEL, BORDER_PADDING)
774720

775721
plt.close()

e3sm_diags/plot/utils.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,50 @@
5757
]
5858

5959

60+
def _save_single_subplot(fig, parameter, panel_idx, panel_config, border_padding):
61+
"""Save a single subplot with specific border padding.
62+
63+
Parameters
64+
----------
65+
fig : plt.Figure
66+
The plot figure.
67+
parameter : CoreParameter
68+
The CoreParameter with file configurations.
69+
panel_idx : int
70+
The panel index for filename generation.
71+
panel_config : tuple or None
72+
Panel configuration (left, bottom, width, height). If None, saves entire figure.
73+
border_padding : tuple
74+
Border padding (left, bottom, right, top) in page coordinates.
75+
"""
76+
for f in parameter.output_format_subplot:
77+
fnm = os.path.join(
78+
_get_output_dir(parameter),
79+
parameter.output_file,
80+
)
81+
82+
if panel_config is None:
83+
# Save entire figure (for full-figure plots like scatter)
84+
fname = fnm + f".{panel_idx}." + f
85+
plt.savefig(fname)
86+
logger.info(f"Sub-plot saved in: {fname}")
87+
else:
88+
# Save cropped subplot with border padding
89+
page = fig.get_size_inches()
90+
91+
# Extent of subplot
92+
subpage = np.array(panel_config).reshape(2, 2)
93+
subpage[1, :] = subpage[0, :] + subpage[1, :]
94+
subpage = subpage + np.array(border_padding).reshape(2, 2)
95+
subpage_list = list(((subpage) * page).flatten())
96+
extent = Bbox.from_extents(*subpage_list)
97+
98+
# Save subplot
99+
fname = fnm + f".{panel_idx}." + f
100+
plt.savefig(fname, bbox_inches=extent)
101+
logger.info(f"Sub-plot saved in: {fname}")
102+
103+
60104
def _save_plot(
61105
fig: plt.Figure,
62106
parameter: CoreParameter,

0 commit comments

Comments
 (0)