Description
I've identified two subtle errors in the from-scratch implementation of the transformer architecture in chapter 11.
Error 1
The first error involves the interaction of the positional encoding with predict_step.
The implementation of PositionalEncoding is correct when training the transformer. In that case the input to the decoder has shape (batch_size, num_steps), the input to PositionalEncoding has shape (batch_size, num_steps, num_hiddens), and so PositionalEncoding adds the slice
X = X + self.P[:, :X.shape[1], :].to(X.device)
which covers positions i = 0 through i = num_steps - 1. This is all correct.
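For reference, here is a minimal sketch of the PositionalEncoding module along the lines of the chapter's implementation (paraphrased; argument names and defaults may differ slightly), showing the forward pass the line above comes from:

import torch
from torch import nn

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (sketch of the chapter's module)."""
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # Precompute encodings for positions 0 .. max_len - 1.
        self.P = torch.zeros((1, max_len, num_hiddens))
        pos = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / torch.pow(
            10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        self.P[:, :, 0::2] = torch.sin(pos)
        self.P[:, :, 1::2] = torch.cos(pos)

    def forward(self, X):
        # Always adds positions 0 .. X.shape[1] - 1, regardless of where
        # X actually sits in the output sequence.
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)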
But during prediction, the tokens are fed to the decoder one step at a time:
outputs, attention_weights = [tgt[:, (0)].unsqueeze(1), ], []
for _ in range(num_steps):
    Y, dec_state = self.decoder(outputs[-1], dec_state)
    outputs.append(Y.argmax(2))
Now the input to PositionalEncoding has shape (batch_size, 1, num_hiddens), so PositionalEncoding always adds the i = 0 entry of self.P: every decoding step receives the positional encoding for position 0, which is incorrect. (I see now that this error has already been mentioned by cddc in the comments of Section 11.7.)
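One possible fix (a sketch building on the PositionalEncoding sketch above, not the book's code) is to let the positional encoding accept an explicit starting position and have the decoder supply the current decoding step, which it can infer from the key/value cache that prediction already maintains in the decoder state:

# Hypothetical variant whose forward takes an explicit starting position, so
# that one-token-at-a-time decoding gets the encoding of step t, not step 0.
class OffsetPositionalEncoding(PositionalEncoding):
    def forward(self, X, offset=0):
        X = X + self.P[:, offset:offset + X.shape[1], :].to(X.device)
        return self.dropout(X)

# In the decoder's forward, the offset could be recovered from the tokens that
# prediction caches in the decoder state. The exact state layout below follows
# the chapter's code and may differ in other versions:
#     offset = 0 if state[2][0] is None else state[2][0].shape[1]
#     X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens), offset)

During training nothing is cached yet when the decoder runs, so the offset is 0 and the behaviour is unchanged.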
Error 2
The second error is that when computing the validation loss, self.training is False, so dec_valid_lens = None. This means no causal mask is applied in the decoder self-attention when the validation loss is computed, so the validation loss breaks causality: each target position can attend to future target positions.
We want dec_valid_lens = None during prediction, where the decoder sees one token at a time, but not during validation, where the whole target sequence is passed in at once.
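One possible fix (a sketch that keys the mask on the input shape rather than on self.training; the helper name is made up) is to build the causal mask whenever the decoder block receives more than one time step, which covers both training and validation, while one-token-at-a-time prediction still gets dec_valid_lens = None:

import torch

def causal_valid_lens(X):
    # X has shape (batch_size, num_steps, num_hiddens).
    # Returns None for single-step inputs (one-token-at-a-time prediction);
    # otherwise a (batch_size, num_steps) tensor in which the query at step j
    # may attend to keys at steps 0 .. j.
    batch_size, num_steps, _ = X.shape
    if num_steps == 1:
        return None
    return torch.arange(1, num_steps + 1, device=X.device).repeat(batch_size, 1)

# Inside the decoder block's forward, this would replace the
# `if self.training: ... else: dec_valid_lens = None` branch:
#     dec_valid_lens = causal_valid_lens(X)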