We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 25d45bb commit 3251011Copy full SHA for 3251011
train.py
@@ -184,16 +184,16 @@ def main(_):
184
data_generator = prepare_data.get_batches(data, batch_size)
185
186
hidden_size = 256
187
- encoder1 = EncoderRNN(input_lang.n_words, hidden_size)
188
- attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p=0.1)
+ encoder1 = Encoder(input_lang.n_words, hidden_size)
+ decoder1 = Decoder(hidden_size, output_lang.n_words, dropout_p=0.1)
189
190
191
if use_cuda:
192
encoder1 = encoder1.cuda()
193
- attn_decoder1 = attn_decoder1.cuda()
+ attn_decoder1 = decoder1.cuda()
194
195
- trainIters(encoder1, attn_decoder1, 75000, print_every=5000)
+ trainIters(encoder1, decoder1, 75000, print_every=5000)
196
197
198
if __name__ == '__main__':
199
- main()
+ main()
0 commit comments