-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvisualize_v5.py
More file actions
235 lines (191 loc) · 10.5 KB
/
visualize_v5.py
File metadata and controls
235 lines (191 loc) · 10.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
"""
visualize_v5.py — Hatch v5 Analysis Visualization
===================================================
4-panel figure:
Top-left: v1→v5 progression bar chart (F1 Folded, F1 Disordered, Mean F1)
Top-right: Feature threshold shift: initial vs optimized (Coordinate Descent)
Bottom-left: K × T_fill heatmap (mean F1, best D=5, T_drain=0.8, SH=True)
Bottom-right: Cup trace comparison (v4 vs v5) on p53 TAD sequence
"""
import csv
import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import LinearSegmentedColormap
# ─── Load data ────────────────────────────────────────────────────────────────
# v1–v5 progression: per-version F1 on the folded class, the disordered class,
# and their mean (hard-coded summary numbers).
PROGRESSION = {
    "v1": {"f1_fold": 0.349, "f1_dis": 0.355, "mean_f1": 0.352},
    "v2": {"f1_fold": 0.318, "f1_dis": 0.619, "mean_f1": 0.469},
    "v3": {"f1_fold": 0.643, "f1_dis": 0.662, "mean_f1": 0.652},
    "v4": {"f1_fold": 0.739, "f1_dis": 0.661, "mean_f1": 0.700},
    "v5": {"f1_fold": 0.746, "f1_dis": 0.682, "mean_f1": 0.714},
}
# Feature threshold shifts (values transcribed from optimized_thresholds_v5.json).
# The four lists below are parallel: index k in each refers to the same feature.
FEATURE_NAMES_SHORT = [
    "Hydrophobicity", "Flexibility", "H-Bond Potential",
    "Net Charge", "Shannon Entropy", "Proline Freq",
    "Bulky Hydrophobic"
]
INITIAL_THRESH = [0.43981, 0.82531, 3.36667, 0.06667, 0.81190, 0.05000, 0.25000]
OPTIMIZED_THRESH = [0.45291, 0.83417, 3.26667, 0.06667, 0.79976, 0.02500, 0.25000]
AUC_WEIGHTS = [0.7807, 0.8212, 0.2895, 0.2933, 0.8415, 0.3462, 1.0000]
# Load v5 grid search results: one dict per CSV row. Every column is parsed as
# float except those listed in _STRING_COLUMNS, which stay as raw text
# (super_hatch is later compared against the string "True").
_STRING_COLUMNS = ("super_hatch",)
# newline="" is what the csv module requires for file objects: the reader does
# its own line-ending handling (including newlines inside quoted fields).
with open("optimization_report_v5.csv", newline="", encoding="utf-8") as f:
    v5_results = [
        {k: v if k in _STRING_COLUMNS else float(v) for k, v in row.items()}
        for row in csv.DictReader(f)
    ]
# ─── Figure setup ─────────────────────────────────────────────────────────────
# Shared text/accent palette, referenced by every panel below.
TITLE_COLOR = "#e6edf3"
LABEL_COLOR = "#8b949e"
ACCENT_FOLD = "#58a6ff"
ACCENT_DIS = "#f78166"
ACCENT_MEAN = "#3fb950"
ACCENT_CD = "#d2a8ff"
# 2×2 grid of panels on a dark background; each panel gets the same
# background, tick styling and spine colour.
fig, axes = plt.subplots(2, 2, figsize=(16, 12))
fig.patch.set_facecolor("#0d1117")
for panel in axes.ravel():
    panel.set_facecolor("#161b22")
    panel.tick_params(colors="#c9d1d9", labelsize=9)
    for edge in panel.spines.values():
        edge.set_edgecolor("#30363d")
