@@ -43,7 +43,8 @@ def __init__(self, ntoken, ninp, dropout=0.5):
         self.norm = LayerNorm(ninp)
         self.dropout = Dropout(dropout)

-    def forward(self, src, token_type_input):
+    def forward(self, seq_inputs):
+        src, token_type_input = seq_inputs
         src = self.embed(src) + self.pos_embed(src) \
             + self.tok_type_embed(src, token_type_input)
         return self.dropout(self.norm(src))
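This hunk changes `BertEmbedding.forward` to take a single `seq_inputs` tuple instead of two positional arguments. A minimal caller-side sketch of the new convention, assuming the `BertEmbedding` class from this file is in scope; all sizes are illustrative, not values from the example:

```python
import torch

# Illustrative sizes only.
ntoken, ninp, seq_len, batch = 1000, 64, 16, 4

embed = BertEmbedding(ntoken, ninp)
src = torch.randint(0, ntoken, (seq_len, batch))
token_type_input = torch.zeros(seq_len, batch, dtype=torch.long)

# Old call: embed(src, token_type_input)
# New call: one tuple argument, unpacked inside forward().
out = embed((src, token_type_input))  # expected shape: (seq_len, batch, ninp)
```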
@@ -114,16 +115,16 @@ def forward(self, src, src_mask=None, src_key_padding_mask=None):
 class BertModel(nn.Module):
     """Contain a transformer encoder."""

-    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
+    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, embed_layer, dropout=0.5):
         super(BertModel, self).__init__()
         self.model_type = 'Transformer'
-        self.bert_embed = BertEmbedding(ntoken, ninp)
+        self.bert_embed = embed_layer
         encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
         self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
         self.ninp = ninp

-    def forward(self, src, token_type_input):
-        src = self.bert_embed(src, token_type_input)
+    def forward(self, seq_inputs):
+        src = self.bert_embed(seq_inputs)
         output = self.transformer_encoder(src)
         return output
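`BertModel` now receives its embedding layer through the constructor instead of building a `BertEmbedding` itself, so tasks can share one embedding instance or substitute a different module. A minimal sketch of the new construction pattern, assuming `BertEmbedding` and `BertModel` from this file are in scope and using illustrative hyperparameters:

```python
# Illustrative hyperparameters only.
ntoken, ninp, nhead, nhid, nlayers = 1000, 64, 4, 128, 2

embed_layer = BertEmbedding(ntoken, ninp)
model = BertModel(ntoken, ninp, nhead, nhid, nlayers, embed_layer, dropout=0.2)

# Two models can share embedding weights by reusing the same instance.
other = BertModel(ntoken, ninp, nhead, nhid, nlayers, embed_layer)
assert other.bert_embed is model.bert_embed
```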
@@ -150,15 +151,16 @@ class MLMTask(nn.Module):
     def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
         super(MLMTask, self).__init__()
-        self.bert_model = BertModel(ntoken, ninp, nhead, nhid, nlayers, dropout=0.5)
+        embed_layer = BertEmbedding(ntoken, ninp)
+        self.bert_model = BertModel(ntoken, ninp, nhead, nhid, nlayers, embed_layer, dropout=0.5)
         self.mlm_span = Linear(ninp, ninp)
         self.activation = F.gelu
         self.norm_layer = LayerNorm(ninp, eps=1e-12)
         self.mlm_head = Linear(ninp, ntoken)

     def forward(self, src, token_type_input=None):
         src = src.transpose(0, 1)  # Wrap up by nn.DataParallel
-        output = self.bert_model(src, token_type_input)
+        output = self.bert_model((src, token_type_input))
         output = self.mlm_span(output)
         output = self.activation(output)
         output = self.norm_layer(output)
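`MLMTask`'s external signature is unchanged; the tuple packing happens inside `forward`. An end-to-end sketch of calling the task, assuming (as the `token_type_input=None` default suggests) that the embedding treats a missing token-type input as all zeros; shapes are illustrative:

```python
import torch

task = MLMTask(ntoken=1000, ninp=64, nhead=4, nhid=128, nlayers=2)
src = torch.randint(0, 1000, (4, 16))  # (batch, seq); transposed internally
logits = task(src)                     # token_type_input defaults to None
# Expected per-token vocabulary logits of shape (seq, batch, ntoken).
```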
@@ -199,7 +201,7 @@ def __init__(self, bert_model):
     def forward(self, src, token_type_input):
         src = src.transpose(0, 1)  # Wrap up by nn.DataParallel
-        output = self.bert_model(src, token_type_input)
+        output = self.bert_model((src, token_type_input))
         # Send the first <'cls'> seq to a classifier
         output = self.activation(self.linear_layer(output[0]))
         output = self.ns_span(output)
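In the next-sentence hunk, `output[0]` pools the encoder output at the first sequence position (the <cls> token) before classification. An illustration of that indexing with assumed shapes:

```python
import torch

seq_len, batch, ninp = 16, 4, 64
encoder_out = torch.randn(seq_len, batch, ninp)  # (S, N, E) layout
cls_states = encoder_out[0]                      # first token per batch element
assert cls_states.shape == (batch, ninp)
```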
@@ -216,7 +218,7 @@ def __init__(self, bert_model):
         self.qa_span = Linear(bert_model.ninp, 2)

     def forward(self, src, token_type_input):
-        output = self.bert_model(src, token_type_input)
+        output = self.bert_model((src, token_type_input))
         # transpose output (S, N, E) to (N, S, E)
         output = output.transpose(0, 1)
         output = self.activation(output)
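For the question-answering head, `qa_span` maps each position to two channels, which are typically split into start- and end-position logits after the (S, N, E) to (N, S, E) transpose shown above. That split falls outside this hunk; a sketch with assumed shapes:

```python
import torch

batch, seq_len = 4, 16
span_logits = torch.randn(batch, seq_len, 2)   # hypothetical qa_span output
start_logits, end_logits = span_logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)        # (batch, seq_len)
end_logits = end_logits.squeeze(-1)            # (batch, seq_len)
```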