-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathutils.py
1281 lines (1056 loc) · 50.4 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import itertools
import json
import linecache
import math
import os
import pickle
import socket
import glob
from logging import getLogger
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Tuple, Union
import numpy as np
import torch
import torch.distributed as dist
from rouge_score import rouge_scorer, scoring
from sacrebleu import corpus_bleu
from torch import nn
from torch.utils.data import Dataset, Sampler, RandomSampler, DataLoader
from sentence_splitter import add_newline_to_end_of_each_sentence
from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer, BertTokenizer, RobertaTokenizer
from transformers.file_utils import cached_property
from transformers.models.bart.modeling_bart import shift_tokens_right
try:
from fairseq.data.data_utils import batch_by_size
FAIRSEQ_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
FAIRSEQ_AVAILABLE = False
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
    """Label-smoothed negative log-likelihood loss (adapted from fairseq).

    Args:
        lprobs: log-probabilities whose last dimension indexes the vocabulary.
        target: gold token ids, either with the same rank as ``lprobs`` (and a
            trailing singleton dimension) or one rank lower.
        epsilon: label-smoothing factor; ``epsilon / vocab_size`` of the mass
            is spread uniformly over the vocabulary.
        ignore_index: target value whose positions contribute zero loss. It may
            lie outside ``[0, vocab_size)`` (e.g. the default ``-100``). Pass
            ``None`` to score every position.

    Returns:
        ``(loss, nll_loss)``: both summed (not averaged) over all positions.
    """
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        # ``gather`` rejects out-of-range indices such as -100, so point the
        # ignored positions at index 0; their values are zeroed by the mask.
        gather_index = target.masked_fill(pad_mask, 0)
    else:
        pad_mask = None
        gather_index = target
    nll_loss = -lprobs.gather(dim=-1, index=gather_index)
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
    if pad_mask is not None:
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
    else:
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)
    nll_loss = nll_loss.sum()  # mean()? Scared to break other math.
    smooth_loss = smooth_loss.sum()
    eps_i = epsilon / lprobs.size(-1)
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss, nll_loss
def lmap(f: Callable, x: Iterable) -> List:
    """Apply ``f`` to each element of ``x``, collecting the results in a list."""
    return [f(element) for element in x]
def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
    """Score ``output_lns`` against one reference stream with sacrebleu.

    ``refs_lns`` is wrapped in a list because ``corpus_bleu`` takes a list of
    reference streams; extra ``kwargs`` are forwarded unchanged. Returns
    ``{"bleu": score}`` with the score rounded to 4 decimal places.
    """
    bleu_result = corpus_bleu(output_lns, [refs_lns], **kwargs)
    return {"bleu": round(bleu_result.score, 4)}
def build_compute_metrics_fn(task_name: str, tokenizer: PreTrainedTokenizer, data_args) -> Callable[[EvalPrediction], Dict]:
    """Create the ``compute_metrics`` callback used during seq2seq evaluation.

    Args:
        task_name: tasks whose name contains ``"summarization"`` are scored with
            ``calculate_rouge`` (using ``data_args.rouge_lang``); every other
            task is scored with ``calculate_bleu``.
        tokenizer: decodes token ids and supplies ``pad_token_id``.
        data_args: only ``rouge_lang`` is read, and only for summarization.

    Returns:
        A callable mapping an ``EvalPrediction`` to a metrics dict that also
        holds ``gen_len``: the mean count of non-pad prediction tokens, rounded
        to one decimal.

    Note:
        The callback overwrites ``pred.predictions`` and ``pred.label_ids`` in
        place. The first prediction column and every -100 become the pad id, and
        ``gen_len`` is measured on those overwritten predictions.
    """

    def non_pad_len(tokens: np.ndarray) -> int:
        # Number of entries that are not the pad id.
        return np.count_nonzero(tokens != tokenizer.pad_token_id)

    def mean_gen_len(predictions) -> float:
        return np.round(np.mean([non_pad_len(row) for row in predictions]), 1)

    def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
        pad_id = tokenizer.pad_token_id
        predictions, label_ids = pred.predictions, pred.label_ids
        # In-place rewrite: blank the first generated position (presumably the
        # decoder start token) and turn -100 "ignore" markers into padding so
        # skip_special_tokens drops them during decoding.
        predictions[..., 0] = pad_id
        predictions[predictions == -100] = pad_id
        label_ids[label_ids == -100] = pad_id
        decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
        return [text.strip() for text in decoded_preds], [text.strip() for text in decoded_labels]

    def summarization_metrics(pred: EvalPrediction) -> Dict:
        pred_str, label_str = decode_pred(pred)
        scores: Dict = calculate_rouge(pred_str, label_str, rouge_lang=data_args.rouge_lang)
        scores.update({"gen_len": mean_gen_len(pred.predictions)})
        return scores

    def translation_metrics(pred: EvalPrediction) -> Dict:
        pred_str, label_str = decode_pred(pred)
        scores: Dict = calculate_bleu(pred_str, label_str)
        scores.update({"gen_len": mean_gen_len(pred.predictions)})
        return scores

    if "summarization" in task_name:
        return summarization_metrics
    return translation_metrics
def trim_batch(
    input_ids,
    pad_token_id,
    attention_mask=None,
):
    """Drop every column of ``input_ids`` that holds nothing but ``pad_token_id``.

    When ``attention_mask`` is given, the same columns are dropped from it and
    an ``(input_ids, attention_mask)`` tuple is returned; otherwise only the
    trimmed ``input_ids`` is returned.
    """
    column_has_content = input_ids.ne(pad_token_id).any(dim=0)
    trimmed_ids = input_ids[:, column_has_content]
    if attention_mask is not None:
        return (trimmed_ids, attention_mask[:, column_has_content])
    return trimmed_ids
