Skip to content

Commit 8c36e02

Browse files
authored
Align MLlama code with Transformers 4.55 (#2319)
1 parent e2118d0 commit 8c36e02

6 files changed

Lines changed: 150 additions & 357 deletions

File tree

examples/image-to-text/run_pipeline.py

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import json
1919
import logging
2020
import os
21+
from contextlib import nullcontext
2122
from pathlib import Path
2223

2324
import PIL.Image
@@ -353,51 +354,53 @@ def main():
353354

354355
htcore.hpu_set_env()
355356

357+
if model_type == "mllama" and args.use_flash_attention:
358+
config._attn_implementation = "gaudi_fused_sdpa"
359+
if args.flash_attention_recompute:
360+
os.environ["FLASH_ATTENTION_RECOMPUTE"] = "1"
361+
356362
if args.world_size > 1:
357363
import deepspeed
358364

359-
with deepspeed.OnDevice(dtype=model_dtype, device="cpu"):
360-
model = AutoModelForVision2Seq.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype)
365+
context = deepspeed.OnDevice(dtype=model_dtype, device="cpu")
366+
else:
367+
context = nullcontext()
368+
369+
with context:
370+
model = AutoModelForVision2Seq.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype, config=config)
371+
372+
if args.world_size > 1:
361373
if model_type == "mllama":
362374
model.language_model = initialize_distributed_model(args, model.language_model, logger, model_dtype)
363375
model.to("hpu")
364376
else:
365377
model = initialize_distributed_model(args, model, logger, model_dtype)
366-
generator = pipeline(
367-
"image-to-text",
368-
model=model,
369-
config=args.model_name_or_path,
370-
tokenizer=args.model_name_or_path,
371-
image_processor=args.model_name_or_path,
372-
torch_dtype=model_dtype,
373-
device="hpu",
374-
)
375-
else:
376-
generator = pipeline(
377-
"image-to-text",
378-
model=args.model_name_or_path,
379-
config=args.model_name_or_path,
380-
tokenizer=args.model_name_or_path,
381-
image_processor=None if model_type == "chatglm" else args.model_name_or_path,
382-
torch_dtype=model_dtype,
383-
device="hpu",
384-
)
385-
if args.use_hpu_graphs:
386-
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
387378

388-
generator.model = wrap_in_hpu_graph(generator.model)
379+
generator = pipeline(
380+
"image-to-text",
381+
model=model,
382+
config=config,
383+
tokenizer=args.model_name_or_path,
384+
image_processor=None if model_type == "chatglm" else args.model_name_or_path,
385+
torch_dtype=model_dtype,
386+
device="hpu",
387+
)
388+
389+
if args.world_size < 2 and args.use_hpu_graphs:
390+
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
391+
392+
generator.model = wrap_in_hpu_graph(generator.model)
389393

390394
if "falcon-11B-vlm" in args.model_name_or_path:
391395
# WA falcon vlm issue that image_token_id == embed size.
392396
generator.model.resize_token_embeddings(generator.tokenizer.vocab_size + 1)
393397
processor.patch_size = config.vision_config.patch_size
398+
394399
generate_kwargs = {
395400
"lazy_mode": use_lazy_mode,
396401
"hpu_graphs": args.use_hpu_graphs,
397402
"max_new_tokens": args.max_new_tokens,
398403
"ignore_eos": args.ignore_eos,
399-
"use_flash_attention": args.use_flash_attention,
400-
"flash_attention_recompute": args.flash_attention_recompute,
401404
"bucket_internal": args.bucket_internal,
402405
"bucket_size": args.bucket_size,
403406
"limit_hpu_graphs": args.limit_hpu_graphs,
@@ -406,6 +409,14 @@ def main():
406409
"logits_bf16": args.logits_bf16,
407410
}
408411

412+
if model_type != "mllama":
413+
generate_kwargs.update(
414+
{
415+
"use_flash_attention": args.use_flash_attention,
416+
"flash_attention_recompute": args.flash_attention_recompute,
417+
}
418+
)
419+
409420
if args.sdp_on_bf16:
410421
torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
411422

optimum/habana/transformers/integrations/gaudi_fused_sdpa_attention.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@ def gaudi_fused_sdpa_attention_forward(
1818
) -> tuple[torch.Tensor, None]:
1919
bsz, num_heads, tgt_len, head_dim = query.shape
2020

