@@ -234,26 +234,18 @@ def llama_attn_forward_legacy(
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
            Optional[Tuple[torch.Tensor]]]:
-    # LlamaFlashAttention2 attention does not support output_attentions
+    # Modified from https://github.com/huggingface/transformers/blob/ced9fd86f55ebb6b656c273f6e23f8ba50652f83/src/transformers/models/llama/modeling_llama.py#L331  # noqa:E501
     if 'padding_mask' in kwargs:
         warnings.warn(
-            'Passing `padding_mask` is deprecated and will be removed in v4.37'
-            ' Please make sure use `attention_mask` instead.`')
-
-        # overwrite attention_mask with padding_mask
-        attention_mask = kwargs.pop('padding_mask')
-
-    output_attentions = False
+            'Passing `padding_mask` is deprecated and will be removed in '
+            'v4.37. Please make sure use `attention_mask` instead.`')
 
     bsz, q_len, _ = hidden_states.size()
 
     query_states = self.q_proj(hidden_states)
     key_states = self.k_proj(hidden_states)
     value_states = self.v_proj(hidden_states)
 
-    # Flash attention requires the input to have the shape
-    # batch_size x seq_length x head_dim x hidden_dim
-    # therefore we just need to keep the original shape
     query_states = query_states.view(bsz, q_len, self.num_heads,
                                      self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
@@ -263,6 +255,13 @@ def llama_attn_forward_legacy(
 
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
+        if self.layer_idx is None:
+            raise ValueError(
+                'The cache structure has changed since version v4.36. '
+                f'If you are using {self.__class__.__name__} '
+                'for auto-regressive decoding with k/v caching, '
+                'please make sure to initialize the attention class '
+                'with a layer index.')
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
                                                        self.layer_idx)
     assert position_ids is not None
@@ -282,10 +281,6 @@ def llama_attn_forward_legacy(
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-    # repeat kv for sequence parallel
-    key_states = repeat_kv_bshd(key_states, self.num_key_value_groups)
-    value_states = repeat_kv_bshd(value_states, self.num_key_value_groups)
-
     assert SUPPORT_FLASH2
     query_states = query_states.transpose(1, 2)
     key_states = key_states.transpose(1, 2)
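
For context on the last hunk: `repeat_kv` alone already expands the key/value heads from `num_key_value_heads` to the full number of query heads for grouped-query attention, so the second `repeat_kv_bshd` pass on the same tensors is removed. Below is a minimal sketch of that expansion, assuming the usual `(batch, num_key_value_heads, seq_len, head_dim)` layout; the helper name `repeat_kv_sketch` is illustrative and not part of this patch or the transformers API.

```python
import torch


def repeat_kv_sketch(kv: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Duplicate each key/value head n_rep times so the tensor lines up with
    # the query heads; equivalent to torch.repeat_interleave along dim 1.
    # kv: (batch, num_key_value_heads, seq_len, head_dim)
    if n_rep == 1:
        return kv
    bsz, num_kv_heads, seq_len, head_dim = kv.shape
    kv = kv[:, :, None, :, :].expand(bsz, num_kv_heads, n_rep, seq_len,
                                     head_dim)
    return kv.reshape(bsz, num_kv_heads * n_rep, seq_len, head_dim)


# Example: 8 KV heads repeated 4x to match 32 query heads
# (num_key_value_groups == 4).
k = torch.randn(2, 8, 16, 64)
assert repeat_kv_sketch(k, 4).shape == (2, 32, 16, 64)
```

Note that the `expand` is only a view; the repeated heads are materialized by the final `reshape`.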