From bb3a5c4432480e762db1b1bde0fc2c11b64c7b98 Mon Sep 17 00:00:00 2001
From: aliaga
Date: Tue, 17 Sep 2024 14:25:09 +0900
Subject: [PATCH] added evaluation for multiple images

---
 llava/conversation.py   | 10 +++++
 llava/eval/run_llava.py | 96 ++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 105 insertions(+), 1 deletion(-)

diff --git a/llava/conversation.py b/llava/conversation.py
index 00c56867d..2aabb6ec1 100644
--- a/llava/conversation.py
+++ b/llava/conversation.py
@@ -391,6 +391,16 @@ def dict(self):
     "mpt": conv_mpt,
 }
 
+# Gets a copy of the conversation for the given template and overrides its system prompt with custom if not None
+def get_conv(conv_mode, custom):
+    conv = conv_templates[conv_mode].copy()
+
+    # Here we only replace the system prompt; we do no checking of the prompt format.
+    # Later iterations might check, for example, whether the template requires |im_start|.
+    if custom is not None:
+        conv.system = custom
+
+    return conv
 
 if __name__ == "__main__":
     print(default_conversation.get_prompt())
diff --git a/llava/eval/run_llava.py b/llava/eval/run_llava.py
index 24b0fffcc..f364fb577 100644
--- a/llava/eval/run_llava.py
+++ b/llava/eval/run_llava.py
@@ -8,7 +8,7 @@
     DEFAULT_IM_END_TOKEN,
     IMAGE_PLACEHOLDER,
 )
-from llava.conversation import conv_templates, SeparatorStyle
+from llava.conversation import conv_templates, SeparatorStyle, get_conv
 from llava.model.builder import load_pretrained_model
 from llava.utils import disable_torch_init
 from llava.mm_utils import (
@@ -128,6 +128,100 @@ def eval_model(args):
     print(outputs)
 
 
+def eval_multiple(args):
+    disable_torch_init()
+
+    model_name = get_model_name_from_path(args.model_path)
+    tokenizer, model, image_processor, context_len = load_pretrained_model(
+        args.model_path, args.model_base, model_name
+    )
+
+    qs = args.query
+    image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
+    if IMAGE_PLACEHOLDER in qs:
+        if model.config.mm_use_im_start_end:
+            qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
+        else:
+            qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
+    else:
+        if model.config.mm_use_im_start_end:
+            qs = image_token_se + "\n" + qs
+        else:
+            qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
+
+    if "llama-2" in model_name.lower():
+        conv_mode = "llava_llama_2"
+    elif "mistral" in model_name.lower():
+        conv_mode = "mistral_instruct"
+    elif "v1.6-34b" in model_name.lower():
+        conv_mode = "chatml_direct"
+    elif "v1" in model_name.lower():
+        conv_mode = "llava_v1"
+    elif "mpt" in model_name.lower():
+        conv_mode = "mpt"
+    else:
+        conv_mode = "llava_v0"
+
+    if args.conv_mode is not None and conv_mode != args.conv_mode:
+        print(
+            "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
+                conv_mode, args.conv_mode, args.conv_mode
+            )
+        )
+    else:
+        args.conv_mode = conv_mode
+
+    # In a later iteration we may also carry conversation history
+    # conv = conv_templates[args.conv_mode].copy()
+
+    if not hasattr(args, "custom_system"):
+        args.custom_system = None
+    conv = get_conv(args.conv_mode, args.custom_system)
+    conv.append_message(conv.roles[0], qs)
+    conv.append_message(conv.roles[1], None)
+    prompt = conv.get_prompt()
+
+    print(f"PROMPT {prompt}")
+
+    image_files = image_parser(args)
+    images = load_images(image_files)
+    image_sizes = [x.size for x in images]
+
+    input_ids = (
+        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
+        .unsqueeze(0)
+        .cuda()
+    )
+
+    for c, image in enumerate(images):
+
+        images_tensor = process_images(
+            [image],
+            image_processor,
+            model.config
+        ).to(model.device, dtype=torch.float16)
+
+
+        with torch.inference_mode():
+            output_ids = model.generate(
+                input_ids,
+                images=images_tensor,
+                image_sizes=[image_sizes[c]],  # only the current image's size
+                do_sample=True if args.temperature > 0 else False,
+                temperature=args.temperature,
+                top_p=args.top_p,
+                num_beams=args.num_beams,
+                max_new_tokens=args.max_new_tokens,
+                use_cache=True,
+            )
+        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
+
+        # print(f"Image {image_files[c]} : {outputs}")
+        yield image_files[c], outputs
+
+
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
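
Usage note (a minimal sketch, not part of the patch): eval_multiple() is a generator; it runs one generate() call per image and yields an (image_file, output) pair for each. The driver below shows how it might be consumed. The SimpleNamespace fields mirror the argparse flags defined at the bottom of run_llava.py (image_parser() splits args.image_file on args.sep); the checkpoint name and image paths are illustrative placeholders, and custom_system is optional since eval_multiple guards it with hasattr.

from types import SimpleNamespace

from llava.eval.run_llava import eval_multiple

# All values below are illustrative placeholders, not part of the patch.
args = SimpleNamespace(
    model_path="liuhaotian/llava-v1.5-7b",  # placeholder checkpoint
    model_base=None,
    query="Describe this image.",
    conv_mode=None,            # let eval_multiple infer the mode from the model name
    custom_system=None,        # or a custom system prompt, forwarded to get_conv
    image_file="a.jpg,b.jpg",  # image_parser splits this on args.sep
    sep=",",
    temperature=0.2,
    top_p=None,
    num_beams=1,
    max_new_tokens=512,
)

# Consume the generator: one (image_file, output) pair per image.
for image_file, output in eval_multiple(args):
    print(f"{image_file}: {output}")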