From 1bf4898f006c43564faf09a153c8731985281f27 Mon Sep 17 00:00:00 2001
From: "Ariel N. Lee" <103224818+arielnlee@users.noreply.github.com>
Date: Wed, 27 Mar 2024 16:21:43 -0400
Subject: [PATCH 1/5] fine-tune llava-1.6 mistral 7b and 34b

---
 llava/model/builder.py                      |  19 +++-
 llava/train/train.py                        | 105 +++++++++++++++-----
 scripts/v1_6/finetune_lora_llava_34b.sh     |  39 ++++++++
 scripts/v1_6/finetune_lora_llava_mistral.sh |  39 ++++++++
 4 files changed, 175 insertions(+), 27 deletions(-)
 create mode 100644 scripts/v1_6/finetune_lora_llava_34b.sh
 create mode 100644 scripts/v1_6/finetune_lora_llava_mistral.sh

diff --git a/llava/model/builder.py b/llava/model/builder.py
index e3d50829f..7d9452214 100644
--- a/llava/model/builder.py
+++ b/llava/model/builder.py
@@ -50,11 +50,18 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
         if 'lora' in model_name.lower() and model_base is None:
             warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
         if 'lora' in model_name.lower() and model_base is not None:
-            from llava.model.language_model.llava_llama import LlavaConfig
-            lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path)
-            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+            if 'mistral' in model_name.lower():
+                from llava.model.language_model.llava_mistral import LlavaMistralConfig
+                lora_cfg_pretrained = LlavaMistralConfig.from_pretrained(model_path)
+                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+                model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
+
+            else:
+                from llava.model.language_model.llava_llama import LlavaConfig
+                lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path)
+                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+                model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
             print('Loading LLaVA from base model...')
-            model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
             token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
             if model.lm_head.weight.shape[0] != token_num:
                 model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
@@ -93,6 +100,10 @@ def load_from_hf(repo_id, filename, subfolder=None):
                 tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
                 cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
                 model = LlavaMptForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
+            elif 'mistral' in model_name.lower():
+                tokenizer = AutoTokenizer.from_pretrained(model_base)
+                cfg_pretrained = AutoConfig.from_pretrained(model_path)
+                model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
             else:
                 tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
                 cfg_pretrained = AutoConfig.from_pretrained(model_path)
diff --git a/llava/train/train.py b/llava/train/train.py
index 477c668b6..830080da9 100644
--- a/llava/train/train.py
+++ b/llava/train/train.py
@@ -33,7 +33,7 @@
 
 from llava import conversation as conversation_lib
 from llava.model import *
-from llava.mm_utils import tokenizer_image_token
+from llava.mm_utils import tokenizer_image_token, process_anyres_image
 
 from PIL import Image
 
@@ -497,6 +497,47 @@ def preprocess_v1(
     )
 
 
