Skip to content

Commit 0398c99

Browse files
committed
Refactor plotting utilities for maintainability
1 parent 78bf83b commit 0398c99

File tree

5 files changed

+89
-147
lines changed

5 files changed

+89
-147
lines changed

e3sm_diags/plot/aerosol_aeronet_plot.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
1-
import os
2-
31
import matplotlib
42
import numpy as np
53
import xarray as xr
64

7-
from e3sm_diags.driver.utils.io import _get_output_dir
85
from e3sm_diags.driver.utils.type_annotations import MetricsDict
96
from e3sm_diags.logger import _setup_child_logger
107
from e3sm_diags.parameter.core_parameter import CoreParameter
118
from e3sm_diags.plot.lat_lon_plot import _add_colormap
12-
from e3sm_diags.plot.utils import _save_single_subplot
9+
from e3sm_diags.plot.utils import _save_main_plot, _save_single_subplot
1310

1411
matplotlib.use("Agg")
1512
import matplotlib.pyplot as plt # isort:skip # noqa: E402
@@ -37,14 +34,7 @@ def _save_plot_aerosol_aeronet(fig, parameter):
3734
configurations (BORDER_PADDING_COLORMAP for panel 0, BORDER_PADDING_SCATTER for panel 1).
3835
"""
3936
# Save the main plot
40-
for f in parameter.output_format:
41-
f = f.lower().split(".")[-1]
42-
fnm = os.path.join(
43-
_get_output_dir(parameter),
44-
parameter.output_file + "." + f,
45-
)
46-
plt.savefig(fnm)
47-
logger.info(f"Plot saved in: {fnm}")
37+
_save_main_plot(parameter)
4838

4939
# Save subplots with different border padding by calling general function
5040
# for each panel individually

e3sm_diags/plot/enso_diags_plot.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import os
43
from typing import TYPE_CHECKING, List, Tuple
54

65
import cartopy.crs as ccrs
@@ -11,7 +10,6 @@
1110
from numpy.polynomial.polynomial import polyfit
1211

1312
from e3sm_diags.derivations.default_regions_xr import REGION_SPECS
14-
from e3sm_diags.driver.utils.io import _get_output_dir
1513
from e3sm_diags.logger import _setup_child_logger
1614
from e3sm_diags.parameter.enso_diags_parameter import EnsoDiagsParameter
1715
from e3sm_diags.plot.utils import (
@@ -25,6 +23,7 @@
2523
_get_x_ticks,
2624
_get_y_ticks,
2725
_make_lon_cyclic,
26+
_save_main_plot,
2827
_save_plot,
2928
_save_single_subplot,
3029
)
@@ -51,15 +50,7 @@
5150

5251
def _save_plot_scatter(fig: plt.Figure, parameter: EnsoDiagsParameter):
5352
"""Save the scatter plot using the shared _save_single_subplot function."""
54-
# Save the main plot
55-
for f in parameter.output_format:
56-
f = f.lower().split(".")[-1]
57-
fnm = os.path.join(
58-
_get_output_dir(parameter),
59-
parameter.output_file + "." + f,
60-
)
61-
plt.savefig(fnm)
62-
logger.info(f"Plot saved in: {fnm}")
53+
_save_main_plot(parameter)
6354

6455
# Save the single subplot using shared helper (panel_config=None for full figure)
6556
if parameter.output_format_subplot:

e3sm_diags/plot/mp_partition_plot.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
import os
2-
31
import matplotlib
42

5-
from e3sm_diags.driver.utils.io import _get_output_dir
63
from e3sm_diags.logger import _setup_child_logger
4+
from e3sm_diags.plot.utils import _save_main_plot
75

86
matplotlib.use("Agg")
97
import matplotlib.pyplot as plt # isort:skip # noqa: E402
@@ -89,11 +87,4 @@ def plot(metrics_dict, parameter):
8987
ax.legend(loc="upper left")
9088
ax.set_title("Mixed-phase Partition LCF [30S - 70S]")
9189

92-
for f in parameter.output_format:
93-
f = f.lower().split(".")[-1]
94-
fnm = os.path.join(
95-
_get_output_dir(parameter),
96-
f"{parameter.output_file}" + "." + f,
97-
)
98-
plt.savefig(fnm)
99-
logger.info(f"Plot saved in: {fnm}")
90+
_save_main_plot(parameter)

e3sm_diags/plot/tc_analysis_plot.py

Lines changed: 12 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
1-
import os
2-
31
import cartopy.crs as ccrs
42
import cartopy.feature as cfeature
53
import matplotlib
64
import numpy as np
75
import xcdat as xc
86
from cartopy.mpl.ticker import LatitudeFormatter, LongitudeFormatter
97

10-
from e3sm_diags.driver.utils.io import _get_output_dir
118
from e3sm_diags.logger import _setup_child_logger
12-
from e3sm_diags.plot.utils import MAIN_TITLE_FONTSIZE
9+
from e3sm_diags.plot.utils import MAIN_TITLE_FONTSIZE, _save_main_plot
1310

1411
matplotlib.use("agg")
1512
import matplotlib.pyplot as plt # isort:skip # noqa: E402
@@ -103,16 +100,8 @@ def plot(test, ref, parameter, basin_dict):
103100
y=0.99,
104101
)
105102

106-
output_file_name = "tc-intensity"
107-
for f in parameter.output_format:
108-
f = f.lower().split(".")[-1]
109-
fnm = os.path.join(
110-
_get_output_dir(parameter),
111-
output_file_name + "." + f,
112-
)
113-
plt.savefig(fnm)
114-
logger.info(f"Plot saved in: {fnm}")
115-
plt.close()
103+
parameter.output_file = "tc-intensity"
104+
_save_main_plot(parameter)
116105

117106
# TC frequency of each basins
118107
fig = plt.figure(figsize=(12, 7))
@@ -150,16 +139,8 @@ def plot(test, ref, parameter, basin_dict):
150139
ax.set_ylabel("Fraction")
151140
ax.set_title("Relative frequency of TCs for each ocean basins")
152141

153-
output_file_name = "tc-frequency"
154-
for f in parameter.output_format:
155-
f = f.lower().split(".")[-1]
156-
fnm = os.path.join(
157-
_get_output_dir(parameter),
158-
output_file_name + "." + f,
159-
)
160-
plt.savefig(fnm)
161-
logger.info(f"Plot saved in: {fnm}")
162-
plt.close()
142+
parameter.output_file = "tc-frequency"
143+
_save_main_plot(parameter)
163144

164145
fig1 = plt.figure(figsize=(12, 6))
165146
ax = fig1.add_subplot(111)
@@ -190,16 +171,9 @@ def plot(test, ref, parameter, basin_dict):
190171
ax.set_title(
191172
"Distribution of accumulated cyclone energy (ACE) among various ocean basins"
192173
)
193-
output_file_name = "ace-distribution"
194-
for f in parameter.output_format:
195-
f = f.lower().split(".")[-1]
196-
fnm = os.path.join(
197-
_get_output_dir(parameter),
198-
output_file_name + "." + f,
199-
)
200-
plt.savefig(fnm)
201-
logger.info(f"Plot saved in: {fnm}")
202-
plt.close()
174+
175+
parameter.output_file = "ace-distribution"
176+
_save_main_plot(parameter)
203177

204178
fig, axes = plt.subplots(2, 3, figsize=(12, 6), sharex=True, sharey=True)
205179
fig.subplots_adjust(hspace=0.4, wspace=0.15)
@@ -231,16 +205,8 @@ def plot(test, ref, parameter, basin_dict):
231205
y=0.99,
232206
)
233207

234-
output_file_name = "tc-frequency-annual-cycle"
235-
for f in parameter.output_format:
236-
f = f.lower().split(".")[-1]
237-
fnm = os.path.join(
238-
_get_output_dir(parameter),
239-
output_file_name + "." + f,
240-
)
241-
plt.savefig(fnm)
242-
logger.info(f"Plot saved in: {fnm}")
243-
plt.close()
208+
parameter.output_file = "tc-frequency-annual-cycle"
209+
_save_main_plot(parameter)
244210

245211
##########################################################
246212
# Plot TC tracks density
@@ -286,17 +252,9 @@ def plot_map(test_data, ref_data, region, parameter):
286252

287253
# Figure title
288254
fig.suptitle(PLOT_INFO[region]["title"], x=0.5, y=0.9, fontsize=14)
289-
output_file_name = "{}-density-map".format(region)
255+
parameter.output_file = "{}-density-map".format(region)
290256

291-
for f in parameter.output_format:
292-
f = f.lower().split(".")[-1]
293-
fnm = os.path.join(
294-
_get_output_dir(parameter),
295-
output_file_name + "." + f,
296-
)
297-
plt.savefig(fnm)
298-
logger.info(f"Plot saved in: {fnm}")
299-
plt.close()
257+
_save_main_plot(parameter)
300258

301259

302260
def plot_panel(n, fig, proj, var, var_num_years, region, title):

0 commit comments

Comments
 (0)