Skip to content

Commit a0e3f55

Browse files
committed
adds wandb
Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
1 parent c4e88bc commit a0e3f55

File tree

3 files changed

+106
-2
lines changed

3 files changed

+106
-2
lines changed

bionemo-recipes/recipes/esm2_minifold_te/hydra_config/run_100_real.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,4 @@ logger:
6363
wandb_init_args:
6464
project: esm2_minifold_te
6565
name: run_100_650M_real_pdb
66-
mode: offline
66+
mode: online

bionemo-recipes/recipes/esm2_minifold_te/perf_logger.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@
2020

2121
import torch
2222
import torchmetrics
23-
import wandb
2423
from omegaconf import DictConfig, OmegaConf
2524
from torch.distributed.tensor import DTensor
2625
from tqdm import tqdm
2726

27+
import wandb
2828
from distributed_config import DistributedConfig
2929

3030

@@ -52,6 +52,11 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig):
5252
"train/learning_rate": torchmetrics.MeanMetric(),
5353
"train/step_time": torchmetrics.MeanMetric(),
5454
"train/gpu_memory_allocated_max_gb": torchmetrics.MaxMetric(),
55+
"train/distogram_acc": torchmetrics.MeanMetric(),
56+
"train/contact_precision_8A": torchmetrics.MeanMetric(),
57+
"train/contact_recall_8A": torchmetrics.MeanMetric(),
58+
"train/lddt_from_distogram": torchmetrics.MeanMetric(),
59+
"train/mean_distance_error": torchmetrics.MeanMetric(),
5560
}
5661