class MultiDataset(Dataset):
    """Mix all ``<id>_<type_path>.<ext>`` datasets found in ``data_dir``.

    One sub-dataset is drawn per effective batch, with probability proportional
    to ``(n_i / sum(n)) ** upsampling_factor`` (renormalised), and every item of
    that batch comes from it. In distributed mode each rank first skips the
    items that lower ranks consume and, after taking its own share, skips the
    rest of the global batch. This presumes all ranks draw the same dataset
    sequence from ``np.random`` (identical seeds) - TODO confirm with launcher.

    Required ``dataset_kwargs``: ``upsampling_factor``, ``total_batch_size``,
    ``actual_batch_size``, ``gradient_accum``, ``is_distributed``,
    ``dataset_class``.
    """

    def __init__(
        self,
        tokenizer,
        data_dir,
        max_source_length,
        max_target_length,
        type_path="train",
        n_obs=None,
        prefix="",
        **dataset_kwargs
    ):
        super().__init__()
        assert "upsampling_factor" in dataset_kwargs, "upsampling_factor required"
        assert "total_batch_size" in dataset_kwargs, "total_batch_size required"
        assert "actual_batch_size" in dataset_kwargs, "actual_batch_size required"
        assert "gradient_accum" in dataset_kwargs, "gradient_accum required"
        assert "is_distributed" in dataset_kwargs, "is_distributed required"
        assert "dataset_class" in dataset_kwargs, "dataset_class required"
        self.dataloaders = []
        self.total_batch_size = dataset_kwargs.pop("total_batch_size")
        dataset_class = dataset_kwargs.pop("dataset_class")
        extension = "tokenized" if dataset_class == TokenizedDataset else "source"

        # identify all source training files; sorted so every process (rank)
        # builds the dataloaders in the same order
        datasets = []
        delimiter = f'_{type_path}.{extension}'
        for src_file in sorted(glob.glob(os.path.join(data_dir, f'*{type_path}.{extension}'))):
            file_name = os.path.basename(src_file)
            data_id = file_name.rsplit(delimiter, 1)[0]
            # per-file type path, i.e. the file name without its extension
            file_type_path = "".join(file_name.rsplit(f".{extension}", 1))
            dataset = dataset_class(
                tokenizer,
                type_path=file_type_path,
                data_dir=data_dir,
                n_obs=n_obs,
                max_target_length=max_target_length,
                max_source_length=max_source_length,
                prefix=prefix,
                data_id=data_id
            )
            datasets.append(dataset)
            train_sampler = RandomSampler(dataset)
            # batch_size=1 with an identity collate: each next() yields [item]
            dataloader = DataLoader(
                dataset,
                batch_size=1,
                sampler=train_sampler,
                collate_fn=lambda batch: batch
            )
            self.dataloaders.append(dataloader)
        assert len(self.dataloaders) > 1, "multiple source/target filepairs required for MultiDataset"

        # compute effective length of this dataset and the sampling probabilities
        logger.info(f"Found datasets: {len(self.dataloaders)}")
        upsampling_factor = dataset_kwargs.get("upsampling_factor")
        datapoint_counts = np.array([len(dataset) for dataset in datasets])
        logger.info(f"Total datapoints: {np.sum(datapoint_counts)}")
        datapoint_probs = datapoint_counts / datapoint_counts.sum()
        smoothed_probs = datapoint_probs ** upsampling_factor
        self.sampling_probs = smoothed_probs / smoothed_probs.sum()
        self.effective_length = int(np.sum(datapoint_counts * self.sampling_probs))
        self.iterators = [iter(dataloader) for dataloader in self.dataloaders]

        is_distributed = dataset_kwargs.get("is_distributed")
        actual_batch_size = dataset_kwargs.get("actual_batch_size")
        gradient_accum = dataset_kwargs.get("gradient_accum")
        self.per_gpu_effective_batch_size = actual_batch_size * gradient_accum
        rank = int(os.environ.get("RANK")) if is_distributed else -1
        # rank -1 marks a non-distributed run; the shift arithmetic assumes a
        # non-negative rank, so treat it like rank 0 there.
        self.pos_shift_count = max(rank, 0) * self.per_gpu_effective_batch_size
        logger.info(f'Rank: {rank}, shifting required: {self.pos_shift_count}')
        self.current_dataset_idx = -1
        self.current_loader_count = 0

    def shift_iterator(self, idx, shift_count):
        """Advance sub-dataset ``idx`` by ``shift_count`` items (no-op if <= 0).

        An exhausted iterator is replaced by a fresh one from its dataloader.
        """
        if shift_count <= 0:
            return
        iterator = self.iterators[idx]
        for _ in range(shift_count):
            try:
                next(iterator)
            except StopIteration:
                dataloader = self.dataloaders[idx]
                iterator = iter(dataloader)
                self.iterators[idx] = iterator

    def __len__(self):
        return self.effective_length

    def __getitem__(self, index):
        """Return the next item of the current sub-dataset (``index`` is ignored)."""
        if self.current_loader_count == 0:
            self.current_dataset_idx = np.random.choice(range(len(self.dataloaders)), p=self.sampling_probs)
            # start of a new effective batch, shift to appropriate pos
            self.shift_iterator(self.current_dataset_idx, self.pos_shift_count)
        iterator = self.iterators[self.current_dataset_idx]
        self.current_loader_count = (self.current_loader_count + 1) % self.total_batch_size
        try:
            datapoint = next(iterator)
        except StopIteration:
            dataloader = self.dataloaders[self.current_dataset_idx]
            self.iterators[self.current_dataset_idx] = iter(dataloader)
            datapoint = next(self.iterators[self.current_dataset_idx])
        if self.current_loader_count == self.per_gpu_effective_batch_size:
            # taken allocated datapoints from this effective batch, move to the start of next batch
            self.shift_iterator(self.current_dataset_idx, self.total_batch_size - self.current_loader_count - self.pos_shift_count)
            self.current_loader_count = 0
        # the identity collate wraps each item in a one-element list
        return datapoint[0]
