diff --git a/llava/VLLMSafety/discriminator.py b/llava/VLLMSafety/discriminator.py new file mode 100644 index 000000000..0a20a7027 --- /dev/null +++ b/llava/VLLMSafety/discriminator.py @@ -0,0 +1,69 @@ +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +import numpy as np +import json +import random + +class Discriminator(nn.Module): + def __init__(self, input_size): + super().__init__() # we can add more layers later + + self.fc1 = nn.Linear(input_size, 50) + self.fc2 = nn.Linear(50, 1) + + def linear(self, x): + x = F.relu(self.fc1(x)) + x = torch.sigmoid(self.fc2(x)) + + return x + + def forward(self, data, d_mode): + device = 'cuda' + loss_function = nn.BCELoss() # follow DCgan + + image_batch = data['image'][0].view(-1, 5120).to(device) + img_tok = image_batch.view(-1, 5120) # flatten the lists + + img_pred = self.linear(img_tok) + img_label = torch.full((img_tok.size(0), 1), 1, dtype=torch.bfloat16, device=device) # use label 1 for imgs + img_loss = loss_function(img_pred, img_label) + + total_lang_loss = 0 + lang_correct_count = 0 + total_lang_preds = 0 + img_correct_count = torch.eq(torch.ge(img_pred, 0.5).float().to(torch.bfloat16), img_label).sum().item() + img_accuracy = img_correct_count / img_tok.size(0) * 100 + + for lang_tensor in data["lang"]: + lang_tensor = lang_tensor.to(device) + lang_pred = self.linear(lang_tensor.view(-1, 5120)) # Process each lang tensor independently + lang_label = torch.full((lang_pred.size(0), 1), 0, dtype=torch.bfloat16, device=device) # Label 0 for language + + lang_loss = loss_function(lang_pred, lang_label) + total_lang_loss += lang_loss + + #for accuracy calculations + lang_correct = torch.eq(torch.ge(lang_pred, 0.5).float().to(torch.bfloat16), lang_label).sum().item() + lang_correct_count += lang_correct + total_lang_preds += lang_pred.size(0) + + if d_mode: + lang_accuracy = lang_correct_count / total_lang_preds * 100 + print(f"Image Accuracy: {img_accuracy:.2f}%") + 
print(f"Language Accuracy: {lang_accuracy:.2f}%") + + loss = img_loss + total_lang_loss + + return { + "loss": loss, + "img_is_correct": img_correct_count, + "lang_is_correct": lang_correct_count, + "img_accuracy": img_accuracy, + "lang_accuracy": lang_accuracy, + } + + else: + img_with_lang_label_loss = loss_function(img_pred, torch.full((img_tok.size(0), 1), 0, dtype=torch.bfloat16, device=device)) + return img_with_lang_label_loss \ No newline at end of file diff --git a/llava/VLLMSafety/pipeline.py b/llava/VLLMSafety/pipeline.py new file mode 100644 index 000000000..d2e9b9db9 --- /dev/null +++ b/llava/VLLMSafety/pipeline.py @@ -0,0 +1,133 @@ +import argparse +import torch +import torch.nn as nn +from PIL import Image +import json +import torch.nn.parallel +import torch.optim as optim +import torch.utils.data +import math +import os +import random +from datetime import datetime + +from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN +from llava.conversation import conv_templates, SeparatorStyle +from llava.model.builder import load_pretrained_model +from llava.utils import disable_torch_init +from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path +#from llava.model.llava_arch import prepare_inputs_labels_for_multimodal + +from discriminator import preprocess_and_call_train + +def split_list(lst, n): # taken from model_vqa.py + """Split a list into n (roughly) equal-sized chunks""" + chunk_size = math.ceil(len(lst) / n) # integer division + return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] + + +def get_chunk(lst, n, k): # taken from model_vqa.py + chunks = split_list(lst, n) + return chunks[k] + +def get_tkns(input_ids, image_tensor, model, img_size): + '''takes in one prompt and one image and returns a dictionary that has the language tokens + from the prompt as one entry and the image tokens from the prompt as another''' + position_ids = None # set 
to None in generate() + attention_mask = None # set to None in generate() must figure out if this and the above is acceptable + + # prep_inputs... returns None as the first value, but idk why + none_q, position_ids, attention_mask, past_key_values, input_embeds, labels, chunk_sizes = model.prepare_inputs_labels_for_multimodal( + input_ids = input_ids, + position_ids = position_ids, + attention_mask = attention_mask, + past_key_values = None, + labels = None, + images = image_tensor.unsqueeze(0).half().cuda(), + image_sizes = img_size + ) + + split_embeddings = torch.split(input_embeds[0], chunk_sizes, dim=0) + lang_tkns = split_embeddings[2] # only the second to avoid adding the same tokens over and over + img_tkns = split_embeddings[1] + + tkn_dict = { + "lang_tkns": lang_tkns, + "img_tkns": img_tkns + } + + return tkn_dict + +def prep_batches(line, model, tokenizer, image_processor, rags, **kwargs): + q_id = line["id"] # can be used to identify each batch, probably good to use to keep track of progress during training + image_file = line["image"] + qs = line["text"] + + if qs.startswith(f"{DEFAULT_IMAGE_TOKEN}\n") == False: + idx = qs.find(DEFAULT_IMAGE_TOKEN) + len(DEFAULT_IMAGE_TOKEN) + qs = qs[idx:].strip() + qs = qs[idx:] + qs = DEFAULT_IMAGE_TOKEN + '\n' + qs + assert qs.startswith(f"{DEFAULT_IMAGE_TOKEN}\n") == True, f'no image tag found in text \n text = {qs} \n id = {q_id}' + + # something to note: this appends a default prompt to each prompt, might impact discrim since it will keep getting trained on + # the same tokens. 
i'll adjust to remove this soon + + conv = conv_templates[args.conv_mode].copy() + conv.append_message(conv.roles[0], qs) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() + + image = Image.open(os.path.join(args.image_folder, image_file)).convert('RGB') + image_tensor = process_images([image], image_processor, model.config)[0] + image_sizes = [image.size] + + tkn_dict = get_tkns(input_ids, image_tensor, model, image_sizes) #returns tkn_dict with image and language tokens + + return tkn_dict + +def train(args): + args_dict = vars(args) + + device = 'cuda' # set device appropriately + EPOCHS = 10 + G_losses = [] + D_losses = [] + iters = 0 + + ## boot up model and get everything running properly + disable_torch_init() + model_path = os.path.expanduser(args.model_path) + model_name = get_model_name_from_path(model_path) + tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) + + # get data - following along with model_vqa.py + questions = [json.loads(q) for q in open(os.path.expanduser(args.conversation_file), "r")] + questions = get_chunk(questions, args.num_chunks, args.chunk_idx) + + # right now each batch is created one by one for each conversation in the file, maybe we want to precompute all the + # batches ahead of time? maybe we consolidate this for-loop into a function? 
for now it should work but + just some things to think about + + for line in questions: + tkn_dict = prep_batches(line, model, tokenizer, image_processor, args, **args_dict) + + projection_model = preprocess_and_call_train(tkn_dict) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, default= "/home/smirrashidi/llava-v1.5-13b") + parser.add_argument("--model-base", type=str, default=None) + parser.add_argument("--image_folder", type=str, default= "/home/smirrashidi/coco_data/images") + parser.add_argument("--conversation_file", type=str, default= "/home/smirrashidi/coco_data/discrim_data.jsonl") + parser.add_argument("--conv-mode", type=str, default="llava_v1") + parser.add_argument("--num-chunks", type=int, default=1) + parser.add_argument("--chunk-idx", type=int, default=0) + parser.add_argument("--temperature", type=float, default=0.2) + parser.add_argument("--top_p", type=float, default=None) + parser.add_argument("--num_beams", type=int, default=1) + args = parser.parse_args() + + train(args) diff --git a/llava/VLLMSafety/send.py b/llava/VLLMSafety/send.py new file mode 100644 index 000000000..6509ad862 --- /dev/null +++ b/llava/VLLMSafety/send.py @@ -0,0 +1,34 @@ +import argparse +import torch +import os +import random +from datetime import datetime + +from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN +from llava.conversation import conv_templates, SeparatorStyle +from llava.model.builder import load_pretrained_model +from llava.utils import disable_torch_init +from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path + +from discriminator import Discriminator # import Laya's discriminator +from make_data import CustomDataset # import the dataset class + +def get_data(image_folder, language_file): + '''takes in images and the language file and outputs a shuffled list + of both the images after
going through _projector and the tokenized language tokens''' + + return None + +def send_to_discriminator(): + ## boot up model and get everything running properly + disable_torch_init() + model_path = os.path.expanduser(args.model_path) + model_name = get_model_name_from_path(model_path) + tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, default="facebook/opt-350m") + parser.add_argument("--model-base", type=str, default=None) + parser.add_argument("--data", type=str, default=None) + args = parser.parse_args() \ No newline at end of file diff --git a/llava/VLLMSafety/test_discrim.py b/llava/VLLMSafety/test_discrim.py new file mode 100644 index 000000000..0c888f287 --- /dev/null +++ b/llava/VLLMSafety/test_discrim.py @@ -0,0 +1,137 @@ +import argparse +import torch +import os +import json +from tqdm import tqdm +from torch.utils.data import DataLoader + +from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN +from llava.conversation import conv_templates, SeparatorStyle +from llava.model.builder import load_pretrained_model +from llava.utils import disable_torch_init +from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path +from llava.train.train import * +from llava.train.llava_trainer import LLaVATrainer + +from PIL import Image +import math +from llava.VLLMSafety.discriminator import Discriminator +from datetime import date + +class DiscArguments: + test_data_path: str = "/home/smirrashidi/coco_data/coco_test_conversations.json" + test_image_folder: str = "/home/smirrashidi/coco_data/coco_test" + model_path: str = "/home/smirrashidi/LLaVAFork/checkpoints/llava-v1.5-13b-lora-disc" + model_base: str = "lmsys/vicuna-13b-v1.3" + +def split_list(lst, n): + """Split a list into n (roughly) equal-sized
chunks""" + chunk_size = math.ceil(len(lst) / n) # integer division + return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] + + +def get_chunk(lst, n, k): + chunks = split_list(lst, n) + return chunks[k] + +def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, + data_args, disc_args, testing) -> Dict: + """Make dataset and collator for supervised fine-tuning.""" + + if testing == True: + data_args.image_folder = "/home/smirrashidi/coco_data/coco_test" + data_args.data_path = "/home/smirrashidi/coco_data/coco_test_conversations.json" + + train_dataset = LazySupervisedDataset(tokenizer=tokenizer, + data_path=data_args.data_path, + data_args=data_args) + + data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) + + else: + train_dataset = LazySupervisedDataset(tokenizer=tokenizer, + data_path=data_args.data_path, + data_args=data_args) + data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) + + return dict(train_dataset=train_dataset, + eval_dataset=None, + data_collator=data_collator) + + +def eval_model(args): + # Model + disable_torch_init() + model_path = os.path.expanduser(args.model_path) + model_name = get_model_name_from_path(model_path) + model_args = ModelArguments(model_name_or_path = args.model_path) + data_args = DataArguments(data_path = args.question_file, + image_folder = args.image_folder) + training_args = TrainingArguments(output_dir="/home/smirrashidi/dump") + disc_args = DiscArguments + + tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) + + model = model.to(torch.bfloat16) + + data_args.image_processor = image_processor + + test_data_module = make_supervised_data_module(tokenizer=tokenizer, + data_args=data_args, + disc_args=disc_args, + testing = True) + + data_collator = test_data_module['data_collator'] + + test_dataloader = DataLoader( + test_data_module['train_dataset'], + batch_size=4, + collate_fn=data_collator, 
+ shuffle=False) + + eval_disc_data = { + "num_img_corr": 0, + "num_lang_corr": 0, + "img_total": 0, + "lang_total": 0, + } + + for i, batch in enumerate(test_dataloader): + print(f"Iteration #{i}") + input_ids = batch['input_ids'] + image = batch['images'] + with torch.inference_mode(): + discrim_dict = model.forward_eval_discrim( + input_ids = input_ids, + images = image + ) + + eval_disc_data["num_img_corr"] += discrim_dict["img_is_correct"].sum().item() + eval_disc_data["num_lang_corr"] += discrim_dict["lang_is_correct"].sum().item() + eval_disc_data["img_total"] += discrim_dict["img_is_correct"].size(0) + eval_disc_data["lang_total"] += discrim_dict["lang_is_correct"].size(0) + + eval_disc_data["date"] = date.today().strftime('%Y-%m-%d') + print(eval_disc_data) + + with open("/home/smirrashidi/eval_discrim_results.json", "a") as json_file: + json.dump(eval_disc_data, json_file) + json_file.write("\n") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, default="/home/smirrashidi/LLaVAFork/checkpoints/llava-v1.5-13b-lora-disc") + parser.add_argument("--model-base", type=str, default="lmsys/vicuna-13b-v1.3") + parser.add_argument("--image-folder", type=str, default="/home/smirrashidi/coco_data/coco_test") + parser.add_argument("--question-file", type=str, default="/home/smirrashidi/coco_data/coco_test_conversations.json") + parser.add_argument("--answers-file", type=str, default="answer.jsonl") + parser.add_argument("--conv-mode", type=str, default="llava_v1") + parser.add_argument("--num-chunks", type=int, default=1) + parser.add_argument("--chunk-idx", type=int, default=0) + parser.add_argument("--temperature", type=float, default=0.2) + parser.add_argument("--top_p", type=float, default=None) + parser.add_argument("--num_beams", type=int, default=1) + args = parser.parse_args() + + eval_model(args) + diff --git a/llava/model/language_model/llava_llama.py b/llava/model/language_model/llava_llama.py 
index 069d0d1c1..897428bff 100644 --- a/llava/model/language_model/llava_llama.py +++ b/llava/model/language_model/llava_llama.py @@ -1,4 +1,4 @@ -# Copyright 2023 Haotian Liu + # Copyright 2023 Haotian Liu # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,9 +23,14 @@ from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.generation.utils import GenerateOutput +from llava.VLLMSafety.discriminator import Discriminator from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM +from transformers.modeling_utils import * +from transformers.modeling_utils import _add_variant +import wandb + class LlavaConfig(LlamaConfig): model_type = "llava_llama" @@ -47,10 +52,20 @@ def __init__(self, config): self.pretraining_tp = config.pretraining_tp self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.disc_data = { + "image": None, + "lang": None, + } + + self.eval_mode = False + + if not self.eval_mode: + self.discriminator = Discriminator(5120) # hard coding in sizes for now # Initialize weights and apply final processing self.post_init() + def get_model(self): return self.model @@ -68,7 +83,8 @@ def forward( images: Optional[torch.FloatTensor] = None, image_sizes: Optional[List[List[int]]] = None, return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: + d_mode: Optional[bool] = False, # False means run without discriminator + ) -> Union[Tuple, CausalLMOutputWithPast]: if inputs_embeds is None: ( @@ -87,19 +103,111 @@ def forward( images, image_sizes ) + + if self.eval_mode: + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + 
output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + d_mode=False # if you want to turn off the disc completely + - return super().forward( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - labels=labels, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict - ) + if d_mode == True: + discrim_dict = self.discriminator.forward(self.disc_data, d_mode=True) # d loss is sum of disc loss on images and lang + model_output = super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + d_loss = discrim_dict["loss"] + + data = {'disc loss': d_loss.item()} + # with open('/home/smirrashidi/loss_9-24.json', 'a') as f: + # json.dump(data, f) + # f.write('\n') + + model_output.loss = d_loss # returning only discriminator loss + + return model_output + else: + # d_loss = self.discriminator.forward(self.disc_data, d_mode=False) # d loss is sum of disc loss on images and lang; same call in both if and else + model_output = super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + #wandb.log({"generator_disc loss": d_loss}) + wandb.log({"generator loss": model_output.loss}) + + model_output.loss = model_output.loss #+ d_loss # returning sum of model and discriminator loss + wandb.log({"summed loss": model_output.loss}) + + return model_output + + def 
forward_eval_discrim( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + images: Optional[torch.FloatTensor] = None, + image_sizes: Optional[List[List[int]]] = None, + return_dict: Optional[bool] = None, + d_mode: Optional[bool] = True, # False means run without discriminator + eval_disc: Optional[bool] = True + ): + + if inputs_embeds is None: + ( + input_ids, + position_ids, + attention_mask, + past_key_values, + inputs_embeds, + labels + ) = self.prepare_inputs_labels_for_multimodal( + input_ids, + position_ids, + attention_mask, + past_key_values, + labels, + images, + image_sizes + ) + + discrim_dict = self.discriminator.forward(self.disc_data, d_mode=True) # d loss is sum of disc loss on images and lang + + return discrim_dict @torch.no_grad() def generate( @@ -152,7 +260,4 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs['images'] = images if image_sizes is not None: inputs['image_sizes'] = image_sizes - return inputs - -AutoConfig.register("llava_llama", LlavaConfig) -AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) + return inputs \ No newline at end of file diff --git a/llava/model/llava_arch.py b/llava/model/llava_arch.py index d71650eac..808d1a9f4 100644 --- a/llava/model/llava_arch.py +++ b/llava/model/llava_arch.py @@ -146,6 +146,9 @@ def prepare_inputs_labels_for_multimodal( self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None ): + self.disc_data['image'] = [] + self.disc_data['lang'] = [] + vision_tower = self.get_vision_tower() if vision_tower is None or images is None or 
input_ids.shape[1] == 1: return input_ids, position_ids, attention_mask, past_key_values, None, labels @@ -154,7 +157,8 @@ def prepare_inputs_labels_for_multimodal( if type(images) is list: images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images] concat_images = torch.cat([image for image in images], dim=0) - image_features = self.encode_images(concat_images) + raw_image_features = self.encode_images(concat_images) + image_features = raw_image_features split_sizes = [image.shape[0] for image in images] image_features = torch.split(image_features, split_sizes, dim=0) mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat') @@ -200,6 +204,9 @@ def prepare_inputs_labels_for_multimodal( raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}") else: image_features = self.encode_images(images) + raw_image_features = image_features + + self.disc_data['image'].append(raw_image_features) # TODO: image start / end is not implemented here to support pretraining. 
if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False): @@ -250,6 +257,12 @@ def prepare_inputs_labels_for_multimodal( split_sizes = [x.shape[0] for x in cur_labels_noim] cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim)) cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0) + + + #curr input embeds is coming from cur_input_ids_noim which means its already filitered + self.disc_data['lang'].append(cur_input_embeds_no_im[1]) + #print(f'self.disc_data: {self.disc_data}\n') + cur_new_input_embeds = [] cur_new_labels = [] @@ -365,4 +378,4 @@ def initialize_vision_tokenizer(self, model_args, tokenizer): for p in self.get_input_embeddings().parameters(): p.requires_grad = False for p in self.get_output_embeddings().parameters(): - p.requires_grad = False + p.requires_grad = False \ No newline at end of file diff --git a/llava/train/llava_trainer.py b/llava/train/llava_trainer.py index ce2853a41..0ee1022da 100644 --- a/llava/train/llava_trainer.py +++ b/llava/train/llava_trainer.py @@ -1,19 +1,62 @@ import os import torch import torch.nn as nn - -from torch.utils.data import Sampler +import math +import torch.optim as optim +from packaging import version +import time +import deepspeed +import random +import sys +import json + +from typing import Dict, Optional, Union, List, Any, Tuple + +from torch.utils.data import Sampler, Dataset, DataLoader +from transformers.trainer_utils import SchedulerType, PredictionOutput, EvalLoopOutput +from llava.VLLMSafety.discriminator import Discriminator +from transformers.integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint +from transformers.optimization import get_scheduler +from transformers.modeling_utils import unwrap_model +from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES from transformers import Trainer +from transformers.trainer_pt_utils import nested_detach 
from transformers.trainer import ( is_sagemaker_mp_enabled, get_parameter_names, has_length, + get_model_param_count, + speed_metrics, + _is_peft_model, + get_dataloader_sampler, + hp_params, + skip_first_batches, + is_torch_tpu_available, ALL_LAYERNORM_LAYERS, logger, + accelerate_version, + DebugOption, + DebugUnderflowOverflow, + TrainerState, + HPSearchBackend, + TrainOutput, + shutil, + RandomSampler, + ParallelMode ) -from typing import List, Optional +from typing import List, Optional, Union +import wandb + +TRAINER_STATE_NAME = "trainer_state.json" +lr = 0.0002 +beta1 = 0.5 + +#os.environ['WANDB_MODE'] = 'disabled' +wandb.init( + project="llava_safety" +) def maybe_zero_3(param, ignore_status=False, name=None): from deepspeed import zero @@ -97,7 +140,7 @@ def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, class LengthGroupedSampler(Sampler): - r""" + """ Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while keeping a bit of randomness. """ @@ -131,6 +174,8 @@ def __iter__(self): class LLaVATrainer(Trainer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: if self.train_dataset is None or not has_length(self.train_dataset): @@ -147,6 +192,20 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: else: return super()._get_train_sampler() + + def create_optimizer_and_scheduler(self, num_training_steps: int): + """ + Setup the optimizer and the learning rate scheduler. + + We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the + Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or + `create_scheduler`) in a subclass. 
+ """ + self.create_optimizer() + + optimizer = self.optimizer + self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer) + def create_optimizer(self): """ Setup the optimizer. @@ -164,18 +223,14 @@ def create_optimizer(self): decay_parameters = [name for name in decay_parameters if "bias" not in name] if self.args.mm_projector_lr is not None: projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_projector" in name] + discriminator_parameters = [name for name, _ in opt_model.named_parameters() if "discriminator" in name] optimizer_grouped_parameters = [ { "params": [ - p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad) - ], - "weight_decay": self.args.weight_decay, - }, - { - "params": [ - p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad) + p for n, p in opt_model.named_parameters() if (n in discriminator_parameters and p.requires_grad) ], - "weight_decay": 0.0, + "weight_decay": 0, # TODO: this can be a hyperparameter + "lr": lr, }, { "params": [ @@ -192,7 +247,7 @@ def create_optimizer(self): "lr": self.args.mm_projector_lr, }, ] - else: + else: # our code will never go here optimizer_grouped_parameters = [ { "params": [ @@ -208,6 +263,7 @@ def create_optimizer(self): }, ] + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) @@ -225,6 +281,14 @@ def create_optimizer(self): logger.debug(f"bitsandbytes: will optimize {module} in fp32") logger.info(f"skipped: {skipped/2**20}M params") + self.d_optimizer = optim.Adam(opt_model.discriminator.parameters(), lr= lr, betas=(beta1, 0.999)) # how to get discriminator parameters? 
+ + for name, param in opt_model.named_parameters(): + if 'mm_projector' not in name and 'discriminator' not in name: + param.requires_grad = False + + # turn off all the params in the model that are not part of the projector or discriminator + return self.optimizer def _save_checkpoint(self, model, trial, metrics=None): @@ -253,3 +317,551 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): pass else: super(LLaVATrainer, self)._save(output_dir, state_dict) + + def _inner_training_loop( + self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None + ): + self.accelerator.free_memory() + self._train_batch_size = batch_size + if self.args.auto_find_batch_size: + if self.state.train_batch_size != self._train_batch_size: + from accelerate.utils import release_memory + + (self.model_wrapped,) = release_memory(self.model_wrapped) + self.model_wrapped = self.model + + # Check for DeepSpeed *after* the intial pass and modify the config + if self.is_deepspeed_enabled: + # Temporarily unset `self.args.train_batch_size` + original_bs = self.args.per_device_train_batch_size + self.args.per_device_train_batch_size = self._train_batch_size // max(1, self.args.n_gpu) + self.propagate_args_to_deepspeed(True) + self.args.per_device_train_batch_size = original_bs + self.state.train_batch_size = self._train_batch_size + logger.debug(f"Currently training with a batch size of: {self._train_batch_size}") + # Data loader and number of training steps + train_dataloader = self.get_train_dataloader() + + # Setting up training control variables: + # number of training epochs: num_train_epochs + # number of training steps per epoch: num_update_steps_per_epoch + # total number of training steps to execute: max_steps + total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size + + len_dataloader = None + num_train_tokens = None + if has_length(train_dataloader): + len_dataloader = 
len(train_dataloader) + num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps + num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) + num_examples = self.num_examples(train_dataloader) + if args.max_steps > 0: + max_steps = args.max_steps + num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( + args.max_steps % num_update_steps_per_epoch > 0 + ) + # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's + # the best we can do. + num_train_samples = args.max_steps * total_train_batch_size + if args.include_tokens_per_second: + num_train_tokens = ( + self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps + ) + else: + max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) + num_train_epochs = math.ceil(args.num_train_epochs) + num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs + if args.include_tokens_per_second: + num_train_tokens = self.num_tokens(train_dataloader) * args.num_train_epochs + elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size + max_steps = args.max_steps + # Setting a very large number of epochs so we go as many times as necessary over the iterator. 
+ num_train_epochs = sys.maxsize + num_update_steps_per_epoch = max_steps + num_examples = total_train_batch_size * args.max_steps + num_train_samples = args.max_steps * total_train_batch_size + if args.include_tokens_per_second: + num_train_tokens = self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps + else: + raise ValueError( + "args.max_steps must be set to a positive value if dataloader does not have a length, was" + f" {args.max_steps}" + ) + + if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: + if self.args.n_gpu > 1: + # nn.DataParallel(model) replicates the model, creating new variables and module + # references registered here no longer work on other gpus, breaking the module + raise ValueError( + "Currently --debug underflow_overflow is not supported under DP. Please use DDP" + " (torchrun or torch.distributed.launch (deprecated))." + ) + else: + debug_overflow = DebugUnderflowOverflow(self.model) # noqa + + delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled + + # We need to reset the scheduler, as its parameters may be different on subsequent calls + if self._created_lr_scheduler: + self.lr_scheduler = None + self._created_lr_scheduler = False + + if self.is_deepspeed_enabled: + self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps) + + if not delay_optimizer_creation: + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + self.state = TrainerState() + self.state.is_hyper_param_search = trial is not None + self.state.train_batch_size = self._train_batch_size + + # Compute absolute values for logging, eval, and save if given as ratio + if args.logging_steps is not None: + if args.logging_steps < 1: + self.state.logging_steps = math.ceil(max_steps * args.logging_steps) + else: + self.state.logging_steps = args.logging_steps + if args.eval_steps is not None: + if args.eval_steps < 1: + self.state.eval_steps = 
math.ceil(max_steps * args.eval_steps) + else: + self.state.eval_steps = args.eval_steps + if args.save_steps is not None: + if args.save_steps < 1: + self.state.save_steps = math.ceil(max_steps * args.save_steps) + else: + self.state.save_steps = args.save_steps + + # Activate gradient checkpointing if needed + if args.gradient_checkpointing: + if args.gradient_checkpointing_kwargs is None: + gradient_checkpointing_kwargs = {} + else: + gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs + + self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs) + + model = self._wrap_model(self.model_wrapped) + + # as the model is wrapped, don't use `accelerator.prepare` + # this is for unhandled cases such as + # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX + use_accelerator_prepare = True if model is self.model else False + + if delay_optimizer_creation: + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + # prepare using `accelerator` prepare + if use_accelerator_prepare: + self.model.train() + if hasattr(self.lr_scheduler, "step"): + if self.use_apex: + model = self.accelerator.prepare(self.model) + else: + model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) + else: + # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config. 
+ model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( + self.model, self.optimizer, self.lr_scheduler + ) + + if self.is_fsdp_enabled: + self.model = self.model_wrapped = model + + # for the rest of this function `model` is the outside model, whether it was wrapped or not + if model is not self.model: + self.model_wrapped = model + + # backward compatibility + if self.is_deepspeed_enabled: + self.deepspeed = self.model_wrapped + + # ckpt loading + if resume_from_checkpoint is not None: + if self.is_deepspeed_enabled: + deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint) + elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled: + self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped) + + # Check if saved optimizer or scheduler states exist + self._load_optimizer_and_scheduler(resume_from_checkpoint) + + # important: at this point: + # self.model is the Transformers Model + # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), + # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc. + + # Train! + logger.info("***** Running training *****") + logger.info(f" Num examples = {num_examples:,}") + logger.info(f" Num Epochs = {num_train_epochs:,}") + logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}") + if self.args.per_device_train_batch_size != self._train_batch_size: + logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}") + logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_train_batch_size:,}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {max_steps:,}") + logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}") + + self.state.epoch = 0 + start_time = time.time() + epochs_trained = 0 + steps_trained_in_current_epoch = 0 + steps_trained_progress_bar = None + + # Check if continuing training from a checkpoint + if resume_from_checkpoint is not None and os.path.isfile( + os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) + ): + self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) + epochs_trained = self.state.global_step // num_update_steps_per_epoch + if not args.ignore_data_skip: + steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) + steps_trained_in_current_epoch *= args.gradient_accumulation_steps + else: + steps_trained_in_current_epoch = 0 + + logger.info(" Continuing training from checkpoint, will skip to saved global_step") + logger.info(f" Continuing training from epoch {epochs_trained}") + logger.info(f" Continuing training from global step {self.state.global_step}") + if not args.ignore_data_skip: + logger.info( + f" Will skip the first {epochs_trained} epochs then the first" + f" {steps_trained_in_current_epoch} batches in the first epoch." + ) + + # Update the references + self.callback_handler.model = self.model + self.callback_handler.optimizer = self.optimizer + self.callback_handler.lr_scheduler = self.lr_scheduler + self.callback_handler.train_dataloader = train_dataloader + if self.hp_name is not None and self._trial is not None: + # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial + # parameter to Train when using DDP. 
+ self.state.trial_name = self.hp_name(self._trial) + if trial is not None: + assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial + self.state.trial_params = hp_params(assignments) + else: + self.state.trial_params = None + # This should be the same if the state has been saved but in case the training arguments changed, it's safer + # to set this after the load. + self.state.max_steps = max_steps + self.state.num_train_epochs = num_train_epochs + self.state.is_local_process_zero = self.is_local_process_zero() + self.state.is_world_process_zero = self.is_world_process_zero() + + # tr_loss is a tensor to avoid synchronization of TPUs through .item() + tr_loss = torch.tensor(0.0).to(args.device) + # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses + self._total_loss_scalar = 0.0 + self._globalstep_last_logged = self.state.global_step + model.zero_grad() + + self.control = self.callback_handler.on_train_begin(args, self.state, self.control) + + # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point. + if not args.ignore_data_skip: + for epoch in range(epochs_trained): + sampler = get_dataloader_sampler(train_dataloader) + sampler_kinds = [RandomSampler] + if version.parse(accelerate_version) > version.parse("0.23.0"): + sampler_kinds.append(SeedableRandomSampler) + is_random_sampler = isinstance(sampler, tuple(sampler_kinds)) + if not is_random_sampler: + # We just need to begin an iteration to create the randomization of the sampler. + for _ in train_dataloader: + break + else: + # Otherwise we need to call the whooooole sampler cause there is some random operation added + # AT THE VERY END! 
+ sampler = sampler if sampler is not None else [] + _ = list(sampler) + + total_batched_samples = 0 + for epoch in range(epochs_trained, num_train_epochs): + epoch_iterator = train_dataloader + if hasattr(epoch_iterator, "set_epoch"): + epoch_iterator.set_epoch(epoch) + + # Reset the past mems state at the beginning of each epoch if necessary. + if args.past_index >= 0: + self._past = None + + steps_in_epoch = ( + len(epoch_iterator) + if len_dataloader is not None + else args.max_steps * args.gradient_accumulation_steps + ) + self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) + + if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0: + self._load_rng_state(resume_from_checkpoint) + + rng_to_sync = False + steps_skipped = 0 + if steps_trained_in_current_epoch > 0: + epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch) + steps_skipped = steps_trained_in_current_epoch + steps_trained_in_current_epoch = 0 + rng_to_sync = True + + step = -1 + for step, inputs in enumerate(epoch_iterator): + inputs["d_mode"] = True if step % 2 == 0 else False + total_batched_samples += 1 + + if self.args.include_num_input_tokens_seen: + main_input_name = getattr(self.model, "main_input_name", "input_ids") + if main_input_name not in inputs: + logger.warning( + "Tried to track the number of tokens seen, however the current model is " + "not configured properly to know what item is the input. To fix this, add " + "a `main_input_name` attribute to the model class you are using." 
+ ) + else: + self.state.num_input_tokens_seen += self.accelerator.gather(inputs[main_input_name]).numel() + if rng_to_sync: + self._load_rng_state(resume_from_checkpoint) + rng_to_sync = False + + # Skip past any already trained steps if resuming training + if steps_trained_in_current_epoch > 0: + steps_trained_in_current_epoch -= 1 + if steps_trained_progress_bar is not None: + steps_trained_progress_bar.update(1) + if steps_trained_in_current_epoch == 0: + self._load_rng_state(resume_from_checkpoint) + continue + elif steps_trained_progress_bar is not None: + steps_trained_progress_bar.close() + steps_trained_progress_bar = None + + if step % args.gradient_accumulation_steps == 0: + self.control = self.callback_handler.on_step_begin(args, self.state, self.control) + + with self.accelerator.accumulate(model): + tr_loss_step = self.training_step(model, inputs) + + if ( + args.logging_nan_inf_filter + and not is_torch_tpu_available() + and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) + ): + # if loss is nan or inf simply add the average of previous logged losses + tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) + else: + tr_loss += tr_loss_step + + self.current_flos += float(self.floating_point_ops(inputs)) + + is_last_step_and_steps_less_than_grad_acc = ( + steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch + ) + + if ( + total_batched_samples % args.gradient_accumulation_steps == 0 + or + # last step in epoch but step is always smaller than gradient_accumulation_steps + is_last_step_and_steps_less_than_grad_acc + ): + # the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered + # in accelerate. So, explicitly enable sync gradients to True in that case. 
+                    if is_last_step_and_steps_less_than_grad_acc:
+                        self.accelerator.gradient_state._set_sync_gradients(True)
+
+                    # Gradient clipping
+                    if args.max_grad_norm is not None and args.max_grad_norm > 0:
+                        # deepspeed does its own clipping
+
+                        if is_sagemaker_mp_enabled() and args.fp16:
+                            self.optimizer.clip_master_grads(args.max_grad_norm)
+                        elif self.use_apex:
+                            # Revert to normal clipping otherwise, handling Apex or full precision
+                            nn.utils.clip_grad_norm_(
+                                amp.master_params(self.optimizer),
+                                args.max_grad_norm,
+                            )
+                        else:
+                            self.accelerator.clip_grad_norm_(
+                                model.parameters(),
+                                args.max_grad_norm,
+                            )
+
+                    # Optimizer step
+                    self.optimizer.step()
+
+                    optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
+                    if optimizer_was_run:
+                        # Delay optimizer scheduling until metrics are generated
+                        if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+                            self.lr_scheduler.step()
+
+                    model.zero_grad()
+                    self.state.global_step += 1
+                    self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
+                    self.control = self.callback_handler.on_step_end(args, self.state, self.control)
+
+                    print(self.state.epoch)
+
+                    # Save an extra checkpoint when we hit one of the fractional-epoch
+                    # milestones. The four branches were byte-identical, so a single
+                    # membership test preserves behavior exactly.
+                    # NOTE(review): _save_checkpoint here only saves the mm_projector
+                    # weights, which can be passed into the model later.
+                    if any(abs(self.state.epoch - milestone) < 1e-4 for milestone in (0.01, 0.25, 0.5, 0.75)):
+                        print(f"Saving checkpoint at epoch {self.state.epoch}")
+                        self._save_checkpoint(model, trial=None)
+
+                    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
+                else:
+                    self.control = self.callback_handler.on_substep_end(args, self.state,
self.control) + + torch.cuda.empty_cache() + + if self.control.should_epoch_stop or self.control.should_training_stop: + break + if step < 0: + logger.warning( + "There seems to be not a single sample in your epoch_iterator, stopping training at step" + f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" + f" num_steps ({max_steps}) higher than the number of available samples." + ) + self.control.should_training_stop = True + + self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) + self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval) + + if DebugOption.TPU_METRICS_DEBUG in self.args.debug: + if is_torch_tpu_available(): + # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) + xm.master_print(met.metrics_report()) + else: + logger.warning( + "You enabled PyTorch/XLA debug metrics but you don't have a TPU " + "configured. Check your training configuration if this is unexpected." + ) + if self.control.should_training_stop: + break + + if args.past_index and hasattr(self, "_past"): + # Clean the state at the end of training + delattr(self, "_past") + + logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") + if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: + # Wait for everyone to get here so we are sure the model has been saved by process 0. 
+ if is_torch_tpu_available(): + xm.rendezvous("load_best_model_at_end") + elif args.parallel_mode == ParallelMode.DISTRIBUTED: + dist.barrier() + elif is_sagemaker_mp_enabled(): + smp.barrier() + + self._load_best_model() + + # add remaining tr_loss + self._total_loss_scalar += tr_loss.item() + train_loss = self._total_loss_scalar / self.state.global_step + + metrics = speed_metrics( + "train", + start_time, + num_samples=num_train_samples, + num_steps=self.state.max_steps, + num_tokens=num_train_tokens, + ) + self.store_flos() + metrics["total_flos"] = self.state.total_flos + metrics["train_loss"] = train_loss + + self.is_in_train = False + + self._memory_tracker.stop_and_update_metrics(metrics) + + self.log(metrics) + + run_dir = self._get_output_dir(trial) + checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir) + + # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save. + if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1: + for checkpoint in checkpoints_sorted: + if not os.path.samefile(checkpoint, self.state.best_model_checkpoint): + logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") + shutil.rmtree(checkpoint) + + self.control = self.callback_handler.on_train_end(args, self.state, self.control) + + # Wait for the checkpoint to be uploaded. + self._finish_current_push() + + # After training we make sure to retrieve back the original forward pass method + # for the embedding layer by removing the forward post hook. 
+        if self.neftune_noise_alpha is not None:
+            self._deactivate_neftune(self.model)
+
+        return TrainOutput(self.state.global_step, train_loss, metrics)
+
+    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+        """
+        GAN-style training step: compute the discriminator loss and the generator
+        loss, run a backward pass for each, and return their combined detached
+        value scaled by gradient accumulation.
+
+        BUGFIX: the discriminator-loss computation was commented out while
+        `d_loss.detach()` was still referenced below, which raised a NameError on
+        the very first step. Restored the computation (the helper methods and the
+        per-step `d_mode` toggling in the training loop clearly intend it).
+        """
+        model.train()
+        inputs = self._prepare_inputs(inputs)
+
+        # get d loss; the discriminator optimizer is stepped immediately
+        d_loss = self._compute_loss_for_discriminator(model, inputs)
+        self._backward_pass(d_loss, self.d_optimizer, update_optimizer=True, loss_name="discriminator_loss")
+
+        # get g loss; the generator optimizer is stepped later (after grad clipping)
+        g_loss = self._compute_loss_for_generator(model, inputs)
+        self._backward_pass(g_loss, self.optimizer, update_optimizer=False, loss_name="generator_loss")
+
+        total_loss = d_loss.detach() + g_loss.detach()
+        return total_loss / self.args.gradient_accumulation_steps
+
+    def _compute_loss_for_discriminator(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+        """Run the model with d_mode=True and return the (GPU-averaged) discriminator loss."""
+        inputs['d_mode'] = True  # enable discriminator mode
+        with self.compute_loss_context_manager():
+            d_loss = self.compute_loss(model, inputs)
+
+        if self.args.n_gpu > 1:
+            d_loss = d_loss.mean()  # average loss across multiple GPUs
+
+        return d_loss
+
+    def _compute_loss_for_generator(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+        """Run the model with d_mode=False and return the (GPU-averaged) generator loss."""
+        inputs['d_mode'] = False  # enable generator mode
+        with self.compute_loss_context_manager():
+            g_loss = self.compute_loss(model, inputs)
+
+        if self.args.n_gpu > 1:
+            g_loss = g_loss.mean()  # Average loss across multiple GPUs
+
+        return g_loss
+
+    def _backward_pass(self, loss: torch.Tensor, optimizer, update_optimizer: bool, loss_name: str):
+        """Backward the given loss; step `optimizer` only when update_optimizer is True."""
+        if self.use_apex:
+            with amp.scale_loss(loss, optimizer) as scaled_loss:
+                scaled_loss.backward()
+        else:
+            self.accelerator.backward(loss)  # backwards pass
+
+        # only update d_optimizer (we want g_optimizer to go through grad clips)
+        if update_optimizer:
+            optimizer.step()
+
optimizer.zero_grad() + + # Log the loss using WandB + #wandb.log({loss_name: loss.item()}) \ No newline at end of file diff --git a/llava/train/train.py b/llava/train/train.py index 477c668b6..5b3c774c9 100644 --- a/llava/train/train.py +++ b/llava/train/train.py @@ -169,7 +169,7 @@ def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): def find_all_linear_names(model): cls = torch.nn.Linear lora_module_names = set() - multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler'] + multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler', 'discriminator'] for name, module in model.named_modules(): if any(mm_keyword in name for mm_keyword in multimodal_keywords): continue @@ -787,6 +787,7 @@ def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, def train(attn_implementation=None): global local_rank + print("Starting Training") parser = transformers.HfArgumentParser( (ModelArguments, DataArguments, TrainingArguments)) @@ -928,11 +929,13 @@ def make_inputs_require_grad(module, input, output): model.requires_grad_(False) for p in model.get_model().mm_projector.parameters(): p.requires_grad = True + print("Tuning mm_projector") model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter if training_args.freeze_mm_mlp_adapter: for p in model.get_model().mm_projector.parameters(): p.requires_grad = False + print("\n\nif this is printing then you are not training the projector") if training_args.bits in [4, 8]: model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device) @@ -958,6 +961,18 @@ def make_inputs_require_grad(module, input, output): data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) + + model.to("cuda") + + for name, param in model.discriminator.named_parameters(): + param.requires_grad = True + + for name, param in model.discriminator.named_parameters(): + assert param.requires_grad, f"Parameter {name} does not have 
requires_grad set to True" + + for name, param in model.get_model().mm_projector.named_parameters(): + assert param.requires_grad, f"Parameter {name} does not have requires_grad set to True" + trainer = LLaVATrainer(model=model, tokenizer=tokenizer, args=training_args, @@ -978,6 +993,7 @@ def make_inputs_require_grad(module, input, output): non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3( model.named_parameters() ) + print(non_lora_state_dict) if training_args.local_rank == 0 or training_args.local_rank == -1: model.config.save_pretrained(training_args.output_dir) model.save_pretrained(training_args.output_dir, state_dict=state_dict) diff --git a/scripts/finetune_test.sh b/scripts/finetune_test.sh new file mode 100644 index 000000000..425911f4f --- /dev/null +++ b/scripts/finetune_test.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +deepspeed llava/train/train_mem.py \ + --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \ + --deepspeed ./scripts/zero3.json \ + --model_name_or_path liuhaotian/llava-v1.5-13b \ + --version v1 \ + --data_path /home/lpullela/LLaVA/playground/data/llava_v1_5_mix665k_subset10.json \ + --image_folder /home/lpullela/LLaVA/playground/data/ \ + --vision_tower openai/clip-vit-large-patch14-336 \ + --mm_projector_type mlp2x_gelu \ + --mm_vision_select_layer -2 \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --image_aspect_ratio pad \ + --group_by_modality_length True \ + --bf16 True \ + --output_dir ./checkpoints/llava-v1.5-13b-task-lora \ + --num_train_epochs 1 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 50000 \ + --save_total_limit 1 \ + --learning_rate 2e-4 \ + --weight_decay 0. 
\ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --dataloader_num_workers 4 \ + --lazy_preprocess True \ + --report_to wandb \ No newline at end of file diff --git a/scripts/v1_5/eval_discrim.sh b/scripts/v1_5/eval_discrim.sh new file mode 100644 index 000000000..d4a274e20 --- /dev/null +++ b/scripts/v1_5/eval_discrim.sh @@ -0,0 +1,43 @@ +#!/bin/bash + +# llava/train/train_mem.py + +# REMOVE TEST DATA PATH, TEST DATA IMAGES, TESTING + +# change output dir back to llava-v1.5-13b-lora_disc + +deepspeed llava/VLLMSafety/evaluate_disc.py \ + --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \ + --deepspeed ./scripts/zero3.json \ + --model_name_or_path lmsys/vicuna-13b-v1.5 \ + --version v1 \ + --data_path ./playground/data/llava_v1_5_mix665k.json \ + --image_folder ./playground/data \ + --vision_tower openai/clip-vit-large-patch14-336 \ + --pretrain_mm_mlp_adapter ./checkpoints/llava-v1.5-13b/mm_projector.bin \ + --mm_projector_type mlp2x_gelu \ + --mm_vision_select_layer -2 \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --image_aspect_ratio pad \ + --group_by_modality_length True \ + --bf16 True \ + --output_dir ./checkpoints/llava-v1.5-13b-lora_eval \ + --num_train_epochs 1 \ + --per_device_train_batch_size 4 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 50000 --save_total_limit 1 \ + --learning_rate 2e-4 \ + --weight_decay 0. 
\ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --dataloader_num_workers 4 \ + --lazy_preprocess True \ + --report_to wandb \ No newline at end of file diff --git a/scripts/v1_5/finetune.sh b/scripts/v1_5/finetune.sh index 435448394..14c4a94c1 100644 --- a/scripts/v1_5/finetune.sh +++ b/scripts/v1_5/finetune.sh @@ -7,7 +7,7 @@ deepspeed llava/train/train_mem.py \ --data_path ./playground/data/llava_v1_5_mix665k.json \ --image_folder ./playground/data \ --vision_tower openai/clip-vit-large-patch14-336 \ - --pretrain_mm_mlp_adapter ./checkpoints/llava-v1.5-13b-pretrain/mm_projector.bin \ + --pretrain_mm_mlp_adapter ./checkpoints/llava-v1.5-13b/mm_projector.bin \ --mm_projector_type mlp2x_gelu \ --mm_vision_select_layer -2 \ --mm_use_im_start_end False \ @@ -15,9 +15,9 @@ deepspeed llava/train/train_mem.py \ --image_aspect_ratio pad \ --group_by_modality_length True \ --bf16 True \ - --output_dir ./checkpoints/llava-v1.5-13b \ + --output_dir ./checkpoints/llava-v1.5-13b-train_disc_no_lora \ --num_train_epochs 1 \ - --per_device_train_batch_size 16 \ + --per_device_train_batch_size 4 \ --per_device_eval_batch_size 4 \ --gradient_accumulation_steps 1 \ --evaluation_strategy "no" \ diff --git a/scripts/v1_5/finetune_lora.sh b/scripts/v1_5/finetune_lora.sh index 90f00707c..c2cdf24d6 100644 --- a/scripts/v1_5/finetune_lora.sh +++ b/scripts/v1_5/finetune_lora.sh @@ -8,7 +8,7 @@ deepspeed llava/train/train_mem.py \ --data_path ./playground/data/llava_v1_5_mix665k.json \ --image_folder ./playground/data \ --vision_tower openai/clip-vit-large-patch14-336 \ - --pretrain_mm_mlp_adapter ./checkpoints/llava-v1.5-13b-pretrain/mm_projector.bin \ + --pretrain_mm_mlp_adapter ./checkpoints/llava-v1.5-13b/mm_projector.bin \ --mm_projector_type mlp2x_gelu \ --mm_vision_select_layer -2 \ --mm_use_im_start_end False \ @@ -16,15 +16,14 @@ deepspeed llava/train/train_mem.py \ 
--image_aspect_ratio pad \ --group_by_modality_length True \ --bf16 True \ - --output_dir ./checkpoints/llava-v1.5-13b-lora \ + --output_dir ./checkpoints/llava-v1.5-13b-lora-9-26 \ --num_train_epochs 1 \ - --per_device_train_batch_size 16 \ + --per_device_train_batch_size 2 \ --per_device_eval_batch_size 4 \ --gradient_accumulation_steps 1 \ --evaluation_strategy "no" \ --save_strategy "steps" \ - --save_steps 50000 \ - --save_total_limit 1 \ + --save_steps 50000 --save_total_limit 1 \ --learning_rate 2e-4 \ --weight_decay 0. \ --warmup_ratio 0.03 \ @@ -35,4 +34,5 @@ deepspeed llava/train/train_mem.py \ --gradient_checkpointing True \ --dataloader_num_workers 4 \ --lazy_preprocess True \ - --report_to wandb + --report_to wandb \ + --tune_mm_mlp_adapter True \ No newline at end of file