Skip to content

Commit c4e88bc

Browse files
committed
adds dataset stuff
Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
1 parent 44252c4 commit c4e88bc

File tree

3 files changed

+47
-17
lines changed

3 files changed

+47
-17
lines changed

bionemo-recipes/recipes/esm2_minifold_te/dataset.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -181,19 +181,32 @@ class MmcifStructureDataset(Dataset):
181181
"MSE": "M",
182182
}
183183

184-
def __init__(self, cif_dir: str, tokenizer, max_seq_length: int = 256, pdb_ids: list[str] | None = None):
184+
def __init__(
185+
self,
186+
cif_dir: str,
187+
tokenizer,
188+
max_seq_length: int = 256,
189+
pdb_ids: list[str] | None = None,
190+
min_residues: int = 50,
191+
max_residues: int = 300,
192+
min_ca_completeness: float = 0.9,
193+
):
185194
from Bio.PDB.MMCIFParser import MMCIFParser
186195

187196
self.tokenizer = tokenizer
188197
self.max_seq_length = max_seq_length
198+
self.min_residues = min_residues
199+
self.max_residues = max_residues
200+
self.min_ca_completeness = min_ca_completeness
189201
self.parser = MMCIFParser(QUIET=True)
190202

191203
cif_path = Path(cif_dir)
192204
all_files = sorted(cif_path.glob("*.cif"))
193205

194206
if pdb_ids is not None:
195-
pdb_set = {pid.upper() for pid in pdb_ids}
196-
self.files = [f for f in all_files if f.stem.upper() in pdb_set]
207+
# Preserve caller's ordering (e.g., to match parquet row order)
208+
file_by_id = {f.stem.upper(): f for f in all_files}
209+
self.files = [file_by_id[pid.upper()] for pid in pdb_ids if pid.upper() in file_by_id]
197210
else:
198211
self.files = all_files
199212

@@ -208,24 +221,35 @@ def __len__(self):
208221
def _parse_cif(self, cif_path):
209222
"""Parse mmCIF file and extract sequence + Ca coordinates.
210223
224+
Uses the same filtering as prepare_pdb_dataset.py: min/max residues,
225+
Ca completeness threshold, and truncation to max_residues.
226+
211227
Returns (sequence, ca_coords, ca_mask) or raises on failure.
212228
"""
213229
pdb_id = cif_path.stem
214230
structure = self.parser.get_structure(pdb_id, str(cif_path))
215231
model = structure[0]
216232

217233
for chain in model:
218-
sequence = []
219-
coords = []
220-
ca_mask = []
221-
234+
residues = []
222235
for res in chain.get_residues():
223236
if res.id[0] != " ":
224237
continue
225238
resname = res.get_resname().strip()
226239
if resname not in self.AA_3TO1:
227240
continue
241+
residues.append(res)
242+
243+
if len(residues) < self.min_residues:
244+
continue
245+
if len(residues) > self.max_residues:
246+
residues = residues[: self.max_residues]
228247

248+
sequence = []
249+
coords = []
250+
ca_mask = []
251+
for res in residues:
252+
resname = res.get_resname().strip()
229253
sequence.append(self.AA_3TO1[resname])
230254
if "CA" in res:
231255
ca = res["CA"].get_vector()
@@ -235,8 +259,11 @@ def _parse_cif(self, cif_path):
235259
coords.append([0.0, 0.0, 0.0])
236260
ca_mask.append(0)
237261

238-
if len(sequence) >= 20:
239-
return "".join(sequence), coords, ca_mask
262+
completeness = sum(ca_mask) / len(ca_mask)
263+
if completeness < self.min_ca_completeness:
264+
continue
265+
266+
return "".join(sequence), coords, ca_mask
240267

241268
raise ValueError(f"No valid protein chain in {pdb_id}")
242269

bionemo-recipes/recipes/esm2_minifold_te/hydra_config/run_100_real.yaml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# ESM2-MiniFold TE: 100-step run with REAL PDB data
22
# 2x RTX 5090, frozen ESM-2 650M
3-
# Usage: torchrun --nproc_per_node=2 train_fsdp2.py --config-name run_100_real
3+
# Usage:
4+
# Parquet (default, faster): torchrun --nproc_per_node=2 train_fsdp2.py --config-name run_100_real
5+
# MmCIF (on-the-fly parsing): torchrun --nproc_per_node=2 train_fsdp2.py --config-name run_100_real dataset.dataset_type=mmcif
46

57
esm_model_name: facebook/esm2_t33_650M_UR50D
68

