Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 36 additions & 12 deletions zppy_interfaces/global_time_series/coupled_global.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ def set_var(
valid_vars: List[str],
invalid_vars: List[str],
rgn: str,
) -> None:
) -> List[Variable]:
new_var_list: List[Variable] = []
if exp[exp_key] != "":
try:
dataset_wrapper: DatasetWrapper = DatasetWrapper(exp[exp_key])
Expand All @@ -177,7 +178,8 @@ def set_var(
data_array: xarray.core.dataarray.DataArray
units: str
data_array, units = dataset_wrapper.globalAnnual(var)
valid_vars.append(str(var_str))
valid_vars.append(str(var_str)) # Append the name
new_var_list.append(var) # Append the variable itself
except Exception as e:
logger.error(e)
logger.error(f"globalAnnual failed for {var_str}")
Expand Down Expand Up @@ -210,6 +212,7 @@ def set_var(
].values
exp["annual"]["year"] = [x.year for x in years]
del dataset_wrapper
return new_var_list


def process_data(
Expand All @@ -222,27 +225,29 @@ def process_data(
for exp in exps:
exp["annual"] = {}

set_var(
requested_variables.vars_original = set_var(
exp,
"atmos",
requested_variables.vars_original,
valid_vars,
invalid_vars,
rgn,
)
set_var(
requested_variables.vars_atm = set_var(
exp, "atmos", requested_variables.vars_atm, valid_vars, invalid_vars, rgn
)
set_var(exp, "ice", requested_variables.vars_ice, valid_vars, invalid_vars, rgn)
set_var(
requested_variables.vars_ice = set_var(
exp, "ice", requested_variables.vars_ice, valid_vars, invalid_vars, rgn
)
requested_variables.vars_land = set_var(
exp,
"land",
requested_variables.vars_land,
valid_vars,
invalid_vars,
rgn,
)
set_var(
requested_variables.vars_ocn = set_var(
exp, "ocean", requested_variables.vars_ocn, valid_vars, invalid_vars, rgn
)

Expand Down Expand Up @@ -283,12 +288,15 @@ def run(parameters: Parameters, requested_variables: RequestedVariables, rgn: st
invalid_plots: List[str] = []

# Use list of tuples rather than a dict, to keep order
# Note: we use `parameters.plots_original` rather than `requested_variables.vars_original`
# because the "original" plots are expecting plot names that are not variable names.
# The model components however are expecting plot names to be variable names.
mapping: List[Tuple[str, List[str]]] = [
("original", parameters.plots_original),
("atm", parameters.plots_atm),
("ice", parameters.plots_ice),
("lnd", parameters.plots_lnd),
("ocn", parameters.plots_ocn),
("atm", list(map(lambda v: v.variable_name, requested_variables.vars_atm))),
("ice", list(map(lambda v: v.variable_name, requested_variables.vars_ice))),
("lnd", list(map(lambda v: v.variable_name, requested_variables.vars_land))),
("ocn", list(map(lambda v: v.variable_name, requested_variables.vars_ocn))),
]
for component, plot_list in mapping:
make_plot_pdfs(
Expand Down Expand Up @@ -332,11 +340,27 @@ def coupled_global(parameters: Parameters) -> None:
# In this case, we don't want the summary PDF.
# Rather, we want to construct a viewer similar to E3SM Diags.
title_and_url_list: List[Tuple[str, str]] = []
for component in ["original", "atm", "ice", "lnd", "ocn"]:
for component in [
"atm",
"ice",
"lnd",
"ocn",
]: # Don't create viewer for original component
vars = get_vars(requested_variables, component)
if vars:
url = create_viewer(parameters, vars, component)
logger.info(f"Viewer URL for {component}: {url}")
title_and_url_list.append((component, url))
# Special case for original plots: always use user-provided dimensions.
vars = get_vars(requested_variables, "original")
if vars:
logger.info("Using user provided dimensions for original plots PDF")
title_and_url_list.append(
(
"original",
f"{parameters.figstr}_glb_original.pdf",
)
)

index_url: str = create_viewer_index(parameters.results_dir, title_and_url_list)
logger.info(f"Viewer index URL: {index_url}")
39 changes: 31 additions & 8 deletions zppy_interfaces/global_time_series/coupled_global_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,10 +551,23 @@ def make_plot_pdfs( # noqa: C901
valid_plots,
invalid_plots,
):
logger.info(f"make_plot_pdfs for rgn={rgn}, component={component}")
num_plots = len(plot_list)
if num_plots == 0:
return
plots_per_page = parameters.nrows * parameters.ncols

# If make_viewer, then we want to do 1 plot per page.
# However, the original plots are excluded from this restriction.
# Note: if the user provides nrows=ncols=1, there will still be a single plot per page
keep_user_dims = (not parameters.make_viewer) or (component == "original")
if keep_user_dims:
nrows = parameters.nrows
ncols = parameters.ncols
else:
nrows = 1
ncols = 1

plots_per_page = nrows * ncols
num_pages = math.ceil(num_plots / plots_per_page)

counter = 0
Expand All @@ -565,26 +578,36 @@ def make_plot_pdfs( # noqa: C901
)
for page in range(num_pages):
if plots_per_page == 1:
logger.info("Using reduced figsize")
fig = plt.figure(1, figsize=[13.5 / 2, 16.5 / 4])
else:
logger.info("Using standard figsize")
fig = plt.figure(1, figsize=[13.5, 16.5])
logger.info(f"Figure size={fig.get_size_inches() * fig.dpi}")
fig.suptitle(f"{parameters.figstr}_{rgn}_{component}")
for j in range(plots_per_page):
logger.info(
f"Plotting plot {j} on page {page}. This is plot {counter} in total."
)
# The final page doesn't need to be filled out with plots.
if counter >= num_plots:
break
ax = plt.subplot(parameters.nrows, parameters.ncols, j + 1)
ax = plt.subplot(
nrows,
ncols,
j + 1,
)
plot_name = plot_list[counter]
if component == "original":
try:
plot_function = PLOT_DICT[plot_list[counter]]
plot_function = PLOT_DICT[plot_name]
except KeyError:
raise KeyError(f"Invalid plot name: {plot_list[counter]}")
raise KeyError(f"Invalid plot name: {plot_name}")
try:
plot_function(ax, xlim, exps, rgn)
valid_plots.append(plot_list[counter])
valid_plots.append(plot_name)
except Exception:
traceback.print_exc()
plot_name = plot_list[counter]
required_vars = []
if plot_name == "net_toa_flux_restom":
required_vars = ["RESTOM"]
Expand All @@ -603,7 +626,6 @@ def make_plot_pdfs( # noqa: C901
counter += 1
else:
try:
plot_name = plot_list[counter]
plot_generic(ax, xlim, exps, plot_name, rgn)
valid_plots.append(plot_name)
except Exception:
Expand All @@ -616,6 +638,7 @@ def make_plot_pdfs( # noqa: C901

fig.tight_layout()
pdf.savefig(1)
# Always save individual PNGs for viewer mode
if plots_per_page == 1:
fig.savefig(
f"{parameters.results_dir}/{parameters.figstr}_{rgn}_{component}_{plot_name}.png",
Expand All @@ -631,5 +654,5 @@ def make_plot_pdfs( # noqa: C901
f"{parameters.results_dir}/{parameters.figstr}_{rgn}_{component}.png",
dpi=150,
)
plt.clf()
plt.close(fig)
pdf.close()
4 changes: 0 additions & 4 deletions zppy_interfaces/global_time_series/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,6 @@ def __init__(self, args: Dict[str, str]):
map(lambda rgn: get_region(rgn), args["regions"].split(","))
)
self.make_viewer: bool = _str2bool(args["make_viewer"])
if self.make_viewer and (self.nrows != 1 or self.ncols != 1):
raise RuntimeError(
f"make_viewer requires 1x1 plots, but nrows={self.nrows} and ncols={self.ncols}"
)

# For both
self.year1: int = int(args["start_yr"])
Expand Down
Loading