Skip to content

Commit 44252c4

Browse files
committed
adds dataset support for minifold
Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
1 parent f5967e6 commit 44252c4

File tree

5 files changed

+577
-5
lines changed

5 files changed

+577
-5
lines changed

bionemo-recipes/recipes/esm2_minifold_te/dataset.py

Lines changed: 163 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,15 @@
1616
"""Dataset for protein structure prediction training.
1717
1818
Provides:
19-
- StructureDataset: loads protein structures from parquet/PDB files
2019
- SyntheticStructureDataset: generates random data for testing
21-
- create_dataloader: creates a DataLoader for training
20+
- ParquetStructureDataset: loads from parquet (pre-processed Ca coords)
21+
- MmcifStructureDataset: loads from mmCIF files on-the-fly via BioPython
22+
- create_dataloader: factory function for any dataset type
2223
"""
2324

2425
import logging
26+
from pathlib import Path
27+
from typing import ClassVar
2528

2629
import torch
2730
from torch.utils.data import DataLoader, Dataset, DistributedSampler
@@ -90,6 +93,7 @@ class ParquetStructureDataset(Dataset):
9093
Expected columns:
9194
sequence: str - amino acid sequence
9295
coords: list[list[float]] - Ca coordinates (N, 3)
96+
ca_mask: list[int] - optional, 1=valid Ca, 0=missing (defaults to all-1s)
9397
"""
9498

9599
def __init__(self, parquet_path: str, tokenizer, max_seq_length: int = 256):
@@ -98,6 +102,7 @@ def __init__(self, parquet_path: str, tokenizer, max_seq_length: int = 256):
98102
self.df = pd.read_parquet(parquet_path)
99103
self.tokenizer = tokenizer
100104
self.max_seq_length = max_seq_length
105+
self.has_ca_mask = "ca_mask" in self.df.columns
101106

102107
def __len__(self):
103108
return len(self.df)
@@ -121,11 +126,150 @@ def __getitem__(self, idx):
121126
mask = attention_mask.float()
122127

123128
# Coordinates: pad to max_seq_length
124-
coords_raw = torch.tensor(row["coords"], dtype=torch.float32)
129+
import numpy as np
130+
131+
# Parquet stores list-of-lists as numpy object array of arrays; np.stack handles this
132+
coords_raw = torch.from_numpy(np.stack(row["coords"]).astype(np.float32))
125133
coords = torch.zeros(self.max_seq_length, 3)
126134
seq_len = min(len(coords_raw), self.max_seq_length)
127135
coords[:seq_len] = coords_raw[:seq_len]
128136

137+
# Zero out coords for residues with missing Ca atoms
138+
if self.has_ca_mask:
139+
ca_mask_list = row["ca_mask"]
140+
for i in range(min(len(ca_mask_list), self.max_seq_length)):
141+
if ca_mask_list[i] == 0:
142+
coords[i] = 0.0
143+
144+
return {
145+
"input_ids": input_ids,
146+
"attention_mask": attention_mask,
147+
"mask": mask,
148+
"coords": coords,
149+
}
150+
151+
152+
class MmcifStructureDataset(Dataset):
    """Loads protein structures directly from mmCIF files via BioPython.

    Parses each .cif file on-the-fly, extracts the amino acid sequence and Ca
    coordinates, tokenizes with ESM-2, and returns the standard batch format.
    """

    # 3-letter -> 1-letter amino acid code mapping for the 20 standard residues.
    # MSE (selenomethionine, a common crystallography substitution) is mapped
    # to methionine ("M") so modified residues are not dropped from sequences.
    AA_3TO1: ClassVar[dict[str, str]] = {
        "ALA": "A",
        "CYS": "C",
        "ASP": "D",
        "GLU": "E",
        "PHE": "F",
        "GLY": "G",
        "HIS": "H",
        "ILE": "I",
        "LYS": "K",
        "LEU": "L",
        "MET": "M",
        "ASN": "N",
        "PRO": "P",
        "GLN": "Q",
        "ARG": "R",
        "SER": "S",
        "THR": "T",
        "VAL": "V",
        "TRP": "W",
        "TYR": "Y",
        "MSE": "M",
    }
183+
184+
def __init__(self, cif_dir: str, tokenizer, max_seq_length: int = 256, pdb_ids: list[str] | None = None):
185+
from Bio.PDB.MMCIFParser import MMCIFParser
186+
187+
self.tokenizer = tokenizer
188+
self.max_seq_length = max_seq_length
189+
self.parser = MMCIFParser(QUIET=True)
190+
191+
cif_path = Path(cif_dir)
192+
all_files = sorted(cif_path.glob("*.cif"))
193+
194+
if pdb_ids is not None:
195+
pdb_set = {pid.upper() for pid in pdb_ids}
196+
self.files = [f for f in all_files if f.stem.upper() in pdb_set]
197+
else:
198+
self.files = all_files
199+
200+
if not self.files:
201+
raise FileNotFoundError(f"No .cif files found in {cif_dir}")
202+
203+
logger.info("MmcifStructureDataset: %d CIF files from %s", len(self.files), cif_dir)
204+
205+
    def __len__(self):
        """Return the number of indexed .cif files (one sample per file)."""
        return len(self.files)
207+
208+
def _parse_cif(self, cif_path):
209+
"""Parse mmCIF file and extract sequence + Ca coordinates.
210+
211+
Returns (sequence, ca_coords, ca_mask) or raises on failure.
212+
"""
213+
pdb_id = cif_path.stem
214+
structure = self.parser.get_structure(pdb_id, str(cif_path))
215+
model = structure[0]
216+
217+
for chain in model:
218+
sequence = []
219+
coords = []
220+
ca_mask = []
221+
222+
for res in chain.get_residues():
223+
if res.id[0] != " ":
224+
continue
225+
resname = res.get_resname().strip()
226+
if resname not in self.AA_3TO1:
227+
continue
228+
229+
sequence.append(self.AA_3TO1[resname])
230+
if "CA" in res:
231+
ca = res["CA"].get_vector()
232+
coords.append([float(ca[0]), float(ca[1]), float(ca[2])])
233+
ca_mask.append(1)
234+
else:
235+
coords.append([0.0, 0.0, 0.0])
236+
ca_mask.append(0)
237+
238+
if len(sequence) >= 20:
239+
return "".join(sequence), coords, ca_mask
240+
241+
raise ValueError(f"No valid protein chain in {pdb_id}")
242+
243+
def __getitem__(self, idx):
244+
try:
245+
sequence, ca_coords, ca_mask = self._parse_cif(self.files[idx])
246+
except Exception as e:
247+
logger.warning("Failed to parse %s: %s, falling back to index 0", self.files[idx].name, e)
248+
sequence, ca_coords, ca_mask = self._parse_cif(self.files[0])
249+
250+
# Tokenize
251+
encoded = self.tokenizer(
252+
sequence,
253+
max_length=self.max_seq_length,
254+
padding="max_length",
255+
truncation=True,
256+
return_tensors="pt",
257+
)
258+
input_ids = encoded["input_ids"].squeeze(0)
259+
attention_mask = encoded["attention_mask"].squeeze(0)
260+
mask = attention_mask.float()
261+
262+
# Coordinates: pad to max_seq_length
263+
coords_raw = torch.tensor(ca_coords, dtype=torch.float32)
264+
coords = torch.zeros(self.max_seq_length, 3)
265+
seq_len = min(len(coords_raw), self.max_seq_length)
266+
coords[:seq_len] = coords_raw[:seq_len]
267+
268+
# Zero out missing Ca positions
269+
for i in range(seq_len):
270+
if ca_mask[i] == 0:
271+
coords[i] = 0.0
272+
129273
return {
130274
"input_ids": input_ids,
131275
"attention_mask": attention_mask,
@@ -143,6 +287,8 @@ def create_dataloader(
143287
parquet_path: str | None = None,
144288
tokenizer_name: str | None = None,
145289
num_samples: int = 1000,
290+
cif_dir: str | None = None,
291+
pdb_ids: list[str] | None = None,
146292
**kwargs,
147293
):
148294
"""Create a DataLoader for structure prediction training.
@@ -152,10 +298,12 @@ def create_dataloader(
152298
micro_batch_size: Batch size per GPU.
153299
max_seq_length: Maximum sequence length.
154300
num_workers: Number of DataLoader workers.
155-
dataset_type: "synthetic" or "parquet".
301+
dataset_type: "synthetic", "parquet", or "mmcif".
156302
parquet_path: Path to parquet file (required if dataset_type="parquet").
157-
tokenizer_name: HuggingFace tokenizer name (required if dataset_type="parquet").
303+
tokenizer_name: HuggingFace tokenizer name (required if dataset_type="parquet" or "mmcif").
158304
num_samples: Number of synthetic samples.
305+
cif_dir: Directory with .cif files (required if dataset_type="mmcif").
306+
pdb_ids: Optional list of PDB IDs to filter (for dataset_type="mmcif").
159307
**kwargs: Additional keyword arguments (ignored).
160308
161309
Returns:
@@ -175,6 +323,16 @@ def create_dataloader(
175323
tokenizer=tokenizer,
176324
max_seq_length=max_seq_length,
177325
)
326+
elif dataset_type == "mmcif":
327+
from transformers import EsmTokenizer
328+
329+
tokenizer = EsmTokenizer.from_pretrained(tokenizer_name)
330+
dataset = MmcifStructureDataset(
331+
cif_dir=cif_dir,
332+
tokenizer=tokenizer,
333+
max_seq_length=max_seq_length,
334+
pdb_ids=pdb_ids,
335+
)
178336
else:
179337
raise ValueError(f"Unknown dataset_type: {dataset_type}")
180338

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# ESM2-MiniFold TE: 100-step exploratory run
2+
# 2x RTX 5090, frozen ESM-2 650M, synthetic data
3+
# Usage: torchrun --nproc_per_node=2 train_fsdp2.py --config-name run_100
4+
5+
esm_model_name: facebook/esm2_t33_650M_UR50D
6+
7+
num_train_steps: 100
8+
9+
model:
10+
c_s: 1024
11+
c_z: 128
12+
num_blocks: 8
13+
no_bins: 64
14+
use_structure_module: false
15+
num_recycling: 0
16+
17+
dataset:
18+
dataset_type: synthetic
19+
micro_batch_size: 2
20+
max_seq_length: 128
21+
num_workers: 2
22+
num_samples: 1000
23+
24+
optimizer:
25+
folding_lr: 1.0e-4
26+
struct_lr: 1.0e-4
27+
backbone_lr: 3.0e-5
28+
betas: [0.9, 0.98]
29+
eps: 1.0e-8
30+
weight_decay: 0.01
31+
32+
lr_scheduler_kwargs:
33+
num_warmup_steps: 10
34+
num_training_steps: ${num_train_steps}
35+
36+
checkpoint:
37+
ckpt_dir: null
38+
save_final_model: false
39+
resume_from_checkpoint: false
40+
save_every_n_steps: 0
41+
max_checkpoints: 1
42+
43+
mxfp8:
44+
enabled: false
45+
tri_proj: false
46+
tri_gate: false
47+
ffn: false
48+
struct_attn: false
49+
struct_ffn: false
50+
seq_proj: false
51+
dist_head: false
52+
fp8_recipe: transformer_engine.common.recipe.DelayedScaling
53+
fp8_recipe_kwargs: {}
54+
55+
logger:
56+
frequency: 5
57+
58+
wandb_init_args:
59+
project: esm2_minifold_te
60+
name: run_100_650M_synthetic
61+
mode: offline
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# ESM2-MiniFold TE: 100-step run with REAL PDB data
2+
# 2x RTX 5090, frozen ESM-2 650M
3+
# Usage: torchrun --nproc_per_node=2 train_fsdp2.py --config-name run_100_real
4+
5+
esm_model_name: facebook/esm2_t33_650M_UR50D
6+
7+
num_train_steps: 100
8+
9+
model:
10+
c_s: 1024
11+
c_z: 128
12+
num_blocks: 8
13+
no_bins: 64
14+
use_structure_module: false
15+
num_recycling: 0
16+
17+
dataset:
18+
dataset_type: parquet
19+
parquet_path: data/pdb_structures.parquet
20+
tokenizer_name: ${esm_model_name}
21+
micro_batch_size: 2
22+
max_seq_length: 256
23+
num_workers: 2
24+
num_samples: 1000
25+
26+
optimizer:
27+
folding_lr: 1.0e-4
28+
struct_lr: 1.0e-4
29+
backbone_lr: 3.0e-5
30+
betas: [0.9, 0.98]
31+
eps: 1.0e-8
32+
weight_decay: 0.01
33+
34+
lr_scheduler_kwargs:
35+
num_warmup_steps: 10
36+
num_training_steps: ${num_train_steps}
37+
38+
checkpoint:
39+
ckpt_dir: null
40+
save_final_model: false
41+
resume_from_checkpoint: false
42+
save_every_n_steps: 0
43+
max_checkpoints: 1
44+
45+
mxfp8:
46+
enabled: false
47+
tri_proj: false
48+
tri_gate: false
49+
ffn: false
50+
struct_attn: false
51+
struct_ffn: false
52+
seq_proj: false
53+
dist_head: false
54+
fp8_recipe: transformer_engine.common.recipe.DelayedScaling
55+
fp8_recipe_kwargs: {}
56+
57+
logger:
58+
frequency: 5
59+
60+
wandb_init_args:
61+
project: esm2_minifold_te
62+
name: run_100_650M_real_pdb
63+
mode: offline

bionemo-recipes/recipes/esm2_minifold_te/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
biopython>=1.80
12
datasets
23
einops
34
fair-esm

0 commit comments

Comments
 (0)