
Commit ecbf110

Authored Oct 18, 2024
Add an auto_expand option to SinusoidalPositionalEmbedding (#5555)
1 parent 018621f commit ecbf110

File tree

2 files changed: +25 -12 lines changed

 

fairseq/modules/positional_embedding.py  (+2)
@@ -14,6 +14,7 @@ def PositionalEmbedding(
     embedding_dim: int,
     padding_idx: int,
     learned: bool = False,
+    auto_expand: bool = True,
 ):
     if learned:
         # if padding_idx is specified then offset the embedding ids by
@@ -31,5 +32,6 @@ def PositionalEmbedding(
             embedding_dim,
             padding_idx,
             init_size=num_embeddings + padding_idx + 1,
+            auto_expand=auto_expand,
         )
     return m

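For reference, a minimal usage sketch of the updated factory (not part of the commit). It assumes a fairseq install that includes this change; the vocabulary size, embedding dimension, and padding index below are illustrative values, not taken from the diff.

    from fairseq.modules import PositionalEmbedding

    # With learned=False the factory builds a SinusoidalPositionalEmbedding and
    # now forwards auto_expand to it; the values here are illustrative only.
    pos_emb = PositionalEmbedding(
        num_embeddings=1024,
        embedding_dim=512,
        padding_idx=1,
        learned=False,
        auto_expand=False,  # keep the precomputed sinusoidal buffer at its initial size
    )
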
fairseq/modules/sinusoidal_positional_embedding.py  (+23 -12)
@@ -18,14 +18,19 @@ class SinusoidalPositionalEmbedding(nn.Module):
     Padding symbols are ignored.
     """
 
-    def __init__(self, embedding_dim, padding_idx, init_size=1024):
+    def __init__(self, embedding_dim, padding_idx, init_size=1024, auto_expand=True):
         super().__init__()
         self.embedding_dim = embedding_dim
         self.padding_idx = padding_idx if padding_idx is not None else 0
-        self.register_buffer("weights", SinusoidalPositionalEmbedding.get_embedding(
-            init_size, embedding_dim, padding_idx
-        ), persistent=False)
+        self.register_buffer(
+            "weights",
+            SinusoidalPositionalEmbedding.get_embedding(
+                init_size, embedding_dim, padding_idx
+            ),
+            persistent=False,
+        )
         self.max_positions = int(1e5)
+        self.auto_expand = auto_expand
         self.onnx_trace = False
 
     def prepare_for_onnx_export_(self):
@@ -75,28 +80,36 @@ def forward(
         bspair = torch.onnx.operators.shape_as_tensor(input)
         bsz, seq_len = bspair[0], bspair[1]
         max_pos = self.padding_idx + 1 + seq_len
+        weights = self.weights
+
         if max_pos > self.weights.size(0):
-            # expand embeddings if needed
-            self.weights = SinusoidalPositionalEmbedding.get_embedding(
+            # If the input is longer than the number of pre-computed embeddings,
+            # compute the extra embeddings on the fly.
+            # Only store the expanded embeddings if auto_expand=True.
+            # In multithreading environments, mutating the weights of a module
+            # may cause trouble. Set auto_expand=False if this happens.
+            weights = SinusoidalPositionalEmbedding.get_embedding(
                 max_pos, self.embedding_dim, self.padding_idx
             ).to(self.weights)
+            if self.auto_expand:
+                self.weights = weights
 
         if incremental_state is not None:
             # positions is the same for every token when decoding a single step
             pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
             if self.onnx_trace:
                 return (
-                    self.weights.index_select(index=self.padding_idx + pos, dim=0)
+                    weights.index_select(index=self.padding_idx + pos, dim=0)
                     .unsqueeze(1)
                     .repeat(bsz, 1, 1)
                 )
-            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
+            return weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
 
         positions = utils.make_positions(
             input, self.padding_idx, onnx_trace=self.onnx_trace
         )
         if self.onnx_trace:
-            flat_embeddings = self.weights.detach().index_select(0, positions.view(-1))
+            flat_embeddings = weights.detach().index_select(0, positions.view(-1))
             embedding_shape = torch.cat(
                 (bsz.view(1), seq_len.view(1), torch.tensor([-1], dtype=torch.long))
             )
@@ -105,7 +118,5 @@ def forward(
             )
             return embeddings
         return (
-            self.weights.index_select(0, positions.view(-1))
-            .view(bsz, seq_len, -1)
-            .detach()
+            weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
         )

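A minimal sketch (not part of the commit) of how the new flag behaves, assuming a fairseq checkout that includes this change. The embedding dimension, init_size, padding index, and token tensor are illustrative. With auto_expand=False, inputs longer than the precomputed table are still embedded correctly, but the module's weights buffer is not mutated.

    import torch
    from fairseq.modules.sinusoidal_positional_embedding import SinusoidalPositionalEmbedding

    pad = 1
    emb = SinusoidalPositionalEmbedding(
        embedding_dim=64, padding_idx=pad, init_size=16, auto_expand=False
    )

    # A batch whose length exceeds the precomputed table: the extra positions
    # are computed on the fly inside forward().
    tokens = torch.full((2, 32), 5, dtype=torch.long)  # (bsz=2, seq_len=32), no padding
    out = emb(tokens)

    print(out.shape)            # torch.Size([2, 32, 64])
    print(emb.weights.size(0))  # 16 -- buffer untouched because auto_expand=False
    # With the default auto_expand=True, the buffer would grow to cover the longer input.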