# ─── Panel 1: v1→v5 Progression ───────────────────────────────────────────────
ax1 = axes[0, 0]
versions = list(PROGRESSION.keys())
x = np.arange(len(versions))
w = 0.25  # width of one bar; the three bars of a group sit at x - w, x, x + w
f1_fold = [PROGRESSION[v]["f1_fold"] for v in versions]
f1_dis = [PROGRESSION[v]["f1_dis"] for v in versions]
mean_f1 = [PROGRESSION[v]["mean_f1"] for v in versions]
b1 = ax1.bar(x - w, f1_fold, w, label="F1 Folded", color=ACCENT_FOLD, alpha=0.85)
b2 = ax1.bar(x, f1_dis, w, label="F1 Disordered", color=ACCENT_DIS, alpha=0.85)
b3 = ax1.bar(x + w, mean_f1, w, label="Mean F1", color=ACCENT_MEAN, alpha=0.85)
# Annotate each mean-F1 bar with its value as a percentage, just above the bar.
for bar, val in zip(b3, mean_f1):
    ax1.text(bar.get_x() + bar.get_width() / 2, val + 0.012,
             f"{val:.1%}", ha="center", va="bottom", fontsize=8,
             color=ACCENT_MEAN, fontweight="bold")
# Highlight the latest version (last PROGRESSION entry). The position is derived
# from x rather than hard-coded, so adding or removing versions keeps it correct.
latest_x = x[-1]
ax1.axvspan(latest_x - 0.45, latest_x + 0.45, alpha=0.08, color=ACCENT_CD, zorder=0)
ax1.set_xticks(x)
ax1.set_xticklabels(versions, color=TITLE_COLOR, fontsize=11, fontweight="bold")
ax1.set_ylim(0, 1.0)
ax1.set_ylabel("F1 Score", color=LABEL_COLOR, fontsize=10)
ax1.set_title("v1 → v5 Performance Progression", color=TITLE_COLOR,
              fontsize=12, fontweight="bold", pad=10)
ax1.legend(loc="upper left", facecolor="#21262d", edgecolor="#30363d",
           labelcolor=TITLE_COLOR, fontsize=9)
# The target line is added after legend() was built, so it is labelled with a
# text annotation placed just right of the last bar group instead of a legend entry.
ax1.axhline(0.80, color="#f0e68c", linestyle="--", linewidth=1, alpha=0.5)
ax1.text(latest_x + 0.6, 0.81, "80% Target", color="#f0e68c", fontsize=8, va="bottom")
ax1.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f"{y:.0%}"))
ax1.grid(axis="y", color="#30363d", linewidth=0.5, alpha=0.5)
# ─── Panel 2: Feature Threshold Shifts ────────────────────────────────────────
ax2 = axes[0, 1]
# The per-feature lists are consumed in lockstep. zip() would silently drop
# features if one list fell out of sync, so check the lengths up front.
_n_features = len(FEATURE_NAMES_SHORT)
if not (len(INITIAL_THRESH) == len(OPTIMIZED_THRESH) == len(AUC_WEIGHTS) == _n_features):
    raise ValueError(
        "FEATURE_NAMES_SHORT, INITIAL_THRESH, OPTIMIZED_THRESH and AUC_WEIGHTS "
        "must all have the same length"
    )
y_pos = np.arange(_n_features)
# Relative shift of each threshold, as a percentage of its initial value. The
# 1e-6 floor avoids division by zero when an initial threshold is exactly 0.
shifts = [(o - i) / max(abs(i), 1e-6) * 100 for i, o in zip(INITIAL_THRESH, OPTIMIZED_THRESH)]
# One sign convention for both bar colour and label placement. A zero shift
# counts as "up"; its bar has zero length, so only the label side is affected.
is_up = [s >= 0 for s in shifts]
colors = [ACCENT_FOLD if up else ACCENT_DIS for up in is_up]
# Horizontal bar chart of % shift, one row per feature
bars = ax2.barh(y_pos, shifts, color=colors, alpha=0.8, height=0.6)
# Annotate each row with the feature's AUC weight, just past the bar tip.
for row_idx, (w_val, shift, up) in enumerate(zip(AUC_WEIGHTS, shifts, is_up)):
    ax2.text(shift + (0.3 if up else -0.3), row_idx, f"w={w_val:.2f}",
             ha="left" if up else "right", va="center",
             fontsize=8, color=LABEL_COLOR)
ax2.set_yticks(y_pos)
ax2.set_yticklabels(FEATURE_NAMES_SHORT, color=TITLE_COLOR, fontsize=9)
ax2.set_xlabel("Threshold Shift (%)", color=LABEL_COLOR, fontsize=10)
ax2.set_title("Coordinate Descent: Per-Feature Threshold Shifts",
              color=TITLE_COLOR, fontsize=12, fontweight="bold", pad=10)
