Skip to content

Commit 83b3460

Browse files
Fix a few Transformers 4.55 issues (#2282)
Co-authored-by: IlyasMoutawwakil <ilyas.moutawwakil@gmail.com>
1 parent 1ce0173 commit 83b3460

File tree

8 files changed

+45
-36
lines changed

8 files changed

+45
-36
lines changed

optimum/habana/transformers/models/baichuan/modeling_baichuan.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
4444
from transformers.utils import logging
4545

46+
from ...generation.utils import GaudiGenerationMixin
4647
from ...modeling_attn_mask_utils import _gaudi_prepare_4d_causal_attention_mask
4748
from .configuration_baichuan import BaichuanConfig
4849
from .generation_utils import TextIterStreamer, build_chat_input
@@ -1163,7 +1164,7 @@ def no_init_weights(_enable=True):
11631164
_init_weights = old_init_weights
11641165

11651166

1166-
class BaichuanForCausalLM(BaichuanPreTrainedModel):
1167+
class BaichuanForCausalLM(BaichuanPreTrainedModel, GaudiGenerationMixin):
11671168
def __init__(self, config, *model_args, **model_kwargs):
11681169
super().__init__(config, *model_args, **model_kwargs)
11691170
self.model = BaichuanModel(config)

optimum/habana/transformers/models/chatglm/modeling_chatglm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from transformers.utils import logging
4444

4545
from ....utils import warn0
46+
from ...generation.utils import GaudiGenerationMixin
4647
from ...modeling_attn_mask_utils import _gaudi_prepare_4d_causal_attention_mask
4748
from .configuration_chatglm import ChatGLMConfig
4849

@@ -1347,7 +1348,7 @@ def forward(
13471348
)
13481349

13491350

1350-
class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1351+
class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel, GaudiGenerationMixin):
13511352
def __init__(self, config: ChatGLMConfig, empty_init=False, device=None):
13521353
super().__init__(config)
13531354

optimum/habana/transformers/models/gemma/modeling_gemma.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -661,7 +661,7 @@ def forward(
661661
):
662662
htcore.mark_step()
663663