class UnistageCrosslingualDataset(Dataset):
    """Mix ``<src>-<tgt>_<type_path>.<ext>`` datasets with single-stage sampling.

    Every language-pair file in ``data_dir`` becomes one sub-dataset, built with
    ``data_id`` set to the second ``-``-separated field of its file id (the
    target language). A pair is drawn with probability proportional to
    ``(n_i / sum(n)) ** upsampling_factor`` (renormalised) and then used for
    ``per_gpu_lang_batch_size`` consecutive items on this rank.

    Required ``dataset_kwargs``: ``per_lang_batch_size``, ``upsampling_factor``,
    ``total_batch_size``, ``actual_batch_size``, ``gradient_accum``,
    ``is_distributed``, ``dataset_class``.
    """

    def __init__(
        self,
        tokenizer,
        data_dir,
        max_source_length,
        max_target_length,
        type_path="train",
        n_obs=None,
        prefix="",
        **dataset_kwargs
    ):
        super().__init__()
        assert "per_lang_batch_size" in dataset_kwargs, "per_lang_batch_size required"
        assert "upsampling_factor" in dataset_kwargs, "upsampling_factor required"
        assert "total_batch_size" in dataset_kwargs, "total_batch_size required"
        assert "actual_batch_size" in dataset_kwargs, "actual_batch_size required"
        assert "gradient_accum" in dataset_kwargs, "gradient_accum required"
        assert "is_distributed" in dataset_kwargs, "is_distributed required"
        assert "dataset_class" in dataset_kwargs, "dataset_class required"
        logger.info("Using cross lingual dataset with unistage sampling")
        self.dataloaders = []
        self.total_batch_size = dataset_kwargs.pop("total_batch_size")
        dataset_class = dataset_kwargs.pop("dataset_class")
        extension = "tokenized" if dataset_class == TokenizedDataset else "source"

        # identify all source training files; sorted so every process (rank)
        # builds the dataloaders in the same order
        datasets = []
        delimiter = f'_{type_path}.{extension}'
        for src_file in sorted(glob.glob(os.path.join(data_dir, f'*{type_path}.{extension}'))):
            file_name = os.path.basename(src_file)
            pair_id = file_name.rsplit(delimiter, 1)[0]
            tgt_lang = pair_id.split("-")[1]
            # per-file type path, i.e. the file name without its extension
            file_type_path = "".join(file_name.rsplit(f".{extension}", 1))
            dataset = dataset_class(
                tokenizer,
                type_path=file_type_path,
                data_dir=data_dir,
                n_obs=n_obs,
                max_target_length=max_target_length,
                max_source_length=max_source_length,
                prefix=prefix,
                data_id=tgt_lang
            )
            datasets.append(dataset)
            train_sampler = RandomSampler(dataset)
            # batch_size=1 with an identity collate: each next() yields [item]
            dataloader = DataLoader(
                dataset,
                batch_size=1,
                sampler=train_sampler,
                collate_fn=lambda batch: batch
            )
            self.dataloaders.append(dataloader)

        # compute effective length of this dataset and the sampling probabilities
        logger.info(f"Found datasets: {len(self.dataloaders)}")
        upsampling_factor = dataset_kwargs.get("upsampling_factor")
        datapoint_counts = np.array([len(dataset) for dataset in datasets])
        logger.info(f"Total datapoints: {np.sum(datapoint_counts)}")
        datapoint_probs = datapoint_counts / datapoint_counts.sum()
        smoothed_probs = datapoint_probs ** upsampling_factor
        self.sampling_probs = smoothed_probs / smoothed_probs.sum()
        self.effective_length = int(np.sum(datapoint_counts * self.sampling_probs))
        self.iterators = [iter(dataloader) for dataloader in self.dataloaders]

        is_distributed = dataset_kwargs.get("is_distributed")
        actual_batch_size = dataset_kwargs.get("actual_batch_size")
        gradient_accum = dataset_kwargs.get("gradient_accum")
        self.per_lang_batch_size = dataset_kwargs.get("per_lang_batch_size")
        self.per_gpu_effective_batch_size = actual_batch_size * gradient_accum
        # this rank's share of one language batch
        self.per_gpu_lang_batch_size = self.per_lang_batch_size // (self.total_batch_size // self.per_gpu_effective_batch_size)
        assert self.total_batch_size % self.per_lang_batch_size == 0, "total_batch_size must be divisible by per_lang_batch_size"
        rank = int(os.environ.get("RANK")) if is_distributed else -1
        # rank -1 marks a non-distributed run. Clamp it to 0: otherwise the
        # end-of-batch skip (per_lang - per_gpu_lang - pos_shift) becomes
        # per_lang_batch_size and a full batch of data is discarded each time.
        self.pos_shift_count = max(rank, 0) * self.per_gpu_lang_batch_size
        logger.info(f'Rank: {rank}, shifting required: {self.pos_shift_count}')
        logger.info(f"Effective length: {self.effective_length}")
        self.current_dataset_idx = -1
        self.current_loader_count = 0

    def shift_iterator(self, idx, shift_count):
        """Advance sub-dataset ``idx`` by ``shift_count`` items (no-op if <= 0).

        An exhausted iterator is replaced by a fresh one from its dataloader.
        """
        if shift_count <= 0:
            return
        iterator = self.iterators[idx]
        for _ in range(shift_count):
            try:
                next(iterator)
            except StopIteration:
                dataloader = self.dataloaders[idx]
                iterator = iter(dataloader)
                self.iterators[idx] = iterator

    def __len__(self):
        return self.effective_length

    def __getitem__(self, index):
        """Return the next item of the current language pair (``index`` is ignored)."""
        if self.current_loader_count == 0:
            self.current_dataset_idx = np.random.choice(range(len(self.dataloaders)), p=self.sampling_probs)
            # start of a new effective batch, shift to appropriate pos
            self.shift_iterator(self.current_dataset_idx, self.pos_shift_count)
        iterator = self.iterators[self.current_dataset_idx]
        self.current_loader_count = (self.current_loader_count + 1) % self.per_gpu_lang_batch_size
        try:
            datapoint = next(iterator)
        except StopIteration:
            dataloader = self.dataloaders[self.current_dataset_idx]
            self.iterators[self.current_dataset_idx] = iter(dataloader)
            datapoint = next(self.iterators[self.current_dataset_idx])
        if self.current_loader_count == 0:
            # taken allocated datapoints from this effective batch, move to the start of next batch
            self.shift_iterator(self.current_dataset_idx, self.per_lang_batch_size - self.per_gpu_lang_batch_size - self.pos_shift_count)
        # the identity collate wraps each item in a one-element list
        return datapoint[0]
