1818import json
1919import logging
2020import os
21+ from contextlib import nullcontext
2122from pathlib import Path
2223
2324import PIL .Image
@@ -353,51 +354,53 @@ def main():
353354
354355 htcore .hpu_set_env ()
355356
357+ if model_type == "mllama" and args .use_flash_attention :
358+ config ._attn_implementation = "gaudi_fused_sdpa"
359+ if args .flash_attention_recompute :
360+ os .environ ["FLASH_ATTENTION_RECOMPUTE" ] = "1"
361+
356362 if args .world_size > 1 :
357363 import deepspeed
358364
359- with deepspeed .OnDevice (dtype = model_dtype , device = "cpu" ):
360- model = AutoModelForVision2Seq .from_pretrained (args .model_name_or_path , torch_dtype = model_dtype )
365+ context = deepspeed .OnDevice (dtype = model_dtype , device = "cpu" )
366+ else :
367+ context = nullcontext ()
368+
369+ with context :
370+ model = AutoModelForVision2Seq .from_pretrained (args .model_name_or_path , torch_dtype = model_dtype , config = config )
371+
372+ if args .world_size > 1 :
361373 if model_type == "mllama" :
362374 model .language_model = initialize_distributed_model (args , model .language_model , logger , model_dtype )
363375 model .to ("hpu" )
364376 else :
365377 model = initialize_distributed_model (args , model , logger , model_dtype )
366- generator = pipeline (
367- "image-to-text" ,
368- model = model ,
369- config = args .model_name_or_path ,
370- tokenizer = args .model_name_or_path ,
371- image_processor = args .model_name_or_path ,
372- torch_dtype = model_dtype ,
373- device = "hpu" ,
374- )
375- else :
376- generator = pipeline (
377- "image-to-text" ,
378- model = args .model_name_or_path ,
379- config = args .model_name_or_path ,
380- tokenizer = args .model_name_or_path ,
381- image_processor = None if model_type == "chatglm" else args .model_name_or_path ,
382- torch_dtype = model_dtype ,
383- device = "hpu" ,
384- )
385- if args .use_hpu_graphs :
386- from habana_frameworks .torch .hpu import wrap_in_hpu_graph
387378
388- generator .model = wrap_in_hpu_graph (generator .model )
379+ generator = pipeline (
380+ "image-to-text" ,
381+ model = model ,
382+ config = config ,
383+ tokenizer = args .model_name_or_path ,
384+ image_processor = None if model_type == "chatglm" else args .model_name_or_path ,
385+ torch_dtype = model_dtype ,
386+ device = "hpu" ,
387+ )
388+
389+ if args .world_size < 2 and args .use_hpu_graphs :
390+ from habana_frameworks .torch .hpu import wrap_in_hpu_graph
391+
392+ generator .model = wrap_in_hpu_graph (generator .model )
389393
390394 if "falcon-11B-vlm" in args .model_name_or_path :
391395 # WA falcon vlm issue that image_token_id == embed size.
392396 generator .model .resize_token_embeddings (generator .tokenizer .vocab_size + 1 )
393397 processor .patch_size = config .vision_config .patch_size
398+
394399 generate_kwargs = {
395400 "lazy_mode" : use_lazy_mode ,
396401 "hpu_graphs" : args .use_hpu_graphs ,
397402 "max_new_tokens" : args .max_new_tokens ,
398403 "ignore_eos" : args .ignore_eos ,
399- "use_flash_attention" : args .use_flash_attention ,
400- "flash_attention_recompute" : args .flash_attention_recompute ,
401404 "bucket_internal" : args .bucket_internal ,
402405 "bucket_size" : args .bucket_size ,
403406 "limit_hpu_graphs" : args .limit_hpu_graphs ,
@@ -406,6 +409,14 @@ def main():
406409 "logits_bf16" : args .logits_bf16 ,
407410 }
408411
412+ if model_type != "mllama" :
413+ generate_kwargs .update (
414+ {
415+ "use_flash_attention" : args .use_flash_attention ,
416+ "flash_attention_recompute" : args .flash_attention_recompute ,
417+ }
418+ )
419+
409420 if args .sdp_on_bf16 :
410421 torch ._C ._set_math_sdp_allow_fp16_bf16_reduction (True )
411422
0 commit comments