@@ -96,6 +96,89 @@ def compute_distogram_loss(preds, coords, mask, no_bins=64, max_dist=25.0):
9696 return mean .mean ()
9797
9898
def _masked_mean(values: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """Return the weighted mean of ``values``, ignoring entries whose weight is zero."""
    # NaN * 0 is NaN, so a plain ``values * weights`` lets non-finite values at
    # masked positions (e.g. NaN coordinates on padded residues) poison the sum.
    # Zero-weight entries are therefore zeroed explicitly after weighting.
    weighted = (values * weights).masked_fill(weights == 0, 0.0)
    # clamp avoids division by zero when no entry is selected (result is then 0).
    return weighted.sum() / weights.sum().clamp(min=1)


def compute_distogram_metrics(
    preds: torch.Tensor,
    coords: torch.Tensor,
    mask: torch.Tensor,
    no_bins: int = 64,
    max_dist: float = 25.0,
    contact_threshold: float = 8.0,
) -> dict[str, torch.Tensor]:
    """Compute structure prediction quality metrics from distogram predictions.

    All metrics are computed under ``torch.no_grad()`` and aggregated over every
    valid residue pair in the batch (pairs where both residues are unmasked and
    the two residues differ).

    Args:
        preds: Predicted distogram logits (B, L, L, no_bins).
        coords: Ca coordinates (B, L, 3).
        mask: Residue mask (B, L); non-zero marks a valid residue.
        no_bins: Number of distance bins.
        max_dist: Maximum distance in Angstroms.
        contact_threshold: Distance threshold for contact prediction (Angstroms).

    Returns:
        Dict of scalar tensors with keys:
            distogram_acc: Fraction of valid pairs whose argmax bin equals the
                true bin.
            contact_precision_8A / contact_recall_8A: Precision and recall of
                contacts, where a contact is a pair whose (expected predicted /
                true) distance is below ``contact_threshold``. The ``_8A``
                suffix is fixed and does not follow ``contact_threshold``.
            lddt_from_distogram: lDDT-style score from the expected predicted
                distances, over valid pairs with true distance < 15 Angstroms.
            mean_distance_error: Mean absolute error of the expected predicted
                distance over the same pairs as the lDDT score.
    """
    with torch.no_grad():
        # True pairwise distances
        true_dists = torch.cdist(coords, coords)

        # Bin layout: no_bins - 1 boundaries between 2 A and max_dist give
        # no_bins bins. Interior bins are represented by their midpoint; the
        # open-ended first (< 2 A) and last (> max_dist) bins get the fixed
        # representative distances 1.0 and max_dist + 2.0.
        boundaries = torch.linspace(2, max_dist, no_bins - 1, device=preds.device)
        bin_centers = torch.cat(
            [
                torch.tensor([1.0], device=preds.device),
                (boundaries[:-1] + boundaries[1:]) / 2,
                torch.tensor([max_dist + 2.0], device=preds.device),
            ]
        )

        # True bin index = number of boundaries strictly below the distance
        true_bins = (true_dists.unsqueeze(-1) > boundaries).sum(dim=-1)

        # Predicted bin indices and probabilities
        pred_bins = preds.argmax(dim=-1)
        pred_probs = F.softmax(preds, dim=-1)

        # Expected predicted distance from distogram
        pred_dists = (pred_probs * bin_centers).sum(dim=-1)

        # Valid pair mask (exclude self and padding)
        square_mask = mask[:, None, :] * mask[:, :, None]
        eye = torch.eye(mask.shape[1], device=mask.device).unsqueeze(0)
        pair_mask = square_mask * (1 - eye)

        # 1. Distogram accuracy
        correct = (pred_bins == true_bins).float()
        distogram_acc = _masked_mean(correct, pair_mask)

        # 2. Contact precision and recall at threshold
        true_contacts = (true_dists < contact_threshold).float() * pair_mask
        pred_contacts = (pred_dists < contact_threshold).float() * pair_mask

        tp = (true_contacts * pred_contacts).sum()
        contact_precision = tp / pred_contacts.sum().clamp(min=1)
        contact_recall = tp / true_contacts.sum().clamp(min=1)

        # 3. lDDT from distogram expected distances:
        # average, over the four tolerances, of "distance error within tolerance"
        dist_error = torch.abs(pred_dists - true_dists)
        lddt_score = (
            (dist_error < 0.5).float()
            + (dist_error < 1.0).float()
            + (dist_error < 2.0).float()
            + (dist_error < 4.0).float()
        ) * 0.25

        # Only score pairs within the 15 A inclusion cutoff. Comparisons with
        # NaN are False, so pairs with undefined true distance drop out here.
        lddt_mask = pair_mask * (true_dists < 15.0).float()
        lddt_from_distogram = _masked_mean(lddt_score, lddt_mask)

        # 4. Mean distance error (on valid pairs within 15 A)
        mean_dist_error = _masked_mean(dist_error, lddt_mask)

    return {
        "distogram_acc": distogram_acc,
        "contact_precision_8A": contact_precision,
        "contact_recall_8A": contact_recall,
        "lddt_from_distogram": lddt_from_distogram,
        "mean_distance_error": mean_dist_error,
    }
180+
181+
99182@hydra .main (config_path = "hydra_config" , config_name = "L0_sanity" , version_base = "1.2" )
100183def main (args : DictConfig ) -> float | None :
101184 """Train ESM2-MiniFold TE with FSDP2.
@@ -227,13 +310,22 @@ def main(args: DictConfig) -> float | None:
227310 scheduler .step ()
228311 optimizer .zero_grad ()
229312
313+ # Compute structure quality metrics (no grad, cheap)
314+ structure_metrics = compute_distogram_metrics (
315+ preds = r_dict ["preds" ].float (),
316+ coords = batch ["coords" ],
317+ mask = batch ["mask" ],
318+ no_bins = args .model .no_bins ,
319+ )
320+
230321 # Logging
231322 perf_logger .log_step (
232323 step = step ,
233324 loss = total_loss ,
234325 disto_loss = disto_loss ,
235326 grad_norm = total_norm ,
236327 lr = optimizer .param_groups [0 ]["lr" ],
328+ structure_metrics = structure_metrics ,
237329 )
238330
239331 # Checkpointing
0 commit comments