5762
self.metrics = torchmetrics.MetricCollection(metrics_dict)
@@ -69,6 +74,7 @@ def log_step(
6974
disto_loss: torch.Tensor | None = None,
7075
grad_norm: torch.Tensor | DTensor | float = 0.0,
7176
lr: float = 0.0,
77+
structure_metrics: dict[str, torch.Tensor] | None = None,
7278
):
7379
"""Log a training step."""
7480
with torch.no_grad():
@@ -90,6 +96,12 @@ def log_step(
9096
self.metrics["train/grad_norm"].update(grad_norm)
9197
self.metrics["train/step_time"].update(step_time)
9298

99+
if structure_metrics is not None:
100+
for key, value in structure_metrics.items():
101+
metric_key = f"train/{key}"
102+
if metric_key in self.metrics:
103+
self.metrics[metric_key].update(value)
104+
93105
memory_allocated = torch.cuda.memory_allocated() / (1024**3)
94106
self.metrics["train/gpu_memory_allocated_max_gb"].update(memory_allocated)
95107

bionemo-recipes/recipes/esm2_minifold_te/train_fsdp2.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,89 @@ def compute_distogram_loss(preds, coords, mask, no_bins=64, max_dist=25.0):
9696
return mean.mean()
9797

9898

99+
def compute_distogram_metrics(preds, coords, mask, no_bins=64, max_dist=25.0, contact_threshold=8.0):
    """Compute structure prediction quality metrics from distogram predictions.

    Every metric is a scalar tensor averaged over ordered residue pairs
    ``(i, j)`` with ``i != j`` whose two residues are both unmasked. The whole
    computation runs under ``torch.no_grad()``.

    Values at masked-out pairs never influence the result, even when they are
    non-finite (e.g. NaN padding in ``coords``).

    Args:
        preds: Predicted distogram logits (B, L, L, no_bins).
        coords: Ca coordinates (B, L, 3).
        mask: Residue mask (B, L); non-zero marks a valid residue.
        no_bins: Number of distance bins.
        max_dist: Maximum distance in Angstroms.
        contact_threshold: Distance threshold for contact prediction (Angstroms).

    Returns:
        Dict of scalar tensors with keys ``distogram_acc``,
        ``contact_precision_8A``, ``contact_recall_8A``,
        ``lddt_from_distogram`` and ``mean_distance_error``. The ``_8A``
        suffix is part of the key name and does not change when
        ``contact_threshold`` is overridden.
    """
    with torch.no_grad():
        # True pairwise distances
        true_dists = torch.cdist(coords, coords)

        # no_bins - 1 boundaries delimit no_bins bins. The first bin (below 2A)
        # is represented by 1.0 and the open-ended last bin by max_dist + 2.0.
        boundaries = torch.linspace(2, max_dist, no_bins - 1, device=preds.device)
        bin_centers = torch.cat(
            [
                torch.tensor([1.0], device=preds.device),
                (boundaries[:-1] + boundaries[1:]) / 2,
                torch.tensor([max_dist + 2.0], device=preds.device),
            ]
        )

        # True bin index = number of boundaries strictly below the distance
        true_bins = (true_dists.unsqueeze(-1) > boundaries).sum(dim=-1)

        # Predicted bin indices and probabilities
        pred_bins = preds.argmax(dim=-1)
        pred_probs = F.softmax(preds, dim=-1)

        # Expected predicted distance from distogram
        pred_dists = (pred_probs * bin_centers).sum(dim=-1)

        # Valid pair mask (exclude self and padding)
        square_mask = mask[:, None, :] * mask[:, :, None]
        eye = torch.eye(mask.shape[1], device=mask.device).unsqueeze(0)
        pair_mask = square_mask * (1 - eye)
        n_pairs = pair_mask.sum().clamp(min=1)

        # 1. Distogram accuracy
        correct = (pred_bins == true_bins).float() * pair_mask
        distogram_acc = correct.sum() / n_pairs

        # 2. Contact precision and recall at threshold
        true_contacts = (true_dists < contact_threshold).float() * pair_mask
        pred_contacts = (pred_dists < contact_threshold).float() * pair_mask

        tp = (true_contacts * pred_contacts).sum()
        contact_precision = tp / pred_contacts.sum().clamp(min=1)
        contact_recall = tp / true_contacts.sum().clamp(min=1)

        # 3. lDDT from distogram expected distances:
        # fraction of pairwise distances preserved within 0.5/1/2/4 A.
        dist_error = torch.abs(pred_dists - true_dists)
        lddt_score = (
            (dist_error < 0.5).float()
            + (dist_error < 1.0).float()
            + (dist_error < 2.0).float()
            + (dist_error < 4.0).float()
        ) * 0.25

        # Only score pairs within 15A cutoff (standard lDDT)
        lddt_mask = pair_mask * (true_dists < 15.0).float()
        n_lddt_pairs = lddt_mask.sum().clamp(min=1)
        lddt_from_distogram = (lddt_score * lddt_mask).sum() / n_lddt_pairs

        # 4. Mean distance error (on valid pairs within 15A).
        # Select with torch.where rather than multiplying by the mask: a
        # non-finite error at a masked-out pair would otherwise poison the sum,
        # because NaN * 0 is NaN.
        masked_error = torch.where(lddt_mask > 0, dist_error, torch.zeros_like(dist_error))
        mean_dist_error = masked_error.sum() / n_lddt_pairs

    return {
        "distogram_acc": distogram_acc,
        "contact_precision_8A": contact_precision,
        "contact_recall_8A": contact_recall,
        "lddt_from_distogram": lddt_from_distogram,
        "mean_distance_error": mean_dist_error,
    }
180+
181+
99182
@hydra.main(config_path="hydra_config", config_name="L0_sanity", version_base="1.2")
100183
def main(args: DictConfig) -> float | None:
101184
"""Train ESM2-MiniFold TE with FSDP2.
@@ -227,13 +310,22 @@ def main(args: DictConfig) -> float | None:
227310
scheduler.step()
228311
optimizer.zero_grad()
229312

313+
# Compute structure quality metrics (no grad, cheap)
314+
structure_metrics = compute_distogram_metrics(
315+
preds=r_dict["preds"].float(),
316+
coords=batch["coords"],
317+
mask=batch["mask"],
318+
no_bins=args.model.no_bins,
319+
)
320+
230321
# Logging
231322
perf_logger.log_step(
232323
step=step,
233324
loss=total_loss,
234325
disto_loss=disto_loss,
235326
grad_norm=total_norm,
236327
lr=optimizer.param_groups[0]["lr"],
328+
structure_metrics=structure_metrics,
237329
)
238330

239331
# Checkpointing

0 commit comments

Comments
 (0)