Skip to content

Commit 460a87f

Browse files
authored
Merge pull request #16 from E3SM-Project/fix_plots_lnd
Improve variable and viewer handling
2 parents ef3c401 + 3a803db commit 460a87f

File tree

3 files changed

+67
-24
lines changed

3 files changed

+67
-24
lines changed

zppy_interfaces/global_time_series/coupled_global.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,8 @@ def set_var(
161161
valid_vars: List[str],
162162
invalid_vars: List[str],
163163
rgn: str,
164-
) -> None:
164+
) -> List[Variable]:
165+
new_var_list: List[Variable] = []
165166
if exp[exp_key] != "":
166167
try:
167168
dataset_wrapper: DatasetWrapper = DatasetWrapper(exp[exp_key])
@@ -177,7 +178,8 @@ def set_var(
177178
data_array: xarray.core.dataarray.DataArray
178179
units: str
179180
data_array, units = dataset_wrapper.globalAnnual(var)
180-
valid_vars.append(str(var_str))
181+
valid_vars.append(str(var_str)) # Append the name
182+
new_var_list.append(var) # Append the variable itself
181183
except Exception as e:
182184
logger.error(e)
183185
logger.error(f"globalAnnual failed for {var_str}")
@@ -210,6 +212,7 @@ def set_var(
210212
].values
211213
exp["annual"]["year"] = [x.year for x in years]
212214
del dataset_wrapper
215+
return new_var_list
213216

214217

215218
def process_data(
@@ -222,27 +225,29 @@ def process_data(
222225
for exp in exps:
223226
exp["annual"] = {}
224227

225-
set_var(
228+
requested_variables.vars_original = set_var(
226229
exp,
227230
"atmos",
228231
requested_variables.vars_original,
229232
valid_vars,
230233
invalid_vars,
231234
rgn,
232235
)
233-
set_var(
236+
requested_variables.vars_atm = set_var(
234237
exp, "atmos", requested_variables.vars_atm, valid_vars, invalid_vars, rgn
235238
)
236-
set_var(exp, "ice", requested_variables.vars_ice, valid_vars, invalid_vars, rgn)
237-
set_var(
239+
requested_variables.vars_ice = set_var(
240+
exp, "ice", requested_variables.vars_ice, valid_vars, invalid_vars, rgn
241+
)
242+
requested_variables.vars_land = set_var(
238243
exp,
239244
"land",
240245
requested_variables.vars_land,
241246
valid_vars,
242247
invalid_vars,
243248
rgn,
244249
)
245-
set_var(
250+
requested_variables.vars_ocn = set_var(
246251
exp, "ocean", requested_variables.vars_ocn, valid_vars, invalid_vars, rgn
247252
)
248253

@@ -283,12 +288,15 @@ def run(parameters: Parameters, requested_variables: RequestedVariables, rgn: st
283288
invalid_plots: List[str] = []
284289

285290
# Use list of tuples rather than a dict, to keep order
291+
# Note: we use `parameters.plots_original` rather than `requested_variables.vars_original`
292+
# because the "original" plots are expecting plot names that are not variable names.
293+
# The model components however are expecting plot names to be variable names.
286294
mapping: List[Tuple[str, List[str]]] = [
287295
("original", parameters.plots_original),
288-
("atm", parameters.plots_atm),
289-
("ice", parameters.plots_ice),
290-
("lnd", parameters.plots_lnd),
291-
("ocn", parameters.plots_ocn),
296+
("atm", list(map(lambda v: v.variable_name, requested_variables.vars_atm))),
297+
("ice", list(map(lambda v: v.variable_name, requested_variables.vars_ice))),
298+
("lnd", list(map(lambda v: v.variable_name, requested_variables.vars_land))),
299+
("ocn", list(map(lambda v: v.variable_name, requested_variables.vars_ocn))),
292300
]
293301
for component, plot_list in mapping:
294302
make_plot_pdfs(
@@ -332,11 +340,27 @@ def coupled_global(parameters: Parameters) -> None:
332340
# In this case, we don't want the summary PDF.
333341
# Rather, we want to construct a viewer similar to E3SM Diags.
334342
title_and_url_list: List[Tuple[str, str]] = []
335-
for component in ["original", "atm", "ice", "lnd", "ocn"]:
343+
for component in [
344+
"atm",
345+
"ice",
346+
"lnd",
347+
"ocn",
348+
]: # Don't create viewer for original component
336349
vars = get_vars(requested_variables, component)
337350
if vars:
338351
url = create_viewer(parameters, vars, component)
339352
logger.info(f"Viewer URL for {component}: {url}")
340353
title_and_url_list.append((component, url))
354+
# Special case for original plots: always use user-provided dimensions.
355+
vars = get_vars(requested_variables, "original")
356+
if vars:
357+
logger.info("Using user provided dimensions for original plots PDF")
358+
title_and_url_list.append(
359+
(
360+
"original",
361+
f"{parameters.figstr}_glb_original.pdf",
362+
)
363+
)
364+
341365
index_url: str = create_viewer_index(parameters.results_dir, title_and_url_list)
342366
logger.info(f"Viewer index URL: {index_url}")

zppy_interfaces/global_time_series/coupled_global_plotting.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -551,10 +551,23 @@ def make_plot_pdfs( # noqa: C901
551551
valid_plots,
552552
invalid_plots,
553553
):
554+
logger.info(f"make_plot_pdfs for rgn={rgn}, component={component}")
554555
num_plots = len(plot_list)
555556
if num_plots == 0:
556557
return
557-
plots_per_page = parameters.nrows * parameters.ncols
558+
559+
# If make_viewer, then we want to do 1 plot per page.
560+
# However, the original plots are excluded from this restriction.
561+
# Note: if the user provides nrows=ncols=1, there will still be a single plot per page
562+
keep_user_dims = (not parameters.make_viewer) or (component == "original")
563+
if keep_user_dims:
564+
nrows = parameters.nrows
565+
ncols = parameters.ncols
566+
else:
567+
nrows = 1
568+
ncols = 1
569+
570+
plots_per_page = nrows * ncols
558571
num_pages = math.ceil(num_plots / plots_per_page)
559572

560573
counter = 0
@@ -565,26 +578,36 @@ def make_plot_pdfs( # noqa: C901
565578
)
566579
for page in range(num_pages):
567580
if plots_per_page == 1:
581+
logger.info("Using reduced figsize")
568582
fig = plt.figure(1, figsize=[13.5 / 2, 16.5 / 4])
569583
else:
584+
logger.info("Using standard figsize")
570585
fig = plt.figure(1, figsize=[13.5, 16.5])
586+
logger.info(f"Figure size={fig.get_size_inches() * fig.dpi}")
571587
fig.suptitle(f"{parameters.figstr}_{rgn}_{component}")
572588
for j in range(plots_per_page):
589+
logger.info(
590+
f"Plotting plot {j} on page {page}. This is plot {counter} in total."
591+
)
573592
# The final page doesn't need to be filled out with plots.
574593
if counter >= num_plots:
575594
break
576-
ax = plt.subplot(parameters.nrows, parameters.ncols, j + 1)
595+
ax = plt.subplot(
596+
nrows,
597+
ncols,
598+
j + 1,
599+
)
600+
plot_name = plot_list[counter]
577601
if component == "original":
578602
try:
579-
plot_function = PLOT_DICT[plot_list[counter]]
603+
plot_function = PLOT_DICT[plot_name]
580604
except KeyError:
581-
raise KeyError(f"Invalid plot name: {plot_list[counter]}")
605+
raise KeyError(f"Invalid plot name: {plot_name}")
582606
try:
583607
plot_function(ax, xlim, exps, rgn)
584-
valid_plots.append(plot_list[counter])
608+
valid_plots.append(plot_name)
585609
except Exception:
586610
traceback.print_exc()
587-
plot_name = plot_list[counter]
588611
required_vars = []
589612
if plot_name == "net_toa_flux_restom":
590613
required_vars = ["RESTOM"]
@@ -603,7 +626,6 @@ def make_plot_pdfs( # noqa: C901
603626
counter += 1
604627
else:
605628
try:
606-
plot_name = plot_list[counter]
607629
plot_generic(ax, xlim, exps, plot_name, rgn)
608630
valid_plots.append(plot_name)
609631
except Exception:
@@ -616,6 +638,7 @@ def make_plot_pdfs( # noqa: C901
616638

617639
fig.tight_layout()
618640
pdf.savefig(1)
641+
# Always save individual PNGs for viewer mode
619642
if plots_per_page == 1:
620643
fig.savefig(
621644
f"{parameters.results_dir}/{parameters.figstr}_{rgn}_{component}_{plot_name}.png",
@@ -631,5 +654,5 @@ def make_plot_pdfs( # noqa: C901
631654
f"{parameters.results_dir}/{parameters.figstr}_{rgn}_{component}.png",
632655
dpi=150,
633656
)
634-
plt.clf()
657+
plt.close(fig)
635658
pdf.close()

zppy_interfaces/global_time_series/utils.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,6 @@ def __init__(self, args: Dict[str, str]):
3535
map(lambda rgn: get_region(rgn), args["regions"].split(","))
3636
)
3737
self.make_viewer: bool = _str2bool(args["make_viewer"])
38-
if self.make_viewer and (self.nrows != 1 or self.ncols != 1):
39-
raise RuntimeError(
40-
f"make_viewer requires 1x1 plots, but nrows={self.nrows} and ncols={self.ncols}"
41-
)
4238

4339
# For both
4440
self.year1: int = int(args["start_yr"])

0 commit comments

Comments
 (0)