Skip to content

Commit 4c661ca

Browse files
kdavis-mozillakylegao91
authored andcommitted
Fixed shape documentation (#131)
1 parent aef9b9f commit 4c661ca

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

seq2seq/models/TopKDecoder.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ class TopKDecoder(torch.nn.Module):
4949
Inputs: inputs, encoder_hidden, encoder_outputs, function, teacher_forcing_ratio
5050
- **inputs** (seq_len, batch, input_size): list of sequences, whose length is the batch size and within which
5151
each sequence is a list of token IDs. It is used for teacher forcing when provided. (default is `None`)
52-
- **encoder_hidden** (batch, seq_len, hidden_size): tensor containing the features in the hidden state `h` of
53-
encoder. Used as the initial hidden state of the decoder.
52+
- **encoder_hidden** (num_layers * num_directions, batch_size, hidden_size): tensor containing the features
53+
in the hidden state `h` of encoder. Used as the initial hidden state of the decoder.
5454
- **encoder_outputs** (batch, seq_len, hidden_size): tensor with containing the outputs of the encoder.
5555
Used for attention mechanism (default is `None`).
5656
- **function** (torch.nn.Module): A function used to generate symbols from RNN hidden state

0 commit comments

Comments
 (0)