Skip to content

Commit 5a72a4d

Browse files
committed
Python: improve handling of stats_x_y
1 parent 6943bd2 commit 5a72a4d

File tree

1 file changed

+192
-18
lines changed

1 file changed

+192
-18
lines changed

Utilities/Python/fdsplotlib.py

Lines changed: 192 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,9 @@ def _compute_metrics_block(
129129
Y_sel = Y[comp_mask, :]
130130

131131
# --- support patterns like mean_1_2, max_2_1, end_1_2 ---
132+
# NOTE: we deliberately DO NOT parse "all_*_*" here.
132133
def _parse_stat_xy(m):
134+
m = m.lower().strip()
133135
for base in ("max", "mean", "end"):
134136
pref = base + "_"
135137
if m.startswith(pref):
@@ -147,11 +149,20 @@ def _parse_stat_xy(m):
147149
titles = []
148150
per_curve_series = []
149151

150-
# --- stat_x_y: only compute for the specified first index (MATLAB uses 1-based) ---
152+
# --- stat_x_y: use first index for EXP, second index for MODEL ---
153+
# Example: "mean_2_3"
154+
# EXP side (variant_side='d1') → column 2
155+
# MODEL side (variant_side='d2') → column 3
151156
if idx_first is not None:
152-
j = idx_first - 1
157+
if variant_side == "d2" and idx_second is not None:
158+
idx_use = idx_second
159+
else:
160+
idx_use = idx_first
161+
162+
j = idx_use - 1 # 1-based → 0-based
153163
if j < 0 or j >= ncols:
154164
return np.array([]), [], []
165+
155166
yj = Y_sel[:, j].reshape(-1)
156167

157168
if base == "max":
@@ -166,7 +177,7 @@ def _parse_stat_xy(m):
166177
if out == 0.0:
167178
out = 1e-12
168179

169-
return np.array([out]), [f"curve{idx_first}"], []
180+
return np.array([out]), [f"curve{idx_use}"], []
170181

171182
# --- metric='all': return all finite Y values (one per data point) ---
172183
if metric_str == "all":
@@ -434,6 +445,11 @@ def read_csv_cached(path, **kwargs):
434445
# --- Save measured (experimental) ---
435446
if not gtest:
436447
try:
448+
metric_raw = str(pp.Metric or '').strip()
449+
metric_str = metric_raw.lower()
450+
# For 'all_2_3', treat EXP metric as 'all' (per-column) here.
451+
metric_for_exp_block = "all" if metric_str.startswith("all") else metric_raw
452+
437453
vals_meas_list = []
438454
qty_meas_list = []
439455
if y.ndim == 2 and x.ndim == 2 and y.shape[1] == x.shape[1]:
@@ -444,7 +460,7 @@ def read_csv_cached(path, **kwargs):
444460
xj, yj = xj[mask], yj[mask]
445461
if len(xj) > 0 and len(yj) > 0:
446462
vals_meas, qty_meas, _ = _compute_metrics_block(
447-
x=xj, Y=yj, metric=pp.Metric,
463+
x=xj, Y=yj, metric=metric_for_exp_block,
448464
initial_value=float(pp.d1_Initial_Value or 0.0),
449465
comp_start=float(pp.d1_Comp_Start or np.nan),
450466
comp_end=float(pp.d1_Comp_End or np.nan),
@@ -456,7 +472,7 @@ def read_csv_cached(path, **kwargs):
456472
qty_meas_list.append(qty_meas)
457473
else:
458474
vals_meas, qty_meas, _ = _compute_metrics_block(
459-
x=x, Y=y, metric=pp.Metric,
475+
x=x, Y=y, metric=metric_for_exp_block,
460476
initial_value=float(pp.d1_Initial_Value or 0.0),
461477
comp_start=float(pp.d1_Comp_Start or np.nan),
462478
comp_end=float(pp.d1_Comp_End or np.nan),
@@ -539,24 +555,181 @@ def read_csv_cached(path, **kwargs):
539555
# --- Interpolated, metric-aware model logic ---
540556
if not gtest:
541557
try:
542-
metric_str = str(pp.Metric or '').strip().lower()
558+
metric_raw = str(pp.Metric or '').strip()
559+
metric_str = metric_raw.lower()
543560
meas_list, pred_list, qty_pred_list = [], [], []
544561

562+
# Local parser for stat_x_y patterns (max_2_3, mean_1_4, end_3_2)
563+
def _parse_stat_xy_local(m):
564+
m = m.lower().strip()
565+
for base in ("max", "mean", "end"):
566+
pref = base + "_"
567+
if m.startswith(pref):
568+
try:
569+
a, b = m[len(pref):].split("_", 1)
570+
return base, int(a), int(b)
571+
except Exception:
572+
pass
573+
return m, None, None
574+
575+
base_stat, idx_first_stat, idx_second_stat = _parse_stat_xy_local(metric_str)
576+
545577
# Load experimental again for alignment (safe; cached)
546578
E = read_csv_cached(expdir + pp.d1_Filename,
547579
header=int(pp.d1_Col_Name_Row - 1),
548580
sep=',', engine='python', quotechar='"',
549581
skip_blank_lines=True).dropna(how='all')
550582
E.columns = E.columns.str.strip()
551583
start_idx_exp = int(pp.d1_Data_Row - pp.d1_Col_Name_Row - 1)
552-
x_exp, _ = get_data(E, pp.d1_Ind_Col_Name, start_idx_exp)
553-
y_exp, _ = get_data(E, pp.d1_Dep_Col_Name, start_idx_exp)
584+
x_exp_raw, _ = get_data(E, pp.d1_Ind_Col_Name, start_idx_exp)
585+
y_exp_raw, _ = get_data(E, pp.d1_Dep_Col_Name, start_idx_exp)
586+
587+
x_mod_raw = x
588+
y_mod_raw = y
589+
590+
# --- CASE 1: stat pair metrics (max_2_3, mean_2_3, end_2_3) ---
591+
if base_stat in ("max", "mean", "end") and idx_first_stat is not None:
592+
v_meas, _, _ = _compute_metrics_block(
593+
x=x_exp_raw, Y=y_exp_raw, metric=metric_raw,
594+
initial_value=float(pp.d1_Initial_Value or 0.0),
595+
comp_start=float(pp.d1_Comp_Start or np.nan),
596+
comp_end=float(pp.d1_Comp_End or np.nan),
597+
dep_comp_start=float(pp.d1_Dep_Comp_Start or np.nan),
598+
dep_comp_end=float(pp.d1_Dep_Comp_End or np.nan),
599+
variant_side="d1",
600+
)
601+
v_pred, qty_pred, _ = _compute_metrics_block(
602+
x=x_mod_raw, Y=y_mod_raw, metric=metric_raw,
603+
initial_value=float(pp.d2_Initial_Value or 0.0),
604+
comp_start=float(pp.d2_Comp_Start or np.nan),
605+
comp_end=float(pp.d2_Comp_End or np.nan),
606+
dep_comp_start=float(pp.d2_Dep_Comp_Start or np.nan),
607+
dep_comp_end=float(pp.d2_Dep_Comp_End or np.nan),
608+
variant_side="d2",
609+
)
610+
611+
flat_meas = np.atleast_1d(v_meas)
612+
flat_pred = np.atleast_1d(v_pred)
613+
nmin = min(flat_meas.size, flat_pred.size)
614+
if nmin == 0:
615+
print(f"[dataplot] Warning: no valid data pairs for {pp.Dataname}")
616+
else:
617+
if flat_meas.size != flat_pred.size:
618+
print(f"[dataplot] Truncated unequal vectors for {pp.Dataname}: "
619+
f"Measured={flat_meas.size}, Predicted={flat_pred.size}{nmin}")
620+
flat_meas = flat_meas[:nmin]
621+
flat_pred = flat_pred[:nmin]
622+
623+
Save_Measured_Metric[-1] = flat_meas
624+
Save_Predicted_Metric[-1] = flat_pred
625+
626+
qty_label = str(pp.d2_Dep_Col_Name).strip() or "Unknown"
627+
Save_Predicted_Quantity[-1] = np.array([qty_label] * len(flat_pred), dtype=object)
628+
629+
plt.figure(f.number)
630+
os.makedirs(pltdir, exist_ok=True)
631+
plt.savefig(pltdir + pp.Plot_Filename + '.pdf', backend='pdf')
632+
f_Last = f
633+
continue # move to next config row
634+
635+
# --- CASE 2: "all" with explicit pairing (all_2_3) ---
636+
is_all_pair = False
637+
idx_first_all = idx_second_all = None
638+
if metric_str.startswith("all_"):
639+
try:
640+
rest = metric_str[len("all_"):]
641+
a, b = rest.split("_", 1)
642+
idx_first_all = int(a)
643+
idx_second_all = int(b)
644+
is_all_pair = True
645+
except Exception:
646+
is_all_pair = False
554647

555648
# Normalize shapes to 2D (col-major semantics)
556-
x_exp = np.atleast_2d(x_exp)
557-
y_exp = np.atleast_2d(y_exp)
558-
x_mod = np.atleast_2d(x)
559-
y_mod = np.atleast_2d(y)
649+
x_exp = np.atleast_2d(x_exp_raw)
650+
y_exp = np.atleast_2d(y_exp_raw)
651+
x_mod = np.atleast_2d(x_mod_raw)
652+
y_mod = np.atleast_2d(y_mod_raw)
653+
654+
# Special "all_2_3" handling: one EXP column vs one MODEL column
655+
if is_all_pair and idx_first_all is not None and idx_second_all is not None:
656+
j_e = idx_first_all - 1
657+
j_m = idx_second_all - 1
658+
if j_e < 0 or j_m < 0 or j_e >= y_exp.shape[1] or j_m >= y_mod.shape[1]:
659+
print(f"[dataplot] all-pair index out of range for {pp.Dataname}")
660+
flat_meas = np.array([])
661+
flat_pred = np.array([])
662+
else:
663+
xj_e = np.ravel(x_exp[:, j_e] if x_exp.shape[1] > 1 else x_exp)
664+
yj_e = np.ravel(y_exp[:, j_e])
665+
m_e = np.isfinite(xj_e) & np.isfinite(yj_e)
666+
xj_e, yj_e = xj_e[m_e], yj_e[m_e]
667+
668+
xj_m = np.ravel(x_mod[:, j_m] if x_mod.shape[1] > 1 else x_mod)
669+
yj_m = np.ravel(y_mod[:, j_m])
670+
m_m = np.isfinite(xj_m) & np.isfinite(yj_m)
671+
xj_m, yj_m = xj_m[m_m], yj_m[m_m]
672+
673+
if xj_m.size < 2 or xj_e.size == 0:
674+
flat_meas = np.array([])
675+
flat_pred = np.array([])
676+
else:
677+
yj_m_i = np.interp(xj_e, xj_m, yj_m, left=np.nan, right=np.nan)
678+
mask_pair = np.isfinite(yj_m_i) & np.isfinite(yj_e)
679+
if not np.any(mask_pair):
680+
flat_meas = np.array([])
681+
flat_pred = np.array([])
682+
else:
683+
x_use = xj_e[mask_pair]
684+
y_exp_use = yj_e[mask_pair]
685+
y_mod_use = yj_m_i[mask_pair]
686+
687+
v_meas, _, _ = _compute_metrics_block(
688+
x=x_use, Y=y_exp_use, metric="all",
689+
initial_value=float(pp.d1_Initial_Value or 0.0),
690+
comp_start=float(pp.d1_Comp_Start or np.nan),
691+
comp_end=float(pp.d1_Comp_End or np.nan),
692+
dep_comp_start=float(pp.d1_Dep_Comp_Start or np.nan),
693+
dep_comp_end=float(pp.d1_Dep_Comp_End or np.nan),
694+
variant_side="d1",
695+
)
696+
v_pred, qty_pred, _ = _compute_metrics_block(
697+
x=x_use, Y=y_mod_use, metric="all",
698+
initial_value=float(pp.d2_Initial_Value or 0.0),
699+
comp_start=float(pp.d2_Comp_Start or np.nan),
700+
comp_end=float(pp.d2_Comp_End or np.nan),
701+
dep_comp_start=float(pp.d2_Dep_Comp_Start or np.nan),
702+
dep_comp_end=float(pp.d2_Dep_Comp_End or np.nan),
703+
variant_side="d2",
704+
)
705+
706+
flat_meas = np.atleast_1d(v_meas)
707+
flat_pred = np.atleast_1d(v_pred)
708+
709+
nmin = min(flat_meas.size, flat_pred.size)
710+
if nmin == 0:
711+
print(f"[dataplot] Warning: no valid data pairs for {pp.Dataname}")
712+
else:
713+
if flat_meas.size != flat_pred.size:
714+
print(f"[dataplot] Truncated unequal vectors for {pp.Dataname}: "
715+
f"Measured={flat_meas.size}, Predicted={flat_pred.size}{nmin}")
716+
flat_meas = flat_meas[:nmin]
717+
flat_pred = flat_pred[:nmin]
718+
719+
Save_Measured_Metric[-1] = flat_meas
720+
Save_Predicted_Metric[-1] = flat_pred
721+
722+
qty_label = str(pp.d2_Dep_Col_Name).strip() or "Unknown"
723+
Save_Predicted_Quantity[-1] = np.array([qty_label] * len(flat_pred), dtype=object)
724+
725+
plt.figure(f.number)
726+
os.makedirs(pltdir, exist_ok=True)
727+
plt.savefig(pltdir + pp.Plot_Filename + '.pdf', backend='pdf')
728+
f_Last = f
729+
continue # move to next config row
730+
731+
# --- CASE 3: general metrics (including plain 'all') ---
732+
metric_for_block = "all" if metric_str.startswith("all") else metric_raw
560733

561734
ncols = min(y_exp.shape[1], y_mod.shape[1])
562735

@@ -572,7 +745,7 @@ def read_csv_cached(path, **kwargs):
572745
m_m = np.isfinite(xj_m) & np.isfinite(yj_m)
573746
xj_m, yj_m = xj_m[m_m], yj_m[m_m]
574747

575-
if metric_str == 'all':
748+
if metric_for_block == 'all':
576749
# align by interpolating model to exp x
577750
if xj_m.size < 2 or xj_e.size == 0:
578751
continue
@@ -585,7 +758,7 @@ def read_csv_cached(path, **kwargs):
585758
y_mod_use = yj_m_i[mask_pair]
586759
# compute both on the same x grid
587760
v_meas, _, _ = _compute_metrics_block(
588-
x=x_use, Y=y_exp_use, metric=metric_str,
761+
x=x_use, Y=y_exp_use, metric="all",
589762
initial_value=float(pp.d1_Initial_Value or 0.0),
590763
comp_start=float(pp.d1_Comp_Start or np.nan),
591764
comp_end=float(pp.d1_Comp_End or np.nan),
@@ -594,7 +767,7 @@ def read_csv_cached(path, **kwargs):
594767
variant_side="d1",
595768
)
596769
v_pred, qty_pred, _ = _compute_metrics_block(
597-
x=x_use, Y=y_mod_use, metric=metric_str,
770+
x=x_use, Y=y_mod_use, metric="all",
598771
initial_value=float(pp.d2_Initial_Value or 0.0),
599772
comp_start=float(pp.d2_Comp_Start or np.nan),
600773
comp_end=float(pp.d2_Comp_End or np.nan),
@@ -607,7 +780,7 @@ def read_csv_cached(path, **kwargs):
607780
if yj_e.size == 0 or yj_m.size == 0:
608781
continue
609782
v_meas, _, _ = _compute_metrics_block(
610-
x=xj_e, Y=yj_e, metric=metric_str,
783+
x=xj_e, Y=yj_e, metric=metric_for_block,
611784
initial_value=float(pp.d1_Initial_Value or 0.0),
612785
comp_start=float(pp.d1_Comp_Start or np.nan),
613786
comp_end=float(pp.d1_Comp_End or np.nan),
@@ -616,7 +789,7 @@ def read_csv_cached(path, **kwargs):
616789
variant_side="d1",
617790
)
618791
v_pred, qty_pred, _ = _compute_metrics_block(
619-
x=xj_m, Y=yj_m, metric=metric_str,
792+
x=xj_m, Y=yj_m, metric=metric_for_block,
620793
initial_value=float(pp.d2_Initial_Value or 0.0),
621794
comp_start=float(pp.d2_Comp_Start or np.nan),
622795
comp_end=float(pp.d2_Comp_End or np.nan),
@@ -643,7 +816,7 @@ def read_csv_cached(path, **kwargs):
643816
flat_meas = flat_meas[:nmin]
644817
flat_pred = flat_pred[:nmin]
645818

646-
# Save truncated paired arrays (but don’t overwrite earlier measured data)
819+
# Save truncated paired arrays
647820
Save_Measured_Metric[-1] = flat_meas
648821
Save_Predicted_Metric[-1] = flat_pred
649822

@@ -680,6 +853,7 @@ def read_csv_cached(path, **kwargs):
680853
return saved_data, drange
681854

682855

856+
683857
def get_data(E, spec, start_idx):
684858
"""
685859
Extract data columns from DataFrame E according to spec string.

0 commit comments

Comments
 (0)