-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference_engine_v3.py
More file actions
296 lines (249 loc) · 11.9 KB
/
Copy pathinference_engine_v3.py
File metadata and controls
296 lines (249 loc) · 11.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
"""
inference_engine_v3.py — Hatch v3: Hysteresis + Super-Hatch
=============================================================
State machine changes from v2:
- Asymmetric thresholds (Hysteresis):
T_fill : score < T_fill → cup fills (default 4)
T_drain : score >= T_drain → counts toward N-consensus drain (default 3)
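The gap [T_drain, T_fill) is the intended "grace zone": windows scoring in
this range are meant to contribute to the drain consensus without adding
entropy (no fill, and no reset of the consecutive counter).
NOTE: in the core logic below the `score < T_fill` test is evaluated first,
so a window scoring in [T_drain, T_fill) currently takes the fill branch.
As written, T_drain only changes behaviour when T_drain > T_fill, in which
case scores in [T_fill, T_drain) neither fill, drain, nor reset the counter.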
- Super-Hatch hard reset:
score == 7 (perfect window) → cup_level = 0 immediately, regardless
of N. Biophysically: a 7/7 window is a high-confidence structural
anchor, not noise.
v3 Core Logic (per window):
score = _score_window(features)
if super_hatch and score == 7:
cup_level = 0 # Super-Hatch: hard reset
consecutive_ordered = 0
elif score < T_fill:
cup_level += 1 # Clog: cup fills
consecutive_ordered = 0
elif score >= T_drain: # Grace zone or above: drain-eligible
consecutive_ordered += 1
if consecutive_ordered >= N:
cup_level = max(0, cup_level - D)
consecutive_ordered = 0
if cup_level > K:
return DISORDERED
Key properties:
- When T_drain < T_fill, windows scoring in [T_drain, T_fill) form the
"grace zone": they are meant to count toward the drain consensus without
filling the cup. This is the Hysteresis.
- When T_drain == T_fill (degenerate case), v3 reduces exactly to v2.
- super_hatch=False reduces v3 to the pure Hysteresis model.
Parameters:
W : window size (residues)
K : overflow depth (cup capacity)
T_fill : fill threshold (score < T_fill → fill)
T_drain : drain threshold (score >= T_drain → count toward drain)
N : consecutive drain-eligible windows required to drain
D : drain amount per consensus event
super_hatch : bool — enable score==7 hard reset (default True)
"""
from __future__ import annotations
from collections import deque
from typing import Any, Dict, List, Optional
import numpy as np
from training_engine import (
CANONICAL,
compute_features,
window_generator,
)
# ─── v3 Classifier ────────────────────────────────────────────────────────────
class HatchClassifierV3:
    """
    Hatch v3 — Hysteresis + Super-Hatch water-in-cup protein disorder classifier.

    Each sliding window is scored by counting how many features fall on the
    "folded" side of their threshold. The score drives a small state machine
    that fills or drains a "cup"; the sequence is called Disordered as soon as
    the cup level exceeds ``K``.

    Parameters
    ----------
    thresholds : array-like (7,) — local midpoint thresholds from training_engine
    train_means : array-like (7,) — folded-class means; determines which side of
        each threshold is "folded"
    window_size : W — sliding window length in residues
    K : cup overflow depth
    T_fill : score < T_fill → cup fills (default 4)
    T_drain : score >= T_drain → counts toward N-consensus drain (default 3)
    N : consecutive drain-eligible windows required to drain (default 2)
    D : drain amount per consensus event (default 3)
    super_hatch : bool — enable score==7 hard reset (default True)
    step : sliding step (default 1)

    Notes
    -----
    The branches are tested in the order Super-Hatch → fill → drain. Because
    the fill test (``score < T_fill``) runs before the drain test, the drain
    branch is only reached when ``score >= T_fill``. Consequently:

    - with ``T_drain <= T_fill`` every non-filling window is drain-eligible
      and ``T_drain`` has no effect;
    - with ``T_drain > T_fill`` scores in ``[T_fill, T_drain)`` match no
      branch: the cup and the consecutive counter are left untouched.
    """

    FOLDED = 1
    DISORDERED = 0
    # Score that triggers the Super-Hatch hard reset (every one of the 7
    # features on the folded side).
    SUPER_HATCH_SCORE = 7

    def __init__(
        self,
        thresholds: np.ndarray,
        train_means: np.ndarray,
        window_size: int = 30,
        K: int = 6,
        T_fill: int = 4,
        T_drain: int = 3,
        N: int = 2,
        D: int = 3,
        super_hatch: bool = True,
        step: int = 1,
    ):
        # Coerce to arrays so the element-wise comparison below is really
        # element-wise: two plain lists would compare lexicographically and
        # yield a single bool, silently giving every feature the same direction.
        # np.asarray does not copy when an ndarray is passed in.
        self.thresholds = np.asarray(thresholds)
        self.train_means = np.asarray(train_means)
        self.window_size = window_size
        self.K = K
        self.T_fill = T_fill
        self.T_drain = T_drain
        self.N = N
        self.D = D
        self.super_hatch = super_hatch
        self.step = step
        # Precompute direction flags: True if folded mean > threshold (higher = folded)
        self._folded_is_high = self.train_means > self.thresholds

    # ── Scoring (identical to v1/v2) ─────────────────────────────────────────
    def _score_window(self, features: np.ndarray) -> int:
        """
        Count how many features satisfy the folded condition.

        For each feature:
        - feature > threshold if folded_mean > threshold (higher = folded)
        - feature < threshold otherwise (lower = folded)

        A feature exactly equal to its threshold never counts as folded,
        because both comparisons are strict.
        """
        features = np.asarray(features)
        folded_conditions = np.where(
            self._folded_is_high,
            features > self.thresholds,
            features < self.thresholds,
        )
        return int(np.sum(folded_conditions))

    # ── Fast path (classify only) ─────────────────────────────────────────────
    def classify(self, sequence: str) -> int:
        """
        Classify a protein sequence.

        The sequence is upper-cased and stripped of every character not in
        ``CANONICAL`` before windowing. Returns 1 (Folded) or 0 (Disordered).
        Sequences that are empty after filtering, or shorter than
        ``window_size``, are reported as Disordered. The loop exits on the
        first overflow (``cup_level > K``).
        """
        canonical = "".join(aa for aa in sequence.upper() if aa in CANONICAL)
        if not canonical or len(canonical) < self.window_size:
            return self.DISORDERED  # too short — conservative

        cup_level = 0
        consecutive_ordered = 0
        for window in window_generator(canonical, self.window_size, self.step):
            score = self._score_window(compute_features(window))

            # ── v3 state machine ─────────────────────────────────────────────
            if self.super_hatch and score == self.SUPER_HATCH_SCORE:
                # Super-Hatch: hard reset of both cup and counter.
                cup_level = 0
                consecutive_ordered = 0
            elif score < self.T_fill:
                # Clog: cup fills and any drain streak is broken.
                cup_level += 1
                consecutive_ordered = 0
            elif score >= self.T_drain:
                # Drain-eligible window; drain once N of them are seen in a row.
                consecutive_ordered += 1
                if consecutive_ordered >= self.N:
                    cup_level = max(0, cup_level - self.D)
                    consecutive_ordered = 0

            if cup_level > self.K:
                return self.DISORDERED
        return self.FOLDED

    # ── Diagnostic trace ─────────────────────────────────────────────────────
    def _empty_trace(self) -> Dict[str, Any]:
        """Build the trace returned for a sequence too short to window."""
        return {
            "prediction": self.DISORDERED,
            "overflow_at": None,
            "positions": [],
            "scores": [],
            "entropy_levels": [],
            "drain_events": [],
            "super_hatch_events": [],
            "fill_events": [],
            "grace_events": [],
            "n_windows": 0,
        }

    def trace(self, sequence: str) -> Dict[str, Any]:
        """
        Run the state machine and return a full per-window trace.

        Unlike ``classify``, the loop does not stop at the first overflow;
        every window is recorded and only the first overflow index is kept.

        Returns
        -------
        dict with keys:
            positions : list[int] — window start indices (``i * step``)
            scores : list[int] — per-window feature scores (0–7)
            entropy_levels : list[int] — cup_level after each window
            drain_events : list[bool] — True when consensus drain fired
            super_hatch_events : list[bool] — True when Super-Hatch fired
            fill_events : list[bool] — True when cup filled
            grace_events : list[bool] — True when a drain-eligible window
                scored in [T_drain, T_fill); with the current branch order
                the drain branch implies score >= T_fill, so this is
                always False
            overflow_at : int | None — index of first overflow, or None
            prediction : int — 0 (Disordered) or 1 (Folded)
            n_windows : int — total windows evaluated
        """
        canonical = "".join(aa for aa in sequence.upper() if aa in CANONICAL)
        if not canonical or len(canonical) < self.window_size:
            return self._empty_trace()

        cup_level = 0
        consecutive_ordered = 0
        overflow_at: Optional[int] = None
        positions: List[int] = []
        scores: List[int] = []
        entropy_levels: List[int] = []
        drain_events: List[bool] = []
        super_hatch_events: List[bool] = []
        fill_events: List[bool] = []
        grace_events: List[bool] = []

        for i, window in enumerate(
            window_generator(canonical, self.window_size, self.step)
        ):
            score = self._score_window(compute_features(window))

            # ── v3 state machine (must stay in sync with classify) ───────────
            is_drain = False
            is_super_hatch = False
            is_fill = False
            is_grace = False
            if self.super_hatch and score == self.SUPER_HATCH_SCORE:
                cup_level = 0
                consecutive_ordered = 0
                is_super_hatch = True
            elif score < self.T_fill:
                cup_level += 1
                consecutive_ordered = 0
                is_fill = True
            elif score >= self.T_drain:
                # Grace zone: T_drain <= score < T_fill. Unreachable as True
                # here (score >= T_fill in this branch); kept so the flag
                # follows the documented definition.
                is_grace = self.T_drain <= score < self.T_fill
                consecutive_ordered += 1
                if consecutive_ordered >= self.N:
                    cup_level = max(0, cup_level - self.D)
                    consecutive_ordered = 0
                    is_drain = True

            # ── Overflow check ────────────────────────────────────────────────
            if cup_level > self.K and overflow_at is None:
                overflow_at = i

            # Window start in the canonical sequence — assumes window_generator
            # yields windows from index 0 with stride `step`; TODO confirm.
            positions.append(i * self.step)
            scores.append(score)
            entropy_levels.append(cup_level)
            drain_events.append(is_drain)
            super_hatch_events.append(is_super_hatch)
            fill_events.append(is_fill)
            grace_events.append(is_grace)

        prediction = self.DISORDERED if overflow_at is not None else self.FOLDED
        return {
            "prediction": prediction,
            "overflow_at": overflow_at,
            "positions": positions,
            "scores": scores,
            "entropy_levels": entropy_levels,
            "drain_events": drain_events,
            "super_hatch_events": super_hatch_events,
            "fill_events": fill_events,
            "grace_events": grace_events,
            "n_windows": len(positions),
        }

    # ── Convenience ──────────────────────────────────────────────────────────
    # Alias for API compatibility with v1/v2
    def predict(self, sequence: str) -> int:
        """Return ``classify(sequence)``."""
        return self.classify(sequence)

    def describe(self) -> str:
        """Return a one-line summary of the classifier's hyper-parameters."""
        sh = "ON" if self.super_hatch else "OFF"
        return (
            f"HatchV3(W={self.window_size}, K={self.K}, "
            f"T_fill={self.T_fill}, T_drain={self.T_drain}, "
            f"N={self.N}, D={self.D}, super_hatch={sh})"
        )