+def debug_34b_tokenization_length(conversation, target, tokenizer, conv, has_image):
+    total_len = int(target.ne(tokenizer.pad_token_id).sum())
+    calculated_len = 0
+
+    rounds = conversation.split(conv.sep)
+    print("Tokenized Conversation:")
+    tokenized_conversation = []
+    for rou in rounds:
+        if has_image:
+            tokenized_rou = tokenizer_image_token(rou, tokenizer)
+        else:
+            tokenized_rou = tokenizer.encode(rou, add_special_tokens=False)
+        print(tokenized_rou)
+        tokenized_conversation.extend(tokenized_rou)
+        calculated_len += len(tokenized_rou)
+
+    print("\nTokenized Target:")
+    tokenized_target = target[target != IGNORE_INDEX].tolist()
+    print(tokenized_target)
+
+    print("\nMissing Tokens:")
+    missing_tokens = []
+    conv_idx = 0
+    for i, token in enumerate(tokenized_target):
+        if conv_idx >= len(tokenized_conversation) or token != tokenized_conversation[conv_idx]:
+            missing_tokens.append((i, token))
+        else:
+            conv_idx += 1
+
+    if missing_tokens:
+        for idx, token in missing_tokens:
+            print(f"Position: {idx}, Token: {token} ({tokenizer.decode([token])})")
+    else:
+        print("No missing tokens found.")
+
+    if calculated_len != total_len:
+        print(f"\nLength mismatch detected. Calculated: {calculated_len}, Actual: {total_len}")
+    else:
+        print(f"\nLengths match. Length: {calculated_len}")
+
+
 def preprocess_mpt(
     sources,
     tokenizer: transformers.PreTrainedTokenizer,
@@ -505,11 +546,9 @@ def preprocess_mpt(
     conv = conversation_lib.default_conversation.copy()
     roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
 
-    # Apply prompt templates
     conversations = []
     for i, source in enumerate(sources):
         if roles[source[0]["from"]] != conv.roles[0]:
-            # Skip the first one if it is not from human
             source = source[1:]
 
         conv.messages = []
@@ -518,27 +557,19 @@ def preprocess_mpt(
             assert role == conv.roles[j % 2], f"{i}"
             conv.append_message(role, sentence["value"])
         conversations.append(conv.get_prompt())
-
-    # Tokenize conversations
+        #print(conv.get_prompt())
 
     if has_image:
-        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
-    else:
-        input_ids = tokenizer(
-            conversations,
-            return_tensors="pt",
-            padding="longest",
-            max_length=tokenizer.model_max_length,
-            truncation=True,
-        ).input_ids
-
+        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt') for prompt in conversations], dim=0)
+
     targets = input_ids.clone()
     assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
 
-    # Mask targets
     sep = conv.sep + conv.roles[1]
     for conversation, target in zip(conversations, targets):
         total_len = int(target.ne(tokenizer.pad_token_id).sum())
+        #print("target: ", target)
+        #print("conversation: ", conversation)
 
         rounds = conversation.split(conv.sep)
         re_rounds = [conv.sep.join(rounds[:3])]  # system + user + gpt
@@ -547,13 +578,14 @@
         cur_len = 0
         target[:cur_len] = IGNORE_INDEX
         for i, rou in enumerate(re_rounds):
+            #print(rou)
             if rou == "":
                 break
-
             parts = rou.split(sep)
             if len(parts) != 2:
                 break
             parts[0] += sep
+            #print("parts ", parts)
 
             if has_image:
                 round_len = len(tokenizer_image_token(rou, tokenizer))
@@ -562,7 +594,9 @@
                 round_len = len(tokenizer(rou).input_ids)
                 instruction_len = len(tokenizer(parts[0]).input_ids) - 1
 
-            if i != 0 and getattr(tokenizer, 'legacy', False) and IS_TOKENIZER_GREATER_THAN_0_14:
+            #if i != 0 and getattr(tokenizer, 'legacy', False) and IS_TOKENIZER_GREATER_THAN_0_14:
+            if getattr(tokenizer, 'legacy', False) and IS_TOKENIZER_GREATER_THAN_0_14:
+                #print("yes")
                 round_len += 1
                 instruction_len += 1
 
@@ -570,6 +604,8 @@
 
             cur_len += round_len
         target[cur_len:] = IGNORE_INDEX
+
+        # debug_34b_tokenization_length(conversation, target, tokenizer, conv, has_image)
 
         if cur_len < tokenizer.model_max_length:
             if cur_len != total_len:
@@ -660,7 +696,8 @@ class LazySupervisedDataset(Dataset):
 
     def __init__(self, data_path: str,
                  tokenizer: transformers.PreTrainedTokenizer,
-                 data_args: DataArguments):
+                 data_args: DataArguments,
+                 model_config):
         super(LazySupervisedDataset, self).__init__()
         list_data_dict = json.load(open(data_path, "r"))
 
@@ -668,6 +705,7 @@ def __init__(self, data_path: str,
         self.tokenizer = tokenizer
         self.list_data_dict = list_data_dict
         self.data_args = data_args
+        self.model_config = model_config
 
     def __len__(self):
         return len(self.list_data_dict)
@@ -699,6 +737,7 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
             image_folder = self.data_args.image_folder
             processor = self.data_args.image_processor
             image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
+            image_size = image.size
             if self.data_args.image_aspect_ratio == 'pad':
                 def expand2square(pil_img, background_color):
                     width, height = pil_img.size
@@ -714,6 +753,8 @@ def expand2square(pil_img, background_color):
                     return result
                 image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
                 image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+            elif self.data_args.image_aspect_ratio == 'anyres':
+                image = process_anyres_image(image, processor, self.model_config)
             else:
                 image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
             sources = preprocess_multimodal(
@@ -732,6 +773,7 @@ def expand2square(pil_img, background_color):
         # image exist in the data
         if 'image' in self.list_data_dict[i]:
             data_dict['image'] = image
+            data_dict['image_size'] = image_size
         elif self.data_args.is_multimodal:
             # image does not exist in the data, but the model is multimodal
             crop_size = self.data_args.image_processor.crop_size
@@ -765,20 +807,24 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
 
         if 'image' in instances[0]:
             images = [instance['image'] for instance in instances]
+            image_sizes = [instance['image_size'] for instance in instances]
             if all(x is not None and x.shape == images[0].shape for x in images):
                 batch['images'] = torch.stack(images)
             else:
                 batch['images'] = images
+            batch['image_sizes'] = image_sizes
 
         return batch
 
 
 def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
-                                data_args) -> Dict:
+                                data_args,
+                                model_config) -> Dict:
     """Make dataset and collator for supervised fine-tuning."""
     train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
                                 data_path=data_args.data_path,
-                                data_args=data_args)
+                                data_args=data_args,
+                                model_config=model_config)
     data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
     return dict(train_dataset=train_dataset,
                 eval_dataset=None,
@@ -823,12 +869,20 @@ def train(attn_implementation=None):
                 cache_dir=training_args.cache_dir,
                 **bnb_model_from_pretrained_args
             )
