From 803e3b62d9d565baed2c6833cba12eead7bdd416 Mon Sep 17 00:00:00 2001
From: Ildar Gaisin
Date: Tue, 26 Aug 2025 12:09:05 +0200
Subject: [PATCH] multiplied token embeddings by sqrt of embedding dimension
 before summing with segment/positional embeddings

---
 .../bert.md | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/chapter_natural-language-processing-pretraining/bert.md b/chapter_natural-language-processing-pretraining/bert.md
index 007e32e024..e46979384b 100644
--- a/chapter_natural-language-processing-pretraining/bert.md
+++ b/chapter_natural-language-processing-pretraining/bert.md
@@ -196,6 +196,7 @@ class BERTEncoder(nn.Block):
     def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads,
                  num_blks, dropout, max_len=1000, **kwargs):
         super(BERTEncoder, self).__init__(**kwargs)
+        self.num_hiddens = num_hiddens
         self.token_embedding = nn.Embedding(vocab_size, num_hiddens)
         self.segment_embedding = nn.Embedding(2, num_hiddens)
         self.blks = nn.Sequential()
@@ -210,7 +211,9 @@ class BERTEncoder(nn.Block):
     def forward(self, tokens, segments, valid_lens):
         # Shape of `X` remains unchanged in the following code snippet:
         # (batch size, max sequence length, `num_hiddens`)
-        X = self.token_embedding(tokens) + self.segment_embedding(segments)
+        # Rescale the token embeddings by the square root of the embedding
+        # dimension before summing with the segment/positional embeddings
+        X = self.token_embedding(tokens) * math.sqrt(self.num_hiddens) + self.segment_embedding(segments)
         X = X + self.pos_embedding.data(ctx=X.ctx)[:, :X.shape[1], :]
         for blk in self.blks:
             X = blk(X, valid_lens)
@@ -225,6 +228,7 @@ class BERTEncoder(nn.Module):
     def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads,
                  num_blks, dropout, max_len=1000, **kwargs):
         super(BERTEncoder, self).__init__(**kwargs)
+        self.num_hiddens = num_hiddens
         self.token_embedding = nn.Embedding(vocab_size, num_hiddens)
         self.segment_embedding = nn.Embedding(2, num_hiddens)
         self.blks = nn.Sequential()
@@ -239,7 +243,9 @@ class BERTEncoder(nn.Module):
     def forward(self, tokens, segments, valid_lens):
         # Shape of `X` remains unchanged in the following code snippet:
         # (batch size, max sequence length, `num_hiddens`)
-        X = self.token_embedding(tokens) + self.segment_embedding(segments)
+        # Rescale the token embeddings by the square root of the embedding
+        # dimension before summing with the segment/positional embeddings
+        X = self.token_embedding(tokens) * math.sqrt(self.num_hiddens) + self.segment_embedding(segments)
         X = X + self.pos_embedding[:, :X.shape[1], :]
         for blk in self.blks:
             X = blk(X, valid_lens)
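
For reviewers who want to sanity-check the change in isolation, here is a minimal standalone sketch of the patched embedding sum (PyTorch variant only). The toy sizes and random inputs are illustrative and not part of the patch; note that the sketch imports `math` explicitly, which the patched `forward` likewise needs in scope in `bert.md`.

```python
import math

import torch
from torch import nn

# Toy sizes for illustration only.
vocab_size, num_hiddens, max_len = 100, 768, 1000
batch_size, seq_len = 2, 8

# The same three embedding tables BERTEncoder combines.
token_embedding = nn.Embedding(vocab_size, num_hiddens)
segment_embedding = nn.Embedding(2, num_hiddens)
pos_embedding = nn.Parameter(torch.randn(1, max_len, num_hiddens))

tokens = torch.randint(0, vocab_size, (batch_size, seq_len))
segments = torch.zeros(batch_size, seq_len, dtype=torch.long)

# The patched sum: token embeddings are rescaled by sqrt(num_hiddens)
# before the segment and positional embeddings are added.
X = token_embedding(tokens) * math.sqrt(num_hiddens) + segment_embedding(segments)
X = X + pos_embedding[:, :seq_len, :]
print(X.shape)  # torch.Size([2, 8, 768])
```

Scaling by `sqrt(num_hiddens)` follows the usual Transformer convention: learned token embeddings are rescaled so their magnitude is comparable to the segment and positional embeddings they are summed with.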