Skip to content

Commit e5ea952

Browse files
authored
[llm_bench] fix vlm processing without image and add more supported models (openvinotoolkit#2182)
1 parent a9a4e41 commit e5ea952

File tree

1 file changed

+20
-3
lines changed

1 file changed

+20
-3
lines changed

tools/llm_bench/llm_bench_utils/model_utils.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,15 +55,15 @@ def get_param_from_file(args, input_key):
5555
data_dict = {}
5656
if "media" in input_key:
5757
if args["media"] is None and args["images"] is None:
58-
if args["use_case"] != "vlm":
58+
if args["use_case"] == "vlm":
5959
log.warn("Input image is not provided. Only text generation part will be evaluated")
6060
elif args["use_case"] != "image_gen":
6161
raise RuntimeError("No input image. ImageToImage/Inpainting Models cannot start generation without one. Please, provide an image.")
6262
else:
6363
data_dict["media"] = args["media"] if args["media"] is not None else args["images"]
6464
if args["prompt"] is None:
65-
if args["use_case"] != "vlm":
66-
data_dict["prompt"] = "What is OpenVINO?" if data_dict["media"] is None else "Describe image"
65+
if args["use_case"] == "vlm":
66+
data_dict["prompt"] = "What is OpenVINO?" if data_dict.get("media") is None else "Describe image"
6767
elif args['use_case'] == 'image_gen':
6868
data_dict["prompt"] = 'sailing ship in storm by Leonardo da Vinci'
6969
else:
@@ -216,6 +216,10 @@ def get_use_case(model_name_or_path):
216216
return "image_gen", pipe_type.replace("Pipeline", "")
217217

218218
if config is not None:
219+
case, model_name = resolve_complex_model_types(config)
220+
if case is not None:
221+
log.info(f'==SUCCESS FOUND==: use_case: {case}, model_type: {model_name}')
222+
return case, model_name
219223
for case, model_ids in USE_CASES.items():
220224
for idx, model_id in enumerate(normalize_model_ids(model_ids)):
221225
if config.get("model_type").lower().replace('_', '-').startswith(model_id):
@@ -230,6 +234,19 @@ def get_use_case(model_name_or_path):
230234
return case, model_name
231235

232236

237+
def resolve_complex_model_types(config):
    """Resolve the use case for model families whose type name is ambiguous.

    Normalizes ``config["model_type"]`` (lowercase, underscores to hyphens)
    and maps known multi-modal / text-only families to their use case.

    Returns a ``(use_case, model_type)`` tuple, or ``(None, None)`` when the
    model type is not one of the specially handled families.
    """
    kind = config.get("model_type").lower().replace('_', '-')
    # Families whose type string alone does not fit the generic USE_CASES scan.
    dispatch = {
        "gemma3": "vlm",
        "gemma3-text": "text_gen",
        "phi4mm": "vlm",
        "phi4-multimodal": "vlm",
        "llama4": "vlm",
    }
    use_case = dispatch.get(kind)
    if use_case is None:
        return None, None
    return use_case, kind
248+
249+
233250
def get_model_name(model_name_or_path):
234251
# try to get use_case from model name
235252
path = os.path.normpath(model_name_or_path)

0 commit comments

Comments (0)