Skip to content

Commit 14022ed

Browse files
Authored by Adam Stachowicz (astachowiczhabana) and co-authors
Fixes for Device Mismatch and Configuration Conflict (#2283)
Co-authored-by: Adam Stachowicz <astachow@habana.ai>
1 parent 83b3460 commit 14022ed

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

optimum/habana/transformers/modeling_attn_mask_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,23 @@ def to_4d(
132132
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
133133
return self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(device)
134134

135+
@staticmethod
136+
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
137+
"""
138+
Adapted from: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_attn_mask_utils.py#L187
139+
140+
Differences:
141+
- inverted_mask tensor is placed on the same device as the input mask to avoid device mismatch errors.
142+
"""
143+
bsz, src_len = mask.size()
144+
tgt_len = tgt_len if tgt_len is not None else src_len
145+
146+
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
147+
148+
inverted_mask = torch.tensor(1.0, dtype=dtype, device=mask.device) - expanded_mask
149+
150+
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
151+
135152

136153
def _gaudi_prepare_4d_causal_attention_mask(
137154
attention_mask: Optional[torch.Tensor],

optimum/habana/transformers/modeling_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,7 @@ def adapt_transformers_to_gaudi():
522522
transformers.models.llama.modeling_llama.LlamaDecoderLayer = GaudiLlamaDecoderLayer
523523
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = GaudiLlamaRotaryEmbedding
524524
transformers.models.llama.modeling_llama.LlamaRMSNorm.forward = gaudi_llama_rmsnorm_forward
525-
transformers.models.llama.configuration_llama.LlamaConfig = LlamaConfig
525+
transformers.AutoConfig.register("llama", LlamaConfig, exist_ok=True)
526526

527527
# Optimization for llava on Gaudi
528528
transformers.models.llava.modeling_llava.LlavaForConditionalGeneration = GaudiLlavaForConditionalGeneration

0 commit comments

Comments
 (0)