Commit 840e497

Author: Grzegorz Pluto-Prondzinski
Fix gradient checkpointing in GaudiGemmaDecoderLayer + align forward() signature with HF 4.55 (#2353)
1 parent 1ab3548 commit 840e497
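
In short, the decoder layer's forward() now tolerates purely keyword-style calls (hidden_states defaults to None and the extra options travel through *args/**kwargs), which is what a gradient-checkpointing wrapper that forwards keyword arguments needs. A rough, generic illustration of that interaction, not the Optimum Habana code itself: non-reentrant torch.utils.checkpoint.checkpoint passes keyword arguments through to the wrapped module, so a kwargs-tolerant layer keeps working when checkpointed.

import torch
from torch.utils.checkpoint import checkpoint

class TinyLayer(torch.nn.Module):
    # Hypothetical stand-in for a decoder layer whose forward tolerates
    # keyword-only calls, mirroring the patched signature shape.
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, hidden_states: torch.Tensor = None, *args, **kwargs):
        if hidden_states is None:
            hidden_states = kwargs.get("hidden_states")
        return (self.proj(hidden_states),)

layer = TinyLayer()
x = torch.randn(2, 4, 16, requires_grad=True)

# Non-reentrant checkpointing forwards keyword arguments to the wrapped module,
# so a purely keyword-style call still recomputes correctly in backward.
out = checkpoint(layer, hidden_states=x, use_reentrant=False)
out[0].sum().backward()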

File tree

1 file changed (+62, -33)

optimum/habana/transformers/models/gemma/modeling_gemma.py

Lines changed: 62 additions & 33 deletions
@@ -272,10 +272,11 @@ def gaudi_flash_attn_v1(
     def pre_attn_forward(
         self,
         hidden_states: torch.Tensor,
-        position_embeddings: tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor],
+        position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
-        use_cache: bool = False,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         token_idx: Optional[torch.Tensor] = None,
         attn_softmax_bf16: Optional[bool] = False,
@@ -284,6 +285,7 @@ def pre_attn_forward(
         flash_attention_recompute: Optional[bool] = False,
         flash_attention_causal_mask: Optional[bool] = False,
         cache_idx: Optional[int] = None,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
         **kwargs,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         """
@@ -295,31 +297,38 @@ def pre_attn_forward(
         - add new args use_flash_attention
         - add new arg flash_attention_recompute
         """
-        input_shape = hidden_states.shape[:-1]
-        q_len = input_shape[1]
+        bsz, q_len = hidden_states.shape[:2]
+        input_shape = (bsz, q_len)
         hidden_shape = (*input_shape, -1, self.head_dim)

         query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
         key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

+        if position_ids is None:
+            position_ids = kwargs.get("position_ids")
+
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            if token_idx is None:
-                if hasattr(past_key_value, "get_usable_length"):
-                    kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-                else:
-                    kv_seq_len += past_key_value[0].shape[-2]
+            if hasattr(past_key_value, "get_usable_length"):
+                kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
             else:
-                if reuse_cache:
-                    kv_seq_len = past_key_value[0][-2]
-                else:
-                    kv_seq_len = past_key_value[0].shape[-2]
+                kv_seq_len += past_key_value[0].shape[-2]

-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(
-            query_states, key_states, cos[kwargs["position_ids"]], sin[kwargs["position_ids"]]
-        )
+        if position_embeddings is not None:
+            cos_emb, sin_emb = position_embeddings
+        else:
+            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+
+            if position_ids is None:
+                start = kv_seq_len - q_len
+                position_ids = torch.arange(
+                    start, start + q_len, dtype=torch.long, device=query_states.device
+                ).unsqueeze(0)
+            cos_emb = cos[position_ids]
+            sin_emb = sin[position_ids]
+
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos_emb, sin_emb)

         if use_cache:
             # reuse k, v, self_attention
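
The hunk above replaces the hard dependency on kwargs["position_ids"] with a two-way path: a precomputed position_embeddings tuple, when supplied, is used directly; otherwise cos/sin come from rotary_emb and are gathered with position_ids, which is synthesized with torch.arange when absent. A standalone sketch of that fallback, with illustrative table shapes rather than the real rotary module:

import torch

def resolve_rotary(cos_table, sin_table, position_embeddings=None,
                   position_ids=None, kv_seq_len=8, q_len=1, device="cpu"):
    # Prefer the precomputed (cos, sin) tuple when the caller provides it.
    if position_embeddings is not None:
        return position_embeddings
    # Otherwise gather from the full tables; default to the last q_len positions.
    if position_ids is None:
        start = kv_seq_len - q_len
        position_ids = torch.arange(start, start + q_len, dtype=torch.long,
                                    device=device).unsqueeze(0)
    return cos_table[position_ids], sin_table[position_ids]

cos_table = torch.randn(16, 4)   # illustrative [max_positions, rotary_dim] tables
sin_table = torch.randn(16, 4)
cos_emb, sin_emb = resolve_rotary(cos_table, sin_table, kv_seq_len=8, q_len=1)
print(cos_emb.shape)             # torch.Size([1, 1, 4])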
@@ -334,8 +343,10 @@ def pre_attn_forward(
                         key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device
                     )
                     past_key_value = (past_key, past_value)
+
                 key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len)
                 value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len)
+
                 if token_idx is None:
                     past_key_value = (key_states, value_states)

@@ -386,7 +397,6 @@ def pre_attn_forward(
                     None,
                     flash_attention_recompute,
                 )
-
         else:
             attn_output, attn_weights = gaudi_eager_attention_forward(
                 self,
@@ -450,68 +460,84 @@ def update_sincos_cache(self, seq_len):
     def pre_attn(
         self,
         hidden_states: torch.Tensor,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[tuple[torch.Tensor]] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         token_idx: Optional[torch.Tensor] = None,
         attn_softmax_bf16: Optional[bool] = False,
         reuse_cache: Optional[bool] = False,
         use_flash_attention: Optional[bool] = False,
         flash_attention_recompute: Optional[bool] = False,
         flash_attention_causal_mask: Optional[bool] = False,
-        cache_idx: Optional[int] = None,
+        cache_idx: int = None,
+        **kwargs,
     ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
         hidden_states = self.input_layernorm(hidden_states)
         hidden_states, attn_weights, present_key_value = self.self_attn.pre_attn_forward(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
+            output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
-            position_embeddings=position_embeddings,
             token_idx=token_idx,
             attn_softmax_bf16=attn_softmax_bf16,
             reuse_cache=reuse_cache,
             use_flash_attention=use_flash_attention,
             flash_attention_recompute=flash_attention_recompute,
             flash_attention_causal_mask=flash_attention_causal_mask,
             cache_idx=cache_idx,
+            position_embeddings=position_embeddings,
         )
         return hidden_states, attn_weights, present_key_value

     def forward(
         self,
-        hidden_states: torch.Tensor,
+        hidden_states: torch.Tensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        token_idx: Optional[torch.Tensor] = None,
-        attn_softmax_bf16: Optional[bool] = False,
-        reuse_cache: Optional[bool] = False,
-        use_flash_attention: Optional[bool] = False,
-        flash_attention_recompute: Optional[bool] = False,
-        flash_attention_causal_mask: Optional[bool] = False,
-        cache_idx: Optional[int] = None,
+        *args,
+        **kwargs,
     ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Copied from GemmaDecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.38.1/src/transformers/models/gemma/modeling_gemma.py
         The only differences are:
         - add new args token_idx
         - add new args attn_softmax_bf16
         """
+        token_idx = kwargs.get("token_idx", None)
+        attn_softmax_bf16 = kwargs.get("attn_softmax_bf16", False)
+        reuse_cache = kwargs.get("reuse_cache", False)
+        use_flash_attention = kwargs.get("use_flash_attention", False)
+        flash_attention_recompute = kwargs.get("flash_attention_recompute", False)
+        flash_attention_causal_mask = kwargs.get("flash_attention_causal_mask", False)
+        cache_idx = kwargs.get("cache_idx", None)
+        position_embeddings = kwargs.get("position_embeddings", None)
+
+        if hidden_states is None:
+            hidden_states = kwargs.get("hidden_states", None)
+
+        if hidden_states is None:
+            raise ValueError("hidden_states is required but missing.")
+
         residual = hidden_states

         hidden_states, self_attn_weights, present_key_value = self.pre_attn(
             hidden_states=hidden_states,
+            position_embeddings=position_embeddings,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
+            output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
             token_idx=token_idx,
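
After the hunk above, the HPU-specific options (token_idx, attn_softmax_bf16, reuse_cache, the flash-attention flags, cache_idx, position_embeddings) no longer appear in the decoder layer's forward() signature; they arrive through **kwargs and are read back with kwargs.get(...), so both HPU-style callers and the plain Transformers 4.55 calling convention work. A minimal sketch of that pattern, with a hypothetical helper name:

def read_hpu_kwargs(kwargs: dict) -> dict:
    # Mirrors the kwargs.get(...) block in the patched forward(): every HPU
    # option has a safe default, so omitting it falls back to stock behaviour.
    return {
        "token_idx": kwargs.get("token_idx", None),
        "attn_softmax_bf16": kwargs.get("attn_softmax_bf16", False),
        "reuse_cache": kwargs.get("reuse_cache", False),
        "use_flash_attention": kwargs.get("use_flash_attention", False),
        "flash_attention_recompute": kwargs.get("flash_attention_recompute", False),
        "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask", False),
        "cache_idx": kwargs.get("cache_idx", None),
        "position_embeddings": kwargs.get("position_embeddings", None),
    }

print(read_hpu_kwargs({"token_idx": 3, "use_flash_attention": True}))  # HPU-style call
print(read_hpu_kwargs({}))                                             # plain HF call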
@@ -532,13 +558,18 @@ def forward(
         hidden_states = self.post_mlp(hidden_states, residual)

         outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
         if use_cache:
             outputs += (present_key_value,)

         return outputs

     def post_attn_pre_mlp(self, hidden_states, residual):
         hidden_states = self.self_attn.post_attn_forward(hidden_states)
+        hidden_states = self.post_attention_layernorm(hidden_states)

         if self.training:
             hidden_states = hidden_states + residual
@@ -547,8 +578,6 @@ def post_attn_pre_mlp(self, hidden_states, residual):
             residual.add_(hidden_states)
             hidden_states = residual

-        hidden_states = self.post_attention_layernorm(hidden_states)
-
         hidden_states = self.mlp.pre_mlp_forward(hidden_states)
         return hidden_states, residual

@@ -660,7 +689,7 @@ def forward(
                 htcore.mark_step()

             layer_outputs = decoder_layer(
-                hidden_states,
+                hidden_states=hidden_states,
                 attention_mask=attention_mask,
                 position_ids=position_ids,
                 past_key_value=None if past_key_values is None else past_key_values[layer_idx],
