-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathPaperGenerationProblem.py
125 lines (109 loc) · 5.2 KB
/
PaperGenerationProblem.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import shutil
import numpy as np
import pickle
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_problems
from tensor2tensor.layers import modalities
from tensor2tensor.utils import metrics
from tensor2tensor.utils import mlperf_log
from tensor2tensor.utils import registry
from tensor2tensor.data_generators import text_encoder
@registry.register_problem
class PaperGenerationProblem(text_problems.Text2SelfProblem):
    """Character-level language-modeling problem built from one text corpus.

    The corpus is fetched from `corpus_url`, decoded as UTF-8 and cut into
    consecutive chunks of `sequence_length` characters; every chunk becomes
    one "targets" example. Encoded examples are emitted without a trailing
    EOS id (see `generate_encoded_samples`).
    """

    @property
    def corpus_url(self):
        """URL the corpus is downloaded from.

        Raises:
          ValueError: if the placeholder below has not been replaced with a
            real URL yet.
        """
        url = "Update here with the URL to the Dataset"
        if url == "Update here with the URL to the Dataset":
            # ValueError derives from Exception, so callers that catch the
            # generic Exception raised previously keep working.
            raise ValueError("Update the URL to the Dataset")
        return url

    @property
    def is_generate_per_split(self):
        """False: `generate_samples` does not produce per-split data."""
        return False

    @property
    def vocab_type(self):
        """Use a character-level vocabulary."""
        return text_problems.VocabType.CHARACTER

    @property
    def sequence_length(self):
        """Length of each example (in tokens)."""
        return 128

    @property
    def dataset_splits(self):
        """Splits of data to produce and number of output shards for each."""
        return [
            {"split": problem.DatasetSplit.TRAIN, "shards": 100},
            {"split": problem.DatasetSplit.EVAL, "shards": 1},
        ]

    def _maybe_download_data(self, tmp_dir):
        """Return the corpus text, downloading and caching it on first use.

        Args:
          tmp_dir: directory the corpus is downloaded to, stored there as
            "paper_dataset.txt".

        Returns:
          The whole corpus as a single UTF-8 decoded string. The result is
          memoized on the instance as `self.paper_dataset`.
        """
        if hasattr(self, "paper_dataset"):
            return self.paper_dataset

        generator_utils.maybe_download(
            tmp_dir, "paper_dataset.txt", self.corpus_url)
        corpus_path = os.path.join(tmp_dir, "paper_dataset.txt")
        # Read raw bytes and decode explicitly so the text is not subject to
        # newline translation; `with` closes the handle even if reading or
        # decoding fails.
        with open(corpus_path, "rb") as corpus_file:
            self.paper_dataset = corpus_file.read().decode(encoding="utf-8")
        return self.paper_dataset

    def generate_samples(self, data_dir, tmp_dir, dataset_split):
        """Generate samples of text.

        The corpus is sliced into consecutive, non-overlapping chunks of
        `sequence_length` characters; the last chunk may be shorter. The
        number of chunks is recorded in `self.nb_samples`.

        Args:
          data_dir: final data directory. Not used by this method.
          tmp_dir: temporary directory that you can use for downloading and
            scratch.
          dataset_split: problem.DatasetSplit, which data split to generate
            samples for. Not used by this method: the same chunks are
            produced whatever split is requested.

        Yields:
          Sample: dict<str feature_name, str text> with only the "targets"
          key, holding one chunk of the corpus.
        """
        corpus = self._maybe_download_data(tmp_dir)
        chunk_len = self.sequence_length
        # Exact integer ceiling division (no float rounding on huge corpora).
        self.nb_samples = (len(corpus) + chunk_len - 1) // chunk_len
        for start in range(0, len(corpus), chunk_len):
            yield {"targets": corpus[start:start + chunk_len]}

    def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
        """Encode the samples of `generate_samples` with the problem's vocab.

        Unlike the stock text2text encoding, no EOS id is appended: examples
        are arbitrary slices of one continuous text, so a sequence should not
        be marked as ended.

        Args:
          data_dir: final data directory, passed on to `generate_samples` and
            `get_or_create_vocab`.
          tmp_dir: temporary directory, passed on likewise.
          dataset_split: problem.DatasetSplit being generated.

        Returns:
          A generator of sample dicts whose "targets" (and "inputs", when
          `self.has_inputs` is true) hold encoded ids instead of text.
        """
        if dataset_split == problem.DatasetSplit.TRAIN:
            mlperf_log.transformer_print(
                key=mlperf_log.PREPROC_TOKENIZE_TRAINING)
        elif dataset_split == problem.DatasetSplit.EVAL:
            mlperf_log.transformer_print(key=mlperf_log.PREPROC_TOKENIZE_EVAL)

        samples = self.generate_samples(data_dir, tmp_dir, dataset_split)
        vocab = self.get_or_create_vocab(data_dir, tmp_dir)
        has_inputs = self.has_inputs
        inputs_prefix = self.inputs_prefix
        targets_prefix = self.targets_prefix

        def encode_without_eos():
            """Encode each sample in place, deliberately omitting EOS."""
            for sample in samples:
                if has_inputs:
                    sample["inputs"] = vocab.encode(
                        inputs_prefix + sample["inputs"])
                sample["targets"] = vocab.encode(
                    targets_prefix + sample["targets"])
                yield sample

        return encode_without_eos()

    def hparams(self, defaults, unused_model_hparams):
        """Apply the parent's hparams, then set the loss multiplier to 1.0.

        Args:
          defaults: hparams object, modified in place.
          unused_model_hparams: model hparams, only forwarded to the parent.
        """
        super(PaperGenerationProblem, self).hparams(
            defaults, unused_model_hparams)
        defaults.loss_multiplier = 1.0

    def eval_metrics(self):
        """Return the list of metrics to report during evaluation."""
        return [
            metrics.Metrics.ACC,
            metrics.Metrics.ACC_TOP5,
            metrics.Metrics.ACC_PER_SEQ,
            metrics.Metrics.BITS_PER_CHAR,
            metrics.Metrics.NEG_LOG_PERPLEXITY,
            metrics.Metrics.APPROX_BLEU,
            metrics.Metrics.ROUGE_2_F,
            metrics.Metrics.ROUGE_L_F,
        ]