-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: train_translation.py
More file actions
79 lines (66 loc) · 4.13 KB
/
train_translation.py
File metadata and controls
79 lines (66 loc) · 4.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from argparse import ArgumentParser
import numpy as np
from transformer import preprocessing
from transformer.model import Transformer, predict
def main():
    """Entry point for the Transformer translation pipeline.

    Parses command-line flags, builds English/German vocabularies and the
    train/val/test datasets, then — depending on ``--mode`` — trains the
    model, evaluates it on the test set, or prints sample translations for
    a random batch of test sentences.
    """
    # Create argument parser
    argparser = ArgumentParser()
    # File flags
    argparser.add_argument('--train_inputs', type=str, default='./data/train.tags.de-en.en',
                           help='Path to training data inputs.')
    argparser.add_argument('--train_labels', type=str, default='./data/train.tags.de-en.de',
                           help='Path to training data labels.')
    argparser.add_argument('--val_inputs', type=str, default='./data/IWSLT16.TED.tst2012.de-en.en.xml',
                           help='Path to val data inputs.')
    argparser.add_argument('--val_labels', type=str, default='./data/IWSLT16.TED.tst2012.de-en.de.xml',
                           help='Path to val data labels.')
    argparser.add_argument('--test_inputs', type=str, default='./data/IWSLT16.TED.tst2014.de-en.en.xml',
                           help='Path to test data inputs.')
    argparser.add_argument('--test_labels', type=str, default='./data/IWSLT16.TED.tst2014.de-en.de.xml',
                           help='Path to test data labels.')
    argparser.add_argument('--en_vocab_path', type=str, default='./data/en-vocab.csv',
                           help='Path to English vocabulary file.')
    argparser.add_argument('--de_vocab_path', type=str, default='./data/de-vocab.csv',
                           help='Path to German vocabulary file.')
    # Model hyperparameters
    argparser.add_argument('--logdir', type=str, default='./tmp/model',
                           help='Path to model checkpoint directory.')
    argparser.add_argument('--num_epochs', type=int, default=10, help='Number of epochs to train for.')
    argparser.add_argument('--batch_size', type=int, default=32, help='Size of each batch.')
    argparser.add_argument('--learning_rate', type=float, default=1e-4, help='Parameter update rate.')
    argparser.add_argument('--mhdpa_heads', type=int, default=8, help='Number of heads in MHDPA module.')
    argparser.add_argument('--mlp_units', type=int, default=512, help='Number of MLP units.')
    argparser.add_argument('--mhdpa_blocks', type=int, default=6, choices=range(1, 10),
                           help='Number of MHDPA blocks to use in encoder.')
    argparser.add_argument('--dropout_rate', type=float, default=0.1, help='Dropout rate.')
    argparser.add_argument('--sequence_length', type=int, default=15,
                           help='Length of input/output sequences.')
    # Running parameters
    argparser.add_argument('--mode', type=str, choices=['train', 'test', 'predict'], default='train',
                           help='Which mode to run script in.')
    # Parse arguments
    flags = argparser.parse_args()
    # Build vocabularies from the training corpora; the returned sizes
    # parameterize the model's embedding/output layers.
    en_vocab_size = preprocessing.make_vocabulary(flags.train_inputs, flags.en_vocab_path)
    de_vocab_size = preprocessing.make_vocabulary(flags.train_labels, flags.de_vocab_path)
    (x_train, y_train), (x_val, y_val), (x_test, y_test) = preprocessing.create_datasets(flags)
    # Create model
    model = Transformer(flags, en_vocab_size, de_vocab_size)
    if flags.mode == 'train':
        model.fit(x_train, y_train, x_val, y_val)
    elif flags.mode == 'test':
        loss, acc = model.eval(x_test, y_test)
        print(f'[loss: {loss}; acc: {acc}]')
    elif flags.mode == 'predict':
        # Translate a random batch of test sentences.
        indices = np.random.randint(0, len(x_test), size=flags.batch_size)
        # BUG FIX: the original passed flags.vocabulary_file,
        # flags.input_sequence_length and flags.output_sequence_length —
        # none of these flags are declared above, so this branch raised
        # AttributeError. Use the declared flags: the German vocabulary
        # (presumably used to decode predicted label tokens — confirm
        # against predict()'s contract) and the single --sequence_length,
        # whose help text says it covers both input and output lengths.
        inputs, true_outputs, predicted_outputs = predict(model, logdir=flags.logdir, inputs=x_test[indices],
                                                          labels=y_test[indices], vocab_file=flags.de_vocab_path,
                                                          input_seq_len=flags.sequence_length,
                                                          output_seq_len=flags.sequence_length)
        # noinspection PyTypeChecker
        for s, o, p in zip(inputs, true_outputs, predicted_outputs):
            print('input: ', s)
            print('true: ', o)
            print('predicted: ', p)
# Run the pipeline only when executed as a script (not when imported).
if __name__ == '__main__':
    main()