Parallel fixes #1754

Open
wants to merge 42 commits into main
Commits (42)
e94132b
started on generator
Jul 3, 2024
44934a8
training script
Jul 3, 2024
9657413
going through codebase
Jul 5, 2024
94da3a5
reworking things
Jul 12, 2024
7caaf55
filtering tokens
sophiamirrashidi Jul 16, 2024
9afb3b8
rename
sophiamirrashidi Jul 17, 2024
8351c77
saving
sophiamirrashidi Jul 21, 2024
35a0d75
debugging pipeline
sophiamirrashidi Jul 21, 2024
f4075a1
finished debugging
sophiamirrashidi Jul 22, 2024
50ddf73
update
sophiamirrashidi Jul 22, 2024
a96449e
integrated discriminator
Jul 23, 2024
76fb9ce
fix?
sophiamirrashidi Jul 24, 2024
ee61bfd
move call to training bug
lpullela Jul 24, 2024
57515d4
Merge remote-tracking branch 'refs/remotes/origin/main'
lpullela Jul 24, 2024
1262f5a
adding updated llava arch
lpullela Jul 24, 2024
1287702
disc working with 92% acc, can add more layers to nn for better acc
lpullela Jul 25, 2024
b2d8882
laya fine tuning script
lpullela Jul 28, 2024
e143c1f
integrated discriminator within fine tuning script, i believe this sh…
lpullela Jul 29, 2024
05db792
float conversions not necessary unless lora
lpullela Jul 29, 2024
87e7336
laya: error in getting pure language token embeddings
lpullela Jul 29, 2024
3aab4f8
check this
lpullela Jul 29, 2024
ea4fa59
adding local changes
sophiamirrashidi Aug 14, 2024
f959052
provisional training changes
sophiamirrashidi Aug 15, 2024
ea030ef
forgot a file
sophiamirrashidi Aug 15, 2024
60ece7c
fixed debugger and working on adding the d_mode:
sophiamirrashidi Aug 19, 2024
13eac7c
added in d_mode
sophiamirrashidi Aug 20, 2024
a30a34e
created second optimizer
sophiamirrashidi Aug 20, 2024
538c928
fixed discriminator
sophiamirrashidi Aug 24, 2024
0568ab2
pipeline is ready
sophiamirrashidi Aug 28, 2024
fd9ede0
trying to test discrim
sophiamirrashidi Sep 4, 2024
43c2844
tentative eval
sophiamirrashidi Sep 11, 2024
40c42d7
code from meeting 9/11
sophiamirrashidi Sep 11, 2024
f005264
saving discriminator
sophiamirrashidi Sep 16, 2024
b539185
adding rest of files from saving the discrim
sophiamirrashidi Sep 16, 2024
9822afb
successfully evaluated discriminator
sophiamirrashidi Sep 18, 2024
2c46155
finished training:
sophiamirrashidi Sep 21, 2024
765dd47
eval works but now training doesn't
sophiamirrashidi Sep 23, 2024
9646ea8
updated code for training - added in checkpointing and changed some f…
sophiamirrashidi Sep 27, 2024
b1ebae6
set up intermittent checkpointing and logic fixes
sophiamirrashidi Sep 27, 2024
14e2b18
parallel rewrite for multigpu training
lpullela Nov 1, 2024
bfae342
eval params
lpullela Nov 1, 2024
0740ec7
dmode off for testing
lpullela Nov 5, 2024
69 changes: 69 additions & 0 deletions llava/VLLMSafety/discriminator.py
@@ -0,0 +1,69 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import json
import random

class Discriminator(nn.Module):
def __init__(self, input_size):
super().__init__() # we can add more layers later

self.fc1 = nn.Linear(input_size, 50)
self.fc2 = nn.Linear(50, 1)

def linear(self, x):
x = F.relu(self.fc1(x))
x = torch.sigmoid(self.fc2(x))

return x

def forward(self, data, d_mode):
device = 'cuda'
loss_function = nn.BCELoss() # follow DCgan

image_batch = data['image'][0].view(-1, 5120).to(device)
img_tok = image_batch.view(-1, 5120) # flatten the lists

img_pred = self.linear(img_tok)
img_label = torch.full((img_tok.size(0), 1), 1, dtype=torch.bfloat16, device=device) # use label 1 for imgs
img_loss = loss_function(img_pred, img_label)

total_lang_loss = 0
lang_correct_count = 0
total_lang_preds = 0
img_correct_count = torch.eq(torch.ge(img_pred, 0.5).float().to(torch.bfloat16), img_label).sum().item()
img_accuracy = img_correct_count / img_tok.size(0) * 100

for lang_tensor in data["lang"]:
lang_tensor = lang_tensor.to(device)
lang_pred = self.linear(lang_tensor.view(-1, 5120)) # Process each lang tensor independently
lang_label = torch.full((lang_pred.size(0), 1), 0, dtype=torch.bfloat16, device=device) # Label 0 for language

lang_loss = loss_function(lang_pred, lang_label)
total_lang_loss += lang_loss

#for accuracy calculations
lang_correct = torch.eq(torch.ge(lang_pred, 0.5).float().to(torch.bfloat16), lang_label).sum().item()
lang_correct_count += lang_correct
total_lang_preds += lang_pred.size(0)

if d_mode:
lang_accuracy = lang_correct_count / total_lang_preds * 100
print(f"Image Accuracy: {img_accuracy:.2f}%")
print(f"Language Accuracy: {lang_accuracy:.2f}%")

loss = img_loss + total_lang_loss

return {
"loss": loss,
"img_is_correct": img_correct_count,
"lang_is_correct": lang_correct_count,
"img_accuracy": img_accuracy,
"lang_accuracy": lang_accuracy,
}

else:
img_with_lang_label_loss = loss_function(img_pred, torch.full((img_tok.size(0), 1), 0, dtype=torch.bfloat16, device=device))
return img_with_lang_label_loss
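
A minimal standalone sketch of how this Discriminator might be exercised, for orientation only: the 5120-wide embeddings and the layout of the data dict follow forward() above, while the tensor sizes and the random inputs are placeholder stand-ins for real projector and embedding outputs.

import torch
from llava.VLLMSafety.discriminator import Discriminator

disc = Discriminator(input_size=5120).to("cuda", dtype=torch.bfloat16)
data = {
    "image": [torch.randn(576, 5120, dtype=torch.bfloat16)],  # projected image-patch embeddings (placeholder size)
    "lang": [torch.randn(32, 5120, dtype=torch.bfloat16)],    # language-token embeddings for one prompt (placeholder size)
}
out = disc(data, d_mode=True)  # d_mode=True scores both modalities and reports accuracies
print(out["loss"].item(), out["img_accuracy"], out["lang_accuracy"])
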
133 changes: 133 additions & 0 deletions llava/VLLMSafety/pipeline.py
@@ -0,0 +1,133 @@
import argparse
import torch
import torch.nn as nn
from PIL import Image
import json
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import math
import os
import random
from datetime import datetime

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
#from llava.model.llava_arch import prepare_inputs_labels_for_multimodal

from discriminator import preprocess_and_call_train

def split_list(lst, n): # taken from model_vqa.py
"""Split a list into n (roughly) equal-sized chunks"""
    chunk_size = math.ceil(len(lst) / n)  # ceiling division so every item is covered
return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k): # taken from model_vqa.py
chunks = split_list(lst, n)
return chunks[k]

def get_tkns(input_ids, image_tensor, model, img_size):
'''takes in one prompt and one image and returns a dictionary that has the language tokens
from the prompt as one entry and the image tokens from the prompt as another'''
position_ids = None # set to None in generate()
    attention_mask = None # set to None in generate(); still need to confirm whether this and position_ids are acceptable

    # prepare_inputs_labels_for_multimodal returns None as its first value; reason unclear
none_q, position_ids, attention_mask, past_key_values, input_embeds, labels, chunk_sizes = model.prepare_inputs_labels_for_multimodal(
input_ids = input_ids,
position_ids = position_ids,
attention_mask = attention_mask,
past_key_values = None,
labels = None,
images = image_tensor.unsqueeze(0).half().cuda(),
image_sizes = img_size
)

split_embeddings = torch.split(input_embeds[0], chunk_sizes, dim=0)
    lang_tkns = split_embeddings[2] # take only the post-image segment so the same prompt tokens are not added over and over
img_tkns = split_embeddings[1]

tkn_dict = {
"lang_tkns": lang_tkns,
"img_tkns": img_tkns
}

return tkn_dict

def prep_batches(line, model, tokenizer, image_processor, args, **kwargs):
q_id = line["id"] # can be used to identify each batch, probably good to use to keep track of progress during training
image_file = line["image"]
qs = line["text"]

    if not qs.startswith(f"{DEFAULT_IMAGE_TOKEN}\n"):
        idx = qs.find(DEFAULT_IMAGE_TOKEN) + len(DEFAULT_IMAGE_TOKEN)
        qs = qs[idx:].strip()  # keep only the text after the image tag
        qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
    assert qs.startswith(f"{DEFAULT_IMAGE_TOKEN}\n"), f'no image tag found in text \n text = {qs} \n id = {q_id}'

    # note: this wraps every prompt in the default conversation template, so the discriminator may keep
    # getting trained on the same tokens; adjust to remove this later

conv = conv_templates[args.conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

image = Image.open(os.path.join(args.image_folder, image_file)).convert('RGB')
image_tensor = process_images([image], image_processor, model.config)[0]
image_sizes = [image.size]

tkn_dict = get_tkns(input_ids, image_tensor, model, image_sizes) #returns tkn_dict with image and language tokens

return tkn_dict

def train(args):
args_dict = vars(args)

device = 'cuda' # set device appropriately
EPOCHS = 10
G_losses = []
D_losses = []
iters = 0

## boot up model and get everything running properly
disable_torch_init()
model_path = os.path.expanduser(args.model_path)
model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)

# get data - following along with model_vqa.py
questions = [json.loads(q) for q in open(os.path.expanduser(args.conversation_file), "r")]
questions = get_chunk(questions, args.num_chunks, args.chunk_idx)

# right now each batch is created one by one for each conversation in the file, maybe we want to precompute all the
# batches ahead of time? maybe we consolidate this for-loop into a function? for now it should work but
# just some things to think about

for line in questions:
tkn_dict = prep_batches(line, model, tokenizer, image_processor, args, **args_dict)

projection_model = preprocess_and_call_train(tkn_dict)

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default= "/home/smirrashidi/llava-v1.5-13b")
parser.add_argument("--model-base", type=str, default=None)
parser.add_argument("--image_folder", type=str, default= "/home/smirrashidi/coco_data/images")
parser.add_argument("--conversation_file", type=str, default= "/home/smirrashidi/coco_data/discrim_data.jsonl")
parser.add_argument("--conv-mode", type=str, default="llava_v1")
parser.add_argument("--num-chunks", type=int, default=1)
parser.add_argument("--chunk-idx", type=int, default=0)
parser.add_argument("--temperature", type=float, default=0.2)
parser.add_argument("--top_p", type=float, default=None)
parser.add_argument("--num_beams", type=int, default=1)
args = parser.parse_args()

train(args)
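
A note on get_tkns above: the modified prepare_inputs_labels_for_multimodal returns the fused embedding sequence together with per-segment sizes, and torch.split then recovers the pre-image text, image, and post-image text segments. A toy illustration with made-up sizes:

import torch

embeds = torch.randn(35 + 576 + 24, 5120)  # [pre-image text | image patches | post-image text], sizes are made up
chunk_sizes = [35, 576, 24]
pre_txt, img_tkns, lang_tkns = torch.split(embeds, chunk_sizes, dim=0)
# get_tkns keeps split_embeddings[1] as the image tokens and split_embeddings[2] as the language tokens
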
34 changes: 34 additions & 0 deletions llava/VLLMSafety/send.py
@@ -0,0 +1,34 @@
import argparse
import torch
import os
import random
from datetime import datetime

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path

from discriminator import Discriminator # import Laya's discriminator
from make_data import CustomDataset # import the dataset class

def get_data(image_folder, language_file):
'''takes in images and the language file and outputs a shuffled list
of both the images after going through _projector and the tokenized language tokens'''

return None

def send_to_discriminator():
## boot up model and get everything running properly
disable_torch_init()
model_path = os.path.expanduser(args.model_path)
model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
parser.add_argument("--model-base", type=str, default=None)
parser.add_argument("--data, type=str, default=")
args = parser.parse_args()
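
get_data is still a stub; purely as a hypothetical sketch of the output shape its docstring describes (projection and tokenization elided, helper name invented), it could shuffle labeled embeddings from both modalities together:

import random

def get_data_sketch(image_embeds, lang_embeds):
    # image_embeds / lang_embeds: lists of (-1, hidden_size) tensors already produced
    # by the multimodal projector and the text embedding layer, respectively
    data = [(t, 1) for t in image_embeds] + [(t, 0) for t in lang_embeds]  # 1 = image, 0 = language, matching discriminator.py
    random.shuffle(data)
    return data
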
137 changes: 137 additions & 0 deletions llava/VLLMSafety/test_discrim.py
@@ -0,0 +1,137 @@
import argparse
import torch
import os
import json
from tqdm import tqdm
from torch.utils.data import DataLoader

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
from llava.train.train import *
from llava.train.llava_trainer import LLaVATrainer

from PIL import Image
import math
from llava.VLLMSafety.discriminator import Discriminator
from datetime import date

class DiscArguments:
test_data_path: str = "/home/smirrashidi/coco_data/coco_test_conversations.json"
test_image_folder: str = "/home/smirrashidi/coco_data/coco_test"
model_path: str = "/home/smirrashidi/LLaVAFork/checkpoints/llava-v1.5-13b-lora-disc"
model_base: str = "lmsys/vicuna-13b-v1.3"

def split_list(lst, n):
"""Split a list into n (roughly) equal-sized chunks"""
    chunk_size = math.ceil(len(lst) / n)  # ceiling division so every item is covered
return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
chunks = split_list(lst, n)
return chunks[k]

def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
data_args, disc_args, testing) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""

    if testing:
data_args.image_folder = "/home/smirrashidi/coco_data/coco_test"
data_args.data_path = "/home/smirrashidi/coco_data/coco_test_conversations.json"

train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
data_path=data_args.data_path,
data_args=data_args)

data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

else:
train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
data_path=data_args.data_path,
data_args=data_args)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

return dict(train_dataset=train_dataset,
eval_dataset=None,
data_collator=data_collator)


def eval_model(args):
# Model
disable_torch_init()
model_path = os.path.expanduser(args.model_path)
model_name = get_model_name_from_path(model_path)
model_args = ModelArguments(model_name_or_path = args.model_path)
data_args = DataArguments(data_path = args.question_file,
image_folder = args.image_folder)
training_args = TrainingArguments(output_dir="/home/smirrashidi/dump")
disc_args = DiscArguments

tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)

model = model.to(torch.bfloat16)

data_args.image_processor = image_processor

test_data_module = make_supervised_data_module(tokenizer=tokenizer,
data_args=data_args,
disc_args=disc_args,
testing = True)

data_collator = test_data_module['data_collator']

test_dataloader = DataLoader(
test_data_module['train_dataset'],
batch_size=4,
collate_fn=data_collator,
shuffle=False)

eval_disc_data = {
"num_img_corr": 0,
"num_lang_corr": 0,
"img_total": 0,
"lang_total": 0,
}

for i, batch in enumerate(test_dataloader):
print(f"Iteration #{i}")
input_ids = batch['input_ids']
image = batch['images']
with torch.inference_mode():
discrim_dict = model.forward_eval_discrim(
input_ids = input_ids,
images = image
)

eval_disc_data["num_img_corr"] += discrim_dict["img_is_correct"].sum().item()
eval_disc_data["num_lang_corr"] += discrim_dict["lang_is_correct"].sum().item()
eval_disc_data["img_total"] += discrim_dict["img_is_correct"].size(0)
eval_disc_data["lang_total"] += discrim_dict["lang_is_correct"].size(0)

eval_disc_data["date"] = date.today().strftime('%Y-%m-%d')
print(eval_disc_data)

with open("/home/smirrashidi/eval_discrim_results.json", "a") as json_file:
json.dump(eval_disc_data, json_file)
json_file.write("\n")

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="/home/smirrashidi/LLaVAFork/checkpoints/llava-v1.5-13b-lora-disc")
parser.add_argument("--model-base", type=str, default="lmsys/vicuna-13b-v1.3")
parser.add_argument("--image-folder", type=str, default="/home/smirrashidi/coco_data/coco_test")
parser.add_argument("--question-file", type=str, default="/home/smirrashidi/coco_data/coco_test_conversations.json")
parser.add_argument("--answers-file", type=str, default="answer.jsonl")
parser.add_argument("--conv-mode", type=str, default="llava_v1")
parser.add_argument("--num-chunks", type=int, default=1)
parser.add_argument("--chunk-idx", type=int, default=0)
parser.add_argument("--temperature", type=float, default=0.2)
parser.add_argument("--top_p", type=float, default=None)
parser.add_argument("--num_beams", type=int, default=1)
args = parser.parse_args()

eval_model(args)
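
The evaluation loop appends raw correct/total counts to a JSONL file; the per-modality accuracies can be recovered afterwards with a small post-processing pass (the path below is the one hard-coded above):

import json

with open("/home/smirrashidi/eval_discrim_results.json") as f:
    for line in f:
        r = json.loads(line)
        img_acc = 100.0 * r["num_img_corr"] / max(r["img_total"], 1)
        lang_acc = 100.0 * r["num_lang_corr"] / max(r["lang_total"], 1)
        print(r["date"], f"image {img_acc:.2f}%", f"language {lang_acc:.2f}%")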
