Subtle errors in Transformer implementation (Chapter 11) #2663

@nickmcgreivy

Description

I've identified two subtle errors in the from-scratch implementation of the Transformer architecture in Chapter 11.

Error 1

The first error involves the interaction of the positional encoding with predict_step.

The implementation of PositionalEncoding is correct when training the transformer. In this case, the input to the decoder has shape (batch_size, num_steps), the input to PositionalEncoding has shape (batch_size, num_steps, num_hiddens), and so PositionalEncoding adds a vector

X = X + self.P[:, :X.shape[1], :].to(X.device)

which ranges from i=0 to i=num_steps-1. This is all correct.
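As a quick shape check (tensor contents are placeholders; only the shapes described above matter):

import torch

batch_size, num_steps, num_hiddens, max_len = 2, 8, 32, 1000
P = torch.zeros(1, max_len, num_hiddens)             # stand-in for self.P
X = torch.zeros(batch_size, num_steps, num_hiddens)  # embedded decoder input during training

print(P[:, :X.shape[1], :].shape)  # torch.Size([1, 8, 32]); broadcasts over the batch,
                                   # i.e. positions 0 .. num_steps - 1 are added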

But when predicting, the tokens are given one step at a time to the decoder:

outputs, attention_weights = [tgt[:, (0)].unsqueeze(1), ], []
for _ in range(num_steps):
    # outputs[-1] has shape (batch_size, 1): a single token per step.
    Y, dec_state = self.decoder(outputs[-1], dec_state)
    outputs.append(Y.argmax(2))

Now the input to PositionalEncoding has shape (batch_size, 1, num_hiddens), so PositionalEncoding always adds the i=0 component of self.P, regardless of which decoding step we are on. This is incorrect. (I see now that this error has already been mentioned by cddc in the comments of Section 11.7.)
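One possible fix is to let the caller tell PositionalEncoding where the current tokens sit in the sequence. The offset argument below is my own suggestion, not the book's API, so treat this as a sketch:

def forward(self, X, offset=0):
    # offset = index of the first position in this call: 0 during training
    # (when the full sequence is encoded at once), and the number of
    # already-decoded tokens during step-by-step prediction.
    X = X + self.P[:, offset:offset + X.shape[1], :].to(X.device)
    return self.dropout(X)

Training is unchanged (offset stays at 0), while during prediction the decoder would pass the number of tokens generated so far, e.g. derived from the key-value prefix it already caches in dec_state.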

Error 2

The second error is that when computing the validation loss, self.training is False and so dec_valid_lens = None. As a result, no causal mask is applied to the decoder's self-attention during validation, so the validation loss is computed from attention that can look at future target tokens, i.e. it breaks causality.

We want dec_valid_lens=None during prediction, but not during validation.
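One possible fix is to key the mask off the shape of the current call rather than off self.training. The sketch below assumes, as in the book, that X has shape (batch_size, num_steps, num_hiddens) at this point; the num_steps > 1 test is my suggestion, not the book's code:

batch_size, num_steps, _ = X.shape
if num_steps > 1:
    # Full target sequence (training or validation loss): apply the causal
    # mask so that position t only attends to positions 1 .. t.
    dec_valid_lens = torch.arange(1, num_steps + 1, device=X.device).repeat(batch_size, 1)
else:
    # Step-by-step prediction: the keys are only the already-generated
    # prefix, so no mask is needed and None is fine.
    dec_valid_lens = None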