class CrosslingualDataset(Dataset):
    """Mix ``<src>-<tgt>_<type_path>.<ext>`` datasets with two-stage sampling.

    Every language-pair file in ``data_dir`` is indexed as
    ``self.dataloaders[src_lang][tgt_lang]``.

    Sampling has two stages. First, a language on one side (source or target)
    is drawn, weighted by its total datapoint count raised to
    ``multistage_upsampling_factors[0]``. Then the other side is drawn among
    that language's pairs, weighted with ``multistage_upsampling_factors[1]``.

    ``minibatching`` selects which side stays fixed:
    - ``"fixed_src"`` / ``"fixed_tgt"``: that side is always fixed.
    - any other value: a fair coin decides at each batch start.
    - ``"ignored"``: additionally sets ``per_lang_batch_size`` to
      ``total_batch_size``.

    Required ``dataset_kwargs``: ``multistage_upsampling_factors``,
    ``per_lang_batch_size``, ``total_batch_size``, ``actual_batch_size``,
    ``gradient_accum``, ``is_distributed``, ``dataset_class``, ``minibatching``.
    """

    def __init__(
        self,
        tokenizer,
        data_dir,
        max_source_length,
        max_target_length,
        type_path="train",
        n_obs=None,
        prefix="",
        **dataset_kwargs
    ):
        super().__init__()
        assert "multistage_upsampling_factors" in dataset_kwargs, "multistage_upsampling_factors required"
        assert "per_lang_batch_size" in dataset_kwargs, "per_lang_batch_size required"
        assert "total_batch_size" in dataset_kwargs, "total_batch_size required"
        assert "actual_batch_size" in dataset_kwargs, "actual_batch_size required"
        assert "gradient_accum" in dataset_kwargs, "gradient_accum required"
        assert "is_distributed" in dataset_kwargs, "is_distributed required"
        assert "dataset_class" in dataset_kwargs, "dataset_class required"
        assert "minibatching" in dataset_kwargs, "minibatching required"
        logger.info("Using cross lingual dataset")
        self.dataloaders = {}
        self.total_batch_size = dataset_kwargs.pop("total_batch_size")
        dataset_class = dataset_kwargs.pop("dataset_class")
        extension = "tokenized" if dataset_class == TokenizedDataset else "source"
        self.minibatching = dataset_kwargs.pop("minibatching")

        # datapoint counts keyed [src][tgt] (forward) and [tgt][src] (reverse)
        forward_datapoint_counts = {}
        reverse_datapoint_counts = {}
        delimiter = f'_{type_path}.{extension}'

        # identify all languages first
        data_langs = set()
        for src_file in glob.glob(os.path.join(data_dir, f'*{type_path}.{extension}')):
            pair_id = os.path.basename(src_file).rsplit(delimiter, 1)[0]
            data_langs.update(pair_id.split("-"))
        # Set iteration order of strings depends on the per-process hash seed.
        # Sort so every rank builds identical dict orders and therefore maps
        # the same np.random draw to the same language.
        ordered_langs = sorted(data_langs)

        total_dataset_count = 0
        # now create datasets according to language pairs
        for src_lang in ordered_langs:
            self.dataloaders[src_lang] = {}
            forward_datapoint_counts[src_lang] = {}
            for tgt_lang in ordered_langs:
                src_file = os.path.join(
                    data_dir,
                    f"{src_lang}-{tgt_lang}{delimiter}"
                )
                if not os.path.isfile(src_file):
                    continue
                # per-file type path, i.e. the file name without its extension
                pair_type_path = "".join(os.path.basename(src_file).rsplit(f".{extension}", 1))
                dataset = dataset_class(
                    tokenizer,
                    type_path=pair_type_path,
                    data_dir=data_dir,
                    n_obs=n_obs,
                    max_target_length=max_target_length,
                    max_source_length=max_source_length,
                    prefix=prefix,
                    data_id=tgt_lang
                )
                train_sampler = RandomSampler(dataset)
                # batch_size=1 with an identity collate: each next() yields [item]
                dataloader = DataLoader(
                    dataset,
                    batch_size=1,
                    sampler=train_sampler,
                    collate_fn=lambda batch: batch
                )
                self.dataloaders[src_lang][tgt_lang] = dataloader
                forward_datapoint_counts[src_lang][tgt_lang] = len(dataset)
                reverse_datapoint_counts.setdefault(tgt_lang, {})[src_lang] = len(dataset)
                total_dataset_count += 1
            # drop source languages that have no pair file at all
            if not self.dataloaders[src_lang]:
                self.dataloaders.pop(src_lang)
                forward_datapoint_counts.pop(src_lang)

        # compute effective length of this dataset and the sampling probabilities
        logger.info(f"Found datasets: {total_dataset_count}")
        multistage_upsampling_factors = dataset_kwargs.get("multistage_upsampling_factors")
        logger.info(f"Total datapoints: {sum(v for counts in forward_datapoint_counts.values() for v in counts.values())}")

        # avoiding matrix operations for easy traversing later
        def get_sampling_probs(data_matrix):
            """Two-stage sampling tables for a nested ``{outer: {inner: count}}`` dict.

            Returns ``(first_stage_probs, second_stage_probs, effective_length)``:
            ``first_stage_probs[outer]`` and ``second_stage_probs[outer][inner]``
            are smoothed, renormalised probabilities, and ``effective_length``
            is the expected datapoint count under that sampling scheme.
            """
            # first stage datapoint probs
            first_stage_datapoint_counts = {
                k: np.sum(list(v.values())) for k, v in data_matrix.items()
            }
            first_stage_datapoint_probs = {
                k: l ** multistage_upsampling_factors[0]
                for k, l in zip(
                    first_stage_datapoint_counts.keys(),
                    list(first_stage_datapoint_counts.values()) / np.sum(list(first_stage_datapoint_counts.values()))
                )
            }
            first_stage_sampling_probs = {
                k: l for k, l in zip(
                    first_stage_datapoint_probs.keys(),
                    list(first_stage_datapoint_probs.values()) / np.sum(list(first_stage_datapoint_probs.values()))
                )
            }
            second_stage_sampling_probs = {}
            langwise_effective_lengths = []
            # second stage datapoint probs
            for src_lang, tgt_data_stats in data_matrix.items():
                lang_datapoint_probs = {
                    k: l ** multistage_upsampling_factors[1]
                    for k, l in zip(
                        tgt_data_stats.keys(),
                        list(tgt_data_stats.values()) / np.sum(list(tgt_data_stats.values()))
                    )
                }
                lang_sampling_probs = {
                    k: l for k, l in zip(
                        lang_datapoint_probs.keys(),
                        list(lang_datapoint_probs.values()) / np.sum(list(lang_datapoint_probs.values()))
                    )
                }
                second_stage_sampling_probs[src_lang] = lang_sampling_probs
                langwise_effective_lengths.append(
                    np.sum(
                        np.array(list(tgt_data_stats.values())) * np.array(list(lang_sampling_probs.values()))
                    )
                )
            effective_length = int(
                np.sum(
                    langwise_effective_lengths * np.array(list(first_stage_sampling_probs.values()))
                )
            )
            return first_stage_sampling_probs, second_stage_sampling_probs, effective_length

        (
            self.src_first_stage_sampling_probs,
            self.src_second_stage_sampling_probs,
            src_effective_length
        ) = get_sampling_probs(forward_datapoint_counts)
        (
            self.tgt_first_stage_sampling_probs,
            self.tgt_second_stage_sampling_probs,
            tgt_effective_length
        ) = get_sampling_probs(reverse_datapoint_counts)
        self.effective_length = (src_effective_length + tgt_effective_length) // 2
        self.iterators = {}
        for src_lang, tgt_dataloaders in self.dataloaders.items():
            self.iterators[src_lang] = {k: iter(l) for k, l in tgt_dataloaders.items()}

        is_distributed = dataset_kwargs.get("is_distributed")
        actual_batch_size = dataset_kwargs.get("actual_batch_size")
        gradient_accum = dataset_kwargs.get("gradient_accum")
        # for choosing whether src (0) or tgt (1) will be fixed when taking a batch
        self.choice = -1
        self.per_lang_batch_size = dataset_kwargs.get("per_lang_batch_size")
        if self.minibatching is not None:
            if self.minibatching == "ignored":
                self.per_lang_batch_size = self.total_batch_size
            elif self.minibatching == "fixed_src":
                self.choice = 0
            elif self.minibatching == "fixed_tgt":
                self.choice = 1
        self.per_gpu_effective_batch_size = actual_batch_size * gradient_accum
        # this rank's share of one language batch
        self.per_gpu_lang_batch_size = self.per_lang_batch_size // (self.total_batch_size // self.per_gpu_effective_batch_size)
        assert self.total_batch_size % self.per_lang_batch_size == 0, "total_batch_size must be divisible by per_lang_batch_size"
        rank = int(os.environ.get("RANK")) if is_distributed else -1
        # rank -1 marks a non-distributed run. Clamp it to 0: otherwise the
        # end-of-batch skip (per_lang - per_gpu_lang - pos_shift) becomes
        # per_lang_batch_size and a full batch of data is discarded each time.
        self.pos_shift_count = max(rank, 0) * self.per_gpu_lang_batch_size
        logger.info(f'Rank: {rank}, shifting required: {self.pos_shift_count}')
        logger.info(f"Effective length: {self.effective_length}")
        logger.info(f"Minibatching type:" + str(self.minibatching))
        self.current_src_lang = None
        self.current_tgt_lang = None
        self.current_src_loader_count = 0
        self.current_tgt_loader_count = 0

    def shift_iterator(self, src_lang, tgt_lang, shift_count):
        """Advance pair ``(src_lang, tgt_lang)`` by ``shift_count`` items (no-op if <= 0).

        An exhausted iterator is replaced by a fresh one from its dataloader.
        """
        if shift_count <= 0:
            return
        iterator = self.iterators[src_lang][tgt_lang]
        for _ in range(shift_count):
            try:
                next(iterator)
            except StopIteration:
                dataloader = self.dataloaders[src_lang][tgt_lang]
                iterator = iter(dataloader)
                self.iterators[src_lang][tgt_lang] = iterator

    def __len__(self):
        return self.effective_length

    def __getitem__(self, index):
        """Return the next item of the current language pair (``index`` is ignored).

        NOTE(review): the fixed side's counter wraps at ``total_batch_size``
        (the global batch), while this rank only draws
        ``per_gpu_effective_batch_size`` items per step - confirm intended.
        """
        if (
            not (self.minibatching is not None and self.minibatching.startswith("fixed"))
            and (self.current_src_loader_count + self.current_tgt_loader_count == 0)
        ):
            self.choice = np.random.choice([0, 1], p=[0.5, 0.5])
        if self.choice == 0:
            # we keep src fixed
            if self.current_src_loader_count == 0:
                self.current_src_lang = np.random.choice(
                    list(self.src_first_stage_sampling_probs.keys()),
                    p=list(self.src_first_stage_sampling_probs.values())
                )
            if self.current_tgt_loader_count == 0:
                self.current_tgt_lang = np.random.choice(
                    list(self.src_second_stage_sampling_probs[self.current_src_lang].keys()),
                    p=list(self.src_second_stage_sampling_probs[self.current_src_lang].values())
                )
                # start of a new effective batch, shift to appropriate pos
                self.shift_iterator(
                    self.current_src_lang,
                    self.current_tgt_lang,
                    self.pos_shift_count
                )
            self.current_src_loader_count = (self.current_src_loader_count + 1) % self.total_batch_size
            self.current_tgt_loader_count = (self.current_tgt_loader_count + 1) % self.per_gpu_lang_batch_size
        elif self.choice == 1:
            # we keep tgt fixed
            if self.current_tgt_loader_count == 0:
                self.current_tgt_lang = np.random.choice(
                    list(self.tgt_first_stage_sampling_probs.keys()),
                    p=list(self.tgt_first_stage_sampling_probs.values())
                )
            if self.current_src_loader_count == 0:
                self.current_src_lang = np.random.choice(
                    list(self.tgt_second_stage_sampling_probs[self.current_tgt_lang].keys()),
                    p=list(self.tgt_second_stage_sampling_probs[self.current_tgt_lang].values())
                )
                # start of a new effective batch, shift to appropriate pos
                self.shift_iterator(
                    self.current_src_lang,
                    self.current_tgt_lang,
                    self.pos_shift_count
                )
            self.current_tgt_loader_count = (self.current_tgt_loader_count + 1) % self.total_batch_size
            self.current_src_loader_count = (self.current_src_loader_count + 1) % self.per_gpu_lang_batch_size
        iterator = self.iterators[self.current_src_lang][self.current_tgt_lang]
        try:
            datapoint = next(iterator)
        except StopIteration:
            dataloader = self.dataloaders[self.current_src_lang][self.current_tgt_lang]
            self.iterators[self.current_src_lang][self.current_tgt_lang] = iter(dataloader)
            datapoint = next(self.iterators[self.current_src_lang][self.current_tgt_lang])
        if (
            (self.choice == 0 and self.current_tgt_loader_count == 0) or
            (self.choice == 1 and self.current_src_loader_count == 0)
        ):
            # taken allocated datapoints from this effective batch, move to the start of next batch
            self.shift_iterator(
                self.current_src_lang,
                self.current_tgt_lang,
                self.per_lang_batch_size - self.per_gpu_lang_batch_size - self.pos_shift_count
            )
        # the identity collate wraps each item in a one-element list
        return datapoint[0]
