"""Dataset for protein structure prediction training.

Provides:
- SyntheticStructureDataset: generates random data for testing
- ParquetStructureDataset: loads from parquet (pre-processed Ca coords)
- MmcifStructureDataset: loads from mmCIF files on-the-fly via BioPython
- create_dataloader: factory function for any dataset type
"""
2324
2425import logging
26+ from pathlib import Path
27+ from typing import ClassVar
2528
2629import torch
2730from torch .utils .data import DataLoader , Dataset , DistributedSampler
@@ -90,6 +93,7 @@ class ParquetStructureDataset(Dataset):
9093 Expected columns:
9194 sequence: str - amino acid sequence
9295 coords: list[list[float]] - Ca coordinates (N, 3)
96+ ca_mask: list[int] - optional, 1=valid Ca, 0=missing (defaults to all-1s)
9397 """
9498
9599 def __init__ (self , parquet_path : str , tokenizer , max_seq_length : int = 256 ):
@@ -98,6 +102,7 @@ def __init__(self, parquet_path: str, tokenizer, max_seq_length: int = 256):
98102 self .df = pd .read_parquet (parquet_path )
99103 self .tokenizer = tokenizer
100104 self .max_seq_length = max_seq_length
105+ self .has_ca_mask = "ca_mask" in self .df .columns
101106
102107 def __len__ (self ):
103108 return len (self .df )
@@ -121,11 +126,150 @@ def __getitem__(self, idx):
121126 mask = attention_mask .float ()
122127
123128 # Coordinates: pad to max_seq_length
124- coords_raw = torch .tensor (row ["coords" ], dtype = torch .float32 )
129+ import numpy as np
130+
131+ # Parquet stores list-of-lists as numpy object array of arrays; np.stack handles this
132+ coords_raw = torch .from_numpy (np .stack (row ["coords" ]).astype (np .float32 ))
125133 coords = torch .zeros (self .max_seq_length , 3 )
126134 seq_len = min (len (coords_raw ), self .max_seq_length )
127135 coords [:seq_len ] = coords_raw [:seq_len ]
128136
137+ # Zero out coords for residues with missing Ca atoms
138+ if self .has_ca_mask :
139+ ca_mask_list = row ["ca_mask" ]
140+ for i in range (min (len (ca_mask_list ), self .max_seq_length )):
141+ if ca_mask_list [i ] == 0 :
142+ coords [i ] = 0.0
143+
144+ return {
145+ "input_ids" : input_ids ,
146+ "attention_mask" : attention_mask ,
147+ "mask" : mask ,
148+ "coords" : coords ,
149+ }
150+
151+
class MmcifStructureDataset(Dataset):
    """Loads protein structures directly from mmCIF files via BioPython.

    Each item parses one .cif file on-the-fly, extracts the amino acid
    sequence and Ca coordinates of the first chain with at least
    ``MIN_CHAIN_LENGTH`` recognized residues, tokenizes the sequence, and
    returns the standard batch dict (input_ids / attention_mask / mask /
    coords).
    """

    # 3-letter -> 1-letter amino acid codes. MSE (selenomethionine) maps to
    # methionine so selenomethionine-substituted crystal structures are usable.
    AA_3TO1: ClassVar[dict[str, str]] = {
        "ALA": "A", "CYS": "C", "ASP": "D", "GLU": "E", "PHE": "F",
        "GLY": "G", "HIS": "H", "ILE": "I", "LYS": "K", "LEU": "L",
        "MET": "M", "ASN": "N", "PRO": "P", "GLN": "Q", "ARG": "R",
        "SER": "S", "THR": "T", "VAL": "V", "TRP": "W", "TYR": "Y",
        "MSE": "M",
    }

    # Minimum number of recognized residues for a chain to count as a
    # valid protein chain (filters out peptide fragments and junk chains).
    MIN_CHAIN_LENGTH: ClassVar[int] = 20

    def __init__(
        self,
        cif_dir: str,
        tokenizer,
        max_seq_length: int = 256,
        pdb_ids: list[str] | None = None,
    ):
        """
        Args:
            cif_dir: Directory containing ``*.cif`` files.
            tokenizer: Tokenizer with the HuggingFace ``__call__`` interface
                (e.g. EsmTokenizer).
            max_seq_length: Sequences and coords are padded/truncated to this.
            pdb_ids: Optional case-insensitive whitelist of file stems.

        Raises:
            FileNotFoundError: If no matching .cif files are found.
        """
        # Imported lazily so BioPython is only required for this dataset type.
        from Bio.PDB.MMCIFParser import MMCIFParser

        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.parser = MMCIFParser(QUIET=True)

        all_files = sorted(Path(cif_dir).glob("*.cif"))
        if pdb_ids is None:
            self.files = all_files
        else:
            wanted = {pid.upper() for pid in pdb_ids}
            self.files = [f for f in all_files if f.stem.upper() in wanted]

        if not self.files:
            raise FileNotFoundError(f"No .cif files found in {cif_dir}")

        logger.info("MmcifStructureDataset: %d CIF files from %s", len(self.files), cif_dir)

    def __len__(self):
        """Number of CIF files in the dataset."""
        return len(self.files)

    def _parse_cif(self, cif_path):
        """Parse an mmCIF file and extract sequence + Ca coordinates.

        Uses the first chain of the first model that yields at least
        ``MIN_CHAIN_LENGTH`` recognized residues.

        Returns:
            Tuple ``(sequence, ca_coords, ca_mask)`` where ``ca_mask[i]`` is
            1 when residue i has a Ca atom and 0 otherwise (its coords entry
            is then a [0, 0, 0] placeholder).

        Raises:
            ValueError: If no chain has enough valid residues.
        """
        pdb_id = cif_path.stem
        structure = self.parser.get_structure(pdb_id, str(cif_path))
        # First model only — NMR entries may contain many models.
        model = structure[0]

        for chain in model:
            sequence: list[str] = []
            coords: list[list[float]] = []
            ca_mask: list[int] = []

            for res in chain.get_residues():
                hetfield = res.id[0]
                # Skip hetero residues (waters, ligands) — but keep MSE:
                # BioPython flags selenomethionine as hetero ("H_MSE") even
                # though it occupies a polymer position, so filtering on
                # hetfield alone would make the MSE mapping unreachable.
                if hetfield != " " and hetfield != "H_MSE":
                    continue
                one_letter = self.AA_3TO1.get(res.get_resname().strip())
                if one_letter is None:
                    continue

                sequence.append(one_letter)
                if "CA" in res:
                    ca = res["CA"].get_vector()
                    coords.append([float(ca[0]), float(ca[1]), float(ca[2])])
                    ca_mask.append(1)
                else:
                    # Missing Ca (disordered residue): placeholder coords,
                    # masked out downstream via ca_mask.
                    coords.append([0.0, 0.0, 0.0])
                    ca_mask.append(0)

            if len(sequence) >= self.MIN_CHAIN_LENGTH:
                return "".join(sequence), coords, ca_mask

        raise ValueError(f"No valid protein chain in {pdb_id}")

    def __getitem__(self, idx):
        """Return one tokenized structure, padded to ``max_seq_length``.

        Unparseable files fall back to the dataset's first file so a single
        corrupt CIF does not kill a training run; if the first file itself
        fails there is nothing to serve, so the error propagates.
        """
        try:
            sequence, ca_coords, ca_mask = self._parse_cif(self.files[idx])
        except Exception as e:
            if idx == 0:
                # The fallback IS this file — re-raise instead of re-parsing.
                raise
            logger.warning("Failed to parse %s: %s, falling back to index 0", self.files[idx].name, e)
            sequence, ca_coords, ca_mask = self._parse_cif(self.files[0])

        # Tokenize (pads/truncates to max_seq_length).
        encoded = self.tokenizer(
            sequence,
            max_length=self.max_seq_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        input_ids = encoded["input_ids"].squeeze(0)
        attention_mask = encoded["attention_mask"].squeeze(0)
        mask = attention_mask.float()

        # Pad/truncate Ca coordinates to max_seq_length, zeroing any residue
        # whose Ca atom was missing (vectorized instead of a Python loop).
        coords = torch.zeros(self.max_seq_length, 3)
        seq_len = min(len(ca_coords), self.max_seq_length)
        if seq_len > 0:
            coords_raw = torch.tensor(ca_coords[:seq_len], dtype=torch.float32)
            valid = torch.tensor(ca_mask[:seq_len], dtype=torch.float32).unsqueeze(-1)
            coords[:seq_len] = coords_raw * valid

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "mask": mask,
            "coords": coords,
        }
@@ -143,6 +287,8 @@ def create_dataloader(
143287 parquet_path : str | None = None ,
144288 tokenizer_name : str | None = None ,
145289 num_samples : int = 1000 ,
290+ cif_dir : str | None = None ,
291+ pdb_ids : list [str ] | None = None ,
146292 ** kwargs ,
147293):
148294 """Create a DataLoader for structure prediction training.
@@ -152,10 +298,12 @@ def create_dataloader(
152298 micro_batch_size: Batch size per GPU.
153299 max_seq_length: Maximum sequence length.
154300 num_workers: Number of DataLoader workers.
155- dataset_type: "synthetic" or "parquet ".
301+ dataset_type: "synthetic", "parquet", or "mmcif ".
156302 parquet_path: Path to parquet file (required if dataset_type="parquet").
157- tokenizer_name: HuggingFace tokenizer name (required if dataset_type="parquet").
303+ tokenizer_name: HuggingFace tokenizer name (required if dataset_type="parquet" or "mmcif" ).
158304 num_samples: Number of synthetic samples.
305+ cif_dir: Directory with .cif files (required if dataset_type="mmcif").
306+ pdb_ids: Optional list of PDB IDs to filter (for dataset_type="mmcif").
159307 **kwargs: Additional keyword arguments (ignored).
160308
161309 Returns:
@@ -175,6 +323,16 @@ def create_dataloader(
175323 tokenizer = tokenizer ,
176324 max_seq_length = max_seq_length ,
177325 )
326+ elif dataset_type == "mmcif" :
327+ from transformers import EsmTokenizer
328+
329+ tokenizer = EsmTokenizer .from_pretrained (tokenizer_name )
330+ dataset = MmcifStructureDataset (
331+ cif_dir = cif_dir ,
332+ tokenizer = tokenizer ,
333+ max_seq_length = max_seq_length ,
334+ pdb_ids = pdb_ids ,
335+ )
178336 else :
179337 raise ValueError (f"Unknown dataset_type: { dataset_type } " )
180338
0 commit comments