Skip to content

Commit 9396546

Browse files
zhangshixuan1987forsyth2
authored andcommitted
Add bug fix to improve robustness and performance
1 parent 2554bd4 commit 9396546

File tree

2 files changed

+215
-77
lines changed

2 files changed

+215
-77
lines changed

zppy_interfaces/pcmdi_diags/synthetic_plots/synthetic_metrics_plotter.py

Lines changed: 121 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -712,18 +712,57 @@ def portrait_metric_plot(
712712
shrink = 0.8 * fscale
713713
legend_fontsize = fontsize * 0.8
714714

715+
# --- SMALL GUARD A: basic inputs present ---
716+
if not var_list:
717+
logger.warning("[Portrait]: No variables to plot (var_list empty); returning.")
718+
return
719+
if not model_list:
720+
logger.warning("[Portrait]: No models to plot (model_list empty); returning.")
721+
return
722+
715723
if group == "mean_climate":
716-
# data for final plot
717-
data_all_nor = np.stack(
718-
[data_dict["djf"], data_dict["mam"], data_dict["jja"], data_dict["son"]]
719-
)
724+
# --- SMALL GUARD B: seasonal arrays exist & stack cleanly ---
725+
required = ["djf", "mam", "jja", "son"]
726+
missing = [k for k in required if (k not in data_dict or data_dict[k] is None)]
727+
if missing:
728+
logger.warning(
729+
"[Portrait]: Missing seasonal arrays for %s; returning. Missing=%s",
730+
group,
731+
missing,
732+
)
733+
return
734+
try:
735+
arrs = [np.asarray(data_dict[k]) for k in required]
736+
if any(a.size == 0 for a in arrs):
737+
logger.warning(
738+
"[Portrait]: One or more seasonal arrays are empty; returning."
739+
)
740+
return
741+
data_all_nor = np.stack(arrs)
742+
except Exception as e:
743+
logger.warning(
744+
"[Portrait]: Failed to stack seasonal arrays: %s; returning.", e
745+
)
746+
return
747+
720748
legend_on = True
721749
legend_labels = ["DJF", "MAM", "JJA", "SON"]
722750
else:
723-
data_all_nor = data_dict
751+
# --- SMALL GUARD C: non-seasonal data present ---
752+
data_all_nor = np.asarray(data_dict)
753+
if data_all_nor.size == 0:
754+
logger.warning("[Portrait]: Input data array is empty; returning.")
755+
return
724756
legend_on = False
725757
legend_labels = []
726758

759+
# --- SMALL GUARD D: minimal shape sanity (avoid cryptic errors downstream) ---
760+
if data_all_nor.ndim < 2:
761+
logger.warning(
762+
"[Portrait]: Data has ndim=%d (<2); returning.", data_all_nor.ndim
763+
)
764+
return
765+
727766
highlight_models = get_highlight_models(model_list, model_name)
728767
lable_colors = []
729768
for model in model_list:
@@ -742,7 +781,6 @@ def portrait_metric_plot(
742781
elif stat in ["stdv_pc_ratio_to_obs"]:
743782
var_range = (0.5, 1.5)
744783
cmap_color = "jet"
745-
cmap_bounds = [0.5, 0.7, 0.9, 1.1, 1.3, 1.5]
746784
cmap_bounds = [r / 10 for r in range(5, 16, 1)]
747785
else:
748786
var_range = (-0.5, 0.5)
@@ -782,7 +820,9 @@ def portrait_metric_plot(
782820

783821
# Add title
784822
fig.suptitle(
785-
f"{region}{group} ({stat_name})", fontsize=fontsize * 1.1, fontweight="bold"
823+
f"{region}{group} ({stat_name})",
824+
fontsize=fontsize * 1.1,
825+
fontweight="bold",
786826
)
787827
fig.tight_layout(rect=[0, 0, 1, 0.95]) # leave top 5 % free for title
788828

@@ -888,29 +928,54 @@ def parcoord_metric_plot(
888928
"#377eb8",
889929
"#dede00",
890930
]
891-
892-
# ensemble mean for E3SM group
931+
# --- SMALL GUARD 1: highlight models may be missing ---
893932
highlight_model1 = get_highlight_models(data_dict.get("model", []), model_name)
894-
irow_str = data_dict[data_dict["model"] == highlight_model1[0]].index[0]
895-
irow_end = data_dict[data_dict["model"] == highlight_model1[-1]].index[0] + 1
896-
data_dict.loc[mean2_name] = data_dict[irow_str:irow_end].mean(
897-
numeric_only=True, skipna=True
898-
)
899-
data_dict.at[mean2_name, "model"] = mean2_name
933+
if not highlight_model1:
934+
# No highlightable models → skip means/highlights later
935+
highlight_model1 = []
936+
have_highlights = False
937+
else:
938+
have_highlights = True
900939

901-
# ensemble mean for CMIP group
902-
irow_sub = data_dict[data_dict["model"] == highlight_model1[0]].index[0]
903-
data_dict.loc[mean1_name] = data_dict[:irow_sub].mean(
904-
numeric_only=True, skipna=True
905-
)
906-
data_dict.at[mean1_name, "model"] = mean1_name
907-
data_dict.loc[mean2_name] = data_dict[irow_sub:].mean(
908-
numeric_only=True, skipna=True
909-
)
910-
data_dict.at[mean2_name, "model"] = mean2_name
940+
# Only compute E3SM mean if we found highlight rows in the DF
941+
if have_highlights:
942+
if (highlight_model1[0] in set(data_dict["model"])) and (
943+
highlight_model1[-1] in set(data_dict["model"])
944+
):
945+
irow_str = data_dict.index[data_dict["model"] == highlight_model1[0]][0]
946+
irow_end = (
947+
data_dict.index[data_dict["model"] == highlight_model1[-1]][0] + 1
948+
)
949+
data_dict.loc[mean2_name] = data_dict.iloc[irow_str:irow_end].mean(
950+
numeric_only=True, skipna=True
951+
)
952+
data_dict.at[mean2_name, "model"] = mean2_name
953+
else:
954+
have_highlights = False # fallback if names weren’t found
955+
956+
# CMIP/E3SM means only if we can split reliably
957+
if have_highlights:
958+
irow_sub = data_dict.index[data_dict["model"] == highlight_model1[0]][0]
959+
data_dict.loc[mean1_name] = data_dict.iloc[:irow_sub].mean(
960+
numeric_only=True, skipna=True
961+
)
962+
data_dict.at[mean1_name, "model"] = mean1_name
963+
data_dict.loc[mean2_name] = data_dict.iloc[irow_sub:].mean(
964+
numeric_only=True, skipna=True
965+
)
966+
data_dict.at[mean2_name, "model"] = mean2_name
967+
968+
# --- SMALL GUARD 1: highlights models
969+
if not have_highlights:
970+
logger.warning(
971+
f"[ParCoord]: No highlightable models found for model_name={model_name}; "
972+
f"Skipping highlight and mean calculations."
973+
)
911974

912-
model_list = data_dict["model"].to_list()
913-
highlight_model2 = highlight_model1 + [mean1_name, mean2_name]
975+
model_list = data_dict["model"].astype(str).to_list()
976+
highlight_model2 = highlight_model1 + (
977+
[mean1_name, mean2_name] if have_highlights else []
978+
)
914979

915980
# colors for highlight lines
916981
lncolors = []
@@ -922,30 +987,48 @@ def parcoord_metric_plot(
922987
else:
923988
lncolors.append(xcolors[i % len(xcolors)])
924989

925-
var_name1 = sorted(var_names.copy())
990+
# --- SMALL GUARD 2: keep only existing, non-empty vars ---
991+
var_name1 = sorted(
992+
v for v in var_names if (v in data_dict.columns) and data_dict[v].notna().any()
993+
)
994+
if not var_name1:
995+
logger.warning(
996+
f"[ParCoord]: Nothing to plot for group={group}, region={region}, stat={stat}. "
997+
f"No valid variables found in metrics data (columns checked={len(var_names)})."
998+
)
999+
return
1000+
9261001
# label information
9271002
var_labels = []
928-
for i, var in enumerate(var_name1):
929-
index = var_names.index(var)
930-
if var_units is not None:
931-
var_labels.append(var_names[index] + "\n" + var_units[index])
1003+
for v in var_name1:
1004+
idx = var_names.index(v)
1005+
if var_units is not None and idx < len(var_units):
1006+
var_labels.append(var_names[idx] + "\n" + var_units[idx])
9321007
else:
933-
var_labels.append(var_names[index])
1008+
var_labels.append(var_names[idx])
9341009

9351010
# final plot data
9361011
data_var = data_dict[var_name1].to_numpy()
9371012

1013+
# --- SMALL GUARD 3: ensure at least 1 column for parallel-coords ---
1014+
if data_var.ndim != 2 or data_var.shape[1] == 0:
1015+
logger.warning(
1016+
f"[ParCoord]: Not enough data to process parallel coordinate plots "
1017+
f"(shape={data_var.shape}); returning without plot."
1018+
)
1019+
return
1020+
9381021
xlabel = "Metric"
9391022
ylabel = "{} ({})".format(stat_name, stat.upper())
9401023

9411024
if "mean_climate" in [group, region]:
942-
title = "Model Performance of Annual Climatology ({}, {})".format(
943-
stat.upper(), region.upper()
944-
)
1025+
title = f"Model Performance of Annual Climatology ({stat.upper()}, {region.upper()})"
9451026
elif "variability_modes" in [group, region]:
946-
title = "Model Performance of Modes Variability ({})".format(stat.upper())
1027+
title = f"Model Performance of Modes Variability ({stat.upper()})"
9471028
elif "enso" in [group, region]:
948-
title = "Model Performance of ENSO ({})".format(stat.upper())
1029+
title = f"Model Performance of ENSO ({stat.upper()})"
1030+
else:
1031+
title = f"Model Performance ({stat.upper()}, {region.upper()})"
9491032

9501033
fig, ax = parallel_coordinate_plot(
9511034
data_var,

0 commit comments

Comments
 (0)