-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtrain.py
More file actions
128 lines (94 loc) · 4.58 KB
/
train.py
File metadata and controls
128 lines (94 loc) · 4.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import torch
from torch import optim
from torch import nn
import random
import time
from dataset import SOS_token, EOS_token
from dataset import tensors_from_pair, prepare_data
from utils import time_since, show_plot
from model import AttnDecoderRNN, EncoderRNN
from evaluate import evaluate_randomly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hidden_size = 256
teacher_forcing_ratio = 0.5
def train_iteration(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, loss_func):
    """Run one optimization step on a single (input, target) sentence pair.

    Encodes the source sequence token by token, then decodes the target
    with a 50/50 choice between teacher forcing and free running, and
    applies one SGD step to both models.

    Returns the average per-token loss over the target sequence.
    """
    n_input = input_tensor.size(0)
    n_target = target_tensor.size(0)

    hidden = encoder.init_hidden().to(device)
    # One row per source position, attended over by the decoder.
    # NOTE(review): sized by decoder.max_length — assumes inputs never exceed it.
    enc_states = torch.zeros(decoder.max_length, encoder.hidden_size, device=device)

    # Clear gradients accumulated by the previous step.
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    total_loss = 0

    # Encode the source sequence, keeping every step's output for attention.
    for step in range(n_input):
        output, hidden = encoder(input_tensor[step], hidden)
        enc_states[step] = output[0, 0]

    dec_input = torch.tensor([[SOS_token]], device=device)
    dec_hidden = hidden  # decoder starts from the encoder's final state

    if random.random() < teacher_forcing_ratio:
        # Teacher forcing: feed the ground-truth token as the next input.
        for step in range(n_target):
            dec_output, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_states)
            total_loss += loss_func(dec_output, target_tensor[step])
            dec_input = target_tensor[step]
    else:
        # Free running: feed the decoder's own best guess back in.
        for step in range(n_target):
            dec_output, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_states)
            _, best = dec_output.topk(1)
            # Detach so gradients don't flow through the sampled token.
            dec_input = best.squeeze().detach()
            total_loss += loss_func(dec_output, target_tensor[step])
            if dec_input.item() == EOS_token:
                break

    total_loss.backward()
    encoder_optimizer.step()
    decoder_optimizer.step()
    return total_loss.item() / n_target
def train(encoder, decoder, n_iters, pairs, input_lang, output_lang, print_every=1000, plot_every=1000, learning_rate=0.01):
    """Train the encoder/decoder for n_iters iterations on randomly sampled pairs.

    Args:
        encoder, decoder: the models to train (updated in place).
        n_iters: number of training iterations (one sentence pair each).
        pairs: list of (source, target) sentence pairs to sample from.
        input_lang, output_lang: vocabularies used to tensorize the pairs.
        print_every: iterations between progress printouts.
        plot_every: iterations between recorded plot points.
        learning_rate: SGD learning rate for both optimizers.

    Returns:
        List of average losses, one per plot_every iterations.
    """
    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every
    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)
    # Randomly sample pairs from the training set and move them to the device.
    # BUG FIX: Tensor.to() is not in-place — the original discarded the moved
    # copies and appended the CPU tensors, which would fail on a CUDA device.
    training_pairs = []
    for _ in range(n_iters):
        src_tensor, tgt_tensor = tensors_from_pair(random.choice(pairs), input_lang, output_lang)
        training_pairs.append((src_tensor.to(device), tgt_tensor.to(device)))
    loss_func = nn.NLLLoss()
    for iter_ in range(1, n_iters + 1):
        input_tensor, target_tensor = training_pairs[iter_ - 1]  # Get a training pair.
        loss = train_iteration(input_tensor, target_tensor, encoder,
                               decoder, encoder_optimizer, decoder_optimizer, loss_func)
        print_loss_total += loss
        plot_loss_total += loss
        if iter_ % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' % (time_since(start, iter_ / n_iters),
                                         iter_, iter_ / n_iters * 100, print_loss_avg))
        if iter_ % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0
    return plot_losses
def main():
    """Train an attention seq2seq eng→fra translator and plot the loss curve."""
    input_lang, output_lang, pairs = prepare_data('eng', 'fra', reverse=False)
    encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
    attn_decoder = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p=0.1, max_length=10).to(device)

    all_losses = []
    # Five passes over the data; after each, print sample translations as a sanity check.
    for _ in range(5):
        all_losses += train(encoder, attn_decoder, len(pairs), pairs=pairs,
                            input_lang=input_lang, output_lang=output_lang, print_every=1000)
        evaluate_randomly(encoder, attn_decoder, pairs, max_length=10,
                          input_lang=input_lang, output_lang=output_lang)

    show_plot(all_losses)
    print('done training')


if __name__ == "__main__":
    main()