
Commit 1bf4898: fine-tune llava-1.6 mistral 7b and 34b

1 parent 4e2277a

File tree: 4 files changed, +175 −27 lines


llava/model/builder.py

+15 −4
@@ -50,11 +50,18 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
         if 'lora' in model_name.lower() and model_base is None:
             warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
         if 'lora' in model_name.lower() and model_base is not None:
-            from llava.model.language_model.llava_llama import LlavaConfig
-            lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path)
-            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+            if 'mistral' in model_name.lower():
+                from llava.model.language_model.llava_mistral import LlavaMistralConfig
+                lora_cfg_pretrained = LlavaMistralConfig.from_pretrained(model_path)
+                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+                model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
+
+            else:
+                from llava.model.language_model.llava_llama import LlavaConfig
+                lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path)
+                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+                model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
             print('Loading LLaVA from base model...')
-            model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
             token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
             if model.lm_head.weight.shape[0] != token_num:
                 model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
@@ -93,6 +100,10 @@ def load_from_hf(repo_id, filename, subfolder=None):
                 tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
                 cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
                 model = LlavaMptForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
+            elif 'mistral' in model_name.lower():
+                tokenizer = AutoTokenizer.from_pretrained(model_base)
+                cfg_pretrained = AutoConfig.from_pretrained(model_path)
+                model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
             else:
                 tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
                 cfg_pretrained = AutoConfig.from_pretrained(model_path)
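
With this change, a LoRA checkpoint fine-tuned on top of llava-v1.6-mistral-7b can be loaded unmerged by passing the base model separately, just as the existing llama path does. A minimal sketch, assuming training with the Mistral script below has finished and ./llava-lora-mistral is its --output_dir; load_pretrained_model and get_model_name_from_path are the existing helpers in llava.model.builder and llava.mm_utils:

# Minimal sketch: load the unmerged LoRA fine-tune produced by the scripts below.
# "./llava-lora-mistral" is the --output_dir of the Mistral script; its name contains
# "llava", "lora", and "mistral", so builder.py takes the new Mistral LoRA branch.
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path

model_path = "./llava-lora-mistral"                 # LoRA checkpoint directory
model_base = "liuhaotian/llava-v1.6-mistral-7b"     # base weights the LoRA was trained on

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, model_base, get_model_name_from_path(model_path)
)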

llava/train/train.py

+82 −23
@@ -33,7 +33,7 @@
 
 from llava import conversation as conversation_lib
 from llava.model import *
-from llava.mm_utils import tokenizer_image_token
+from llava.mm_utils import tokenizer_image_token, process_anyres_image
 
 from PIL import Image
 
@@ -497,6 +497,47 @@ def preprocess_v1(
     )
 
 
+def debug_34b_tokenization_length(conversation, target, tokenizer, conv, has_image):
+    total_len = int(target.ne(tokenizer.pad_token_id).sum())
+    calculated_len = 0
+
+    rounds = conversation.split(conv.sep)
+    print("Tokenized Conversation:")
+    tokenized_conversation = []
+    for rou in rounds:
+        if has_image:
+            tokenized_rou = tokenizer_image_token(rou, tokenizer)
+        else:
+            tokenized_rou = tokenizer.encode(rou, add_special_tokens=False)
+        print(tokenized_rou)
+        tokenized_conversation.extend(tokenized_rou)
+        calculated_len += len(tokenized_rou)
+
+    print("\nTokenized Target:")
+    tokenized_target = target[target != IGNORE_INDEX].tolist()
+    print(tokenized_target)
+
+    print("\nMissing Tokens:")
+    missing_tokens = []
+    conv_idx = 0
+    for i, token in enumerate(tokenized_target):
+        if conv_idx >= len(tokenized_conversation) or token != tokenized_conversation[conv_idx]:
+            missing_tokens.append((i, token))
+        else:
+            conv_idx += 1
+
+    if missing_tokens:
+        for idx, token in missing_tokens:
+            print(f"Position: {idx}, Token: {token} ({tokenizer.decode([token])})")
+    else:
+        print("No missing tokens found.")
+
+    if calculated_len != total_len:
+        print(f"\nLength mismatch detected. Calculated: {calculated_len}, Actual: {total_len}")
+    else:
+        print(f"\nLengths match. Length: {calculated_len}")
+
+
 def preprocess_mpt(
     sources,
     tokenizer: transformers.PreTrainedTokenizer,
@@ -505,11 +546,9 @@ def preprocess_mpt(
     conv = conversation_lib.default_conversation.copy()
     roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
 
-    # Apply prompt templates
     conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
-            # Skip the first one if it is not from human
            source = source[1:]
 
        conv.messages = []
@@ -518,27 +557,19 @@ def preprocess_mpt(
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())
-
-    # Tokenize conversations
+        #print(conv.get_prompt())
 
    if has_image:
-        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
-    else:
-        input_ids = tokenizer(
-            conversations,
-            return_tensors="pt",
-            padding="longest",
-            max_length=tokenizer.model_max_length,
-            truncation=True,
-        ).input_ids
-
+        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt') for prompt in conversations], dim=0)
+
    targets = input_ids.clone()
    assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
 
-    # Mask targets
    sep = conv.sep + conv.roles[1]
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())
+        #print("target: ", target)
+        #print("conversation: ", conversation)
 
        rounds = conversation.split(conv.sep)
        re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
@@ -547,13 +578,14 @@ def preprocess_mpt(
        cur_len = 0
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(re_rounds):
+            #print(rou)
            if rou == "":
                break
-
            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
+            #print("parts ", parts)
 
            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
@@ -562,14 +594,18 @@ def preprocess_mpt(
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 1
 
-            if i != 0 and getattr(tokenizer, 'legacy', False) and IS_TOKENIZER_GREATER_THAN_0_14:
+            #if i != 0 and getattr(tokenizer, 'legacy', False) and IS_TOKENIZER_GREATER_THAN_0_14:
+            if getattr(tokenizer, 'legacy', False) and IS_TOKENIZER_GREATER_THAN_0_14:
+                #print("yes")
                round_len += 1
                instruction_len += 1
 
            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
 
            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX
+
+        # debug_34b_tokenization_length(conversation, target, tokenizer, conv, has_image)
 
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
@@ -660,14 +696,16 @@ class LazySupervisedDataset(Dataset):
 
    def __init__(self, data_path: str,
                 tokenizer: transformers.PreTrainedTokenizer,
-                 data_args: DataArguments):
+                 data_args: DataArguments,
+                 model_config):
        super(LazySupervisedDataset, self).__init__()
        list_data_dict = json.load(open(data_path, "r"))
 
        rank0_print("Formatting inputs...Skip in lazy mode")
        self.tokenizer = tokenizer
        self.list_data_dict = list_data_dict
        self.data_args = data_args
+        self.model_config = model_config
 
    def __len__(self):
        return len(self.list_data_dict)
@@ -699,6 +737,7 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
            image_folder = self.data_args.image_folder
            processor = self.data_args.image_processor
            image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
+            image_size = image.size
            if self.data_args.image_aspect_ratio == 'pad':
                def expand2square(pil_img, background_color):
                    width, height = pil_img.size
@@ -714,6 +753,8 @@ def expand2square(pil_img, background_color):
                    return result
                image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
                image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+            elif self.data_args.image_aspect_ratio == 'anyres':
+                image = process_anyres_image(image, processor, self.model_config)
            else:
                image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            sources = preprocess_multimodal(
@@ -732,6 +773,7 @@ def expand2square(pil_img, background_color):
        # image exist in the data
        if 'image' in self.list_data_dict[i]:
            data_dict['image'] = image
+            data_dict['image_size'] = image_size
        elif self.data_args.is_multimodal:
            # image does not exist in the data, but the model is multimodal
            crop_size = self.data_args.image_processor.crop_size
@@ -765,20 +807,24 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
 
        if 'image' in instances[0]:
            images = [instance['image'] for instance in instances]
+            image_sizes = [instance['image_size'] for instance in instances]
            if all(x is not None and x.shape == images[0].shape for x in images):
                batch['images'] = torch.stack(images)
            else:
                batch['images'] = images
+            batch['image_sizes'] = image_sizes
 
        return batch
 
 
 def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
-                                data_args) -> Dict:
+                                data_args,
+                                model_config) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
                                data_path=data_args.data_path,
-                                data_args=data_args)
+                                data_args=data_args,
+                                model_config=model_config)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset,
                eval_dataset=None,
@@ -823,12 +869,20 @@ def train(attn_implementation=None):
                cache_dir=training_args.cache_dir,
                **bnb_model_from_pretrained_args
            )
+        elif 'mistral' in model_args.model_name_or_path.lower():
+            model = LlavaMistralForCausalLM.from_pretrained(
+                model_args.model_name_or_path,
+                cache_dir=training_args.cache_dir,
+                attn_implementation=attn_implementation,
+                torch_dtype=(torch.bfloat16 if training_args.bf16 else torch.float16),
+                **bnb_model_from_pretrained_args
+            )
        else:
            model = LlavaLlamaForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                attn_implementation=attn_implementation,
-                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
+                torch_dtype=(torch.bfloat16 if training_args.bf16 else torch.float16),
                **bnb_model_from_pretrained_args
            )
    else:
@@ -943,6 +997,7 @@ def make_inputs_require_grad(module, input, output):
        model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
        model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
 
+
    if training_args.bits in [4, 8]:
        from peft.tuners.lora import LoraLayer
        for name, module in model.named_modules():
@@ -956,8 +1011,11 @@ def make_inputs_require_grad(module, input, output):
                    if training_args.bf16 and module.weight.dtype == torch.float32:
                        module = module.to(torch.bfloat16)
 
+    model.resize_token_embeddings(len(tokenizer))
+
    data_module = make_supervised_data_module(tokenizer=tokenizer,
-                                              data_args=data_args)
+                                              data_args=data_args,
+                                              model_config=model.config)
    trainer = LLaVATrainer(model=model,
                           tokenizer=tokenizer,
                           args=training_args,
@@ -989,3 +1047,4 @@ def make_inputs_require_grad(module, input, output):
 
 if __name__ == "__main__":
     train()
+
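
The collator change above is needed because 'anyres' preprocessing can produce a different number of image patches per sample, so the image tensors cannot always be stacked; in that case the batch keeps a plain list plus the original image_sizes. A small self-contained sketch of that behavior with dummy tensors (shapes are illustrative only, not taken from the commit):

# Illustrative sketch of the image-collation logic added above, using dummy tensors.
import torch

def collate_images(instances):
    batch = {}
    images = [inst['image'] for inst in instances]
    image_sizes = [inst['image_size'] for inst in instances]
    if all(x is not None and x.shape == images[0].shape for x in images):
        batch['images'] = torch.stack(images)   # uniform shapes ('pad' / plain resize)
    else:
        batch['images'] = images                # 'anyres' patch counts differ per image
    batch['image_sizes'] = image_sizes
    return batch

# e.g. two anyres samples tiled into 5 and 7 patches of 336x336:
a = {'image': torch.zeros(5, 3, 336, 336), 'image_size': (640, 480)}
b = {'image': torch.zeros(7, 3, 336, 336), 'image_size': (1024, 768)}
out = collate_images([a, b])
print(type(out['images']).__name__, out['image_sizes'])   # list [(640, 480), (1024, 768)]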
New file (+39): LoRA fine-tuning script for llava-v1.6-34b

@@ -0,0 +1,39 @@
+#!/bin/bash
+
+deepspeed llava/train/train_mem.py \
+    --lora_enable True --lora_r 16 --lora_alpha 32 --mm_projector_lr 2e-5 \
+    --deepspeed ./scripts/zero3.json \
+    --model_name_or_path liuhaotian/llava-v1.6-34b \
+    --version chatml_direct_ft \
+    --data_path combined_data.json \
+    --image_folder random_images \
+    --vision_tower openai/clip-vit-large-patch14-336 \
+    --mm_projector_type mlp2x_gelu \
+    --mm_vision_select_layer -2 \
+    --mm_use_im_start_end False \
+    --mm_use_im_patch_token False \
+    --mm_patch_merge_type spatial_unpad \
+    --image_aspect_ratio anyres \
+    --group_by_modality_length False \
+    --bf16 False \
+    --fp16 True \
+    --output_dir ./llava-lora-34b \
+    --num_train_epochs 1 \
+    --per_device_train_batch_size 4 \
+    --per_device_eval_batch_size 1 \
+    --gradient_accumulation_steps 1 \
+    --evaluation_strategy "no" \
+    --save_strategy "steps" \
+    --save_steps 450 \
+    --save_total_limit 5 \
+    --learning_rate 2e-5 \
+    --weight_decay 0. \
+    --warmup_ratio 0.05 \
+    --lr_scheduler_type "cosine" \
+    --logging_steps 1 \
+    --tf32 True \
+    --model_max_length 4096 \
+    --gradient_checkpointing True \
+    --dataloader_num_workers 4 \
+    --lazy_preprocess True \
+    --report_to wandb \
New file (+39): LoRA fine-tuning script for llava-v1.6-mistral-7b

@@ -0,0 +1,39 @@
+#!/bin/bash
+
+deepspeed llava/train/train_mem.py \
+    --lora_enable True --lora_r 16 --lora_alpha 32 --mm_projector_lr 2e-5 \
+    --deepspeed ./scripts/zero3.json \
+    --model_name_or_path liuhaotian/llava-v1.6-mistral-7b \
+    --version mistral_instruct \
+    --data_path combined_data.json \
+    --image_folder random_images \
+    --vision_tower openai/clip-vit-large-patch14-336 \
+    --mm_projector_type mlp2x_gelu \
+    --mm_vision_select_layer -2 \
+    --mm_use_im_start_end False \
+    --mm_use_im_patch_token False \
+    --mm_patch_merge_type spatial_unpad \
+    --image_aspect_ratio anyres \
+    --group_by_modality_length False \
+    --bf16 False \
+    --fp16 True \
+    --output_dir ./llava-lora-mistral \
+    --num_train_epochs 1 \
+    --per_device_train_batch_size 4 \
+    --per_device_eval_batch_size 1 \
+    --gradient_accumulation_steps 1 \
+    --evaluation_strategy "no" \
+    --save_strategy "steps" \
+    --save_steps 500 \
+    --save_total_limit 5 \
+    --learning_rate 2e-5 \
+    --weight_decay 0. \
+    --warmup_ratio 0.05 \
+    --lr_scheduler_type "cosine" \
+    --logging_steps 1 \
+    --tf32 True \
+    --model_max_length 4096 \
+    --gradient_checkpointing True \
+    --dataloader_num_workers 4 \
+    --lazy_preprocess True \
+    --report_to wandb \
