Skip to content

Commit 669edde

Browse files
docs: Add PyTorch DataLoader integration example for torch transforms
- Demonstrates practical usage pattern with AptamerDataset - Shows how to chain GreedyEncode + RandomMask with DataLoader - Provides pattern for batch processing in training loops
1 parent 0e980af commit 669edde

1 file changed

Lines changed: 71 additions & 0 deletions

File tree

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
"""Usage example: Torch-compatible transformations with PyTorch DataLoader."""
2+
3+
import torch
4+
from torch.utils.data import Dataset, DataLoader
5+
6+
# Example assumes PR #246 is merged
7+
# from pyaptamer.trafos.torch import GreedyEncode, RandomMask, DNAtoRNA
8+
9+
10+
class AptamerDataset(Dataset):
11+
"""Minimal example: Sequence dataset with torch transforms."""
12+
13+
def __init__(self, sequences, vocab, max_len=128, augment=True):
14+
"""Initialize with sequences and vocabulary.
15+
16+
Parameters
17+
----------
18+
sequences : list[str]
19+
DNA/RNA sequences
20+
vocab : dict[str, int]
21+
Token to ID mapping
22+
max_len : int
23+
Padded sequence length
24+
augment : bool
25+
Apply random masking augmentation
26+
"""
27+
self.sequences = sequences
28+
# from pyaptamer.trafos.torch import GreedyEncode, RandomMask
29+
self.encoder = None # GreedyEncode(vocab=vocab, max_len=max_len)
30+
self.masker = None # RandomMask(mask_idx=999, mask_rate=0.15) if augment else None
31+
32+
def __len__(self):
33+
return len(self.sequences)
34+
35+
def __getitem__(self, idx):
36+
"""Return encoded and optionally masked sequence."""
37+
seq = self.sequences[idx]
38+
# encoded = self.encoder(seq)
39+
# if self.masker:
40+
# encoded = self.masker(encoded)
41+
# return {
42+
# 'input_ids': encoded,
43+
# 'attention_mask': (encoded != 0).long()
44+
# }
45+
return {'input_ids': torch.zeros(128), 'attention_mask': torch.ones(128)}
46+
47+
48+
# Usage in training loop
49+
if __name__ == "__main__":
50+
# Define vocabulary
51+
vocab = {"A": 1, "T": 2, "C": 3, "G": 4, "AT": 5, "GC": 6}
52+
53+
# Create dataset
54+
sequences = ["ATGCTAGC", "GGCCTTAA", "ATATATAA"]
55+
dataset = AptamerDataset(sequences, vocab=vocab, max_len=64, augment=True)
56+
57+
# Create DataLoader for batched training
58+
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0)
59+
60+
# Training loop pattern
61+
for batch in dataloader:
62+
input_ids = batch['input_ids'] # Shape: (batch_size, max_len)
63+
attention_mask = batch['attention_mask']
64+
65+
# Model forward pass (example):
66+
# outputs = model(input_ids=input_ids, attention_mask=attention_mask)
67+
# loss = criterion(outputs, targets)
68+
# optimizer.step()
69+
70+
print(f"✓ Batch shape: {input_ids.shape}, dtype: {input_ids.dtype}")
71+
break # Show first batch only

0 commit comments

Comments
 (0)