-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathDataloader.py
83 lines (75 loc) · 3.07 KB
/
Dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import math
import random
import torch
from torch.autograd import Variable
class Dataloader(object):
    """Load parallel language pairs and serve them as padded mini-batches.

    The inputs are line-aligned files of whitespace-separated token ids, as
    generated by sentencepiece (https://github.com/google/sentencepiece.git).
    Pairs are length-sorted on the source side so that each batch pads to a
    similar width.
    """
    def __init__(self, srcFilename, tgtFilename, batch_size, cuda=False,
                 volatile=False, max_len=64):
        """Read, filter, and length-sort the parallel corpus.

        Args:
            srcFilename: path to the source-side file of token-id lines.
            tgtFilename: path to the target-side file (line-aligned with src).
            batch_size: number of sentence pairs per batch.
            cuda: if True, move batch tensors to the GPU.
            volatile: forwarded to Variable (deprecated since PyTorch 0.4).
            max_len: drop pairs whose source or target side exceeds this many
                tokens (default 64, the previously hard-coded limit).
        """
        # Need to reload every time because memory error in pickle
        src = []
        tgt = []
        nb_pairs = 0
        # `with` guarantees both files are closed even if a line fails to parse.
        with open(srcFilename) as srcFile, open(tgtFilename) as tgtFile:
            while True:
                src_line = srcFile.readline()
                tgt_line = tgtFile.readline()
                if src_line == '' and tgt_line == '':
                    break
                src_ids = list(map(int, src_line.strip().split()))
                tgt_ids = list(map(int, tgt_line.strip().split()))
                # 0 is reserved as the padding id (see _wrap); a real token id
                # of 0 would be indistinguishable from padding, so drop the pair.
                if 0 in src_ids or 0 in tgt_ids:
                    continue
                # Keep only non-empty pairs within the length limit.
                if 0 < len(src_ids) <= max_len and 0 < len(tgt_ids) <= max_len:
                    src.append(src_ids)
                    tgt.append(tgt_ids)
                    nb_pairs += 1
        print('%d pairs are converted in the data' % nb_pairs)
        # Sort by source length so sentences in the same batch need little padding.
        sorted_idx = sorted(range(nb_pairs), key=lambda i: len(src[i]))
        self.src = [src[i] for i in sorted_idx]
        self.tgt = [tgt[i] for i in sorted_idx]
        self.batch_size = batch_size
        self.nb_pairs = nb_pairs
        self.nb_batches = math.ceil(nb_pairs / batch_size)
        self.cuda = cuda
        self.volatile = volatile

    def __len__(self):
        """Return the number of batches (the last batch may be smaller)."""
        return self.nb_batches

    def _shuffle_index(self, n, m):
        """Yield indexes for shuffling a length n seq within every m elements"""
        indexes = []
        for i in range(n):
            indexes.append(i)
            # Flush each full window of m (or the final partial window) in
            # random order, so shuffling never crosses a window boundary.
            if (i + 1) % m == 0 or i == n - 1:
                random.shuffle(indexes)
                for index in indexes:
                    yield index
                indexes = []

    def shuffle(self, m):
        """Shuffle the language pairs within every m elements.

        Because pairs are length-sorted, a windowed shuffle randomizes order
        while making sure pairs in the same batch still have similar length.
        """
        shuffled_indexes = self._shuffle_index(self.nb_pairs, m)
        src, tgt = [], []
        for index in shuffled_indexes:
            src.append(self.src[index])
            tgt.append(self.tgt[index])
        self.src = src
        self.tgt = tgt

    def _wrap(self, sentences):
        """Pad sentences to same length (pad id 0) and wrap into Variable"""
        max_size = max(len(s) for s in sentences)
        out = [s + [0] * (max_size - len(s)) for s in sentences]
        out = torch.LongTensor(out)
        if self.cuda:
            out = out.cuda()
        # NOTE(review): Variable/volatile are deprecated since PyTorch 0.4;
        # kept for backward compatibility with callers expecting a Variable.
        return Variable(out, volatile=self.volatile)

    def __getitem__(self, i):
        """Generate the i-th (src, tgt) batch of padded LongTensor Variables."""
        src_batch = self.src[i * self.batch_size:(i + 1) * self.batch_size]
        tgt_batch = self.tgt[i * self.batch_size:(i + 1) * self.batch_size]
        return self._wrap(src_batch), self._wrap(tgt_batch)