@@ -18,14 +18,19 @@ class SinusoidalPositionalEmbedding(nn.Module):
     Padding symbols are ignored.
     """

-    def __init__(self, embedding_dim, padding_idx, init_size=1024):
+    def __init__(self, embedding_dim, padding_idx, init_size=1024, auto_expand=True):
         super().__init__()
         self.embedding_dim = embedding_dim
         self.padding_idx = padding_idx if padding_idx is not None else 0
-        self.register_buffer("weights", SinusoidalPositionalEmbedding.get_embedding(
-            init_size, embedding_dim, padding_idx
-        ), persistent=False)
+        self.register_buffer(
+            "weights",
+            SinusoidalPositionalEmbedding.get_embedding(
+                init_size, embedding_dim, padding_idx
+            ),
+            persistent=False,
+        )
         self.max_positions = int(1e5)
+        self.auto_expand = auto_expand
         self.onnx_trace = False

     def prepare_for_onnx_export_(self):
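For context, the `weights` buffer registered above is a pre-computed table of sinusoidal positional embeddings, one row per position. A rough standalone sketch of what such a table looks like (an illustration only, not the exact `get_embedding` implementation):

```python
import math
import torch

def sinusoidal_table(num_positions, embedding_dim, padding_idx=None):
    # One row per position; sine and cosine at geometrically spaced frequencies
    # (first half of each row sin, second half cos), as in fairseq-style tables.
    half_dim = embedding_dim // 2
    freq = torch.exp(
        torch.arange(half_dim, dtype=torch.float) * -(math.log(10000) / (half_dim - 1))
    )
    angles = torch.arange(num_positions, dtype=torch.float).unsqueeze(1) * freq.unsqueeze(0)
    table = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if embedding_dim % 2 == 1:
        # Pad odd embedding sizes with a zero column.
        table = torch.cat([table, torch.zeros(num_positions, 1)], dim=1)
    if padding_idx is not None:
        table[padding_idx, :] = 0
    return table

table = sinusoidal_table(1024, 512, padding_idx=1)  # shape [1024, 512]
```

With the default `init_size=1024`, the constructor pre-computes that many rows; the `forward` change below handles inputs that need more positions than were pre-computed.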
@@ -75,28 +80,36 @@ def forward(
         bspair = torch.onnx.operators.shape_as_tensor(input)
         bsz, seq_len = bspair[0], bspair[1]
         max_pos = self.padding_idx + 1 + seq_len
+        weights = self.weights
+
         if max_pos > self.weights.size(0):
-            # expand embeddings if needed
-            self.weights = SinusoidalPositionalEmbedding.get_embedding(
+            # If the input is longer than the number of pre-computed embeddings,
+            # compute the extra embeddings on the fly.
+            # Only store the expanded embeddings if auto_expand=True.
+            # In multithreading environments, mutating the weights of a module
+            # may cause trouble. Set auto_expand=False if this happens.
+            weights = SinusoidalPositionalEmbedding.get_embedding(
                 max_pos, self.embedding_dim, self.padding_idx
             ).to(self.weights)
+            if self.auto_expand:
+                self.weights = weights

         if incremental_state is not None:
             # positions is the same for every token when decoding a single step
             pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
             if self.onnx_trace:
                 return (
-                    self.weights.index_select(index=self.padding_idx + pos, dim=0)
+                    weights.index_select(index=self.padding_idx + pos, dim=0)
                     .unsqueeze(1)
                     .repeat(bsz, 1, 1)
                 )
-            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
+            return weights[self.padding_idx + pos, :].expand(bsz, 1, -1)

         positions = utils.make_positions(
             input, self.padding_idx, onnx_trace=self.onnx_trace
         )
         if self.onnx_trace:
-            flat_embeddings = self.weights.detach().index_select(0, positions.view(-1))
+            flat_embeddings = weights.detach().index_select(0, positions.view(-1))
             embedding_shape = torch.cat(
                 (bsz.view(1), seq_len.view(1), torch.tensor([-1], dtype=torch.long))
             )
@@ -105,7 +118,5 @@ def forward(
             )
             return embeddings
         return (
-            self.weights.index_select(0, positions.view(-1))
-            .view(bsz, seq_len, -1)
-            .detach()
+            weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
         )
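A hypothetical usage sketch of the new flag (constructor arguments as in this diff; the import path and surrounding setup are assumptions): with the default `auto_expand=True`, an input longer than the pre-computed table grows the stored buffer as before, while `auto_expand=False` computes the larger table only for that call and leaves `self.weights` untouched, avoiding mutation of module state from concurrent threads.

```python
import torch
from fairseq.modules import SinusoidalPositionalEmbedding  # import path assumed

# init_size=1024 pre-computes 1024 positions; padding_idx=1 follows fairseq's convention.
emb = SinusoidalPositionalEmbedding(embedding_dim=512, padding_idx=1, auto_expand=False)
tokens = torch.full((2, 2048), 5, dtype=torch.long)  # longer than the pre-computed table

out = emb(tokens)           # extra positions are computed on the fly for this call
print(out.shape)            # torch.Size([2, 2048, 512])
print(emb.weights.size(0))  # still 1024: the buffer is not expanded when auto_expand=False
```

With `auto_expand=True` the last line would instead report the expanded size, matching the old behavior.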