class TokenizedDataset(Dataset):
    """Dataset to load tokenized data. Backwards compatible with AbstractSeq2SeqDataset

    Reads ``<data_dir>/<type_path>.tokenized``. Each line is one JSON value, and
    ``__getitem__`` returns the decoded value of that line. Tokenization
    arguments are accepted only for constructor compatibility and are unused.
    """

    def __init__(
        self,
        tokenizer,
        data_dir,
        max_source_length,
        max_target_length,
        type_path="train",
        n_obs=None,
        prefix="",
        **dataset_kwargs
    ):
        super().__init__()
        self.src_file = Path(data_dir).joinpath(type_path + ".tokenized")
        self.length = self.get_lc(self.src_file)
        # n_obs of None or -1 means "use every line"
        if n_obs is not None and n_obs != -1:
            self.length = min(n_obs, self.length)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        index = index + 1  # linecache starts at 1
        source_line = linecache.getline(str(self.src_file), index).rstrip("\n")
        return json.loads(source_line)

    @staticmethod
    def get_lc(data_file):
        """Return the number of lines in ``data_file`` (0 for an empty file).

        Read as UTF-8 to match how ``linecache`` decodes the same file in
        ``__getitem__``, independent of the process locale.
        """
        with open(data_file, encoding="utf-8") as f:
            return sum(1 for _ in f)
