Skip to content

Commit 47de7a2

Browse files
authored
optimize generation caching (#12)
Over 10x speedup; adds MLP caching and optimizes attention caching. Uses changes from https://t.co/BTwo6NKq9H.
1 parent 0b3d648 commit 47de7a2

File tree

1 file changed

+48
-9
lines changed

1 file changed

+48
-9
lines changed

rudalle/dalle/transformer.py

Lines changed: 48 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -146,7 +146,7 @@ def forward(self, hidden_states, ltor_mask, has_cache, use_cache):
146146
layernorm_output = self.input_layernorm(hidden_states)
147147

148148
# Self attention.
149-
attention_output, has_cache = self.attention(
149+
attention_output, att_has_cache = self.attention(
150150
layernorm_output, ltor_mask, has_cache=has_cache, use_cache=use_cache)
151151

152152
if self.cogview_sandwich_layernorm:
@@ -159,15 +159,16 @@ def forward(self, hidden_states, ltor_mask, has_cache, use_cache):
159159
layernorm_output = self.post_attention_layernorm(layernorm_input)
160160

161161
# MLP.
162-
mlp_output = self.mlp(layernorm_output)
162+
mlp_output, mlp_has_cache = self.mlp(
163+
layernorm_output, has_cache=has_cache, use_cache=use_cache)
163164

164165
if self.cogview_sandwich_layernorm:
165166
mlp_output = self.before_second_addition_layernorm(mlp_output)
166167

167168
# Second residual connection.
168169
output = layernorm_input + mlp_output
169170

170-
return output, has_cache
171+
return output, att_has_cache and mlp_has_cache
171172

172173

173174
class DalleSelfAttention(torch.nn.Module):
@@ -212,6 +213,11 @@ def __init__(self, hidden_size, num_attention_heads,
212213
self.dense = torch.nn.Linear(hidden_size, hidden_size)
213214
self.output_dropout = torch.nn.Dropout(output_dropout_prob)
214215

216+
# Cache
217+
self.past_key = None
218+
self.past_value = None
219+
self.past_output = None
220+
215221
def _transpose_for_scores(self, tensor):
216222
""" Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with size [b, np, s, hn]. """
217223
new_tensor_shape = tensor.size()[:-1] + (self.num_attention_heads, self.hidden_size_per_attention_head)
@@ -227,6 +233,7 @@ def _calculate_attention_scores(self, query_layer, key_layer, ltor_mask):
227233
)
228234
else:
229235
attention_scores = torch.matmul(query_layer, key_t) / math.sqrt(self.hidden_size_per_attention_head)
236+
ltor_mask = ltor_mask[:, :, -attention_scores.shape[-2]:]
230237
attention_scores = torch.mul(attention_scores, ltor_mask) - 10000.0 * (1.0 - ltor_mask)
231238
if self.cogview_pb_relax:
232239
# normalize attention scores. Should not affect resulting softmax value
@@ -258,10 +265,10 @@ def forward(self, hidden_states, ltor_mask, has_cache=False, use_cache=False,):
258265
key_layer = self._transpose_for_scores(mixed_key_layer)
259266
value_layer = self._transpose_for_scores(mixed_value_layer)
260267

268+
# Can be simplified, but I didn't for readability's sake
261269
if use_cache and has_cache:
262-
value_layer = torch.cat((self.past_value, value_layer), dim=-2)
263-
query_layer = torch.cat((self.past_query, query_layer), dim=-2)
264270
key_layer = torch.cat((self.past_key, key_layer), dim=-2)
271+
value_layer = torch.cat((self.past_value, value_layer), dim=-2)
265272
attention_scores = self._calculate_attention_scores(
266273
query_layer=query_layer, key_layer=key_layer, ltor_mask=ltor_mask
267274
)
@@ -271,13 +278,17 @@ def forward(self, hidden_states, ltor_mask, has_cache=False, use_cache=False,):
271278
)
272279

273280
if use_cache:
274-
self.past_query = query_layer
275281
self.past_key = key_layer
276282
self.past_value = value_layer
277-
has_cache = True
278283
else:
284+
self.past_key = None
285+
self.past_value = None
286+
self.past_output = None
279287
has_cache = False
280288

289+
if use_cache and has_cache:
290+
attention_scores = attention_scores[..., -1:, :]
291+
281292
# Attention probabilities. [b, np, s, s]
282293
attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
283294

@@ -298,6 +309,16 @@ def forward(self, hidden_states, ltor_mask, has_cache=False, use_cache=False,):
298309

299310
# Output. [b, s, h]
300311
output = self.dense(context_layer)
312+
313+
if use_cache:
314+
# Can be simplified, but I didn't for readability's sake
315+
if has_cache:
316+
output = torch.cat((self.past_output, output), dim=-2)
317+
self.past_output = output
318+
else:
319+
self.past_output = output
320+
has_cache = True
321+
301322
output = self.output_dropout(output)
302323
return output, has_cache
303324

@@ -321,12 +342,30 @@ def __init__(self, hidden_size, output_dropout_prob):
321342
# Project back to h.
322343
self.dense_4h_to_h = torch.nn.Linear(4*hidden_size, hidden_size)
323344
self.dropout = torch.nn.Dropout(output_dropout_prob)
345+
# MLP cache
346+
self.past_x = None
347+
348+
def forward(self, hidden_states, has_cache=False, use_cache=False):
349+
if has_cache and use_cache:
350+
hidden_states = hidden_states[:, -1:]
324351

325-
def forward(self, hidden_states):
326352
# [b, s, 4hp]
327353
x = self.dense_h_to_4h(hidden_states)
328354
x = gelu(x)
329355
# [b, s, h]
330356
x = self.dense_4h_to_h(x)
357+
if use_cache:
358+
# Can be simplified, but I didn't for readability's sake
359+
if has_cache:
360+
x = torch.cat((self.past_x, x), dim=-2)
361+
self.past_x = x
362+
else:
363+
self.past_x = x
364+
365+
has_cache = True
366+
else:
367+
self.past_x = None
368+
has_cache = False
331369
output = self.dropout(x)
332-
return output
370+
371+
return output, has_cache

0 commit comments

Comments (0)