Skip to content

Commit 3251011

Browse files
authored
Update train.py
1 parent 25d45bb commit 3251011

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

Diff for: train.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -184,16 +184,16 @@ def main(_):
184184
data_generator = prepare_data.get_batches(data, batch_size)
185185

186186
hidden_size = 256
187-
encoder1 = EncoderRNN(input_lang.n_words, hidden_size)
188-
attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p=0.1)
187+
encoder1 = Encoder(input_lang.n_words, hidden_size)
188+
decoder1 = Decoder(hidden_size, output_lang.n_words, dropout_p=0.1)
189189

190190

191191
if use_cuda:
192192
encoder1 = encoder1.cuda()
193-
attn_decoder1 = attn_decoder1.cuda()
193+
attn_decoder1 = decoder1.cuda()
194194

195-
trainIters(encoder1, attn_decoder1, 75000, print_every=5000)
195+
trainIters(encoder1, decoder1, 75000, print_every=5000)
196196

197197

198198
if __name__ == '__main__':
199-
main()
199+
main()

0 commit comments

Comments
 (0)