|
| 1 | +import torch |
| 2 | +from collections import Counter |
| 3 | +from torch.utils.data import Dataset, DataLoader |
| 4 | +from datasets import load_dataset |
| 5 | +import spacy |
| 6 | +from torch.nn.utils.rnn import pad_sequence |
| 7 | +from tqdm import tqdm |
| 8 | +import random |
| 9 | +random.seed(0) |
| 10 | + |
| 11 | +MAX_LEN = 300 |
| 12 | +MAX_STORIES = 250000 |
| 13 | + |
class Vocabulary:
    """Word-level vocabulary mapping normalised tokens to integer ids.

    Tokens are normalised by stripping surrounding whitespace and
    lower-casing. Ids 0-3 are reserved for the special tokens ``<PAD>``,
    ``<SOS>``, ``<EOS>`` and ``<UNK>``; corpus tokens are numbered from 4
    upwards in order of first appearance.

    Args:
        corpus: Iterable of raw text strings used to count token frequencies.
        tokenizer: Callable that splits a string into an iterable of tokens;
            each token is converted with ``str()`` before normalisation.
        min_count: Minimum number of occurrences a token needs in ``corpus``
            to receive its own id (rarer tokens map to ``<UNK>``).
    """

    # Order defines the reserved ids: <PAD>=0, <SOS>=1, <EOS>=2, <UNK>=3.
    SPECIAL_TOKENS = ('<PAD>', '<SOS>', '<EOS>', '<UNK>')

    def __init__(self, corpus, tokenizer, min_count=30):
        self.tokenizer = tokenizer
        self.word2idx, self.idx2word = self.build_vocab(corpus, min_count)

    def __len__(self):
        """Return the number of entries, special tokens included."""
        return len(self.word2idx)

    def _tokenize(self, text):
        """Tokenise ``text`` and normalise every token (strip + lower-case)."""
        return [str(tok).strip().lower() for tok in self.tokenizer(text)]

    def text2idx(self, text):
        """Convert ``text`` to a list of ids; unknown tokens become ``<UNK>``."""
        unk = self.word2idx['<UNK>']
        return [self.word2idx.get(tok, unk) for tok in self._tokenize(text)]

    def idx2text(self, idxs):
        """Convert a flat iterable of ids to a list of token strings.

        ``idxs`` may hold plain ints or scalar array/tensor elements (for
        example when iterating over a 1-D tensor). Ids without an entry are
        rendered as ``'<UNK>'``.
        """
        words = []
        for i in idxs:
            # Scalar tensors do not hash like the int keys of idx2word, so a
            # direct lookup would miss every time; unwrap them first.
            key = i.item() if hasattr(i, 'item') else i
            words.append(self.idx2word.get(key, '<UNK>'))
        return words

    def build_vocab(self, corpus, min_count=30):
        """Count tokens over ``corpus`` and build both lookup tables.

        Returns:
            A ``(word2idx, idx2word)`` pair of dicts.
        """
        counts = Counter()
        for datapoint in tqdm(corpus):
            counts.update(self._tokenize(datapoint))
        return self._index_tokens(counts, min_count)

    @staticmethod
    def _index_tokens(counts, min_count=30):
        """Assign ids to the tokens of ``counts`` seen at least ``min_count`` times."""
        specials = Vocabulary.SPECIAL_TOKENS
        kept = [tok for tok, freq in counts.items() if freq >= min_count]

        # Regular tokens start right after the reserved special-token ids.
        first_id = len(specials)
        word2idx = {tok: i for i, tok in enumerate(kept, start=first_id)}
        idx2word = {i: tok for i, tok in enumerate(kept, start=first_id)}

        for i, tok in enumerate(specials):
            word2idx[tok] = i
            idx2word[i] = tok

        return word2idx, idx2word
| 53 | + |
class TinyStories(Dataset):
    """Random subset of a TinyStories split, served as token-id tensors.

    At most ``MAX_STORIES`` stories are sampled from the requested split.
    Each item is the story's token ids truncated to ``MAX_LEN`` and wrapped
    in ``<SOS>`` / ``<EOS>`` ids.

    Args:
        split: Name of the dataset split passed to ``load_dataset``.
        vocab: Existing vocabulary to reuse. When ``None`` a new
            ``Vocabulary`` is built from the sampled stories.
    """

    def __init__(self, split="train", vocab=None):
        print("Loading data...")
        dataset = load_dataset("roneneldan/TinyStories", split=split)

        # random.sample raises ValueError when asked for more items than the
        # population holds, so cap the request at the split size.
        n_rows = len(dataset)
        n_stories = min(MAX_STORIES, n_rows)

        # Sample row indices instead of materialising every row in a list.
        chosen = random.sample(range(n_rows), n_stories)
        self.data = [row["text"] for row in dataset.select(chosen)]

        if vocab is None:
            print("Building vocab...")
            self.vocab = Vocabulary(self.data, spacy.load('en_core_web_sm').tokenizer)
        else:
            self.vocab = vocab

    def __len__(self):
        """Return the number of sampled stories."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return story ``idx`` as a 1-D tensor: ``<SOS>``, up to MAX_LEN ids, ``<EOS>``."""
        # Slicing already clamps to the sequence length, so no min() is needed.
        ids = self.vocab.text2idx(self.data[idx])[:MAX_LEN]
        sos = self.vocab.word2idx['<SOS>']
        eos = self.vocab.word2idx['<EOS>']
        return torch.tensor([sos] + ids + [eos])

    @staticmethod
    def pad_collate(batch):
        """Stack variable-length id tensors into one batch-first tensor.

        Shorter sequences are right-padded with 0, the id that ``Vocabulary``
        assigns to ``<PAD>``.
        """
        return pad_sequence(batch, batch_first=True, padding_value=0)
| 83 | + |
def getTinyStoriesDataloadersAndVocab(batch_size=128, num_workers=8):
    """Build the TinyStories training loader together with its vocabulary.

    Only the ``"train"`` split is loaded; the vocabulary is the one built by
    the training dataset.

    Args:
        batch_size: Number of stories per batch.
        num_workers: Number of DataLoader worker processes. Pass 0 to load
            in the main process.

    Returns:
        A ``(train_loader, vocab)`` tuple. The loader shuffles the data,
        pads each batch with ``TinyStories.pad_collate`` and drops the last
        incomplete batch.
    """
    train = TinyStories(split="train")

    train_loader = DataLoader(
        train,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
        collate_fn=TinyStories.pad_collate,
        drop_last=True,
    )

    return train_loader, train.vocab
| 91 | + |
| 92 | + |
0 commit comments