21+
softmax_mode = "fast" if os.getenv("FLASH_ATTENTION_FAST_SOFTMAX") == "1" else "None"
22+
2123
if tgt_len == 1:
2224
# next token
23-
softmax_mode = True if os.getenv("QUANT_CONFIG", "") else False
24-
recompute_mode = False
25+
recompute_mode = True if os.getenv("QUANT_CONFIG", "") else False
2526
else:
2627
# first token
27-
softmax_mode = "fast" if os.getenv("FLASH_ATTENTION_FAST_SOFTMAX") == "1" else "None"
2828
recompute_mode = True if os.getenv("FLASH_ATTENTION_RECOMPUTE") == "1" else False
2929

3030
attn_output = FusedSDPA.apply(

optimum/habana/transformers/modeling_utils.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,7 @@
138138
GaudiMllamaTextModel,
139139
GaudiMllamaTextSelfAttention,
140140
GaudiMllamaVisionEncoder,
141-
GaudiMllamaVisionEncoderLayer,
142141
GaudiMllamaVisionModel,
143-
GaudiMllamaVisionSdpaAttention,
144142
GaudiMptAttention,
145143
GaudiMptBlock,
146144
GaudiMptForCausalLM,
@@ -847,19 +845,18 @@ def adapt_transformers_to_gaudi():
847845
transformers.models.whisper.modeling_whisper.WhisperForConditionalGeneration = GaudiWhisperForConditionalGeneration
848846

849847
# Optimization for mllama on Gaudi
850-
transformers.models.mllama.modeling_mllama.MllamaSelfAttentionDecoderLayer = GaudiMllamaSelfAttentionDecoderLayer
851848
transformers.models.mllama.modeling_mllama.MllamaCrossAttentionDecoderLayer = GaudiMllamaCrossAttentionDecoderLayer
852849
transformers.models.mllama.modeling_mllama.MllamaForCausalLM = GaudiMllamaForCausalLM
853-
transformers.models.mllama.modeling_mllama.MllamaTextSelfAttention = GaudiMllamaTextSelfAttention
854-
transformers.models.mllama.modeling_mllama.MllamaTextCrossAttention = GaudiMllamaTextCrossAttention
855850
transformers.models.mllama.modeling_mllama.MllamaForConditionalGeneration = GaudiMllamaForConditionalGeneration
851+
transformers.models.mllama.modeling_mllama.MllamaModel = GaudiMllamaModel
852+
transformers.models.mllama.modeling_mllama.MllamaSelfAttentionDecoderLayer = GaudiMllamaSelfAttentionDecoderLayer
853+
transformers.models.mllama.modeling_mllama.MllamaTextCrossAttention = GaudiMllamaTextCrossAttention
856854
transformers.models.mllama.modeling_mllama.MllamaTextModel = GaudiMllamaTextModel
857-
transformers.models.mllama.modeling_mllama.MllamaVisionModel = GaudiMllamaVisionModel
855+
transformers.models.mllama.modeling_mllama.MllamaTextSelfAttention = GaudiMllamaTextSelfAttention
858856
transformers.models.mllama.modeling_mllama.MllamaVisionEncoder = GaudiMllamaVisionEncoder
859-
transformers.models.mllama.modeling_mllama.MllamaVisionEncoderLayer = GaudiMllamaVisionEncoderLayer
860-
transformers.models.mllama.modeling_mllama.MllamaVisionSdpaAttention = GaudiMllamaVisionSdpaAttention
861-
transformers.models.mllama.modeling_mllama.MllamaModel = GaudiMllamaModel
857+
transformers.models.mllama.modeling_mllama.MllamaVisionModel = GaudiMllamaVisionModel
862858

859+
# Optimization for deciLM on Gaudi
863860
transformers.AutoConfig.register("deci", DeciLMConfig)
864861
transformers.AutoModelForCausalLM.register(DeciLMConfig, DeciLMForCausalLM)
865862

optimum/habana/transformers/models/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,9 +226,7 @@
226226
GaudiMllamaTextModel,
227227
GaudiMllamaTextSelfAttention,
228228
GaudiMllamaVisionEncoder,
229-
GaudiMllamaVisionEncoderLayer,
230229
GaudiMllamaVisionModel,
231-
GaudiMllamaVisionSdpaAttention,
232230
)
233231
from .modeling_all_models import (
234232
KVCache,

optimum/habana/transformers/models/mllama/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,5 @@
88
GaudiMllamaTextModel,
99
GaudiMllamaTextSelfAttention,
1010
GaudiMllamaVisionEncoder,
11-
GaudiMllamaVisionEncoderLayer,
1211
GaudiMllamaVisionModel,
13-
GaudiMllamaVisionSdpaAttention,
1412
)

0 commit comments

Comments (0)