|
| 1 | +"""Usage example: Torch-compatible transformations with PyTorch DataLoader.""" |
| 2 | + |
| 3 | +import torch |
| 4 | +from torch.utils.data import Dataset, DataLoader |
| 5 | + |
| 6 | +# Requires PR #246; until it is merged, the import below stays commented out. |
| 7 | +# from pyaptamer.trafos.torch import GreedyEncode, RandomMask, DNAtoRNA |
| 8 | + |
| 9 | + |
class AptamerDataset(Dataset):
    """Minimal example: sequence dataset with torch transforms.

    Until the ``pyaptamer.trafos.torch`` transforms are available, no encoder
    is configured and every item is a placeholder pair of tensors whose
    length is ``max_len``.
    """

    def __init__(self, sequences, vocab, max_len=128, augment=True):
        """Initialize with sequences and vocabulary.

        Parameters
        ----------
        sequences : list[str]
            DNA/RNA sequences
        vocab : dict[str, int]
            Token to ID mapping
        max_len : int
            Padded sequence length
        augment : bool
            Apply random masking augmentation
        """
        self.sequences = sequences
        # Kept so the transforms below can be built from them once enabled.
        self.vocab = vocab
        self.max_len = max_len
        self.augment = augment
        # TODO: enable once pyaptamer.trafos.torch is importable:
        #   from pyaptamer.trafos.torch import GreedyEncode, RandomMask
        #   self.encoder = GreedyEncode(vocab=vocab, max_len=max_len)
        #   self.masker = (
        #       RandomMask(mask_idx=999, mask_rate=0.15) if augment else None
        #   )
        self.encoder = None
        self.masker = None

    def __len__(self):
        """Return the number of sequences in the dataset."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return encoded and optionally masked sequence.

        Parameters
        ----------
        idx : int
            Position of the sequence in ``self.sequences``.

        Returns
        -------
        dict
            ``'input_ids'`` and ``'attention_mask'`` tensors. With no encoder
            configured these are integer placeholders of length ``max_len``.

        Raises
        ------
        IndexError
            If ``idx`` is outside the range of ``self.sequences``.
        """
        # Indexing first keeps out-of-range lookups failing with IndexError
        # even while the placeholder path is active.
        seq = self.sequences[idx]

        if self.encoder is None:
            # Placeholder output: sized by max_len (not a hard-coded 128) and
            # integer-typed, matching the `.long()` mask of the real path.
            return {
                'input_ids': torch.zeros(self.max_len, dtype=torch.long),
                'attention_mask': torch.ones(self.max_len, dtype=torch.long),
            }

        encoded = self.encoder(seq)
        if self.masker is not None:
            encoded = self.masker(encoded)
        return {
            'input_ids': encoded,
            # NOTE(review): assumes the encoder pads with ID 0 -- confirm
            # against the GreedyEncode implementation.
            'attention_mask': (encoded != 0).long(),
        }
| 46 | + |
| 47 | + |
# Usage in training loop
def main():
    """Build a toy dataset and print the shape and dtype of the first batch."""
    # Define vocabulary (single nucleotides plus two dinucleotide tokens)
    vocab = {"A": 1, "T": 2, "C": 3, "G": 4, "AT": 5, "GC": 6}

    # Create dataset
    sequences = ["ATGCTAGC", "GGCCTTAA", "ATATATAA"]
    dataset = AptamerDataset(sequences, vocab=vocab, max_len=64, augment=True)

    # Create DataLoader for batched training. There are fewer sequences than
    # batch_size, so the single batch produced holds all of them.
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0)

    # Training loop pattern
    for batch in dataloader:
        # The default collate function stacks each dict entry along a new
        # leading batch dimension.
        input_ids = batch['input_ids']
        # Only consumed by the commented-out forward pass below.
        attention_mask = batch['attention_mask']

        # Model forward/backward pass (example):
        # optimizer.zero_grad()
        # outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        # loss = criterion(outputs, targets)
        # loss.backward()
        # optimizer.step()

        print(f"✓ Batch shape: {input_ids.shape}, dtype: {input_ids.dtype}")
        break  # Show first batch only


if __name__ == "__main__":
    main()
0 commit comments