 from llava import conversation as conversation_lib
 from llava.model import *
-from llava.mm_utils import tokenizer_image_token
+from llava.mm_utils import tokenizer_image_token, process_anyres_image
 
 from PIL import Image
@@ -497,6 +497,47 @@ def preprocess_v1(
                 )
 
 
+def debug_34b_tokenization_length(conversation, target, tokenizer, conv, has_image):
+    total_len = int(target.ne(tokenizer.pad_token_id).sum())
+    calculated_len = 0
+
+    rounds = conversation.split(conv.sep)
+    print("Tokenized Conversation:")
+    tokenized_conversation = []
+    for rou in rounds:
+        if has_image:
+            tokenized_rou = tokenizer_image_token(rou, tokenizer)
+        else:
+            tokenized_rou = tokenizer.encode(rou, add_special_tokens=False)
+        print(tokenized_rou)
+        tokenized_conversation.extend(tokenized_rou)
+        calculated_len += len(tokenized_rou)
+
+    print("\nTokenized Target:")
+    tokenized_target = target[target != IGNORE_INDEX].tolist()
+    print(tokenized_target)
+
+    print("\nMissing Tokens:")
+    missing_tokens = []
+    conv_idx = 0
+    for i, token in enumerate(tokenized_target):
+        if conv_idx >= len(tokenized_conversation) or token != tokenized_conversation[conv_idx]:
+            missing_tokens.append((i, token))
+        else:
+            conv_idx += 1
+
+    if missing_tokens:
+        for idx, token in missing_tokens:
+            print(f"Position: {idx}, Token: {token} ({tokenizer.decode([token])})")
+    else:
+        print("No missing tokens found.")
+
+    if calculated_len != total_len:
+        print(f"\nLength mismatch detected. Calculated: {calculated_len}, Actual: {total_len}")
+    else:
+        print(f"\nLengths match. Length: {calculated_len}")
+
+
 def preprocess_mpt(
     sources,
     tokenizer: transformers.PreTrainedTokenizer,
@@ -505,11 +546,9 @@ def preprocess_mpt(
     conv = conversation_lib.default_conversation.copy()
     roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
 
-    # Apply prompt templates
     conversations = []
     for i, source in enumerate(sources):
         if roles[source[0]["from"]] != conv.roles[0]:
-            # Skip the first one if it is not from human
             source = source[1:]
 
         conv.messages = []
@@ -518,27 +557,19 @@ def preprocess_mpt(
             assert role == conv.roles[j % 2], f"{i}"
             conv.append_message(role, sentence["value"])
         conversations.append(conv.get_prompt())
-
-    # Tokenize conversations
+        #print(conv.get_prompt())
 
     if has_image:
-        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
-    else:
-        input_ids = tokenizer(
-            conversations,
-            return_tensors="pt",
-            padding="longest",
-            max_length=tokenizer.model_max_length,
-            truncation=True,
-        ).input_ids
-
+        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt') for prompt in conversations], dim=0)
+
     targets = input_ids.clone()
     assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
 
-    # Mask targets
     sep = conv.sep + conv.roles[1]
     for conversation, target in zip(conversations, targets):
         total_len = int(target.ne(tokenizer.pad_token_id).sum())
+        #print("target: ", target)
+        #print("conversation: ", conversation)
 
         rounds = conversation.split(conv.sep)
         re_rounds = [conv.sep.join(rounds[:3])]    # system + user + gpt
@@ -547,13 +578,14 @@ def preprocess_mpt(
         cur_len = 0
         target[:cur_len] = IGNORE_INDEX
         for i, rou in enumerate(re_rounds):
+            #print(rou)
             if rou == "":
                 break
-
             parts = rou.split(sep)
             if len(parts) != 2:
                 break
             parts[0] += sep
+            #print("parts ", parts)
 
             if has_image:
                 round_len = len(tokenizer_image_token(rou, tokenizer))
@@ -562,14 +594,18 @@ def preprocess_mpt(
                 round_len = len(tokenizer(rou).input_ids)
                 instruction_len = len(tokenizer(parts[0]).input_ids) - 1
 
-            if i != 0 and getattr(tokenizer, 'legacy', False) and IS_TOKENIZER_GREATER_THAN_0_14:
+            #if i != 0 and getattr(tokenizer, 'legacy', False) and IS_TOKENIZER_GREATER_THAN_0_14:
+            if getattr(tokenizer, 'legacy', False) and IS_TOKENIZER_GREATER_THAN_0_14:
+                #print("yes")
                 round_len += 1
                 instruction_len += 1
 
             target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
 
             cur_len += round_len
         target[cur_len:] = IGNORE_INDEX
+
+        # debug_34b_tokenization_length(conversation, target, tokenizer, conv, has_image)
 
         if cur_len < tokenizer.model_max_length:
             if cur_len != total_len:
@@ -660,14 +696,16 @@ class LazySupervisedDataset(Dataset):
 
     def __init__(self, data_path: str,
                  tokenizer: transformers.PreTrainedTokenizer,
-                 data_args: DataArguments):
+                 data_args: DataArguments,
+                 model_config):
         super(LazySupervisedDataset, self).__init__()
        list_data_dict = json.load(open(data_path, "r"))
 
         rank0_print("Formatting inputs...Skip in lazy mode")
         self.tokenizer = tokenizer
         self.list_data_dict = list_data_dict
         self.data_args = data_args
+        self.model_config = model_config
 
     def __len__(self):
         return len(self.list_data_dict)
@@ -699,6 +737,7 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
             image_folder = self.data_args.image_folder
             processor = self.data_args.image_processor
             image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
+            image_size = image.size
             if self.data_args.image_aspect_ratio == 'pad':
                 def expand2square(pil_img, background_color):
                     width, height = pil_img.size
@@ -714,6 +753,8 @@ def expand2square(pil_img, background_color):
                     return result
                 image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
                 image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+            elif self.data_args.image_aspect_ratio == 'anyres':
+                image = process_anyres_image(image, processor, self.model_config)
             else:
                 image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
             sources = preprocess_multimodal(
@@ -732,6 +773,7 @@ def expand2square(pil_img, background_color):
         # image exist in the data
         if 'image' in self.list_data_dict[i]:
             data_dict['image'] = image
+            data_dict['image_size'] = image_size
         elif self.data_args.is_multimodal:
             # image does not exist in the data, but the model is multimodal
             crop_size = self.data_args.image_processor.crop_size
@@ -765,20 +807,24 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
 
         if 'image' in instances[0]:
             images = [instance['image'] for instance in instances]
+            image_sizes = [instance['image_size'] for instance in instances]
             if all(x is not None and x.shape == images[0].shape for x in images):
                 batch['images'] = torch.stack(images)
             else:
                 batch['images'] = images
+            batch['image_sizes'] = image_sizes
 
         return batch
 
 
 def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
-                                data_args) -> Dict:
+                                data_args,
+                                model_config) -> Dict:
     """Make dataset and collator for supervised fine-tuning."""
     train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
                                 data_path=data_args.data_path,
-                                data_args=data_args)
+                                data_args=data_args,
+                                model_config=model_config)
     data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
     return dict(train_dataset=train_dataset,
                 eval_dataset=None,
@@ -823,12 +869,20 @@ def train(attn_implementation=None):
                 cache_dir=training_args.cache_dir,
                 **bnb_model_from_pretrained_args
             )
+        elif 'mistral' in model_args.model_name_or_path.lower():
+            model = LlavaMistralForCausalLM.from_pretrained(
+                model_args.model_name_or_path,
+                cache_dir=training_args.cache_dir,
+                attn_implementation=attn_implementation,
+                torch_dtype=(torch.bfloat16 if training_args.bf16 else torch.float16),
+                **bnb_model_from_pretrained_args
+            )
         else:
             model = LlavaLlamaForCausalLM.from_pretrained(
                 model_args.model_name_or_path,
                 cache_dir=training_args.cache_dir,
                 attn_implementation=attn_implementation,
-                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
+                torch_dtype=(torch.bfloat16 if training_args.bf16 else torch.float16),
                 **bnb_model_from_pretrained_args
             )
     else:
@@ -943,6 +997,7 @@ def make_inputs_require_grad(module, input, output):
         model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
         model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
 
+
     if training_args.bits in [4, 8]:
         from peft.tuners.lora import LoraLayer
         for name, module in model.named_modules():
@@ -956,8 +1011,11 @@ def make_inputs_require_grad(module, input, output):
                     if training_args.bf16 and module.weight.dtype == torch.float32:
                         module = module.to(torch.bfloat16)
 
+    model.resize_token_embeddings(len(tokenizer))
+
     data_module = make_supervised_data_module(tokenizer=tokenizer,
-                                              data_args=data_args)
+                                              data_args=data_args,
+                                              model_config=model.config)
     trainer = LLaVATrainer(model=model,
                     tokenizer=tokenizer,
                     args=training_args,
@@ -989,3 +1047,4 @@ def make_inputs_require_grad(module, input, output):
 
 if __name__ == "__main__":
     train()
+
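
Below is a minimal usage sketch (not part of the diff) of how the re-wired pieces fit together after these changes. The call shapes mirror the hunks above; `processor`, `model`, `tokenizer`, `data_args`, and the image path are placeholders assumed from the surrounding training setup.

# Illustrative sketch only: the new 'anyres' branch and the threaded model config.
from PIL import Image
from llava.mm_utils import process_anyres_image

image = Image.open("example.jpg").convert('RGB')   # placeholder image path
image_size = image.size                            # kept and collated later as 'image_sizes'
image_tensor = process_anyres_image(image, processor, model.config)

# model.config is now passed down so LazySupervisedDataset.__getitem__ can reach it:
data_module = make_supervised_data_module(tokenizer=tokenizer,
                                           data_args=data_args,
                                           model_config=model.config)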