Commit 6d47468

adds eval script
Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
1 parent a0e3f55

3 files changed: +252 −3 lines changed

bionemo-recipes/recipes/esm2_minifold_te/dataset.py

Lines changed: 7 additions & 3 deletions
@@ -316,9 +316,11 @@ def create_dataloader(
     num_samples: int = 1000,
     cif_dir: str | None = None,
     pdb_ids: list[str] | None = None,
+    shuffle: bool = True,
+    drop_last: bool = True,
     **kwargs,
 ):
-    """Create a DataLoader for structure prediction training.
+    """Create a DataLoader for structure prediction training or evaluation.
 
     Args:
         dist_config: Distributed training configuration.
@@ -331,6 +333,8 @@ def create_dataloader(
         num_samples: Number of synthetic samples.
         cif_dir: Directory with .cif files (required if dataset_type="mmcif").
         pdb_ids: Optional list of PDB IDs to filter (for dataset_type="mmcif").
+        shuffle: Whether to shuffle the data (False for eval).
+        drop_last: Whether to drop the last incomplete batch (False for eval).
         **kwargs: Additional keyword arguments (ignored).
 
     Returns:
@@ -367,7 +371,7 @@ def create_dataloader(
         dataset,
         num_replicas=dist_config.world_size,
         rank=dist_config.rank,
-        shuffle=True,
+        shuffle=shuffle,
     )
 
     dataloader = DataLoader(
@@ -376,7 +380,7 @@ def create_dataloader(
         sampler=sampler,
         num_workers=num_workers,
         pin_memory=True,
-        drop_last=True,
+        drop_last=drop_last,
     )
 
     return dataloader, sampler
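
For context, a minimal sketch of how an eval-side caller might use the new flags (the keyword arguments mirror the `eval_dataset` block of the new hydra config; the exact values here are illustrative, not part of the commit):

    # Hypothetical caller: build a deterministic, full-coverage eval loader.
    from dataset import create_dataloader
    from distributed_config import DistributedConfig

    dist_config = DistributedConfig()
    eval_dataloader, sampler = create_dataloader(
        dist_config,
        dataset_type="parquet",
        parquet_path="data/eval_structures.parquet",
        micro_batch_size=4,
        shuffle=False,    # keep sample order stable across ranks and runs
        drop_last=False,  # evaluate every held-out structure, even a short final batch
    )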
bionemo-recipes/recipes/esm2_minifold_te/eval_fsdp2.py

Lines changed: 202 additions & 0 deletions
@@ -0,0 +1,202 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""FSDP2 evaluation script for ESM2-MiniFold TE structure prediction.

Loads a trained checkpoint and evaluates on a held-out dataset, reporting
structure quality metrics (lDDT, distogram accuracy, contact prediction)
to WandB and stdout.

Usage:
    # With FSDP2 distributed checkpoint
    torchrun --nproc_per_node=2 eval_fsdp2.py checkpoint.ckpt_dir=/path/to/checkpoints

    # With exported safetensors model
    torchrun --nproc_per_node=2 eval_fsdp2.py \
        checkpoint.ckpt_dir=/path/to/final_model \
        checkpoint.checkpoint_type=safetensors
"""

import logging
import os
from pathlib import Path

import hydra
import torch
from omegaconf import DictConfig, OmegaConf
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard
from tqdm import tqdm

import wandb
from checkpoint import load_checkpoint_fsdp2
from dataset import create_dataloader
from distributed_config import DistributedConfig
from modeling_esm2_minifold_te import ESM2MiniFoldTE
from precision_config import FoldingHeadPrecisionConfig
from scheduler import get_linear_schedule_with_warmup
from train_fsdp2 import compute_distogram_loss, compute_distogram_metrics


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


@hydra.main(config_path="hydra_config", config_name="eval", version_base="1.2")
def main(args: DictConfig) -> None:
    """Evaluate ESM2-MiniFold TE on a held-out dataset."""
    os.environ["HF_HUB_TRUST_REMOTE_CODE"] = "1"
    logging.getLogger("httpx").setLevel(logging.WARNING)

    # Initialize distributed
    dist_config = DistributedConfig()
    logger.info("Initializing eval: %s", dist_config)
    device = torch.device(f"cuda:{dist_config.local_rank}")
    torch.distributed.init_process_group(backend="nccl", device_id=device)
    torch.cuda.set_device(dist_config.local_rank)

    device_mesh = init_device_mesh(
        "cuda",
        mesh_shape=(dist_config.world_size,),
        mesh_dim_names=("dp",),
    )

    # Create model (same architecture as training)
    model = ESM2MiniFoldTE(
        esm_model_name=args.esm_model_name,
        c_s=args.model.c_s,
        c_z=args.model.c_z,
        num_blocks=args.model.num_blocks,
        no_bins=args.model.no_bins,
        use_structure_module=args.model.use_structure_module,
    ).to(device)

    # FSDP2 sharding (must match training for checkpoint loading)
    mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
    for block in model.fold.miniformer.blocks:
        fully_shard(block, mesh=device_mesh["dp"], mp_policy=mp_policy)
    fully_shard(model, mesh=device_mesh["dp"], mp_policy=mp_policy)

    # Load checkpoint
    ckpt_dir = Path(args.checkpoint.ckpt_dir)
    checkpoint_type = args.checkpoint.get("checkpoint_type", "fsdp2")

    if checkpoint_type == "fsdp2":
        # Need dummy optimizer/scheduler for the checkpoint loader
        dummy_optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
        dummy_scheduler = get_linear_schedule_with_warmup(dummy_optimizer, num_warmup_steps=0, num_training_steps=1)
        ckpt_path = ckpt_dir / "train_fsdp2"
        model, _, _, _, loaded_step, _ = load_checkpoint_fsdp2(
            model=model,
            optimizer=dummy_optimizer,
            scheduler=dummy_scheduler,
            ckpt_path=ckpt_path,
            dist_config=dist_config,
        )
        logger.info("Loaded FSDP2 checkpoint from step %d", loaded_step)
    elif checkpoint_type == "safetensors":
        from safetensors.torch import load_file

        state_dict = load_file(str(ckpt_dir / "model.safetensors"))
        model.load_state_dict(state_dict, strict=False)
        logger.info("Loaded safetensors model from %s", ckpt_dir)
    else:
        raise ValueError(f"Unknown checkpoint_type: {checkpoint_type}")

    # MXFP8 precision config
    precision_config = FoldingHeadPrecisionConfig(**OmegaConf.to_container(args.mxfp8, resolve=True))
    if dist_config.is_main_process():
        logger.info("Precision: %s", precision_config.summary())

    # Create eval dataloader (shuffle=False, drop_last=False from config)
    eval_dataloader, _ = create_dataloader(dist_config, **args.eval_dataset)
    logger.info("Eval dataset: %d batches", len(eval_dataloader))

    # Initialize WandB
    run_config = OmegaConf.to_container(args, resolve=True, throw_on_missing=True)
    if dist_config.is_main_process():
        wandb.init(**args.wandb_init_args, config=run_config)

    # Eval loop
    model.eval()
    all_metrics = {
        "loss": [],
        "disto_loss": [],
        "distogram_acc": [],
        "contact_precision_8A": [],
        "contact_recall_8A": [],
        "lddt_from_distogram": [],
        "mean_distance_error": [],
    }

    progress = tqdm(eval_dataloader, desc="Evaluating", disable=not dist_config.is_main_process())

    with torch.no_grad():
        for batch in progress:
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

            with torch.autocast("cuda", dtype=torch.bfloat16):
                r_dict = model(batch, num_recycling=args.model.get("num_recycling", 0))

            # Distogram loss
            disto_loss = compute_distogram_loss(
                preds=r_dict["preds"],
                coords=batch["coords"],
                mask=batch["mask"],
                no_bins=args.model.no_bins,
            )

            # Structure quality metrics
            metrics = compute_distogram_metrics(
                preds=r_dict["preds"].float(),
                coords=batch["coords"],
                mask=batch["mask"],
                no_bins=args.model.no_bins,
            )

            all_metrics["loss"].append(disto_loss.item())
            all_metrics["disto_loss"].append(disto_loss.item())
            for key, value in metrics.items():
                all_metrics[key].append(value.item())

            progress.set_postfix(
                {
                    "loss": f"{disto_loss.item():.3f}",
                    "lddt": f"{metrics['lddt_from_distogram'].item():.3f}",
                }
            )

    # Aggregate metrics
    summary = {}
    for key, values in all_metrics.items():
        if values:
            summary[f"eval/{key}"] = sum(values) / len(values)

    # Log to WandB and stdout
    if dist_config.is_main_process():
        wandb.log(summary)
        wandb.finish()

    if dist_config.local_rank == 0:
        logger.info("=== Evaluation Results ===")
        logger.info("Batches evaluated: %d", len(all_metrics["loss"]))
        for key, value in summary.items():
            logger.info("  %s: %.4f", key, value)

    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main()
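
One caveat worth noting: with a DistributedSampler each rank averages only its own shard of the eval set, and the script logs the main process's numbers. If a global average over all shards is wanted, a minimal sketch of cross-rank averaging (not part of this commit; `all_reduce_mean` is a hypothetical helper) could run just before the WandB logging:

    # Sketch: average each scalar summary metric across ranks with one all-reduce.
    import torch

    def all_reduce_mean(summary: dict[str, float], device: torch.device) -> dict[str, float]:
        """Replace per-rank means with their mean across all ranks."""
        keys = sorted(summary)  # fixed order so every rank packs the same tensor
        values = torch.tensor([summary[k] for k in keys], device=device)
        torch.distributed.all_reduce(values, op=torch.distributed.ReduceOp.AVG)
        return dict(zip(keys, values.tolist()))

    summary = all_reduce_mean(summary, device)

This treats every rank's shard-mean equally; with drop_last=False the last batch on each rank may be smaller, so a sample-count-weighted reduce would be slightly more precise.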
bionemo-recipes/recipes/esm2_minifold_te/hydra_config/eval.yaml

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
# ESM2-MiniFold TE: Post-training evaluation on held-out structures
# Usage: torchrun --nproc_per_node=2 eval_fsdp2.py checkpoint.ckpt_dir=/path/to/checkpoint

esm_model_name: facebook/esm2_t33_650M_UR50D

model:
  c_s: 1024
  c_z: 128
  num_blocks: 8
  no_bins: 64
  use_structure_module: false
  num_recycling: 0

eval_dataset:
  dataset_type: parquet
  parquet_path: data/eval_structures.parquet
  tokenizer_name: ${esm_model_name}
  micro_batch_size: 4
  max_seq_length: 256
  num_workers: 2
  shuffle: false
  drop_last: false

checkpoint:
  ckpt_dir: ???  # required: path to trained checkpoint or final model
  checkpoint_type: fsdp2  # "fsdp2" for distributed checkpoints, "safetensors" for exported model

mxfp8:
  enabled: false
  tri_proj: false
  tri_gate: false
  ffn: false
  struct_attn: false
  struct_ffn: false
  seq_proj: false
  dist_head: false
  fp8_recipe: transformer_engine.common.recipe.DelayedScaling
  fp8_recipe_kwargs: {}

wandb_init_args:
  project: esm2_minifold_te
  name: eval_${now:%Y%m%d_%H%M%S}
  mode: online
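
Because the script is wrapped in @hydra.main, any key in this file can be overridden from the command line with Hydra's dotted key=value syntax; for example (values illustrative):

    torchrun --nproc_per_node=2 eval_fsdp2.py \
        checkpoint.ckpt_dir=/path/to/checkpoints \
        eval_dataset.micro_batch_size=8 \
        wandb_init_args.mode=offline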