class TokenizedDataCollator:
    """Collate pre-tokenized examples into a dict of stacked tensors.

    Every key of the first example becomes an output key. Each example's value
    is converted with ``torch.tensor``, squeezed, and the results are stacked
    along a new leading batch dimension.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, batch) -> Dict[str, torch.Tensor]:
        def stack_field(name):
            return torch.stack([torch.tensor(example[name]).squeeze() for example in batch])

        return {name: stack_field(name) for name in batch[0]}
class AbstractSeq2SeqDataset(Dataset):
    """Base dataset over parallel ``<type_path>.source`` / ``<type_path>.target`` files.

    Per-example source lengths (used for sortish/dynamic batching) come from
    ``<type_path>.len`` via ``pickle_load`` when that file exists; otherwise
    they are character counts of the source lines. Subclasses implement
    ``__getitem__`` and ``collate_fn``.
    """

    def __init__(
        self,
        tokenizer,
        data_dir,
        max_source_length,
        max_target_length,
        type_path="train",
        n_obs=None,
        prefix="",
        **dataset_kwargs
    ):
        super().__init__()
        self.src_file = Path(data_dir).joinpath(type_path + ".source")
        self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
        self.len_file = Path(data_dir).joinpath(type_path + ".len")
        if os.path.exists(self.len_file):
            self.src_lens = pickle_load(self.len_file)
            self.used_char_len = False
        else:
            self.src_lens = self.get_char_lens(self.src_file)
            self.used_char_len = True
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        # NOTE(review): character lengths include the trailing newline, so this
        # check cannot detect an empty line when char lengths are used.
        assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
        self.tokenizer = tokenizer
        self.prefix = prefix if prefix is not None else ""
        # n_obs of None or -1 means "use every example"
        if n_obs is not None and n_obs != -1:
            self.src_lens = self.src_lens[:n_obs]
        self.pad_token_id = self.tokenizer.pad_token_id
        # extra kwargs are stored and later forwarded to the tokenizer by subclasses
        self.dataset_kwargs = dataset_kwargs
        dataset_kwargs.update({"add_prefix_space": True} if isinstance(self.tokenizer, BartTokenizer) else {})

    def __len__(self):
        return len(self.src_lens)

    @staticmethod
    def get_char_lens(data_file):
        """Return the character length of every line of ``data_file``.

        Lengths include the trailing newline. The file is read as UTF-8 (the
        encoding ``linecache`` uses when subclasses read it) and closed
        promptly rather than left for the garbage collector.
        """
        with Path(data_file).open(encoding="utf-8") as f:
            return [len(x) for x in f]

    @cached_property
    def tgt_lens(self):
        """Length in characters of target documents"""
        return self.get_char_lens(self.tgt_file)

    def make_sortish_sampler(self, batch_size, distributed=False, shuffle=True, **kwargs):
        """Return a length-aware sampler (distributed or single-process variant)."""
        if distributed:
            return DistributedSortishSampler(self, batch_size, shuffle=shuffle, **kwargs)
        else:
            return SortishSampler(self.src_lens, batch_size, shuffle=shuffle)

    def make_dynamic_sampler(self, max_tokens_per_batch=1024, **kwargs):
        """Build shuffled, token-budgeted batches of indices using fairseq's ``batch_by_size``."""
        assert FAIRSEQ_AVAILABLE, "Dynamic batch size requires `pip install fairseq`"
        assert not self.used_char_len, "You must call python make_len_file.py before calling make_dynamic_sampler"
        sorted_indices = list(self.make_sortish_sampler(1024, shuffle=False))

        def num_tokens_in_example(i):
            # NOTE(review): caps the *source* length with max_target_length;
            # confirm whether max_source_length was intended.
            return min(self.src_lens[i], self.max_target_length)

        # call fairseq cython function
        batch_sampler: List[List[int]] = batch_by_size(
            sorted_indices,
            num_tokens_fn=num_tokens_in_example,
            max_tokens=max_tokens_per_batch,
            required_batch_size_multiple=64,
        )
        shuffled_batches = [batch_sampler[i] for i in np.random.permutation(range(len(batch_sampler)))]
        # move the largest batch to the front to OOM quickly (uses an approximation for padding)
        approximate_toks_per_batch = [max(self.src_lens[i] for i in batch) * len(batch) for batch in shuffled_batches]
        largest_batch_idx = np.argmax(approximate_toks_per_batch)
        shuffled_batches[0], shuffled_batches[largest_batch_idx] = (
            shuffled_batches[largest_batch_idx],
            shuffled_batches[0],
        )
        return shuffled_batches

    def __getitem__(self, item):
        raise NotImplementedError("You must implement this")

    def collate_fn(self, batch):
        raise NotImplementedError("You must implement this")
