@@ -196,6 +196,7 @@ class BERTEncoder(nn.Block):
196196 def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads,
197197 num_blks, dropout, max_len=1000, **kwargs):
198198 super(BERTEncoder, self).__init__(**kwargs)
199+ self.num_hiddens = num_hiddens
199200 self.token_embedding = nn.Embedding(vocab_size, num_hiddens)
200201 self.segment_embedding = nn.Embedding(2, num_hiddens)
201202 self.blks = nn.Sequential()
@@ -210,7 +211,9 @@ class BERTEncoder(nn.Block):
210211 def forward(self, tokens, segments, valid_lens):
211212 # Shape of `X` remains unchanged in the following code snippet:
212213 # (batch size, max sequence length, `num_hiddens`)
213- X = self.token_embedding(tokens) + self.segment_embedding(segments)
214+ # The token embedding values are multiplied by the square root of the
215+ # embedding dimension to rescale them before the embeddings are summed up
216+ X = self.token_embedding(tokens)*math.sqrt(self.num_hiddens) + self.segment_embedding(segments)
214217 X = X + self.pos_embedding.data(ctx=X.ctx)[:, :X.shape[1], :]
215218 for blk in self.blks:
216219 X = blk(X, valid_lens)
@@ -225,6 +228,7 @@ class BERTEncoder(nn.Module):
225228 def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads,
226229 num_blks, dropout, max_len=1000, **kwargs):
227230 super(BERTEncoder, self).__init__(**kwargs)
231+ self.num_hiddens = num_hiddens
228232 self.token_embedding = nn.Embedding(vocab_size, num_hiddens)
229233 self.segment_embedding = nn.Embedding(2, num_hiddens)
230234 self.blks = nn.Sequential()
@@ -239,7 +243,9 @@ class BERTEncoder(nn.Module):
239243 def forward(self, tokens, segments, valid_lens):
240244 # Shape of `X` remains unchanged in the following code snippet:
241245 # (batch size, max sequence length, `num_hiddens`)
242- X = self.token_embedding(tokens) + self.segment_embedding(segments)
246+ # The token embedding values are multiplied by the square root of the
247+ # embedding dimension to rescale them before the embeddings are summed up
248+ X = self.token_embedding(tokens)*math.sqrt(self.num_hiddens) + self.segment_embedding(segments)
243249 X = X + self.pos_embedding[:, :X.shape[1], :]
244250 for blk in self.blks:
245251 X = blk(X, valid_lens)
0 commit comments