Skip to content

Commit 2f45719

Browse files
committed
Add accuracy tests.
Signed-off-by: Muti Chung <[email protected]>
1 parent db21a1c commit 2f45719

File tree

1 file changed

+56
-0
lines changed

1 file changed

+56
-0
lines changed
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from tempfile import TemporaryDirectory
2+
3+
from lm_eval.evaluator import simple_evaluate
4+
5+
from llmcompressor.modifiers.awq.convert_autoawq import convert_and_save
6+
from tests.testing_utils import requires_gpu
7+
8+
9+
def run_lm_eval(model_name_or_path: str):
    """Evaluate a HF checkpoint on ARC-Challenge and ARC-Easy.

    Runs the lm-evaluation-harness in 5-shot mode at float16 and returns
    the raw results dict produced by ``simple_evaluate``.
    """
    eval_kwargs = dict(
        model="hf",
        model_args=f"pretrained={model_name_or_path},dtype=float16",
        tasks=["arc_challenge", "arc_easy"],
        num_fewshot=5,
        batch_size=16,
    )
    return simple_evaluate(**eval_kwargs)
19+
20+
21+
# Maximum allowed absolute difference in acc_norm between the original
# AutoAWQ checkpoint and its converted counterpart.
_ACC_NORM_ATOL = 1e-2


def compare_models(model_name_or_path: str):
    """Assert accuracy parity between an AutoAWQ model and its conversion.

    Evaluates ``model_name_or_path`` as-is, converts it to the
    "naive-quantized" compressed-tensors format in a temporary directory,
    evaluates the converted copy, and asserts that ``acc_norm`` on
    ARC-Easy and ARC-Challenge agrees within ``_ACC_NORM_ATOL``.
    """
    autoawq_result = run_lm_eval(model_name_or_path)
    with TemporaryDirectory() as converted_model_dir:
        convert_and_save(model_name_or_path, converted_model_dir, "naive-quantized")
        converted_result = run_lm_eval(converted_model_dir)

    # Same order as the original checks: easy first, then challenge.
    for task in ("arc_easy", "arc_challenge"):
        autoawq_acc = autoawq_result["results"][task]["acc_norm,none"]
        converted_acc = converted_result["results"][task]["acc_norm,none"]
        # Tolerance-based comparison; the message states the tolerance
        # rather than implying an exact-equality check.
        assert abs(autoawq_acc - converted_acc) < _ACC_NORM_ATOL, (
            f"{task}: autoawq={autoawq_acc} vs converted={converted_acc} "
            f"differ by more than {_ACC_NORM_ATOL}."
        )
38+
39+
40+
@requires_gpu
def test_mistral():
    """Accuracy-parity check for a Mistral-7B-Instruct AutoAWQ checkpoint."""
    model_id = "fbaldassarri/mistralai_Mistral-7B-Instruct-v0.3-autoawq-int4-gs128-asym"
    compare_models(model_id)
45+
46+
47+
@requires_gpu
def test_qwen():
    """Accuracy-parity check for a DeepSeek-R1-Distill-Qwen AutoAWQ checkpoint."""
    model_id = "ruikangliu/DeepSeek-R1-Distill-Qwen-1.5B-quantized.awq-autoawq-w4g128"
    compare_models(model_id)
52+
53+
54+
@requires_gpu
def test_llama():
    """Accuracy-parity check for a Llama-3.2-3B AutoAWQ checkpoint."""
    model_id = "AMead10/Llama-3.2-3B-Instruct-AWQ"
    compare_models(model_id)

0 commit comments

Comments
 (0)