class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
    """Seq2seq dataset that tokenizes each example inside ``__getitem__``."""

    def __getitem__(self, index) -> Dict[str, torch.Tensor]:
        """Call tokenizer on src and tgt_lines"""
        index = index + 1  # linecache starts at 1
        source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
        tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
        assert source_line, f"empty source line for index {index}"
        assert tgt_line, f"empty tgt line for index {index}"
        source_inputs = self.encode_line(self.tokenizer, source_line, self.max_source_length)
        target_inputs = self.encode_line(self.tokenizer, tgt_line, self.max_target_length)
        # tokenizer returns a batch of one; squeeze the batch dimension away
        source_ids = source_inputs["input_ids"].squeeze()
        target_ids = target_inputs["input_ids"].squeeze()
        src_mask = source_inputs["attention_mask"].squeeze()
        return {
            "input_ids": source_ids,
            "attention_mask": src_mask,
            "labels": target_ids,
        }

    def encode_line(self, tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
        """Only used by LegacyDataset

        Tokenizes ``line`` as a batch of one, truncated to ``max_length`` and
        padded to it when ``pad_to_max_length`` is true.
        """
        # ``data_id`` is dataset metadata added by the multi-dataset wrappers,
        # not a tokenizer argument, so keep it out of the tokenizer call.
        tokenizer_kwargs = {k: v for k, v in self.dataset_kwargs.items() if k != "data_id"}
        return tokenizer(
            [line],
            max_length=max_length,
            # tokenizers take False (not None) to disable padding
            padding="max_length" if pad_to_max_length else False,
            truncation=True,
            return_tensors=return_tensors,
            **tokenizer_kwargs,
        )

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        """Stack examples and drop columns that are padding in every row."""
        input_ids = torch.stack([x["input_ids"] for x in batch])
        masks = torch.stack([x["attention_mask"] for x in batch])
        target_ids = torch.stack([x["labels"] for x in batch])
        pad_token_id = self.pad_token_id
        y = trim_batch(target_ids, pad_token_id)
        source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
        batch = {
            "input_ids": source_ids,
            "attention_mask": source_mask,
            "labels": y,
        }
        return batch
class Seq2SeqDataset(AbstractSeq2SeqDataset):
    """A dataset that calls prepare_seq2seq_batch."""

    def __getitem__(self, index) -> Dict:
        """Return the raw source/target strings of example ``index``.

        The item also carries ``data_id`` (from ``dataset_kwargs``, ``None`` if
        absent) and ``id`` (the zero-based example index), which ``collate_fn``
        turns into the ``ids`` tensor.
        """
        line_no = index + 1  # linecache starts at 1
        source_line = self.prefix + linecache.getline(str(self.src_file), line_no).rstrip("\n")
        tgt_line = linecache.getline(str(self.tgt_file), line_no).rstrip("\n")
        data_id = self.dataset_kwargs.get("data_id", None)
        assert source_line, f"empty source line for index {line_no}"
        assert tgt_line, f"empty tgt line for index {line_no}"
        return {"tgt_texts": tgt_line, "src_texts": source_line, "data_id": data_id, "id": index}

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        """Call prepare_seq2seq_batch."""
        # ``data_id`` is dataset metadata added by the multi-dataset wrappers,
        # not a tokenizer argument, so keep it out of the tokenizer call.
        tokenizer_kwargs = {k: v for k, v in self.dataset_kwargs.items() if k != "data_id"}
        batch_encoding: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch(
            [x["src_texts"] for x in batch],
            tgt_texts=[x["tgt_texts"] for x in batch],
            max_length=self.max_source_length,
            max_target_length=self.max_target_length,
            return_tensors="pt",
            **tokenizer_kwargs,
        ).data
        # requires the "id" key that __getitem__ now provides
        batch_encoding["ids"] = torch.tensor([x["id"] for x in batch])
        return batch_encoding
class Seq2SeqDataCollator:
    """Collate seq2seq examples into model-ready tensor dicts.

    Encoding path, chosen per tokenizer:
      * tokenizers with a ``cls_token`` (treated as BERT-based): sources and
        targets are tokenized separately by ``_bert_encode``;
      * tokenizers exposing ``prepare_seq2seq_batch``: ``_encode``;
      * otherwise examples must already hold ``input_ids``/``attention_mask``/
        ``labels`` tensors, which are stacked and trimmed.

    For non-BERT tokenizers ``decoder_input_ids`` are derived from ``labels``.
    With a ``T5Tokenizer`` the first decoder token (and, when
    ``data_args.use_langid_prefix`` is set, the first encoder token) is a
    language-id token looked up from the batch's ``data_id``.
    """

    def __init__(self, tokenizer, data_args, padding=None, tpu_num_cores=None):
        self.tokenizer = tokenizer
        self.pad_token_id = tokenizer.pad_token_id
        assert (
            self.pad_token_id is not None
        ), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined."
        self.data_args = data_args
        self.tpu_num_cores = tpu_num_cores
        self.dataset_kwargs = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
        if data_args.src_lang is not None:
            self.dataset_kwargs["src_lang"] = data_args.src_lang
        if data_args.tgt_lang is not None:
            self.dataset_kwargs["tgt_lang"] = data_args.tgt_lang
        self.is_bert_based = self.tokenizer.cls_token is not None
        # Fixed-length padding on TPU (static shapes), dynamic padding elsewhere.
        self.padding = padding if padding is not None else ("max_length" if self.tpu_num_cores is not None else "longest")

    def __call__(self, batch) -> Dict[str, torch.Tensor]:
        """Encode ``batch`` (a list of examples) into a dict of tensors."""
        init_token_id = self._resolve_init_token_id(batch)

        if self.is_bert_based:
            # bert based models will automatically add the [CLS] token;
            # _bert_encode already builds decoder inputs and labels.
            return self._bert_encode(batch)

        if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
            encoded = self._encode(batch)
            input_ids = encoded["input_ids"]
            attention_mask = encoded["attention_mask"]
            labels = encoded["labels"]
        else:
            input_ids = torch.stack([x["input_ids"] for x in batch])
            attention_mask = torch.stack([x["attention_mask"] for x in batch])
            labels = torch.stack([x["labels"] for x in batch])
            labels = trim_batch(labels, self.pad_token_id)
            input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)

        if isinstance(self.tokenizer, T5Tokenizer):
            decoder_input_ids = self._shift_right_t5(labels, init_token_id)
            if getattr(self.data_args, "use_langid_prefix", False):
                input_ids = self._shift_right_t5(input_ids, init_token_id)
                # Shift the mask identically so each token keeps its own mask
                # value; the prepended lang-id slot is always attended (1).
                attention_mask = self._shift_right_t5(attention_mask, 1)
        else:
            decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
            "labels": labels,
            "init_token_id": init_token_id,
        }

    def _resolve_init_token_id(self, batch) -> int:
        """Return the token id used to start shifted sequences for this batch.

        Uses ``data_args.langid_map[data_id]`` for the first example's
        ``data_id`` when available; otherwise ``self.pad_token_id``. Logs an
        error when a non-None ``data_id`` is missing from an existing map.
        """
        data_id = batch[0].get("data_id")
        # Examples may carry data_id=None when no id is configured; that is not
        # an unknown language, so fall back silently.
        if data_id is None or not hasattr(self.data_args, "langid_map"):
            return self.pad_token_id
        mapped_data = self.data_args.langid_map.get(data_id, None)
        if not mapped_data:
            logger.error(f"Unknown langid: {data_id}")
            return self.pad_token_id
        # langid_map values are (language index, token); the index is unused here.
        _lang_idx, mapped_token = mapped_data
        return self.tokenizer._convert_token_to_id(mapped_token)

    def _shift_right_t5(self, input_ids, init_token_id):
        """Shift ``input_ids`` one position right along the last dim, filling slot 0 with ``init_token_id``.

        The last element of each row is dropped, so the shape is unchanged.
        """
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
        shifted_input_ids[..., 0] = init_token_id
        return shifted_input_ids

    def _encode(self, batch) -> Dict[str, torch.Tensor]:
        """Tokenize ``src_texts``/``tgt_texts`` with ``prepare_seq2seq_batch`` and return its dict."""
        batch_encoding = self.tokenizer.prepare_seq2seq_batch(
            [x["src_texts"] for x in batch],
            tgt_texts=[x["tgt_texts"] for x in batch],
            max_length=self.data_args.max_source_length,
            max_target_length=self.data_args.max_target_length,
            padding=self.padding,  # TPU hack
            return_tensors="pt",
            **self.dataset_kwargs,
        )
        return batch_encoding.data

    def _bert_encode(self, batch):
        """Tokenize sources and targets separately for a BERT-style encoder-decoder.

        ``labels`` copy the target ids with padding replaced by -100
        (presumably the loss ignore index — confirm against the model).
        """
        inputs = self.tokenizer(
            [x["src_texts"] for x in batch],
            truncation=True,
            max_length=self.data_args.max_source_length,
            padding=self.padding,  # TPU hack
            return_tensors="pt",
            **self.dataset_kwargs,
        )
        outputs = self.tokenizer(
            [x["tgt_texts"] for x in batch],
            truncation=True,
            max_length=self.data_args.max_target_length,
            padding=self.padding,  # TPU hack
            return_tensors="pt",
            **self.dataset_kwargs,
        )
        labels = outputs.input_ids.clone()
        labels[labels == self.tokenizer.pad_token_id] = -100
        return {
            "input_ids": inputs.input_ids,
            "attention_mask": inputs.attention_mask,
            "decoder_input_ids": outputs.input_ids,
            "decoder_attention_mask": outputs.attention_mask,
            "labels": labels,
        }
class SortishSampler(Sampler):
    """Yield dataset indices roughly grouped by source length, with some shuffling.

    Adapted from the fastai repo; the actual ordering is produced by
    ``sortish_sampler_indices(data, batch_size, shuffle=...)``.
    """

    def __init__(self, data, batch_size, shuffle=True):
        self.data = data
        self.bs = batch_size
        self.shuffle = shuffle

    def __len__(self) -> int:
        return len(self.data)

    def __iter__(self):
        ordered = sortish_sampler_indices(self.data, self.bs, shuffle=self.shuffle)
        return iter(ordered)