From 2d7f0e0b6fac50fe3c1423acece7f6d75bc0d4e1 Mon Sep 17 00:00:00 2001 From: Yuzhe Date: Tue, 25 Jun 2024 15:53:47 +0800 Subject: [PATCH 1/2] add qwen2 support for pretraining and finetuning --- llava/__init__.py | 2 +- llava/conversation.py | 63 +- llava/eval/model_qa.py | 4 +- llava/eval/model_vqa.py | 1 + llava/eval/model_vqa_loader.py | 1 + llava/eval/model_vqa_mmbench.py | 18 +- llava/eval/model_vqa_science.py | 1 + llava/mm_utils.py | 5 +- llava/model/__init__.py | 1 + llava/model/builder.py | 1 + llava/model/language_model/llava_mpt.py | 28 +- llava/model/language_model/llava_qwen.py | 159 ++++ llava/model/llava_arch.py | 1 + llava/serve/cli.py | 2 +- llava/serve/controller.py | 8 +- llava/serve/gradio_web_server.py | 14 +- llava/serve/model_worker.py | 6 +- llava/serve/sglang_worker.py | 6 +- llava/serve/test_message.py | 7 +- llava/train/train.py | 129 ++- llava/train/train_mem.py | 2 +- llava/train/train_qwen.py | 1103 ++++++++++++++++++++++ llava/train/train_xformers.py | 2 +- llava/utils.py | 1 + scripts/v1_5/finetune_qwen_2.sh | 37 + 25 files changed, 1531 insertions(+), 71 deletions(-) create mode 100644 llava/model/language_model/llava_qwen.py create mode 100644 llava/train/train_qwen.py create mode 100644 scripts/v1_5/finetune_qwen_2.sh diff --git a/llava/__init__.py b/llava/__init__.py index 4d1f016db..8b70d448d 100644 --- a/llava/__init__.py +++ b/llava/__init__.py @@ -1 +1 @@ -from .model import LlavaLlamaForCausalLM +from .model import LlavaLlamaForCausalLM, LlavaQwen2ForCausalLM diff --git a/llava/conversation.py b/llava/conversation.py index 00c56867d..aef42c937 100644 --- a/llava/conversation.py +++ b/llava/conversation.py @@ -13,6 +13,8 @@ class SeparatorStyle(Enum): MPT = auto() PLAIN = auto() LLAMA_2 = auto() + QWEN_2 = auto() # fix: add qwen2 + CHATML = auto() @dataclasses.dataclass @@ -51,6 +53,27 @@ def get_prompt(self): ret += role + ": " + message + self.sep else: ret += role + ":" + elif self.sep_style == SeparatorStyle.QWEN_2: # fix: add qwen2 + seps = [self.sep, self.sep2] + ret = self.system + seps[0] + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + elif self.sep_style == SeparatorStyle.CHATML: + ret = "" if self.system == "" else self.system + self.sep + "\n" + for role, message in messages: + if message: + if type(message) is tuple: + message, images = message + message = "" * len(images) + message + ret += role + "\n" + message + self.sep + "\n" + else: + ret += role + "\n" + return ret elif self.sep_style == SeparatorStyle.TWO: seps = [self.sep, self.sep2] ret = self.system + seps[0] @@ -71,8 +94,8 @@ def get_prompt(self): else: ret += role elif self.sep_style == SeparatorStyle.LLAMA_2: - wrap_sys = lambda msg: f"<>\n{msg}\n<>\n\n" if len(msg) > 0 else msg - wrap_inst = lambda msg: f"[INST] {msg} [/INST]" + def wrap_sys(msg): return f"<>\n{msg}\n<>\n\n" if len(msg) > 0 else msg + def wrap_inst(msg): return f"[INST] {msg} [/INST]" ret = "" for i, (role, message) in enumerate(messages): @@ -82,7 +105,8 @@ def get_prompt(self): if message: if type(message) is tuple: message, _, _ = message - if i == 0: message = wrap_sys(self.system) + message + if i == 0: + message = wrap_sys(self.system) + message if i % 2 == 0: message = wrap_inst(message) ret += self.sep + message @@ -369,12 +393,38 @@ def dict(self): sep="<|im_end|>", ) -default_conversation = conv_vicuna_v1 + +# fix: add qwen2 +conv_qwen_2 = Conversation( + 
system="A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions.", + roles=("USER", "ASSISTANT"), + version="qwen_v2", + messages=(), + offset=0, + sep_style=SeparatorStyle.QWEN_2, + sep=" ", + sep2="<|endoftext|>", +) + +# conv_qwen_2 = Conversation( +# system="""<|im_start|>system +# You are a helpful assistant.""", +# roles=("<|im_start|>user", "<|im_start|>assistant"), +# version="qwen_v2", +# messages=[], +# offset=0, +# sep_style=SeparatorStyle.CHATML, +# sep="<|im_end|>", +# ) + +default_conversation = conv_qwen_2 conv_templates = { - "default": conv_vicuna_v0, + "default": conv_qwen_2, "v0": conv_vicuna_v0, "v1": conv_vicuna_v1, "vicuna_v1": conv_vicuna_v1, + "qwen_2": conv_qwen_2, "llama_2": conv_llama_2, "mistral_instruct": conv_mistral_instruct, "chatml_direct": conv_chatml_direct, @@ -391,6 +441,5 @@ def dict(self): "mpt": conv_mpt, } - if __name__ == "__main__": - print(default_conversation.get_prompt()) + print("conversation:", default_conversation.get_prompt()) diff --git a/llava/eval/model_qa.py b/llava/eval/model_qa.py index 2e254da15..f0b9599ed 100644 --- a/llava/eval/model_qa.py +++ b/llava/eval/model_qa.py @@ -17,8 +17,7 @@ def eval_model(model_name, questions_file, answers_file): model_name = os.path.expanduser(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) model = AutoModelForCausalLM.from_pretrained(model_name, - torch_dtype=torch.float16).cuda() - + torch_dtype=torch.float16).cuda() ques_file = open(os.path.expanduser(questions_file), "r") ans_file = open(os.path.expanduser(answers_file), "w") @@ -54,6 +53,7 @@ def eval_model(model_name, questions_file, answers_file): ans_file.flush() ans_file.close() + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model-name", type=str, default="facebook/opt-350m") diff --git a/llava/eval/model_vqa.py b/llava/eval/model_vqa.py index 938706438..516dec236 100644 --- a/llava/eval/model_vqa.py +++ b/llava/eval/model_vqa.py @@ -83,6 +83,7 @@ def eval_model(args): ans_file.flush() ans_file.close() + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model-path", type=str, default="facebook/opt-350m") diff --git a/llava/eval/model_vqa_loader.py b/llava/eval/model_vqa_loader.py index d435b7d83..a2b30c0df 100644 --- a/llava/eval/model_vqa_loader.py +++ b/llava/eval/model_vqa_loader.py @@ -125,6 +125,7 @@ def eval_model(args): # ans_file.flush() ans_file.close() + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model-path", type=str, default="facebook/opt-350m") diff --git a/llava/eval/model_vqa_mmbench.py b/llava/eval/model_vqa_mmbench.py index bd7a4c808..045350711 100644 --- a/llava/eval/model_vqa_mmbench.py +++ b/llava/eval/model_vqa_mmbench.py @@ -41,6 +41,7 @@ def is_none(value): return True return False + def get_options(row, options): parsed_options = [] for option in options: @@ -124,14 +125,14 @@ def eval_model(args): ans_id = shortuuid.uuid() ans_file.write(json.dumps({"question_id": idx, - "round_id": round_idx, - "prompt": cur_prompt, - "text": outputs, - "options": options, - "option_char": cur_option_char, - "answer_id": ans_id, - "model_id": model_name, - "metadata": {}}) + "\n") + "round_id": round_idx, + "prompt": cur_prompt, + "text": outputs, + "options": options, + "option_char": cur_option_char, + "answer_id": ans_id, + "model_id": model_name, + "metadata": {}}) + "\n") 
ans_file.flush() # rotate options @@ -139,6 +140,7 @@ def eval_model(args): cur_option_char = cur_option_char[1:] + cur_option_char[:1] ans_file.close() + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model-path", type=str, default="facebook/opt-350m") diff --git a/llava/eval/model_vqa_science.py b/llava/eval/model_vqa_science.py index 90fc681a2..a62d3668e 100644 --- a/llava/eval/model_vqa_science.py +++ b/llava/eval/model_vqa_science.py @@ -93,6 +93,7 @@ def eval_model(args): ans_file.flush() ans_file.close() + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model-path", type=str, default="facebook/opt-350m") diff --git a/llava/mm_utils.py b/llava/mm_utils.py index de97345cf..43836cc16 100644 --- a/llava/mm_utils.py +++ b/llava/mm_utils.py @@ -212,6 +212,7 @@ def get_model_name_from_path(model_path): else: return model_paths[-1] + class KeywordsStoppingCriteria(StoppingCriteria): def __init__(self, keywords, tokenizer, input_ids): self.keywords = keywords @@ -226,7 +227,7 @@ def __init__(self, keywords, tokenizer, input_ids): self.keyword_ids.append(torch.tensor(cur_keyword_ids)) self.tokenizer = tokenizer self.start_len = input_ids.shape[1] - + def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] @@ -239,7 +240,7 @@ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor if keyword in outputs: return True return False - + def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: outputs = [] for i in range(output_ids.shape[0]): diff --git a/llava/model/__init__.py b/llava/model/__init__.py index dbd91789f..9952691ed 100644 --- a/llava/model/__init__.py +++ b/llava/model/__init__.py @@ -2,5 +2,6 @@ from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig + from .language_model.llava_qwen import LlavaQwen2ForCausalLM, LlavaConfig except: pass diff --git a/llava/model/builder.py b/llava/model/builder.py index e3d50829f..5ba12783a 100644 --- a/llava/model/builder.py +++ b/llava/model/builder.py @@ -66,6 +66,7 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l else: # this is probably from HF Hub from huggingface_hub import hf_hub_download + def load_from_hf(repo_id, filename, subfolder=None): cache_file = hf_hub_download( repo_id=repo_id, diff --git a/llava/model/language_model/llava_mpt.py b/llava/model/language_model/llava_mpt.py index 02e5237ec..7243ed647 100644 --- a/llava/model/language_model/llava_mpt.py +++ b/llava/model/language_model/llava_mpt.py @@ -18,7 +18,7 @@ import torch from transformers import AutoConfig, AutoModelForCausalLM, \ - MptConfig, MptForCausalLM, MptModel + MptConfig, MptForCausalLM, MptModel from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM @@ -32,7 +32,7 @@ class LlavaMptModel(LlavaMetaModel, MptModel): def __init__(self, config: MptConfig): config.hidden_size = config.d_model super(LlavaMptModel, self).__init__(config) - + def embed_tokens(self, x): return self.wte(x) @@ -58,20 +58,20 @@ def _set_gradient_checkpointing(self, module, value=False): module.gradient_checkpointing = value def 
forward( - self, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - attention_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - images=None): + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + images=None): input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) - + return super().forward( input_ids, past_key_values=past_key_values, diff --git a/llava/model/language_model/llava_qwen.py b/llava/model/language_model/llava_qwen.py new file mode 100644 index 000000000..3c2d47913 --- /dev/null +++ b/llava/model/language_model/llava_qwen.py @@ -0,0 +1,159 @@ + +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from transformers import AutoConfig, AutoModelForCausalLM, Qwen2Config, Qwen2Model, Qwen2ForCausalLM + +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.generation.utils import GenerateOutput + +from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM + + +class LlavaConfig(Qwen2Config): + model_type = "llava_qwen2" + + +class LlavaQwen2Model(LlavaMetaModel, Qwen2Model): + config_class = LlavaConfig + + def __init__(self, config: Qwen2Config): + super(LlavaQwen2Model, self).__init__(config) + + +class LlavaQwen2ForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM): + config_class = LlavaConfig + + def __init__(self, config): + super(Qwen2ForCausalLM, self).__init__(config) + self.model = LlavaQwen2Model(config) + # self.pretraining_tp = config.pretraining_tp + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_model(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + images: Optional[torch.FloatTensor] = None, + image_sizes: Optional[List[List[int]]] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + + if inputs_embeds is None: + ( + input_ids, + position_ids, + attention_mask, + past_key_values, + inputs_embeds, + labels + ) = self.prepare_inputs_labels_for_multimodal( + input_ids, + position_ids, + attention_mask, + past_key_values, + labels, + images, + image_sizes + ) + + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + images: Optional[torch.Tensor] = None, + image_sizes: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + position_ids = kwargs.pop("position_ids", None) + attention_mask = kwargs.pop("attention_mask", None) + if "inputs_embeds" in kwargs: + raise NotImplementedError("`inputs_embeds` is not supported") + + if images is not None: + ( + inputs, + position_ids, + attention_mask, + _, + inputs_embeds, + _ + ) = self.prepare_inputs_labels_for_multimodal( + inputs, + position_ids, + attention_mask, + None, + None, + images, + image_sizes=image_sizes + ) + else: + inputs_embeds = self.get_model().embed_tokens(inputs) + + return super().generate( + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + **kwargs + ) + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, + inputs_embeds=None, **kwargs): + images = kwargs.pop("images", None) + image_sizes = kwargs.pop("image_sizes", None) + inputs = super().prepare_inputs_for_generation( + input_ids, past_key_values=past_key_values, 
inputs_embeds=inputs_embeds, **kwargs + ) + if images is not None: + inputs['images'] = images + if image_sizes is not None: + inputs['image_sizes'] = image_sizes + return inputs + + +AutoConfig.register("llava_qwen2", LlavaConfig) +AutoModelForCausalLM.register(LlavaConfig, LlavaQwen2ForCausalLM) diff --git a/llava/model/llava_arch.py b/llava/model/llava_arch.py index d71650eac..90ee7685f 100644 --- a/llava/model/llava_arch.py +++ b/llava/model/llava_arch.py @@ -91,6 +91,7 @@ def initialize_vision_modules(self, model_args, fsdp=None): if pretrain_mm_mlp_adapter is not None: mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu') + def get_w(weights, keyword): return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} diff --git a/llava/serve/cli.py b/llava/serve/cli.py index 5ecb30d56..d38527c4d 100644 --- a/llava/serve/cli.py +++ b/llava/serve/cli.py @@ -82,7 +82,7 @@ def main(args): else: inp = DEFAULT_IMAGE_TOKEN + '\n' + inp image = None - + conv.append_message(conv.roles[0], inp) conv.append_message(conv.roles[1], None) prompt = conv.get_prompt() diff --git a/llava/serve/controller.py b/llava/serve/controller.py index d4bf1b4c4..387a1191d 100644 --- a/llava/serve/controller.py +++ b/llava/serve/controller.py @@ -132,14 +132,14 @@ def get_worker_address(self, model_name: str): worker_speeds = worker_speeds / norm if True: # Directly return address pt = np.random.choice(np.arange(len(worker_names)), - p=worker_speeds) + p=worker_speeds) worker_name = worker_names[pt] return worker_name # Check status before returning while True: pt = np.random.choice(np.arange(len(worker_names)), - p=worker_speeds) + p=worker_speeds) worker_name = worker_names[pt] if self.get_worker_status(worker_name): @@ -202,7 +202,7 @@ def worker_api_generate_stream(self, params): try: response = requests.post(worker_addr + "/worker_generate_stream", - json=params, stream=True, timeout=5) + json=params, stream=True, timeout=5) for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): if chunk: yield chunk + b"\0" @@ -214,9 +214,9 @@ def worker_api_generate_stream(self, params): } yield json.dumps(ret).encode() + b"\0" - # Let the controller act as a worker to achieve hierarchical # management. This can be used to connect isolated sub networks. 
+ def worker_api_get_status(self): model_names = set() speed = 0 diff --git a/llava/serve/gradio_web_server.py b/llava/serve/gradio_web_server.py index c07efc122..e45693d9e 100644 --- a/llava/serve/gradio_web_server.py +++ b/llava/serve/gradio_web_server.py @@ -8,10 +8,10 @@ import requests from llava.conversation import (default_conversation, conv_templates, - SeparatorStyle) + SeparatorStyle) from llava.constants import LOGDIR from llava.utils import (build_logger, server_error_msg, - violates_moderation, moderation_msg) + violates_moderation, moderation_msg) import hashlib @@ -205,7 +205,7 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: # Query worker address controller_url = args.controller_url ret = requests.post(controller_url + "/get_worker_address", - json={"model": model_name}) + json={"model": model_name}) worker_addr = ret.json()["address"] logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}") @@ -247,7 +247,7 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: try: # Stream output response = requests.post(worker_addr + "/worker_generate_stream", - headers=headers, json=pload, stream=True, timeout=10) + headers=headers, json=pload, stream=True, timeout=10) for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): if chunk: data = json.loads(chunk.decode()) @@ -285,6 +285,7 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: } fout.write(json.dumps(data) + "\n") + title_markdown = (""" # 🌋 LLaVA: Large Language and Vision Assistant [[Project Page](https://llava-vl.github.io)] [[Code](https://github.com/haotian-liu/LLaVA)] [[Model](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)] | 📚 [[LLaVA](https://arxiv.org/abs/2304.08485)] [[LLaVA-v1.5](https://arxiv.org/abs/2310.03744)] [[LLaVA-v1.6](https://llava-vl.github.io/blog/2024-01-30-llava-1-6/)] @@ -312,6 +313,7 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: """ + def build_demo(embed_mode, cur_dir=None, concurrency_count=10): textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False) with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo: @@ -364,7 +366,7 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10): upvote_btn = gr.Button(value="👍 Upvote", interactive=False) downvote_btn = gr.Button(value="👎 Downvote", interactive=False) flag_btn = gr.Button(value="âš ī¸ Flag", interactive=False) - #stop_btn = gr.Button(value="âšī¸ Stop Generation", interactive=False) + # stop_btn = gr.Button(value="âšī¸ Stop Generation", interactive=False) regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) clear_btn = gr.Button(value="đŸ—‘ī¸ Clear", interactive=False) @@ -459,7 +461,7 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10): parser.add_argument("--controller-url", type=str, default="http://localhost:21001") parser.add_argument("--concurrency-count", type=int, default=16) parser.add_argument("--model-list-mode", type=str, default="once", - choices=["once", "reload"]) + choices=["once", "reload"]) parser.add_argument("--share", action="store_true") parser.add_argument("--moderate", action="store_true") parser.add_argument("--embed", action="store_true") diff --git a/llava/serve/model_worker.py b/llava/serve/model_worker.py index 914432989..94dc4a042 100644 --- a/llava/serve/model_worker.py +++ b/llava/serve/model_worker.py @@ -17,7 +17,7 @@ from 
llava.constants import WORKER_HEART_BEAT_INTERVAL from llava.utils import (build_logger, server_error_msg, - pretty_print_semaphore) + pretty_print_semaphore) from llava.model.builder import load_pretrained_model from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN @@ -254,9 +254,9 @@ async def get_status(request: Request): parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--port", type=int, default=21002) parser.add_argument("--worker-address", type=str, - default="http://localhost:21002") + default="http://localhost:21002") parser.add_argument("--controller-address", type=str, - default="http://localhost:21001") + default="http://localhost:21001") parser.add_argument("--model-path", type=str, default="facebook/opt-350m") parser.add_argument("--model-base", type=str, default=None) parser.add_argument("--model-name", type=str) diff --git a/llava/serve/sglang_worker.py b/llava/serve/sglang_worker.py index a3297b7c2..119bbe793 100644 --- a/llava/serve/sglang_worker.py +++ b/llava/serve/sglang_worker.py @@ -18,7 +18,7 @@ from llava.constants import WORKER_HEART_BEAT_INTERVAL from llava.utils import (build_logger, server_error_msg, - pretty_print_semaphore) + pretty_print_semaphore) from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, expand2square from llava.constants import DEFAULT_IMAGE_TOKEN @@ -224,9 +224,9 @@ async def get_status(request: Request): parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--port", type=int, default=21002) parser.add_argument("--worker-address", type=str, - default="http://localhost:21002") + default="http://localhost:21002") parser.add_argument("--controller-address", type=str, - default="http://localhost:21001") + default="http://localhost:21001") parser.add_argument("--model-name", type=str) parser.add_argument("--sgl-endpoint", type=str) parser.add_argument("--limit-model-concurrency", type=int, default=5) diff --git a/llava/serve/test_message.py b/llava/serve/test_message.py index 6b090faed..d104cf366 100644 --- a/llava/serve/test_message.py +++ b/llava/serve/test_message.py @@ -18,7 +18,7 @@ def main(): print(f"Models: {models}") ret = requests.post(controller_addr + "/get_worker_address", - json={"model": args.model_name}) + json={"model": args.model_name}) worker_addr = ret.json()["address"] print(f"worker_addr: {worker_addr}") @@ -38,7 +38,7 @@ def main(): "stop": conv.sep, } response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, - json=pload, stream=True) + json=pload, stream=True) print(prompt.replace(conv.sep, "\n"), end="") for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): @@ -55,8 +55,7 @@ def main(): parser.add_argument("--worker-address", type=str) parser.add_argument("--model-name", type=str, default="facebook/opt-350m") parser.add_argument("--max-new-tokens", type=int, default=32) - parser.add_argument("--message", type=str, default= - "Tell me a story with more than 1000 words.") + parser.add_argument("--message", type=str, default="Tell me a story with more than 1000 words.") args = parser.parse_args() main() diff --git a/llava/train/train.py b/llava/train/train.py index 477c668b6..87a29dce4 100644 --- a/llava/train/train.py +++ b/llava/train/train.py @@ -14,6 +14,7 @@ # See the License for the specific language governing 
permissions and # limitations under the License. +from packaging import version import os import copy from dataclasses import dataclass, field @@ -46,7 +47,6 @@ def rank0_print(*args): print(*args) -from packaging import version IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14') @@ -177,7 +177,7 @@ def find_all_linear_names(model): names = name.split('.') lora_module_names.add(names[0] if len(names) == 1 else names[-1]) - if 'lm_head' in lora_module_names: # needed for 16-bit + if 'lm_head' in lora_module_names: # needed for 16-bit lora_module_names.remove('lm_head') return list(lora_module_names) @@ -392,7 +392,7 @@ def preprocess_llama_2( round_len = len(tokenizer(rou).input_ids) instruction_len = len(tokenizer(parts[0]).input_ids) - 2 - target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + target[cur_len: cur_len + instruction_len] = IGNORE_INDEX cur_len += round_len target[cur_len:] = IGNORE_INDEX @@ -411,6 +411,103 @@ def preprocess_llama_2( ) +# fix: add qwen2 +def preprocess_qwen_2( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_image: + input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + + assert conv.sep_style == conversation_lib.SeparatorStyle.QWEN_2 + + # Mask targets + sep = conv.sep + conv.roles[1] + ": " + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep2) + rounds_len = len(rounds) + cur_len = 0 + # target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_image: + round_ids = tokenizer_image_token(rou, tokenizer) + instruction_ids = tokenizer_image_token(parts[0], tokenizer) + equal_parts = [x == y for x, y in zip(round_ids, instruction_ids)] + + instruction_len = equal_parts.index(False) if False in equal_parts else len(equal_parts) + round_len = len(round_ids) + + else: + round_ids = tokenizer(rou).input_ids + instruction_ids = tokenizer(parts[0]).input_ids + equal_parts = [x == y for x, y in zip(round_ids, instruction_ids)] + + instruction_len = equal_parts.index(False) if False in equal_parts else len(equal_parts) + round_len = len(round_ids) + + if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14: + round_len += 1 + instruction_len += 1 + + target[cur_len: cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len + rounds_len - 2: + target[:] = IGNORE_INDEX + print( + 
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." + f" (ignored)" + ) + + return dict( + input_ids=input_ids, + labels=targets, + ) + + def preprocess_v1( sources, tokenizer: transformers.PreTrainedTokenizer, @@ -478,7 +575,7 @@ def preprocess_v1( round_len -= 1 instruction_len -= 1 - target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + target[cur_len: cur_len + instruction_len] = IGNORE_INDEX cur_len += round_len target[cur_len:] = IGNORE_INDEX @@ -541,7 +638,7 @@ def preprocess_mpt( total_len = int(target.ne(tokenizer.pad_token_id).sum()) rounds = conversation.split(conv.sep) - re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt + re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt for conv_idx in range(3, len(rounds), 2): re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt cur_len = 0 @@ -566,7 +663,7 @@ def preprocess_mpt( round_len += 1 instruction_len += 1 - target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + target[cur_len: cur_len + instruction_len] = IGNORE_INDEX cur_len += round_len target[cur_len:] = IGNORE_INDEX @@ -627,6 +724,9 @@ def preprocess( return preprocess_v1(sources, tokenizer, has_image=has_image) if conversation_lib.default_conversation.version == "mpt": return preprocess_mpt(sources, tokenizer, has_image=has_image) + # fix: add qwen2 + if conversation_lib.default_conversation.version.startswith("qwen_v2"): + return preprocess_qwen_2(sources, tokenizer, has_image=has_image) # add end signal and concatenate together conversations = [] for source in sources: @@ -634,6 +734,7 @@ def preprocess( conversation = _add_speaker_and_signal(header, source) conversations.append(conversation) # tokenize conversations + def get_tokenize_len(prompts): return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts] @@ -777,8 +878,8 @@ def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict: """Make dataset and collator for supervised fine-tuning.""" train_dataset = LazySupervisedDataset(tokenizer=tokenizer, - data_path=data_args.data_path, - data_args=data_args) + data_path=data_args.data_path, + data_args=data_args) data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) return dict(train_dataset=train_dataset, eval_dataset=None, @@ -809,7 +910,7 @@ def train(attn_implementation=None): llm_int8_has_fp16_weight=False, bnb_4bit_compute_dtype=compute_dtype, bnb_4bit_use_double_quant=training_args.double_quant, - bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'} + bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'} ) )) @@ -846,7 +947,7 @@ def train(attn_implementation=None): if training_args.bits in [4, 8]: from peft import prepare_model_for_kbit_training - model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) + model.config.torch_dtype = (torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing) if training_args.gradient_checkpointing: @@ -912,7 +1013,7 @@ def make_inputs_require_grad(module, input, output): model_args=model_args, fsdp=training_args.fsdp ) - + vision_tower = model.get_vision_tower() vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device) @@ -959,9 +1060,9 @@ def make_inputs_require_grad(module, input, output): data_module = 
make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) trainer = LLaVATrainer(model=model, - tokenizer=tokenizer, - args=training_args, - **data_module) + tokenizer=tokenizer, + args=training_args, + **data_module) if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): trainer.train(resume_from_checkpoint=True) diff --git a/llava/train/train_mem.py b/llava/train/train_mem.py index 29ea06170..d51a2e2ea 100644 --- a/llava/train/train_mem.py +++ b/llava/train/train_mem.py @@ -1,4 +1,4 @@ -from llava.train.train import train +from llava.train.train_qwen import train if __name__ == "__main__": train(attn_implementation="flash_attention_2") diff --git a/llava/train/train_qwen.py b/llava/train/train_qwen.py new file mode 100644 index 000000000..76ee7589d --- /dev/null +++ b/llava/train/train_qwen.py @@ -0,0 +1,1103 @@ +# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: +# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from packaging import version +import os +import copy +from dataclasses import dataclass, field +import json +import logging +import pathlib +from typing import Dict, Optional, Sequence, List + +import torch + +import transformers +import tokenizers + +from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN +from torch.utils.data import Dataset +from llava.train.llava_trainer import LLaVATrainer + +from llava import conversation as conversation_lib +from llava.model import * +from llava.mm_utils import tokenizer_image_token + +from PIL import Image + + +local_rank = None + + +def rank0_print(*args): + if local_rank == 0: + print(*args) + + +IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14') + + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + version: Optional[str] = field(default="v0") + freeze_backbone: bool = field(default=False) + tune_mm_mlp_adapter: bool = field(default=False) + vision_tower: Optional[str] = field(default=None) + mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer + pretrain_mm_mlp_adapter: Optional[str] = field(default=None) + mm_projector_type: Optional[str] = field(default='linear') + mm_use_im_start_end: bool = field(default=False) + mm_use_im_patch_token: bool = field(default=True) + mm_patch_merge_type: Optional[str] = field(default='flat') + mm_vision_select_feature: Optional[str] = field(default="patch") + + +@dataclass +class DataArguments: + data_path: str = field(default=None, + metadata={"help": "Path to the training data."}) + lazy_preprocess: bool = False + is_multimodal: bool = False + image_folder: Optional[str] = field(default=None) + image_aspect_ratio: str = 'square' + + +@dataclass +class 
TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + remove_unused_columns: bool = field(default=False) + freeze_mm_mlp_adapter: bool = field(default=False) + mpt_attn_impl: Optional[str] = field(default="triton") + model_max_length: int = field( + default=512, + metadata={ + "help": + "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + }, + ) + double_quant: bool = field( + default=True, + metadata={"help": "Compress the quantization statistics through double quantization."} + ) + quant_type: str = field( + default="nf4", + metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."} + ) + bits: int = field( + default=16, + metadata={"help": "How many bits to use."} + ) + lora_enable: bool = False + lora_r: int = 64 + lora_alpha: int = 16 + lora_dropout: float = 0.05 + lora_weight_path: str = "" + lora_bias: str = "none" + mm_projector_lr: Optional[float] = None + group_by_modality_length: bool = field(default=False) + + +def maybe_zero_3(param, ignore_status=False, name=None): + from deepspeed import zero + from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus + if hasattr(param, "ds_id"): + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + if not ignore_status: + logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}") + with zero.GatheredParameters([param]): + param = param.data.detach().cpu().clone() + else: + param = param.detach().cpu().clone() + return param + + +# Borrowed from peft.utils.get_peft_model_state_dict +def get_peft_state_maybe_zero_3(named_params, bias): + if bias == "none": + to_return = {k: t for k, t in named_params if "lora_" in k} + elif bias == "all": + to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} + elif bias == "lora_only": + to_return = {} + maybe_lora_bias = {} + lora_bias_names = set() + for k, t in named_params: + if "lora_" in k: + to_return[k] = t + bias_name = k.split("lora_")[0] + "bias" + lora_bias_names.add(bias_name) + elif "bias" in k: + maybe_lora_bias[k] = t + for k, t in maybe_lora_bias: + if bias_name in lora_bias_names: + to_return[bias_name] = t + else: + raise NotImplementedError + to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()} + return to_return + + +def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True): + to_return = {k: t for k, t in named_params if "lora_" not in k} + if require_grad_only: + to_return = {k: t for k, t in to_return.items() if t.requires_grad} + to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} + return to_return + + +def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): + to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} + to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} + return to_return + + +def find_all_linear_names(model): + cls = torch.nn.Linear + lora_module_names = set() + multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler'] + for name, module in model.named_modules(): + if any(mm_keyword in name for mm_keyword in multimodal_keywords): + continue + if isinstance(module, cls): + names = name.split('.') + lora_module_names.add(names[0] if len(names) == 1 else names[-1]) + + if 'lm_head' in lora_module_names: # needed for 16-bit + 
lora_module_names.remove('lm_head') + return list(lora_module_names) + + +def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, + output_dir: str): + """Collects the state dict and dump to disk.""" + + if getattr(trainer.args, "tune_mm_mlp_adapter", False): + # Only save Adapter + keys_to_match = ['mm_projector'] + if getattr(trainer.args, "use_im_start_end", False): + keys_to_match.extend(['embed_tokens', 'embed_in']) + + weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match) + trainer.model.config.save_pretrained(output_dir) + + current_folder = output_dir.split('/')[-1] + parent_folder = os.path.dirname(output_dir) + if trainer.args.local_rank == 0 or trainer.args.local_rank == -1: + if current_folder.startswith('checkpoint-'): + mm_projector_folder = os.path.join(parent_folder, "mm_projector") + os.makedirs(mm_projector_folder, exist_ok=True) + torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin')) + else: + torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) + return + + if trainer.deepspeed: + torch.cuda.synchronize() + trainer.save_model(output_dir) + return + + state_dict = trainer.model.state_dict() + if trainer.args.should_save: + cpu_state_dict = { + key: value.cpu() + for key, value in state_dict.items() + } + del state_dict + trainer._save(output_dir, state_dict=cpu_state_dict) # noqa + + +def smart_tokenizer_and_embedding_resize( + special_tokens_dict: Dict, + tokenizer: transformers.PreTrainedTokenizer, + model: transformers.PreTrainedModel, +): + """Resize tokenizer and embedding. + + Note: This is the unoptimized version that may make your embedding size not be divisible by 64. + """ + num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + model.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = model.get_input_embeddings().weight.data + output_embeddings = model.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + +def _tokenize_fn(strings: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer) -> Dict: + """Tokenize a list of strings.""" + tokenized_list = [ + tokenizer( + text, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ) for text in strings + ] + input_ids = labels = [ + tokenized.input_ids[0] for tokenized in tokenized_list + ] + input_ids_lens = labels_lens = [ + tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() + for tokenized in tokenized_list + ] + return dict( + input_ids=input_ids, + labels=labels, + input_ids_lens=input_ids_lens, + labels_lens=labels_lens, + ) + + +def _mask_targets(target, tokenized_lens, speakers): + # cur_idx = 0 + cur_idx = tokenized_lens[0] + tokenized_lens = tokenized_lens[1:] + target[:cur_idx] = IGNORE_INDEX + for tokenized_len, speaker in zip(tokenized_lens, speakers): + if speaker == "human": + target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX + cur_idx += tokenized_len + + +def _add_speaker_and_signal(header, source, get_conversation=True): + """Add speaker and start/end signal on each round.""" + BEGIN_SIGNAL = "### " + END_SIGNAL = "\n" + conversation = header + for sentence in source: + from_str = 
sentence["from"] + if from_str.lower() == "human": + from_str = conversation_lib.default_conversation.roles[0] + elif from_str.lower() == "gpt": + from_str = conversation_lib.default_conversation.roles[1] + else: + from_str = 'unknown' + sentence["value"] = (BEGIN_SIGNAL + from_str + ": " + + sentence["value"] + END_SIGNAL) + if get_conversation: + conversation += sentence["value"] + conversation += BEGIN_SIGNAL + return conversation + + +def preprocess_multimodal( + sources: Sequence[str], + data_args: DataArguments +) -> Dict: + is_multimodal = data_args.is_multimodal + if not is_multimodal: + return sources + + for source in sources: + for sentence in source: + if DEFAULT_IMAGE_TOKEN in sentence['value']: + sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip() + sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value'] + sentence['value'] = sentence['value'].strip() + if "mmtag" in conversation_lib.default_conversation.version: + sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '' + DEFAULT_IMAGE_TOKEN + '') + replace_token = DEFAULT_IMAGE_TOKEN + if data_args.mm_use_im_start_end: + replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN + sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token) + + return sources + + +def preprocess_llama_2( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_image: + input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + + assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2 + + # Mask targets + sep = "[/INST] " + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep2) + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_image: + round_len = len(tokenizer_image_token(rou, tokenizer)) + instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + target[cur_len: cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
+ f" (ignored)" + ) + + return dict( + input_ids=input_ids, + labels=targets, + ) + + +# fix: add qwen2 +def preprocess_qwen_2( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + # print('-----preprocess_qwen_2-------') + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_image: + input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + + assert conv.sep_style == conversation_lib.SeparatorStyle.QWEN_2 + + # Mask targets + sep = conv.sep + conv.roles[1] + ": " + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep2) + rounds_len = len(rounds) + cur_len = 0 + # target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_image: + round_ids = tokenizer_image_token(rou, tokenizer) + instruction_ids = tokenizer_image_token(parts[0], tokenizer) + equal_parts = [x == y for x, y in zip(round_ids, instruction_ids)] + + instruction_len = equal_parts.index(False) if False in equal_parts else len(equal_parts) + round_len = len(round_ids) + + else: + round_ids = tokenizer(rou).input_ids + instruction_ids = tokenizer(parts[0]).input_ids + equal_parts = [x == y for x, y in zip(round_ids, instruction_ids)] + + instruction_len = equal_parts.index(False) if False in equal_parts else len(equal_parts) + round_len = len(round_ids) + + if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14: + round_len += 1 + instruction_len += 1 + + target[cur_len: cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len + rounds_len - 2: + target[:] = IGNORE_INDEX + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
+ f" (ignored)" + ) + + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess_v1( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_image: + input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + + assert conv.sep_style == conversation_lib.SeparatorStyle.TWO + + # Mask targets + sep = conv.sep + conv.roles[1] + ": " + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep2) + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_image: + round_len = len(tokenizer_image_token(rou, tokenizer)) + instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14: + round_len -= 1 + instruction_len -= 1 + + target[cur_len: cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
+ f" (ignored)" + ) + + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess_mpt( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_image: + input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + assert conv.sep_style == conversation_lib.SeparatorStyle.MPT + + # Mask targets + sep = conv.sep + conv.roles[1] + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep) + re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt + for conv_idx in range(3, len(rounds), 2): + re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt + cur_len = 0 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(re_rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_image: + round_len = len(tokenizer_image_token(rou, tokenizer)) + instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 1 + + if i != 0 and getattr(tokenizer, 'legacy', False) and IS_TOKENIZER_GREATER_THAN_0_14: + round_len += 1 + instruction_len += 1 + + target[cur_len: cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
+ f" (ignored)" + ) + + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess_plain( + sources: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, +) -> Dict: + # add end signal and concatenate together + conversations = [] + for source in sources: + assert len(source) == 2 + assert DEFAULT_IMAGE_TOKEN in source[0]['value'] + source[0]['value'] = DEFAULT_IMAGE_TOKEN + conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep + conversations.append(conversation) + # tokenize conversations + input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations] + targets = copy.deepcopy(input_ids) + for target, source in zip(targets, sources): + tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer)) + target[:tokenized_len] = IGNORE_INDEX + + return dict(input_ids=input_ids, labels=targets) + + +def preprocess( + sources: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + """ + Given a list of sources, each is a conversation list. This transform: + 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; + 2. Concatenate conversations together; + 3. Tokenize the concatenated conversation; + 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX. + """ + # print("conversation:",conversation_lib.default_conversation.version) + # conversation_lib.default_conversation.version == "qwen_v2" + + if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN: + return preprocess_plain(sources, tokenizer) + if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2: + return preprocess_llama_2(sources, tokenizer, has_image=has_image) + if conversation_lib.default_conversation.version.startswith("v1"): + # print('--v1--') + return preprocess_v1(sources, tokenizer, has_image=has_image) + if conversation_lib.default_conversation.version == "mpt": + # print('--mpt--') + return preprocess_mpt(sources, tokenizer, has_image=has_image) + # fix: add qwen2 + if conversation_lib.default_conversation.version.startswith("qwen_v2"): + # print('--qwen_v2--') + return preprocess_qwen_2(sources, tokenizer, has_image=has_image) + # add end signal and concatenate together + conversations = [] + for source in sources: + header = f"{conversation_lib.default_conversation.system}\n\n" + conversation = _add_speaker_and_signal(header, source) + conversations.append(conversation) + # tokenize conversations + + def get_tokenize_len(prompts): + return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts] + + if has_image: + input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations] + else: + conversations_tokenized = _tokenize_fn(conversations, tokenizer) + input_ids = conversations_tokenized["input_ids"] + + targets = copy.deepcopy(input_ids) + for target, source in zip(targets, sources): + if has_image: + tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source]) + else: + tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"] + speakers = [sentence["from"] for sentence in source] + _mask_targets(target, tokenized_lens, speakers) + + return dict(input_ids=input_ids, labels=targets) + + +class LazySupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__(self, data_path: str, + tokenizer: 
+                 data_args: DataArguments):
+        super(LazySupervisedDataset, self).__init__()
+        list_data_dict = json.load(open(data_path, "r"))
+
+        rank0_print("Formatting inputs...Skip in lazy mode")
+        self.tokenizer = tokenizer
+        self.list_data_dict = list_data_dict
+        self.data_args = data_args
+
+    def __len__(self):
+        return len(self.list_data_dict)
+
+    @property
+    def lengths(self):
+        length_list = []
+        for sample in self.list_data_dict:
+            img_tokens = 128 if 'image' in sample else 0
+            length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
+        return length_list
+
+    @property
+    def modality_lengths(self):
+        length_list = []
+        for sample in self.list_data_dict:
+            cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
+            cur_len = cur_len if 'image' in sample else -cur_len
+            length_list.append(cur_len)
+        return length_list
+
+    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+        sources = self.list_data_dict[i]
+        if isinstance(i, int):
+            sources = [sources]
+        assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME
+        if 'image' in sources[0]:
+            image_file = self.list_data_dict[i]['image']
+            image_folder = self.data_args.image_folder
+            processor = self.data_args.image_processor
+            image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
+            if self.data_args.image_aspect_ratio == 'pad':
+                def expand2square(pil_img, background_color):
+                    width, height = pil_img.size
+                    if width == height:
+                        return pil_img
+                    elif width > height:
+                        result = Image.new(pil_img.mode, (width, width), background_color)
+                        result.paste(pil_img, (0, (width - height) // 2))
+                        return result
+                    else:
+                        result = Image.new(pil_img.mode, (height, height), background_color)
+                        result.paste(pil_img, ((height - width) // 2, 0))
+                        return result
+                image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
+                image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+            else:
+                image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+            sources = preprocess_multimodal(
+                copy.deepcopy([e["conversations"] for e in sources]),
+                self.data_args)
+        else:
+            sources = copy.deepcopy([e["conversations"] for e in sources])
+        data_dict = preprocess(
+            sources,
+            self.tokenizer,
+            has_image=('image' in self.list_data_dict[i]))
+        if isinstance(i, int):
+            data_dict = dict(input_ids=data_dict["input_ids"][0],
+                             labels=data_dict["labels"][0])
+
+        # image exists in the data
+        if 'image' in self.list_data_dict[i]:
+            data_dict['image'] = image
+        elif self.data_args.is_multimodal:
+            # image does not exist in the data, but the model is multimodal
+            crop_size = self.data_args.image_processor.crop_size
+            data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
+        return data_dict
+
+
+@dataclass
+class DataCollatorForSupervisedDataset(object):
+    """Collate examples for supervised fine-tuning."""
+
+    tokenizer: transformers.PreTrainedTokenizer
+
+    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+        input_ids, labels = tuple([instance[key] for instance in instances]
+                                  for key in ("input_ids", "labels"))
+        input_ids = torch.nn.utils.rnn.pad_sequence(
+            input_ids,
+            batch_first=True,
+            padding_value=self.tokenizer.pad_token_id)
+        labels = torch.nn.utils.rnn.pad_sequence(labels,
+                                                 batch_first=True,
+                                                 padding_value=IGNORE_INDEX)
+        input_ids = input_ids[:, :self.tokenizer.model_max_length]
+        labels = labels[:, :self.tokenizer.model_max_length]
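+        # Sequences were truncated to model_max_length above; the attention mask below marks non-padding positions.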
+        batch = dict(
+            input_ids=input_ids,
+            labels=labels,
+            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
+        )
+
+        if 'image' in instances[0]:
+            images = [instance['image'] for instance in instances]
+            if all(x is not None and x.shape == images[0].shape for x in images):
+                batch['images'] = torch.stack(images)
+            else:
+                batch['images'] = images
+
+        return batch
+
+
+def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
+                                data_args) -> Dict:
+    """Make dataset and collator for supervised fine-tuning."""
+    train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
+                                          data_path=data_args.data_path,
+                                          data_args=data_args)
+    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
+    return dict(train_dataset=train_dataset,
+                eval_dataset=None,
+                data_collator=data_collator)
+
+
+def train(attn_implementation=None):
+    global local_rank
+
+    parser = transformers.HfArgumentParser(
+        (ModelArguments, DataArguments, TrainingArguments))
+    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+    local_rank = training_args.local_rank
+    compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
+
+    bnb_model_from_pretrained_args = {}
+    if training_args.bits in [4, 8]:
+        from transformers import BitsAndBytesConfig
+        bnb_model_from_pretrained_args.update(dict(
+            device_map={"": training_args.device},
+            load_in_4bit=training_args.bits == 4,
+            load_in_8bit=training_args.bits == 8,
+            quantization_config=BitsAndBytesConfig(
+                load_in_4bit=training_args.bits == 4,
+                load_in_8bit=training_args.bits == 8,
+                llm_int8_skip_modules=["mm_projector"],
+                llm_int8_threshold=6.0,
+                llm_int8_has_fp16_weight=False,
+                bnb_4bit_compute_dtype=compute_dtype,
+                bnb_4bit_use_double_quant=training_args.double_quant,
+                bnb_4bit_quant_type=training_args.quant_type  # {'fp4', 'nf4'}
+            )
+        ))
+
+    if model_args.vision_tower is not None:
+        if 'mpt' in model_args.model_name_or_path:
+            config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
+            config.attn_config['attn_impl'] = training_args.mpt_attn_impl
+            model = LlavaMptForCausalLM.from_pretrained(
+                model_args.model_name_or_path,
+                config=config,
+                cache_dir=training_args.cache_dir,
+                **bnb_model_from_pretrained_args
+            )
+        else:
+            model = LlavaQwen2ForCausalLM.from_pretrained(
+                model_args.model_name_or_path,
+                cache_dir=training_args.cache_dir,
+                attn_implementation=attn_implementation,
+                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
+                **bnb_model_from_pretrained_args
+            )
+    else:
+        model = transformers.Qwen2ForCausalLM.from_pretrained(
+            model_args.model_name_or_path,
+            cache_dir=training_args.cache_dir,
+            attn_implementation=attn_implementation,
+            torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
+            **bnb_model_from_pretrained_args
+        )
+    model.config.use_cache = False
+
+    if model_args.freeze_backbone:
+        model.model.requires_grad_(False)
+
+    if training_args.bits in [4, 8]:
+        from peft import prepare_model_for_kbit_training
+        model.config.torch_dtype = (torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
+        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
+
+    if training_args.gradient_checkpointing:
+        if hasattr(model, "enable_input_require_grads"):
+            model.enable_input_require_grads()
+        else:
+            def make_inputs_require_grad(module, input, output):
+                output.requires_grad_(True)
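+            # Fallback for models without enable_input_require_grads: mark embedding outputs as
+            # requiring grad so gradient checkpointing can backpropagate into the trainable modules.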
+            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+    if training_args.lora_enable:
+        from peft import LoraConfig, get_peft_model
+        lora_config = LoraConfig(
+            r=training_args.lora_r,
+            lora_alpha=training_args.lora_alpha,
+            target_modules=find_all_linear_names(model),
+            lora_dropout=training_args.lora_dropout,
+            bias=training_args.lora_bias,
+            task_type="CAUSAL_LM",
+        )
+        if training_args.bits == 16:
+            if training_args.bf16:
+                model.to(torch.bfloat16)
+            if training_args.fp16:
+                model.to(torch.float16)
+        rank0_print("Adding LoRA adapters...")
+        model = get_peft_model(model, lora_config)
+
+    if 'mpt' in model_args.model_name_or_path:
+        tokenizer = transformers.AutoTokenizer.from_pretrained(
+            model_args.model_name_or_path,
+            cache_dir=training_args.cache_dir,
+            model_max_length=training_args.model_max_length,
+            padding_side="right"
+        )
+    else:
+        tokenizer = transformers.AutoTokenizer.from_pretrained(
+            model_args.model_name_or_path,
+            cache_dir=training_args.cache_dir,
+            model_max_length=training_args.model_max_length,
+            padding_side="right",
+            use_fast=False,
+        )
+
+    if model_args.version == "v0":
+        if tokenizer.pad_token is None:
+            smart_tokenizer_and_embedding_resize(
+                special_tokens_dict=dict(pad_token="[PAD]"),
+                tokenizer=tokenizer,
+                model=model,
+            )
+    elif model_args.version == "v0.5":
+        tokenizer.pad_token = tokenizer.unk_token
+    else:
+        if tokenizer.unk_token:
+            tokenizer.pad_token = tokenizer.unk_token
+        else:  # Qwen2 tokenizer defines no unk_token; rely on its built-in pad token
+            tokenizer.legacy = False
+        if model_args.version in conversation_lib.conv_templates:
+            conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
+        else:
+            conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]
+
+    if model_args.vision_tower is not None:
+        model.get_model().initialize_vision_modules(
+            model_args=model_args,
+            fsdp=training_args.fsdp
+        )
+
+        vision_tower = model.get_vision_tower()
+        vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)
+
+        data_args.image_processor = vision_tower.image_processor
+        data_args.is_multimodal = True
+
+        model.config.image_aspect_ratio = data_args.image_aspect_ratio
+        model.config.tokenizer_padding_side = tokenizer.padding_side
+        model.config.tokenizer_model_max_length = tokenizer.model_max_length
+
+        model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
+        if model_args.tune_mm_mlp_adapter:
+            model.requires_grad_(False)
+            for p in model.get_model().mm_projector.parameters():
+                p.requires_grad = True
+
+        model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
+        if training_args.freeze_mm_mlp_adapter:
+            for p in model.get_model().mm_projector.parameters():
+                p.requires_grad = False
+
+        if training_args.bits in [4, 8]:
+            model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)
+
+        model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
+        model.config.mm_projector_lr = training_args.mm_projector_lr
+        training_args.use_im_start_end = model_args.mm_use_im_start_end
+        model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
+        model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
+
+    if training_args.bits in [4, 8]:
+        from peft.tuners.lora import LoraLayer
+        for name, module in model.named_modules():
+            if isinstance(module, LoraLayer):
+                if training_args.bf16:
+                    module = module.to(torch.bfloat16)
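+            # Keep normalization layers in float32 for numerical stability during low-bit training.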
+            if 'norm' in name:
+                module = module.to(torch.float32)
+            if 'lm_head' in name or 'embed_tokens' in name:
+                if hasattr(module, 'weight'):
+                    if training_args.bf16 and module.weight.dtype == torch.float32:
+                        module = module.to(torch.bfloat16)
+
+    data_module = make_supervised_data_module(tokenizer=tokenizer,
+                                              data_args=data_args)
+    trainer = LLaVATrainer(model=model,
+                           tokenizer=tokenizer,
+                           args=training_args,
+                           **data_module)
+
+    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
+        trainer.train(resume_from_checkpoint=True)
+    else:
+        trainer.train()
+    trainer.save_state()
+
+    model.config.use_cache = True
+
+    if training_args.lora_enable:
+        state_dict = get_peft_state_maybe_zero_3(
+            model.named_parameters(), training_args.lora_bias
+        )
+        non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
+            model.named_parameters()
+        )
+        if training_args.local_rank == 0 or training_args.local_rank == -1:
+            model.config.save_pretrained(training_args.output_dir)
+            model.save_pretrained(training_args.output_dir, state_dict=state_dict)
+            torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
+    else:
+        safe_save_model_for_hf_trainer(trainer=trainer,
+                                       output_dir=training_args.output_dir)
+
+
+if __name__ == "__main__":
+    train()
diff --git a/llava/train/train_xformers.py b/llava/train/train_xformers.py
index 23a59bf4e..e33096429 100644
--- a/llava/train/train_xformers.py
+++ b/llava/train/train_xformers.py
@@ -1,13 +1,13 @@
 # Make it more memory efficient by monkey patching the LLaMA model with xformers attention.
 
 # Need to call this before importing transformers.
+from llava.train.train import train
 from llava.train.llama_xformers_attn_monkey_patch import (
     replace_llama_attn_with_xformers_attn,
 )
 
 replace_llama_attn_with_xformers_attn()
 
-from llava.train.train import train
 
 if __name__ == "__main__":
     train()
diff --git a/llava/utils.py b/llava/utils.py
index 4006cf917..5e1b7eb29 100644
--- a/llava/utils.py
+++ b/llava/utils.py
@@ -61,6 +61,7 @@ class StreamToLogger(object):
     """
     Fake file-like stream object that redirects writes to a logger instance.
     """
+
     def __init__(self, logger, log_level=logging.INFO):
         self.terminal = sys.stdout
         self.logger = logger
diff --git a/scripts/v1_5/finetune_qwen_2.sh b/scripts/v1_5/finetune_qwen_2.sh
new file mode 100644
index 000000000..4d1593ac6
--- /dev/null
+++ b/scripts/v1_5/finetune_qwen_2.sh
@@ -0,0 +1,37 @@
+#!/bin/bash
+
+deepspeed llava/train/train_mem.py \
+    --deepspeed ./scripts/zero3.json \
+    --model_name_or_path lmsys/vicuna-13b-v1.5 \
+    --version qwen_2 \
+    --data_path ./playground/data/llava_v1_5_mix665k.json \
+    --image_folder ./playground/data \
+    --vision_tower openai/clip-vit-large-patch14-336 \
+    --pretrain_mm_mlp_adapter ./checkpoints/llava-v1.5-13b-pretrain/mm_projector.bin \
+    --mm_projector_type mlp2x_gelu \
+    --mm_vision_select_layer -2 \
+    --mm_use_im_start_end False \
+    --mm_use_im_patch_token False \
+    --image_aspect_ratio pad \
+    --group_by_modality_length True \
+    --bf16 True \
+    --output_dir ./checkpoints/llava-v1.5-13b \
+    --num_train_epochs 1 \
+    --per_device_train_batch_size 16 \
+    --per_device_eval_batch_size 4 \
+    --gradient_accumulation_steps 1 \
+    --evaluation_strategy "no" \
+    --save_strategy "steps" \
+    --save_steps 50000 \
+    --save_total_limit 1 \
+    --learning_rate 2e-5 \
+    --weight_decay 0. \
+    --warmup_ratio 0.03 \
+    --lr_scheduler_type "cosine" \
+    --logging_steps 1 \
+    --tf32 True \
+    --model_max_length 2048 \
+    --gradient_checkpointing True \
+    --dataloader_num_workers 4 \
+    --lazy_preprocess True \
+    --report_to wandb

From 24c4e351df2ddd4df55b50d3c96e8300ce05028b Mon Sep 17 00:00:00 2001
From: Yuzhe
Date: Thu, 1 Aug 2024 01:18:22 +0800
Subject: [PATCH 2/2] update slurm script for qwen2

---
 scripts/v1_5/finetune_qwen_2.sh |  6 +++---
 scripts/v1_5/pretrain_qwen_2.sh | 35 +++++++++++++++++++++++++++++++++
 2 files changed, 38 insertions(+), 3 deletions(-)
 create mode 100644 scripts/v1_5/pretrain_qwen_2.sh

diff --git a/scripts/v1_5/finetune_qwen_2.sh b/scripts/v1_5/finetune_qwen_2.sh
index 4d1593ac6..ac56b7698 100644
--- a/scripts/v1_5/finetune_qwen_2.sh
+++ b/scripts/v1_5/finetune_qwen_2.sh
@@ -2,12 +2,12 @@
 
 deepspeed llava/train/train_mem.py \
     --deepspeed ./scripts/zero3.json \
-    --model_name_or_path lmsys/vicuna-13b-v1.5 \
+    --model_name_or_path Qwen/Qwen2-1.5B \
     --version qwen_2 \
     --data_path ./playground/data/llava_v1_5_mix665k.json \
     --image_folder ./playground/data \
     --vision_tower openai/clip-vit-large-patch14-336 \
-    --pretrain_mm_mlp_adapter ./checkpoints/llava-v1.5-13b-pretrain/mm_projector.bin \
+    --pretrain_mm_mlp_adapter ./checkpoints/Qwen2-1.5B-pretrain/mm_projector.bin \
     --mm_projector_type mlp2x_gelu \
     --mm_vision_select_layer -2 \
     --mm_use_im_start_end False \
@@ -15,7 +15,7 @@ deepspeed llava/train/train_mem.py \
     --image_aspect_ratio pad \
     --group_by_modality_length True \
     --bf16 True \
-    --output_dir ./checkpoints/llava-v1.5-13b \
+    --output_dir ./checkpoints/LLaVA-Qwen2-1.5B \
     --num_train_epochs 1 \
     --per_device_train_batch_size 16 \
     --per_device_eval_batch_size 4 \
diff --git a/scripts/v1_5/pretrain_qwen_2.sh b/scripts/v1_5/pretrain_qwen_2.sh
new file mode 100644
index 000000000..5814b0c76
--- /dev/null
+++ b/scripts/v1_5/pretrain_qwen_2.sh
@@ -0,0 +1,35 @@
+#!/bin/bash
+
+deepspeed llava/train/train_mem.py \
+    --deepspeed ./scripts/zero2.json \
+    --model_name_or_path Qwen/Qwen2-1.5B \
+    --version plain \
+    --data_path ./playground/data/LLaVA-Pretrain/blip_laion_cc_sbu_558k.json \
+    --image_folder ./playground/data/LLaVA-Pretrain/images \
+    --vision_tower openai/clip-vit-large-patch14-336 \
+    --mm_projector_type mlp2x_gelu \
+    --tune_mm_mlp_adapter True \
+    --mm_vision_select_layer -2 \
+    --mm_use_im_start_end False \
+    --mm_use_im_patch_token False \
+    --bf16 True \
+    --output_dir ./checkpoints/Qwen2-1.5B-pretrain \
+    --num_train_epochs 1 \
+    --per_device_train_batch_size 32 \
+    --per_device_eval_batch_size 4 \
+    --gradient_accumulation_steps 1 \
+    --evaluation_strategy "no" \
+    --save_strategy "steps" \
+    --save_steps 24000 \
+    --save_total_limit 1 \
+    --learning_rate 1e-3 \
+    --weight_decay 0. \
+    --warmup_ratio 0.03 \
+    --lr_scheduler_type "cosine" \
+    --logging_steps 1 \
+    --tf32 True \
+    --model_max_length 2048 \
+    --gradient_checkpointing True \
+    --dataloader_num_workers 4 \
+    --lazy_preprocess True \
+    --report_to wandb