@@ -15,8 +17,9 @@ model:
1517
num_recycling: 0
1618

1719
dataset:
18-
dataset_type: parquet
20+
dataset_type: parquet # "parquet" (fast, pre-processed) or "mmcif" (on-the-fly BioPython parsing)
1921
parquet_path: data/pdb_structures.parquet
22+
cif_dir: data/cif_files
2023
tokenizer_name: ${esm_model_name}
2124
micro_batch_size: 2
2225
max_seq_length: 256

bionemo-recipes/recipes/esm2_minifold_te/tests/test_data_pipeline.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -179,36 +179,36 @@ def test_first_residue_is_threonine(self, parsed_data):
179179

180180
class TestMmcifStructureDataset:
181181
def test_batch_keys(self, cif_dir, tokenizer):
182-
ds = MmcifStructureDataset(cif_dir, tokenizer, max_seq_length=MAX_SEQ_LENGTH)
182+
ds = MmcifStructureDataset(cif_dir, tokenizer, max_seq_length=MAX_SEQ_LENGTH, min_residues=20)
183183
sample = ds[0]
184184
assert set(sample.keys()) == {"input_ids", "attention_mask", "mask", "coords"}
185185

186186
def test_batch_shapes(self, cif_dir, tokenizer):
187-
ds = MmcifStructureDataset(cif_dir, tokenizer, max_seq_length=MAX_SEQ_LENGTH)
187+
ds = MmcifStructureDataset(cif_dir, tokenizer, max_seq_length=MAX_SEQ_LENGTH, min_residues=20)
188188
sample = ds[0]
189189
assert sample["input_ids"].shape == (MAX_SEQ_LENGTH,)
190190
assert sample["attention_mask"].shape == (MAX_SEQ_LENGTH,)
191191
assert sample["mask"].shape == (MAX_SEQ_LENGTH,)
192192
assert sample["coords"].shape == (MAX_SEQ_LENGTH, 3)
193193

194194
def test_batch_dtypes(self, cif_dir, tokenizer):
195-
ds = MmcifStructureDataset(cif_dir, tokenizer, max_seq_length=MAX_SEQ_LENGTH)
195+
ds = MmcifStructureDataset(cif_dir, tokenizer, max_seq_length=MAX_SEQ_LENGTH, min_residues=20)
196196
sample = ds[0]
197197
assert sample["input_ids"].dtype == torch.long
198198
assert sample["attention_mask"].dtype == torch.long
199199
assert sample["mask"].dtype == torch.float32
200200
assert sample["coords"].dtype == torch.float32
201201

202202
def test_cls_eos_tokens(self, cif_dir, tokenizer):
203-
ds = MmcifStructureDataset(cif_dir, tokenizer, max_seq_length=MAX_SEQ_LENGTH)
203+
ds = MmcifStructureDataset(cif_dir, tokenizer, max_seq_length=MAX_SEQ_LENGTH, min_residues=20)
204204
sample = ds[0]
205205
assert sample["input_ids"][0].item() == 0, "First token should be CLS (0)"
206206
# Find EOS position
207207
real_len = sample["attention_mask"].sum().item()
208208
assert sample["input_ids"][int(real_len) - 1].item() == 2, "Last real token should be EOS (2)"
209209

210210
def test_padding_is_zero(self, cif_dir, tokenizer):
211-
ds = MmcifStructureDataset(cif_dir, tokenizer, max_seq_length=MAX_SEQ_LENGTH)
211+
ds = MmcifStructureDataset(cif_dir, tokenizer, max_seq_length=MAX_SEQ_LENGTH, min_residues=20)
212212
sample = ds[0]
213213
real_len = sample["attention_mask"].sum().item()
214214
assert (sample["attention_mask"][int(real_len) :] == 0).all()
@@ -248,7 +248,7 @@ class TestDatasetEquivalence:
248248
"""Both datasets must produce matching outputs for the same protein."""
249249

250250
def _get_samples(self, cif_dir, parquet_path, tokenizer):
251-
ds_cif = MmcifStructureDataset(cif_dir, tokenizer, max_seq_length=MAX_SEQ_LENGTH)
251+
ds_cif = MmcifStructureDataset(cif_dir, tokenizer, max_seq_length=MAX_SEQ_LENGTH, min_residues=20)
252252
ds_pq = ParquetStructureDataset(parquet_path, tokenizer, max_seq_length=MAX_SEQ_LENGTH)
253253
return ds_cif[0], ds_pq[0]
254254

0 commit comments

Comments
 (0)