ax2.axvline(0, color="#8b949e", linewidth=1, linestyle="-")
ax2.grid(axis="x", color="#30363d", linewidth=0.5, alpha=0.5)
# Legend. Both entries previously read "(stricter folded)", which cannot be true
# for both directions. Whether a raised threshold is stricter depends on each
# feature's orientation, which is not visible here, so the labels only state
# the direction of the change.
patch_up = mpatches.Patch(color=ACCENT_FOLD, label="Threshold raised")
patch_down = mpatches.Patch(color=ACCENT_DIS, label="Threshold lowered")
ax2.legend(handles=[patch_up, patch_down], loc="lower right",
           facecolor="#21262d", edgecolor="#30363d", labelcolor=TITLE_COLOR, fontsize=8)
# ─── Panel 3: K × T_fill Heatmap ──────────────────────────────────────────────
ax3 = axes[1, 0]
K_vals = [12, 15, 18, 20]
T_fill_vals = [1.2, 1.5, 1.8]
# Fixed slice of the grid shown in this panel (must match the title below).
HEATMAP_D = 5
HEATMAP_T_DRAIN = 0.8
# Build the heatmap matrix: the best (max) mean_f1 for each (K, T_fill) cell
# among rows with D=5, T_drain=0.8, SH=True. Taking the max matters when the
# grid varies other parameters, so several rows land in the same cell. Cells
# with no matching row stay NaN, so they cannot be mistaken for a real F1 of 0.
k_index = {k: i for i, k in enumerate(K_vals)}
# CSV floats are matched after rounding, so tiny representation errors
# (e.g. 1.5000000001) still hit the intended column.
t_fill_index = {round(t, 6): j for j, t in enumerate(T_fill_vals)}
heatmap = np.full((len(K_vals), len(T_fill_vals)), np.nan)
for row in v5_results:
    if (int(row["D"]) != HEATMAP_D
            or round(float(row["T_drain"]), 6) != round(HEATMAP_T_DRAIN, 6)
            or row["super_hatch"] != "True"):
        continue
    ki = k_index.get(int(row["K"]))
    ti = t_fill_index.get(round(float(row["T_fill"]), 6))
    if ki is None or ti is None:
        continue  # value outside the plotted grid
    # fmax ignores NaN, so the first score seeds the cell and later ones keep the max.
    heatmap[ki, ti] = np.fmax(heatmap[ki, ti], float(row["mean_f1"]))
cmap = LinearSegmentedColormap.from_list("hatch",
    ["#1a1f2e", "#1e3a5f", "#1d6fa4", "#3fb950", "#f0e68c"], N=256)
cmap.set_bad("#30363d")  # colour used for NaN (no data) cells
im = ax3.imshow(heatmap, cmap=cmap, aspect="auto", vmin=0.68, vmax=0.75)
ax3.set_xticks(range(len(T_fill_vals)))
ax3.set_xticklabels([f"T_fill={v}" for v in T_fill_vals], color=TITLE_COLOR, fontsize=9)
ax3.set_yticks(range(len(K_vals)))
ax3.set_yticklabels([f"K={v}" for v in K_vals], color=TITLE_COLOR, fontsize=9)
ax3.set_title("Mean F1 Heatmap (D=5, T_drain=0.8, SH=True)",
              color=TITLE_COLOR, fontsize=12, fontweight="bold", pad=10)
# Write each cell's value. Text is dark on the bright upper end of the colormap
# and white elsewhere; missing cells are labelled "n/a".
for (i, j), val in np.ndenumerate(heatmap):
    if np.isnan(val):
        cell_text, text_color = "n/a", LABEL_COLOR
    else:
        cell_text = f"{val:.3f}"
        text_color = "white" if val < 0.72 else "#0d1117"
    ax3.text(j, i, cell_text, ha="center", va="center",
             fontsize=10, color=text_color, fontweight="bold")
cbar = plt.colorbar(im, ax=ax3, fraction=0.046, pad=0.04)
cbar.ax.tick_params(colors=LABEL_COLOR, labelsize=8)
cbar.set_label("Mean F1", color=LABEL_COLOR, fontsize=9)
# ─── Panel 4: Cup Trace v4 vs v5 ──────────────────────────────────────────────
ax4 = axes[1, 1]
# p53 TAD sequence
# NOTE(review): the tail after "...SQKTY" does not appear to match native human
# p53 (UniProt P04637). Confirm this sequence is intentional.
P53_TAD = ("MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPG"
           "PDEAPRMPEAAPPVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYPQGLNGTVNLFGRNSFEV")
