@@ -146,7 +146,7 @@ def forward(self, hidden_states, ltor_mask, has_cache, use_cache):
         layernorm_output = self.input_layernorm(hidden_states)
 
         # Self attention.
-        attention_output, has_cache = self.attention(
+        attention_output, att_has_cache = self.attention(
             layernorm_output, ltor_mask, has_cache=has_cache, use_cache=use_cache)
 
         if self.cogview_sandwich_layernorm:
@@ -159,15 +159,16 @@ def forward(self, hidden_states, ltor_mask, has_cache, use_cache):
         layernorm_output = self.post_attention_layernorm(layernorm_input)
 
         # MLP.
-        mlp_output = self.mlp(layernorm_output)
+        mlp_output, mlp_has_cache = self.mlp(
+            layernorm_output, has_cache=has_cache, use_cache=use_cache)
 
         if self.cogview_sandwich_layernorm:
             mlp_output = self.before_second_addition_layernorm(mlp_output)
 
         # Second residual connection.
         output = layernorm_input + mlp_output
 
-        return output, has_cache
+        return output, att_has_cache and mlp_has_cache
 
 
 class DalleSelfAttention(torch.nn.Module):
@@ -212,6 +213,11 @@ def __init__(self, hidden_size, num_attention_heads,
         self.dense = torch.nn.Linear(hidden_size, hidden_size)
         self.output_dropout = torch.nn.Dropout(output_dropout_prob)
 
+        # Cache
+        self.past_key = None
+        self.past_value = None
+        self.past_output = None
+
     def _transpose_for_scores(self, tensor):
         """ Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with size [b, np, s, hn]. """
         new_tensor_shape = tensor.size()[:-1] + (self.num_attention_heads, self.hidden_size_per_attention_head)
@@ -227,6 +233,7 @@ def _calculate_attention_scores(self, query_layer, key_layer, ltor_mask):
             )
         else:
             attention_scores = torch.matmul(query_layer, key_t) / math.sqrt(self.hidden_size_per_attention_head)
+        ltor_mask = ltor_mask[:, :, -attention_scores.shape[-2]:]
         attention_scores = torch.mul(attention_scores, ltor_mask) - 10000.0 * (1.0 - ltor_mask)
         if self.cogview_pb_relax:
             # normalize attention scores. Should not affect resulting softmax value
@@ -258,10 +265,10 @@ def forward(self, hidden_states, ltor_mask, has_cache=False, use_cache=False,):
         key_layer = self._transpose_for_scores(mixed_key_layer)
         value_layer = self._transpose_for_scores(mixed_value_layer)
 
+        # Can be simplified, but I didn't for readability's sake
         if use_cache and has_cache:
-            value_layer = torch.cat((self.past_value, value_layer), dim=-2)
-            query_layer = torch.cat((self.past_query, query_layer), dim=-2)
             key_layer = torch.cat((self.past_key, key_layer), dim=-2)
+            value_layer = torch.cat((self.past_value, value_layer), dim=-2)
             attention_scores = self._calculate_attention_scores(
                 query_layer=query_layer, key_layer=key_layer, ltor_mask=ltor_mask
             )
@@ -271,13 +278,17 @@ def forward(self, hidden_states, ltor_mask, has_cache=False, use_cache=False,):
             )
 
         if use_cache:
-            self.past_query = query_layer
             self.past_key = key_layer
             self.past_value = value_layer
-            has_cache = True
         else:
+            self.past_key = None
+            self.past_value = None
+            self.past_output = None
             has_cache = False
 
+        if use_cache and has_cache:
+            attention_scores = attention_scores[..., -1:, :]
+
         # Attention probabilities. [b, np, s, s]
         attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
 
@@ -298,6 +309,16 @@ def forward(self, hidden_states, ltor_mask, has_cache=False, use_cache=False,):
 
         # Output. [b, s, h]
         output = self.dense(context_layer)
+
+        if use_cache:
+            # Can be simplified, but I didn't for readability's sake
+            if has_cache:
+                output = torch.cat((self.past_output, output), dim=-2)
+                self.past_output = output
+            else:
+                self.past_output = output
+                has_cache = True
+
         output = self.output_dropout(output)
         return output, has_cache
 
@@ -321,12 +342,30 @@ def __init__(self, hidden_size, output_dropout_prob):
         # Project back to h.
         self.dense_4h_to_h = torch.nn.Linear(4 * hidden_size, hidden_size)
         self.dropout = torch.nn.Dropout(output_dropout_prob)
+        # MLP cache
+        self.past_x = None
+
+    def forward(self, hidden_states, has_cache=False, use_cache=False):
+        if has_cache and use_cache:
+            hidden_states = hidden_states[:, -1:]
 
-    def forward(self, hidden_states):
         # [b, s, 4hp]
         x = self.dense_h_to_4h(hidden_states)
         x = gelu(x)
         # [b, s, h]
         x = self.dense_4h_to_h(x)
+        if use_cache:
+            # Can be simplified, but I didn't for readability's sake
+            if has_cache:
+                x = torch.cat((self.past_x, x), dim=-2)
+                self.past_x = x
+            else:
+                self.past_x = x
+
+            has_cache = True
+        else:
+            self.past_x = None
+            has_cache = False
         output = self.dropout(x)
-        return output
+
+        return output, has_cache
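For context on how the `has_cache` / `use_cache` flags introduced here are meant to be threaded through an incremental decoding loop, below is a minimal, self-contained sketch. It is not code from this repository: `ToyCachedMLP` is a hypothetical stand-in that only mirrors the MLP caching pattern in this diff (cache primed on the first `use_cache=True` call, then only the newest token recomputed and concatenated onto the cached outputs), and the caller side shows how the returned `has_cache` flag is passed back in on the next step.

```python
import torch


class ToyCachedMLP(torch.nn.Module):
    """Hypothetical stand-in mirroring the MLP caching pattern in this diff."""

    def __init__(self, hidden_size):
        super().__init__()
        self.dense_h_to_4h = torch.nn.Linear(hidden_size, 4 * hidden_size)
        self.dense_4h_to_h = torch.nn.Linear(4 * hidden_size, hidden_size)
        self.past_x = None  # cached outputs, [b, s_cached, h]

    def forward(self, hidden_states, has_cache=False, use_cache=False):
        if has_cache and use_cache:
            # Only the newest token needs to be recomputed.
            hidden_states = hidden_states[:, -1:]
        x = self.dense_4h_to_h(
            torch.nn.functional.gelu(self.dense_h_to_4h(hidden_states)))
        if use_cache:
            if has_cache:
                # Prepend the cached outputs so the full sequence is returned.
                x = torch.cat((self.past_x, x), dim=-2)
            self.past_x = x
            has_cache = True
        else:
            self.past_x = None
            has_cache = False
        return x, has_cache


# Caller side: prime the cache on the first step, then reuse it.
mlp = ToyCachedMLP(hidden_size=8)
hidden = torch.randn(2, 4, 8)           # [b, s, h] prefix
has_cache = False
for _ in range(3):                       # three incremental steps
    out, has_cache = mlp(hidden, has_cache=has_cache, use_cache=True)
    new_token = torch.randn(2, 1, 8)     # next token's hidden state (dummy)
    hidden = torch.cat((hidden, new_token), dim=1)
print(out.shape)  # torch.Size([2, 6, 8]) -- full-length output, as in the diff
```

As in the diff, the cached module still returns the full-length output (not just the last position), which is why the layer's second residual connection and the attention's `past_output` concatenation continue to see tensors of the complete sequence length.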