Skip to content

Commit 803e3b6

Browse files
committed
Multiplied token embeddings by the square root of the embedding dimension before summing them with the segment/positional embeddings
1 parent 23d7a5a commit 803e3b6

File tree

1 file changed

+8
-2
lines changed
  • chapter_natural-language-processing-pretraining

1 file changed

+8
-2
lines changed

chapter_natural-language-processing-pretraining/bert.md

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ class BERTEncoder(nn.Block):
196196
def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads,
197197
num_blks, dropout, max_len=1000, **kwargs):
198198
super(BERTEncoder, self).__init__(**kwargs)
199+
self.num_hiddens = num_hiddens
199200
self.token_embedding = nn.Embedding(vocab_size, num_hiddens)
200201
self.segment_embedding = nn.Embedding(2, num_hiddens)
201202
self.blks = nn.Sequential()
@@ -210,7 +211,9 @@ class BERTEncoder(nn.Block):
210211
def forward(self, tokens, segments, valid_lens):
211212
# Shape of `X` remains unchanged in the following code snippet:
212213
# (batch size, max sequence length, `num_hiddens`)
213-
X = self.token_embedding(tokens) + self.segment_embedding(segments)
214+
# The embedding values are multiplied by the square root of the embedding dimension
215+
# to rescale them before they are summed with the segment and positional embeddings
216+
X = self.token_embedding(tokens)*math.sqrt(self.num_hiddens) + self.segment_embedding(segments)
214217
X = X + self.pos_embedding.data(ctx=X.ctx)[:, :X.shape[1], :]
215218
for blk in self.blks:
216219
X = blk(X, valid_lens)
@@ -225,6 +228,7 @@ class BERTEncoder(nn.Module):
225228
def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads,
226229
num_blks, dropout, max_len=1000, **kwargs):
227230
super(BERTEncoder, self).__init__(**kwargs)
231+
self.num_hiddens = num_hiddens
228232
self.token_embedding = nn.Embedding(vocab_size, num_hiddens)
229233
self.segment_embedding = nn.Embedding(2, num_hiddens)
230234
self.blks = nn.Sequential()
@@ -239,7 +243,9 @@ class BERTEncoder(nn.Module):
239243
def forward(self, tokens, segments, valid_lens):
240244
# Shape of `X` remains unchanged in the following code snippet:
241245
# (batch size, max sequence length, `num_hiddens`)
242-
X = self.token_embedding(tokens) + self.segment_embedding(segments)
246+
# The embedding values are multiplied by the square root of the embedding dimension
247+
# to rescale them before they are summed with the segment and positional embeddings
248+
X = self.token_embedding(tokens)*math.sqrt(self.num_hiddens) + self.segment_embedding(segments)
243249
X = X + self.pos_embedding[:, :X.shape[1], :]
244250
for blk in self.blks:
245251
X = blk(X, valid_lens)

0 commit comments

Comments
 (0)