664-
hidden_states = decoder_layer(
664+
layer_outputs = decoder_layer(
665665
hidden_states,
666666
attention_mask=attention_mask,
667667
position_ids=position_ids,
@@ -678,8 +678,10 @@ def forward(
678678
**kwargs,
679679
)
680680

681+
hidden_states = layer_outputs[0]
682+
681683
if use_cache:
682-
next_decoder_cache += (hidden_states[1],)
684+
next_decoder_cache += (layer_outputs[1],)
683685

684686
hidden_states = self.norm(hidden_states)
685687

optimum/habana/transformers/models/minicpm/modeling_minicpm.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
from transformers.utils.import_utils import is_torch_fx_available
5555

5656
from ....utils import warn0
57+
from ...generation.utils import GaudiGenerationMixin
5758
from .configuration_minicpm import MiniCPM3Config
5859

5960

@@ -505,8 +506,7 @@ def forward(
505506
value_states = past_value_states.index_add(
506507
-2, token_idx - 1, value_states - torch.index_select(past_value_states, -2, token_idx - 1)
507508
)
508-
past_key_value.key_cache[self.layer_idx] = key_states
509-
past_key_value.value_cache[self.layer_idx] = value_states
509+
past_key_value.update(key_states, value_states, self.layer_idx)
510510

511511
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale
512512

@@ -644,8 +644,7 @@ def forward(
644644
value_states = past_value_states.index_add(
645645
-2, token_idx - 1, value_states - torch.index_select(past_value_states, -2, token_idx - 1)
646646
)
647-
past_key_value.key_cache[self.layer_idx] = key_states
648-
past_key_value.value_cache[self.layer_idx] = value_states
647+
past_key_value.update(key_states, value_states, self.layer_idx)
649648

650649
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
651650
# to be able to avoid many of these transpose/reshape/view.
@@ -854,7 +853,7 @@ def forward(
854853
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
855854
"with a layer index."
856855
)
857-
usable_length = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
856+
usable_length = past_key_value.get_seq_length(self.layer_idx)
858857
if token_idx is None:
859858
kv_seq_len += usable_length
860859
elif usable_length > 0:
@@ -1023,7 +1022,9 @@ def forward(
10231022
"The bare MiniCPM Model outputting raw hidden-states without any specific head on top.",
10241023
MINICPM_START_DOCSTRING,
10251024
)
1026-
class MiniCPM3PreTrainedModel(PreTrainedModel):
1025+
class MiniCPM3PreTrainedModel(
1026+
PreTrainedModel,
1027+
):
10271028
config_class = MiniCPM3Config
10281029
base_model_prefix = "model"
10291030
supports_gradient_checkpointing = True
@@ -1200,7 +1201,7 @@ def forward(
12001201
use_legacy_cache = not isinstance(past_key_values, Cache)
12011202
if use_legacy_cache:
12021203
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1203-
past_key_values_length = past_key_values.get_usable_length(seq_length)
1204+
past_key_values_length = past_key_values.get_seq_length()
12041205

12051206
if position_ids is None:
12061207
device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -1292,7 +1293,7 @@ def forward(
12921293
)
12931294

12941295

1295-
class MiniCPM3ForCausalLM(MiniCPM3PreTrainedModel):
1296+
class MiniCPM3ForCausalLM(MiniCPM3PreTrainedModel, GaudiGenerationMixin):
12961297
_tied_weights_keys = ["lm_head.weight"]
12971298

12981299
def __init__(self, config):

optimum/habana/transformers/models/mistral/modeling_mistral.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,7 @@ def forward(
385385
hidden_states = self.input_layernorm(hidden_states)
386386

387387
# Self Attention
388-
hidden_states, self_attn_weights, present_key_value = self.self_attn(
388+
hidden_states, _, present_key_value = self.self_attn(
389389
hidden_states=hidden_states,
390390
attention_mask=attention_mask,
391391
position_ids=position_ids,

optimum/habana/transformers/models/persimmon/modeling_persimmon.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,11 @@ def forward(
7373
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
7474
"with a layer index."
7575
)
76-
if token_idx is not None and past_key_value.get_usable_length(kv_seq_len, self.layer_idx) > 0:
76+
if token_idx is not None and past_key_value.get_seq_length(self.layer_idx) > 0:
7777
# When token_idx is used, static seq len = (input token len + max output token len)
78-
kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
78+
kv_seq_len = past_key_value.get_seq_length(self.layer_idx)
7979
else:
80-
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
80+
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
8181
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
8282

8383
# Partial rotary embedding
@@ -98,14 +98,16 @@ def forward(
9898

9999
if past_key_value is not None:
100100
if token_idx is not None:
101-
if 0 <= self.layer_idx < len(past_key_value.key_cache):
102-
past_key_value.key_cache[self.layer_idx].index_copy_(2, token_idx - 1, key_states)
103-
past_key_value.value_cache[self.layer_idx].index_copy_(2, token_idx - 1, value_states)
104-
key_states = past_key_value.key_cache[self.layer_idx]
105-
value_states = past_key_value.value_cache[self.layer_idx]
101+
if (
102+
0 <= self.layer_idx < len(past_key_value)
103+
and past_key_value.layers[self.layer_idx].keys is not None
104+
):
105+
past_key_value.layers[self.layer_idx].keys.index_copy_(2, token_idx - 1, key_states)
106+
past_key_value.layers[self.layer_idx].values.index_copy_(2, token_idx - 1, value_states)
107+
key_states = past_key_value.layers[self.layer_idx].keys
108+
value_states = past_key_value.layers[self.layer_idx].values
106109
else:
107-
past_key_value.key_cache.append(key_states)
108-
past_key_value.value_cache.append(value_states)
110+
past_key_value.update(key_states, value_states, self.layer_idx)
109111
else:
110112
# Specific to RoPE models with partial rotation
111113
cache_kwargs = {

optimum/habana/transformers/models/stablelm/modeling_stablelm.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,11 @@ def forward(
7171
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
7272
"with a layer index."
7373
)
74-
if token_idx is not None and past_key_value.get_usable_length(kv_seq_len, self.layer_idx) > 0:
74+
if token_idx is not None and past_key_value.get_seq_length(self.layer_idx) > 0:
7575
# When token_idx is used, static seq len = (input token len + max output token len)
76-
kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
76+
kv_seq_len = past_key_value.get_seq_length(self.layer_idx)
7777
else:
78-
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
78+
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
7979
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
8080

8181
# Partial rotary embedding
@@ -96,14 +96,16 @@ def forward(
9696

9797
if past_key_value is not None:
9898
if token_idx is not None:
99-
if 0 <= self.layer_idx < len(past_key_value.key_cache):
100-
past_key_value.key_cache[self.layer_idx].index_copy_(2, token_idx - 1, key_states)
101-
past_key_value.value_cache[self.layer_idx].index_copy_(2, token_idx - 1, value_states)
102-
key_states = past_key_value.key_cache[self.layer_idx]
103-
value_states = past_key_value.value_cache[self.layer_idx]
99+
if (
100+
0 <= self.layer_idx < len(past_key_value)
101+
and past_key_value.layers[self.layer_idx].keys is not None
102+
):
103+
past_key_value.layers[self.layer_idx].keys.index_copy_(2, token_idx - 1, key_states)
104+
past_key_value.layers[self.layer_idx].values.index_copy_(2, token_idx - 1, value_states)
105+
key_states = past_key_value.layers[self.layer_idx].keys
106+
value_states = past_key_value.layers[self.layer_idx].values
104107
else:
105-
past_key_value.key_cache.append(key_states)
106-
past_key_value.value_cache.append(value_states)
108+
past_key_value.update(key_states, value_states, self.layer_idx)
107109
else:
108110
# Specific to RoPE models with partial rotation
109111
cache_kwargs = {

tests/baselines/fixture/tests/test_text_generation_example.json

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@
305305
"throughput": 109.70751574382221
306306
},
307307
"gaudi3": {
308-
"output": "DeepSpeed is a machine learning framework that enables training of large-scale models on commodity hardware. It is designed to be a drop-in replacement for PyTorch, and it is compatible with the existing PyTorch ecosystem. DeepSpeed is designed to be easy to use, and it provides a number of features that make it easy to train large-scale models. DeepSpeed is designed to be scalable, and it can be used to train models on a single machine or on a cluster of machines. DeepSpeed is designed to be efficient,",
308+
"output": "DeepSpeed is a machine learning framework that enables training of large-scale models on commodity hardware. It is designed to be a drop-in replacement for PyTorch and is compatible with existing PyTorch code. DeepSpeed is open source and available on GitHub.\n\nDeepSpeed is a machine learning framework that enables training of large-scale models on commodity hardware. It is designed to be a drop-in replacement for PyTorch and is compatible with existing PyTorch code. DeepSpeed is open source and available on GitHub.\n\n<h2>What is",
309309
"throughput": 135.97272017864475
310310
}
311311
},
@@ -415,7 +415,7 @@
415415
"throughput": 134.94827207337997
416416
},
417417
"gaudi3": {
418-
"output": "DeepSpeed is a machine learning framework that accelerates training of large models on a single machine or distributed systems. It is designed to be compatible with PyTorch and TensorFlow, and can be used to train models on a single machine or on a distributed system.\n\nDeepSpeed is a machine learning framework that accelerates training of large models on a single machine or distributed systems. It is designed to be compatible with PyTorch and TensorFlow, and can be used to train models on a single machine or on a distributed system",
418+
"output": "DeepSpeed is a machine learning framework that accelerates training and inference of deep learning models. It is designed to be flexible and easy to use, with a focus on performance and scalability. DeepSpeed is built on top of PyTorch, and it provides a set of tools and libraries that can be used to optimize the training and inference of deep learning models.\n\nDeepSpeed is designed to be used with a variety of hardware platforms, including GPUs, TPUs, and CPUs. It provides a",
419419
"throughput": 160.48685620965531
420420
}
421421
},
@@ -425,7 +425,7 @@
425425
"throughput": 71.29570003665306
426426
},
427427
"gaudi3": {
428-
"output": "DeepSpeed is a machine learning framework that enables training of large models on a single machine with multiple GPUs. It is designed to be easy to use and efficient, and it supports a wide range of models and tasks.\n\nDeepSpeed is a deep learning framework that enables training of large models on a single machine with multiple GPUs. It is designed to be easy to use and efficient, and it supports a wide range of models and tasks.\n\nDeepSpeed is a deep learning framework that enables training of large models on a",
428+
"output": "DeepSpeed is a machine learning framework that enables training of large models on a single machine with a single GPU. It is designed to be easy to use and efficient, and it can be used to train models on a variety of tasks.\n\nThe latest DeepSpeed for PC has come up with a few updates that are better than the previous version. Want to know those? Here are they:\n\n## DeepSpeed Andorid App Summary\n\nDeepSpeed has developed the DeepSpeed for Android. You can find it under the",
429429
"throughput": 81.6817273229847
430430
}
431431
},

0 commit comments

Comments
 (0)