# Overflow capacity K, shared by both classifier configs and the dashed
# reference line so the plotted threshold cannot drift from the traced configs.
K_LINE = 15
# v4 trace (using v4 best config: K=15, T_fill=1.5, T_drain=1.0, N=2, D=4, SH=True)
from inference_engine_v4 import HatchClassifierV4
from training_engine import train as _train4
_thresholds_v4, _train_means_v4, _ = _train4(window_size=30)
clf_v4 = HatchClassifierV4(
    thresholds=_thresholds_v4, train_means=_train_means_v4,
    K=K_LINE, T_fill=1.5, T_drain=1.0, N=2, D=4, super_hatch=True
)
trace_v4 = clf_v4.trace(P53_TAD)
# v5 trace (best config: K=15, T_fill=1.5, T_drain=0.8, N=2, D=5, SH=True)
from inference_engine_v5 import load_v5_classifier
clf_v5 = load_v5_classifier(K=K_LINE, T_fill=1.5, T_drain=0.8, N=2, D=5, super_hatch=True)
trace_v5 = clf_v5.trace(P53_TAD)
# The "entropy_levels" series of each trace is plotted as the cup level.
# Presumably this is the per-position cup fill; confirm against the
# classifiers' trace() implementation.
pos_v4 = trace_v4["positions"]
pos_v5 = trace_v5["positions"]
cup_v4 = trace_v4["entropy_levels"]
cup_v5 = trace_v5["entropy_levels"]
ax4.plot(pos_v4, cup_v4, color=ACCENT_DIS, linewidth=2.0, alpha=0.9, label="v4 (global thresholds)")
ax4.plot(pos_v5, cup_v5, color=ACCENT_MEAN, linewidth=2.0, alpha=0.9, label="v5 (calibrated thresholds)")
ax4.axhline(K_LINE, color="#f0e68c", linestyle="--", linewidth=1.2, alpha=0.7, label=f"Overflow (K={K_LINE})")
# Mark overflow points. "overflow_at" is an index into the trace's positions,
# or None when the cup never overflowed. The vertical text offsets keep the
# v4 and v5 labels from overlapping.
for trace, positions, color, tag, y_offset in (
    (trace_v4, pos_v4, ACCENT_DIS, "v4", 0.3),
    (trace_v5, pos_v5, ACCENT_MEAN, "v5", -2.5),
):
    if trace["overflow_at"] is None:
        continue
    ov = positions[trace["overflow_at"]]
    ax4.axvline(ov, color=color, linestyle=":", linewidth=1, alpha=0.6)
    ax4.text(ov + 1, K_LINE + y_offset, f"{tag} overflow\n@pos {ov}", color=color,
             fontsize=7.5, va="bottom")
ax4.set_xlabel("Sequence Position", color=LABEL_COLOR, fontsize=10)
ax4.set_ylabel("Cup Level", color=LABEL_COLOR, fontsize=10)
ax4.set_title("Cup Trace: v4 vs v5 on p53 TAD (Disordered)",
              color=TITLE_COLOR, fontsize=12, fontweight="bold", pad=10)
ax4.legend(loc="upper left", facecolor="#21262d", edgecolor="#30363d",
           labelcolor=TITLE_COLOR, fontsize=9)
ax4.grid(color="#30363d", linewidth=0.5, alpha=0.5)
# ─── Save ─────────────────────────────────────────────────────────────────────
# Add the overall title just above the panel grid, tighten spacing, and write
# the PNG using the figure's dark background colour.
fig.suptitle(
    "Hatch v5: Per-Feature Calibrated Thresholds (Coordinate Descent)",
    color=TITLE_COLOR,
    fontsize=14,
    fontweight="bold",
    y=1.01,
)
fig.tight_layout()
fig.savefig(
    "hatch_v5_analysis.png",
    dpi=150,
    bbox_inches="tight",
    facecolor=fig.get_facecolor(),
)
print("Saved: hatch_v5_analysis.png")