Skip to content

Commit 8cab9bf

Browse files
committed
fix test
1 parent e475fa7 commit 8cab9bf

2 files changed

Lines changed: 9 additions & 2 deletions

File tree

tools/who_what_benchmark/tests/test_cli_text.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from transformers import AutoTokenizer
1111
from optimum.intel.openvino import OVModelForCausalLM, OVWeightQuantizationConfig
1212

13+
from test_cli_image import get_similarity
1314
from conftest import convert_text_model, run_wwb
1415

1516

@@ -279,9 +280,10 @@ def test_text_genai_json_string_config():
279280

280281
@pytest.mark.parametrize(
281282
("model_id"),
282-
[("TinyLlama/TinyLlama-1.1B-Chat-v1.0")],
283+
[("optimum-intel-internal-testing/tiny-random-Phi3ForCausalLM")],
283284
)
284285
def test_text_chat_model(model_id, tmp_path):
286+
SIMILARITY_THRESHOLD = 0.9
285287
temp_file_name = tmp_path / "gt.csv"
286288
chat_model_path = convert_text_model(model_id, model_id.split("/")[1], _convert_base)
287289

@@ -322,6 +324,9 @@ def test_text_chat_model(model_id, tmp_path):
322324
assert (outputs_path / "metrics_per_question.csv").exists()
323325
assert (outputs_path / "metrics.csv").exists()
324326
assert (outputs_path / "target.csv").exists()
327+
328+
similarity = get_similarity(output)
329+
assert similarity >= SIMILARITY_THRESHOLD
325330

326331
outputs_path = tmp_path / "genai"
327332
output = run_wwb(
@@ -345,3 +350,5 @@ def test_text_chat_model(model_id, tmp_path):
345350
assert (outputs_path / "metrics_per_question.csv").exists()
346351
assert (outputs_path / "metrics.csv").exists()
347352
assert (outputs_path / "target.csv").exists()
353+
similarity = get_similarity(output)
354+
assert similarity >= SIMILARITY_THRESHOLD

tools/who_what_benchmark/whowhatbench/whowhat_metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def evaluate_similarity(model, data_gold, data_prediction):
3535
# in chat mode - gold, prediction are list of answers
3636
metric_per_chat_answer_list = []
3737
metric_per_question = []
38-
for i, gold, prediction in tqdm(
38+
for i, (gold, prediction) in tqdm(
3939
enumerate(zip(answers_gold, answers_prediction)),
4040
total=min(len(answers_gold), len(answers_prediction)),
4141
desc="Similarity evaluation",

0 commit comments

Comments
 (0)