Skip to content

Commit 0e9e778

Browse files
karol-brejna-ia, astachowiczhabana, and regisss
authored
CWE 476 llava, gpt2, falcon, cohere, all_model (#2208)
Co-authored-by: Adam Stachowicz <105052242+astachowiczhabana@users.noreply.github.com>
Co-authored-by: regisss <15324346+regisss@users.noreply.github.com>
1 parent 0ce775f commit 0e9e778

5 files changed

Lines changed: 13 additions & 5 deletions

File tree

optimum/habana/transformers/models/cohere/modeling_cohere.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ def prepare_inputs_for_generation(
323323
# The clone here is for the same reason as for `position_ids`.
324324
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
325325

326-
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
326+
if isinstance(past_key_values, StaticCache) and attention_mask is not None and attention_mask.ndim == 2:
327327
if model_inputs["inputs_embeds"] is not None:
328328
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
329329
device = model_inputs["inputs_embeds"].device

optimum/habana/transformers/models/falcon/modeling_falcon.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,13 @@ def pre_attn_forward(
364364
else:
365365
kv_length = present[0][-2] if reuse_cache else present[0].shape[-2]
366366

367-
if (not reuse_cache) and (token_idx is not None) and (cache_idx is not None) and (query_length == 1):
367+
if (
368+
(not reuse_cache)
369+
and (token_idx is not None)
370+
and (cache_idx is not None)
371+
and (query_length == 1)
372+
and (present is not None)
373+
):
368374
# Return only past key value shapes and not the tensors during decode phase (q len is 1)
369375
# to avoid making past key values as persistent output tensors of HPU graphs.
370376
present = (present[0].shape, present[1].shape)

optimum/habana/transformers/models/gpt_neo/modeling_gpt_neo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def gaudi_gpt_neo_model_forward(
157157
use_cache = use_cache if use_cache is not None else self.config.use_cache
158158
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
159159

160-
if input_ids is not None and inputs_embeds is not None:
160+
if (input_ids is None) == (inputs_embeds is None):
161161
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
162162
elif input_ids is not None:
163163
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)

optimum/habana/transformers/models/llava/modeling_llava.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def forward(
165165

166166
image_features = None
167167
# 2. Merge text and images
168-
if pixel_values is not None and input_ids.shape[1] != 1:
168+
if pixel_values is not None and input_ids is not None and input_ids.shape[1] != 1:
169169
image_outputs = self.vision_tower(
170170
pixel_values,
171171
output_hidden_states=True,

optimum/habana/transformers/models/modeling_all_models.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,10 @@ def gaudi_invert_attention_mask(self, encoder_attention_mask: torch.Tensor) -> t
113113
"""
114114
if encoder_attention_mask.dim() == 3:
115115
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
116-
if encoder_attention_mask.dim() == 2:
116+
elif encoder_attention_mask.dim() == 2:
117117
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
118+
else:
119+
raise ValueError(f"encoder_attention_mask must be 2D or 3D, but got shape {encoder_attention_mask.shape}")
118120
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
119121
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
120122
# /transformer/transformer_layers.py#L270

0 commit comments

Comments (0)