+        elif 'mistral' in model_args.model_name_or_path.lower():
+            model = LlavaMistralForCausalLM.from_pretrained(
+                model_args.model_name_or_path,
+                cache_dir=training_args.cache_dir,
+                attn_implementation=attn_implementation,
+                torch_dtype=(torch.bfloat16 if training_args.bf16 else torch.float16),
+                **bnb_model_from_pretrained_args
+            )
         else:
             model = LlavaLlamaForCausalLM.from_pretrained(
                 model_args.model_name_or_path,
                 cache_dir=training_args.cache_dir,
                 attn_implementation=attn_implementation,
-                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
+                torch_dtype=(torch.bfloat16 if training_args.bf16 else torch.float16),
                 **bnb_model_from_pretrained_args
             )
     else:
@@ -943,6 +997,7 @@ def make_inputs_require_grad(module, input, output):
         model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
         model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
 
+
     if training_args.bits in [4, 8]:
         from peft.tuners.lora import LoraLayer
         for name, module in model.named_modules():
@@ -956,8 +1011,11 @@ def make_inputs_require_grad(module, input, output):
                 if training_args.bf16 and module.weight.dtype == torch.float32:
                     module = module.to(torch.bfloat16)
 
+    model.resize_token_embeddings(len(tokenizer))
+
     data_module = make_supervised_data_module(tokenizer=tokenizer,
-                                              data_args=data_args)
+                                              data_args=data_args,
+                                              model_config=model.config)
     trainer = LLaVATrainer(model=model,
                     tokenizer=tokenizer,
                     args=training_args,
@@ -989,3 +1047,4 @@ def make_inputs_require_grad(module, input, output):
 
 if __name__ == "__main__":
     train()
+    
\ No newline at end of file
diff --git a/scripts/v1_6/finetune_lora_llava_34b.sh b/scripts/v1_6/finetune_lora_llava_34b.sh
new file mode 100644
index 000000000..18341e501
--- /dev/null
+++ b/scripts/v1_6/finetune_lora_llava_34b.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+
+deepspeed llava/train/train_mem.py \
+    --lora_enable True --lora_r 16 --lora_alpha 32 --mm_projector_lr 2e-5 \
+    --deepspeed ./scripts/zero3.json \
+    --model_name_or_path liuhaotian/llava-v1.6-34b \
+    --version chatml_direct_ft \
+    --data_path combined_data.json \
+    --image_folder random_images \
+    --vision_tower openai/clip-vit-large-patch14-336 \
+    --mm_projector_type mlp2x_gelu \
+    --mm_vision_select_layer -2 \
+    --mm_use_im_start_end False \
+    --mm_use_im_patch_token False \
+    --mm_patch_merge_type spatial_unpad \
+    --image_aspect_ratio anyres \
+    --group_by_modality_length False \
+    --bf16 False \
+    --fp16 True \
+    --output_dir ./llava-lora-34b \
+    --num_train_epochs 1 \
+    --per_device_train_batch_size 4 \
+    --per_device_eval_batch_size 1 \
+    --gradient_accumulation_steps 1 \
+    --evaluation_strategy "no" \
+    --save_strategy "steps" \
+    --save_steps 450 \
+    --save_total_limit 5 \
+    --learning_rate 2e-5 \
+    --weight_decay 0. \
+    --warmup_ratio 0.05 \
+    --lr_scheduler_type "cosine" \
+    --logging_steps 1 \
+    --tf32 True \
+    --model_max_length 4096 \
+    --gradient_checkpointing True \
+    --dataloader_num_workers 4 \
+    --lazy_preprocess True \
+    --report_to wandb \
\ No newline at end of file
diff --git a/scripts/v1_6/finetune_lora_llava_mistral.sh b/scripts/v1_6/finetune_lora_llava_mistral.sh
new file mode 100644
index 000000000..f185fe8fe
--- /dev/null
+++ b/scripts/v1_6/finetune_lora_llava_mistral.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+
+deepspeed llava/train/train_mem.py \
+    --lora_enable True --lora_r 16 --lora_alpha 32 --mm_projector_lr 2e-5 \
+    --deepspeed ./scripts/zero3.json \
+    --model_name_or_path liuhaotian/llava-v1.6-mistral-7b \
+    --version mistral_instruct \
+    --data_path combined_data.json \
+    --image_folder random_images \
+    --vision_tower openai/clip-vit-large-patch14-336 \
+    --mm_projector_type mlp2x_gelu \
+    --mm_vision_select_layer -2 \
+    --mm_use_im_start_end False \
+    --mm_use_im_patch_token False \
+    --mm_patch_merge_type spatial_unpad \
+    --image_aspect_ratio anyres \
+    --group_by_modality_length False \
+    --bf16 False \
+    --fp16 True \
+    --output_dir ./llava-lora-mistral \
+    --num_train_epochs 1 \
+    --per_device_train_batch_size 4 \
+    --per_device_eval_batch_size 1 \
+    --gradient_accumulation_steps 1 \
+    --evaluation_strategy "no" \
+    --save_strategy "steps" \
+    --save_steps 500 \
+    --save_total_limit 5 \
+    --learning_rate 2e-5 \
+    --weight_decay 0. \
+    --warmup_ratio 0.05 \
+    --lr_scheduler_type "cosine" \
+    --logging_steps 1 \
+    --tf32 True \
+    --model_max_length 4096 \
+    --gradient_checkpointing True \
+    --dataloader_num_workers 4 \
+    --lazy_preprocess True \
+    --report_to wandb \
\ No newline at end of file

From 40949ba868b3f9efed072edf508591d3688517de Mon Sep 17 00:00:00 2001
From: "Ariel N. Lee" <103224818+arielnlee@users.noreply.github.com>
Date: Wed, 27 Mar 2024 16:41:36 -0400
Subject: [PATCH 2/5] minor fix

---
 llava/train/train.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/llava/train/train.py b/llava/train/train.py
index 830080da9..4f96049f7 100644
--- a/llava/train/train.py
+++ b/llava/train/train.py
@@ -561,6 +561,14 @@ def preprocess_mpt(
 
     if has_image:
         input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt') for prompt in conversations], dim=0)
+    else:
+        input_ids = tokenizer(
+            conversations,
+            return_tensors="pt",
+            padding="longest",
+            max_length=tokenizer.model_max_length,
+            truncation=True,
+        ).input_ids
 
     targets = input_ids.clone()
     assert conv.sep_style == conversation_lib.SeparatorStyle.MPT

From 53e4fcf2222e23b0e7628918f256b439775b8917 Mon Sep 17 00:00:00 2001
From: "Ariel N. Lee" <103224818+arielnlee@users.noreply.github.com>
Date: Sun, 31 Mar 2024 22:17:10 -0400
Subject: [PATCH 3/5] update conversation template for 34b fine-tune

---
 llava/conversation.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/llava/conversation.py b/llava/conversation.py
index 00c56867d..d60935389 100644
--- a/llava/conversation.py
+++ b/llava/conversation.py
@@ -369,6 +369,17 @@ def dict(self):
     sep="<|im_end|>",
 )
 
+conv_chatml_direct_ft = Conversation(
+    system="""<|im_start|>system\nAnswer the questions.""",
+    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+    version="mpt",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.MPT,
+    sep="<|im_end|>",
+)
+
+
 default_conversation = conv_vicuna_v1
 conv_templates = {
     "default": conv_vicuna_v0,
@@ -378,6 +389,7 @@ def dict(self):
     "llama_2": conv_llama_2,
     "mistral_instruct": conv_mistral_instruct,
     "chatml_direct": conv_chatml_direct,
+    "chatml_direct_ft": conv_chatml_direct_ft,
     "mistral_direct": conv_chatml_direct,
 
     "plain": conv_llava_plain,

From c889f4a320d5897407065ef4ad1e77ef1ca51fef Mon Sep 17 00:00:00 2001
From: "Ariel N. Lee" <103224818+arielnlee@users.noreply.github.com>
Date: Tue, 2 Apr 2024 11:59:23 -0400
Subject: [PATCH 4/5] minor update

---
 llava/train/train.py                    |  2 +-
 scripts/v1_6/finetune_lora_llava_34b.sh | 10 +++++-----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/llava/train/train.py b/llava/train/train.py
index 4f96049f7..d241a4dc5 100644
--- a/llava/train/train.py
+++ b/llava/train/train.py
@@ -73,7 +73,7 @@ class DataArguments:
     lazy_preprocess: bool = False
     is_multimodal: bool = False
     image_folder: Optional[str] = field(default=None)
-    image_aspect_ratio: str = 'square'
+    image_aspect_ratio: str = 'anyres'
 
 
 @dataclass
diff --git a/scripts/v1_6/finetune_lora_llava_34b.sh b/scripts/v1_6/finetune_lora_llava_34b.sh
index 18341e501..5894e5ee3 100644
--- a/scripts/v1_6/finetune_lora_llava_34b.sh
+++ b/scripts/v1_6/finetune_lora_llava_34b.sh
@@ -5,8 +5,8 @@ deepspeed llava/train/train_mem.py \
     --deepspeed ./scripts/zero3.json \
     --model_name_or_path liuhaotian/llava-v1.6-34b \
     --version chatml_direct_ft \
-    --data_path combined_data.json \
-    --image_folder random_images \
+    --data_path transformed_data.json \
+    --image_folder train_images \
     --vision_tower openai/clip-vit-large-patch14-336 \
     --mm_projector_type mlp2x_gelu \
     --mm_vision_select_layer -2 \
@@ -15,8 +15,8 @@ deepspeed llava/train/train_mem.py \
     --mm_patch_merge_type spatial_unpad \
     --image_aspect_ratio anyres \
     --group_by_modality_length False \
-    --bf16 False \
-    --fp16 True \
+    --bf16 True \
+    --fp16 False \
     --output_dir ./llava-lora-34b \
    --num_train_epochs 1 \
     --per_device_train_batch_size 4 \
@@ -24,7 +24,7 @@ deepspeed llava/train/train_mem.py \
     --gradient_accumulation_steps 1 \
     --evaluation_strategy "no" \
     --save_strategy "steps" \
-    --save_steps 450 \
+    --save_steps 250 \
     --save_total_limit 5 \
     --learning_rate 2e-5 \
     --weight_decay 0. \

From ae157ec51863456a4cb93fff1bb3eaa651b3aa5f Mon Sep 17 00:00:00 2001
From: "Ariel N. Lee" <103224818+arielnlee@users.noreply.github.com>
Date: Tue, 2 Apr 2024 12:57:44 -0400
Subject: [PATCH 5/5] update anyres

---
 llava/mm_utils.py    | 5 +++++
 llava/train/train.py | 4 ++--
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/llava/mm_utils.py b/llava/mm_utils.py
index de97345cf..b828ae26e 100644
--- a/llava/mm_utils.py
+++ b/llava/mm_utils.py
@@ -182,6 +182,11 @@ def process_images(images, image_processor, model_cfg):
     return new_images
 
 
+def train_process_images(images, image_processor, model_cfg):
+    new_image = process_anyres_image(images, image_processor, model_cfg.image_grid_pinpoints)
+    return new_image
+
+
 def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
     prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
 
diff --git a/llava/train/train.py b/llava/train/train.py
index d241a4dc5..9a88c5f47 100644
--- a/llava/train/train.py
+++ b/llava/train/train.py
@@ -33,7 +33,7 @@
 
 from llava import conversation as conversation_lib
 from llava.model import *
-from llava.mm_utils import tokenizer_image_token, process_anyres_image
+from llava.mm_utils import tokenizer_image_token, train_process_images
 
 from PIL import Image
 
@@ -762,7 +762,7 @@ def expand2square(pil_img, background_color):
                 image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
                 image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
             elif self.data_args.image_aspect_ratio == 'anyres':
-                image = process_anyres_image(image, processor, self.model_config)
+                image = train_process_images(image, processor, self.model_config)
             else:
                 image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
             sources = preprocess_multimodal(