@@ -272,10 +272,11 @@ def gaudi_flash_attn_v1(
272272 def pre_attn_forward (
273273 self ,
274274 hidden_states : torch .Tensor ,
275- position_embeddings : tuple [torch .Tensor , torch .Tensor ],
276275 attention_mask : Optional [torch .Tensor ],
276+ position_ids : Optional [torch .LongTensor ] = None ,
277277 past_key_value : Optional [Cache ] = None ,
278- use_cache : bool = False ,
278+ output_attentions : Optional [bool ] = False ,
279+ use_cache : Optional [bool ] = False ,
279280 cache_position : Optional [torch .LongTensor ] = None ,
280281 token_idx : Optional [torch .Tensor ] = None ,
281282 attn_softmax_bf16 : Optional [bool ] = False ,
@@ -284,6 +285,7 @@ def pre_attn_forward(
284285 flash_attention_recompute : Optional [bool ] = False ,
285286 flash_attention_causal_mask : Optional [bool ] = False ,
286287 cache_idx : Optional [int ] = None ,
288+ position_embeddings : Optional [tuple [torch .Tensor , torch .Tensor ]] = None ,
287289 ** kwargs ,
288290 ) -> tuple [torch .Tensor , Optional [torch .Tensor ], Optional [tuple [torch .Tensor ]]]:
289291 """
@@ -295,31 +297,38 @@ def pre_attn_forward(
295297 - add new args use_flash_attention
296298 - add new arg flash_attention_recompute
297299 """
298- input_shape = hidden_states .shape [:- 1 ]
299- q_len = input_shape [ 1 ]
300+ bsz , q_len = hidden_states .shape [:2 ]
301+ input_shape = ( bsz , q_len )
300302 hidden_shape = (* input_shape , - 1 , self .head_dim )
301303
302304 query_states = self .q_proj (hidden_states ).view (hidden_shape ).transpose (1 , 2 )
303305 key_states = self .k_proj (hidden_states ).view (hidden_shape ).transpose (1 , 2 )
304306 value_states = self .v_proj (hidden_states ).view (hidden_shape ).transpose (1 , 2 )
305307
308+ if position_ids is None :
309+ position_ids = kwargs .get ("position_ids" )
310+
306311 kv_seq_len = key_states .shape [- 2 ]
307312 if past_key_value is not None :
308- if token_idx is None :
309- if hasattr (past_key_value , "get_usable_length" ):
310- kv_seq_len += past_key_value .get_usable_length (kv_seq_len , self .layer_idx )
311- else :
312- kv_seq_len += past_key_value [0 ].shape [- 2 ]
313+ if hasattr (past_key_value , "get_usable_length" ):
314+ kv_seq_len += past_key_value .get_usable_length (kv_seq_len , self .layer_idx )
313315 else :
314- if reuse_cache :
315- kv_seq_len = past_key_value [0 ][- 2 ]
316- else :
317- kv_seq_len = past_key_value [0 ].shape [- 2 ]
316+ kv_seq_len += past_key_value [0 ].shape [- 2 ]
318317
319- cos , sin = self .rotary_emb (value_states , seq_len = kv_seq_len )
320- query_states , key_states = apply_rotary_pos_emb (
321- query_states , key_states , cos [kwargs ["position_ids" ]], sin [kwargs ["position_ids" ]]
322- )
318+ if position_embeddings is not None :
319+ cos_emb , sin_emb = position_embeddings
320+ else :
321+ cos , sin = self .rotary_emb (value_states , seq_len = kv_seq_len )
322+
323+ if position_ids is None :
324+ start = kv_seq_len - q_len
325+ position_ids = torch .arange (
326+ start , start + q_len , dtype = torch .long , device = query_states .device
327+ ).unsqueeze (0 )
328+ cos_emb = cos [position_ids ]
329+ sin_emb = sin [position_ids ]
330+
331+ query_states , key_states = apply_rotary_pos_emb (query_states , key_states , cos_emb , sin_emb )
323332
324333 if use_cache :
325334 # reuse k, v, self_attention
@@ -334,8 +343,10 @@ def pre_attn_forward(
334343 key_states .shape , dtype = self .k_proj .weight .dtype , device = key_states .device
335344 )
336345 past_key_value = (past_key , past_value )
346+
337347 key_states = self .k_cache .update (past_key_value [0 ], key_states , 2 , token_idx , self .inp_seq_len )
338348 value_states = self .v_cache .update (past_key_value [1 ], value_states , 2 , token_idx , self .inp_seq_len )
349+
339350 if token_idx is None :
340351 past_key_value = (key_states , value_states )
341352
@@ -386,7 +397,6 @@ def pre_attn_forward(
386397 None ,
387398 flash_attention_recompute ,
388399 )
389-
390400 else :
391401 attn_output , attn_weights = gaudi_eager_attention_forward (
392402 self ,
@@ -450,68 +460,84 @@ def update_sincos_cache(self, seq_len):
450460 def pre_attn (
451461 self ,
452462 hidden_states : torch .Tensor ,
463+ position_embeddings : Optional [tuple [torch .Tensor , torch .Tensor ]] = None ,
453464 attention_mask : Optional [torch .Tensor ] = None ,
454465 position_ids : Optional [torch .LongTensor ] = None ,
455- past_key_value : Optional [tuple [torch .Tensor ]] = None ,
466+ past_key_value : Optional [Cache ] = None ,
467+ output_attentions : Optional [bool ] = False ,
456468 use_cache : Optional [bool ] = False ,
457469 cache_position : Optional [torch .LongTensor ] = None ,
458- position_embeddings : Optional [tuple [torch .Tensor , torch .Tensor ]] = None , # necessary, but kept here for BC
459470 token_idx : Optional [torch .Tensor ] = None ,
460471 attn_softmax_bf16 : Optional [bool ] = False ,
461472 reuse_cache : Optional [bool ] = False ,
462473 use_flash_attention : Optional [bool ] = False ,
463474 flash_attention_recompute : Optional [bool ] = False ,
464475 flash_attention_causal_mask : Optional [bool ] = False ,
465- cache_idx : Optional [int ] = None ,
476+ cache_idx : int = None ,
477+ ** kwargs ,
466478 ) -> tuple [torch .FloatTensor , Optional [tuple [torch .FloatTensor , torch .FloatTensor ]]]:
467479 hidden_states = self .input_layernorm (hidden_states )
468480 hidden_states , attn_weights , present_key_value = self .self_attn .pre_attn_forward (
469481 hidden_states = hidden_states ,
470482 attention_mask = attention_mask ,
471483 position_ids = position_ids ,
472484 past_key_value = past_key_value ,
485+ output_attentions = output_attentions ,
473486 use_cache = use_cache ,
474487 cache_position = cache_position ,
475- position_embeddings = position_embeddings ,
476488 token_idx = token_idx ,
477489 attn_softmax_bf16 = attn_softmax_bf16 ,
478490 reuse_cache = reuse_cache ,
479491 use_flash_attention = use_flash_attention ,
480492 flash_attention_recompute = flash_attention_recompute ,
481493 flash_attention_causal_mask = flash_attention_causal_mask ,
482494 cache_idx = cache_idx ,
495+ position_embeddings = position_embeddings ,
483496 )
484497 return hidden_states , attn_weights , present_key_value
485498
486499 def forward (
487500 self ,
488- hidden_states : torch .Tensor ,
501+ hidden_states : torch .Tensor = None ,
489502 attention_mask : Optional [torch .Tensor ] = None ,
490503 position_ids : Optional [torch .LongTensor ] = None ,
491504 past_key_value : Optional [Cache ] = None ,
505+ output_attentions : Optional [bool ] = False ,
492506 use_cache : Optional [bool ] = False ,
493507 cache_position : Optional [torch .LongTensor ] = None ,
494- token_idx : Optional [torch .Tensor ] = None ,
495- attn_softmax_bf16 : Optional [bool ] = False ,
496- reuse_cache : Optional [bool ] = False ,
497- use_flash_attention : Optional [bool ] = False ,
498- flash_attention_recompute : Optional [bool ] = False ,
499- flash_attention_causal_mask : Optional [bool ] = False ,
500- cache_idx : Optional [int ] = None ,
508+ * args ,
509+ ** kwargs ,
501510 ) -> tuple [torch .FloatTensor , Optional [tuple [torch .FloatTensor , torch .FloatTensor ]]]:
502511 """
503512 Copied from GemmaDecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.38.1/src/transformers/models/gemma/modeling_gemma.py
504513 The only differences are:
505514 - add new args token_idx
506515 - add new args attn_softmax_bf16
507516 """
517+ token_idx = kwargs .get ("token_idx" , None )
518+ attn_softmax_bf16 = kwargs .get ("attn_softmax_bf16" , False )
519+ reuse_cache = kwargs .get ("reuse_cache" , False )
520+ use_flash_attention = kwargs .get ("use_flash_attention" , False )
521+ flash_attention_recompute = kwargs .get ("flash_attention_recompute" , False )
522+ flash_attention_causal_mask = kwargs .get ("flash_attention_causal_mask" , False )
523+ cache_idx = kwargs .get ("cache_idx" , None )
524+ position_embeddings = kwargs .get ("position_embeddings" , None )
525+
526+ if hidden_states is None :
527+ hidden_states = kwargs .get ("hidden_states" , None )
528+
529+ if hidden_states is None :
530+ raise ValueError ("hidden_states is required but missing." )
531+
508532 residual = hidden_states
509533
510534 hidden_states , self_attn_weights , present_key_value = self .pre_attn (
511535 hidden_states = hidden_states ,
536+ position_embeddings = position_embeddings ,
512537 attention_mask = attention_mask ,
513538 position_ids = position_ids ,
514539 past_key_value = past_key_value ,
540+ output_attentions = output_attentions ,
515541 use_cache = use_cache ,
516542 cache_position = cache_position ,
517543 token_idx = token_idx ,
@@ -532,13 +558,18 @@ def forward(
532558 hidden_states = self .post_mlp (hidden_states , residual )
533559
534560 outputs = (hidden_states ,)
561+
562+ if output_attentions :
563+ outputs += (self_attn_weights ,)
564+
535565 if use_cache :
536566 outputs += (present_key_value ,)
537567
538568 return outputs
539569
540570 def post_attn_pre_mlp (self , hidden_states , residual ):
541571 hidden_states = self .self_attn .post_attn_forward (hidden_states )
572+ hidden_states = self .post_attention_layernorm (hidden_states )
542573
543574 if self .training :
544575 hidden_states = hidden_states + residual
@@ -547,8 +578,6 @@ def post_attn_pre_mlp(self, hidden_states, residual):
547578 residual .add_ (hidden_states )
548579 hidden_states = residual
549580
550- hidden_states = self .post_attention_layernorm (hidden_states )
551-
552581 hidden_states = self .mlp .pre_mlp_forward (hidden_states )
553582 return hidden_states , residual
554583
@@ -660,7 +689,7 @@ def forward(
660689 htcore .mark_step ()
661690
662691 layer_outputs = decoder_layer (
663- hidden_states ,
692+ hidden_states = hidden_states ,
664693 attention_mask = attention_mask ,
665694 position_ids = position_ids ,
666695 past_key_value = None if past_key_values is None else past_key_values [layer_idx ],
0 commit comments