Skip to content

Commit 5d76696

Browse files
committed
[AI Explainability] Add LLM as a Judge test in LMEval (#584)
* feat: add llmaaj initial * feat: add LLMAAJ * feat: increase Timeout for LMEval tasks * fix: task name in metrics * feat: increase timeout for lmeval task * fix: change model name * fix: actually remove model format
1 parent cdfe48d commit 5d76696

File tree

4 files changed

+118
-30
lines changed

4 files changed

+118
-30
lines changed
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# LMEval task payload exercising a fully custom unitxt task: one system prompt
# plus one serialized unitxt "input_output_template", bound to the WNLI card
# through a single task recipe.
CUSTOM_UNITXT_TASK_DATA = {
    "task_list": {
        "custom": {
            "systemPrompts": [
                {
                    "name": "sp_0",
                    "value": "Be concise. At every point give the shortest acceptable answer.",
                },
            ],
            "templates": [
                {
                    "name": "tp_0",
                    # Serialized unitxt input/output template (a JSON string;
                    # the {placeholders} are unitxt format fields, not Python's).
                    "value": (
                        '{ "__type__": "input_output_template", "input_format": '
                        '"{text_a_type}: {text_a}\\n{text_b_type}: {text_b}", '
                        '"output_format": "{label}", "target_prefix": '
                        '"The {type_of_relation} class is ", "instruction": '
                        '"Given a {text_a_type} and {text_b_type} classify the '
                        '{type_of_relation} of the {text_b_type} to one of {classes}.", '
                        '"postprocessors": [ "processors.take_first_non_empty_line", '
                        '"processors.lower_case_till_punc" ] }'
                    ),
                },
            ],
        },
        # Wire the prompt and template above to the WNLI card by reference.
        "taskRecipes": [
            {
                "card": {"name": "cards.wnli"},
                "systemPrompt": {"ref": "sp_0"},
                "template": {"ref": "tp_0"},
            },
        ],
    },
}
26+
27+
# LMEval "LLM as a Judge" task payload. Every "value"/"custom" entry below is a
# serialized unitxt artifact (a JSON string): a rating template, the rating
# task, an llm_as_judge metric, and a task card over the MT-Bench single-turn
# GPT-4 judgement dataset.
LLMAAJ_TASK_DATA = {
    "task_list": {
        "custom": {
            "templates": [
                {
                    "name": "response_assessment.rating.mt_bench_single_turn",
                    # Judge prompt asking for a 1-10 verdict in "[[rating]]" form,
                    # extracted by the mt_bench rating postprocessor.
                    "value": (
                        '{\n "__type__": "input_output_template",\n "instruction":'
                        ' "Please act as an impartial judge and evaluate the quality of the '
                        "response provided by an AI assistant to the user question displayed below."
                        " Your evaluation should consider factors such as the helpfulness, relevance,"
                        " accuracy, depth, creativity, and level of detail of the response. Begin your"
                        " evaluation by providing a short explanation. Be as objective as possible. "
                        "After providing your explanation, you must rate the response on a scale of 1 to 10"
                        ' by strictly following this format: \\"[[rating]]\\", for example: \\"Rating: '
                        '[[5]]\\".\\n\\n",\n "input_format": "[Question]\\n{question}\\n\\n[The Start '
                        "of Assistant's Answer]\\n{answer}\\n[The End of Assistant's Answer]\",\n "
                        '"output_format": "[[{rating}]]",\n "postprocessors": [\n '
                        '"processors.extract_mt_bench_rating_judgment"\n ]\n}\n'
                    ),
                },
            ],
            "tasks": [
                {
                    "name": "response_assessment.rating.single_turn",
                    # unitxt task: (question, answer) -> float rating, scored with
                    # Spearman correlation.
                    "value": (
                        '{\n "__type__": "task",\n "input_fields": {\n '
                        '"question": "str",\n "answer": "str"\n },\n '
                        '"outputs": {\n "rating": "float"\n },\n '
                        '"metrics": [\n "metrics.spearman"\n ]\n}\n'
                    ),
                },
            ],
            "metrics": [
                {
                    "name": "llmaaj_metric",
                    # llm_as_judge metric backed by an HF pipeline judge model.
                    # NOTE(review): "main_score" still carries a mistral-derived
                    # name while the judge model is rgeada/tiny-untrained-granite —
                    # confirm this key matches what the job actually reports.
                    "value": (
                        '{\n "__type__": "llm_as_judge",\n "inference_model": {\n '
                        '"__type__": "hf_pipeline_based_inference_engine",\n '
                        '"model_name": "rgeada/tiny-untrained-granite",\n '
                        '"max_new_tokens": 256,\n "use_fp16": true},\n '
                        '"template": "templates.response_assessment.rating.mt_bench_single_turn",\n '
                        '"task": "response_assessment.rating.single_turn",\n '
                        '"main_score": "mistral_7b_instruct_v0_2_huggingface_template_mt_bench_single_turn"\n}'
                    ),
                },
            ],
        },
        "taskRecipes": [
            {
                "card": {
                    # unitxt task_card: load the MT-Bench single-score GPT-4
                    # judgement split, keep only first-turn / no-reference rows,
                    # and massage them into (question, answer, rating) examples.
                    "custom": (
                        '{\n "__type__": "task_card",\n "loader": '
                        '{\n "__type__": "load_hf",\n '
                        '"path": "OfirArviv/mt_bench_single_score_gpt4_judgement",\n '
                        '"split": "train"\n },\n "preprocess_steps": [\n '
                        '{\n "__type__": "rename_splits",\n '
                        '"mapper": {\n "train": "test"\n }\n },\n '
                        '{\n "__type__": "filter_by_condition",\n '
                        '"values": {\n "turn": 1\n },\n '
                        '"condition": "eq"\n },\n {\n '
                        '"__type__": "filter_by_condition",\n '
                        '"values": {\n "reference": "[]"\n },\n '
                        '"condition": "eq"\n },\n {\n '
                        '"__type__": "rename",\n "field_to_field": {\n '
                        '"model_input": "question",\n '
                        '"score": "rating",\n '
                        '"category": "group",\n '
                        '"model_output": "answer"\n }\n },\n '
                        '{\n "__type__": "literal_eval",\n '
                        '"field": "question"\n },\n '
                        '{\n "__type__": "copy",\n '
                        '"field": "question/0",\n '
                        '"to_field": "question"\n },\n '
                        '{\n "__type__": "literal_eval",\n '
                        '"field": "answer"\n },\n {\n '
                        '"__type__": "copy",\n '
                        '"field": "answer/0",\n '
                        '"to_field": "answer"\n }\n ],\n '
                        '"task": "tasks.response_assessment.rating.single_turn",\n '
                        '"templates": [\n '
                        '"templates.response_assessment.rating.mt_bench_single_turn"\n ]\n}\n'
                    ),
                    "template": {"ref": "response_assessment.rating.mt_bench_single_turn"},
                    "metrics": [{"ref": "llmaaj_metric"}],
                },
            },
        ],
    },
}

tests/model_explainability/lm_eval/data/__init__.py

Whitespace-only changes.

tests/model_explainability/lm_eval/test_lm_eval.py

Lines changed: 7 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22
from typing import List
33

4-
4+
from tests.model_explainability.lm_eval.constants import LLMAAJ_TASK_DATA, CUSTOM_UNITXT_TASK_DATA
55
from tests.model_explainability.utils import validate_tai_component_images
66

77
from tests.model_explainability.lm_eval.utils import get_lmeval_tasks, validate_lmeval_job_pod_and_logs
@@ -28,35 +28,14 @@
2828
),
2929
pytest.param(
3030
{"name": "test-lmeval-hf-custom-task"},
31-
{
32-
"task_list": {
33-
"custom": {
34-
"systemPrompts": [
35-
{"name": "sp_0", "value": "Be concise. At every point give the shortest acceptable answer."}
36-
],
37-
"templates": [
38-
{
39-
"name": "tp_0",
40-
"value": '{ "__type__": "input_output_template", '
41-
'"input_format": "{text_a_type}: {text_a}\\n'
42-
'{text_b_type}: {text_b}", '
43-
'"output_format": "{label}", '
44-
'"target_prefix": '
45-
'"The {type_of_relation} class is ", '
46-
'"instruction": "Given a {text_a_type} and {text_b_type} '
47-
'classify the {type_of_relation} of the {text_b_type} to one of {classes}.",'
48-
' "postprocessors": [ "processors.take_first_non_empty_line",'
49-
' "processors.lower_case_till_punc" ] }',
50-
}
51-
],
52-
},
53-
"taskRecipes": [
54-
{"card": {"name": "cards.wnli"}, "systemPrompt": {"ref": "sp_0"}, "template": {"ref": "tp_0"}}
55-
],
56-
}
57-
},
31+
CUSTOM_UNITXT_TASK_DATA,
5832
id="custom_task",
5933
),
34+
pytest.param(
35+
{"name": "test-lmeval-hf-llmaaj"},
36+
LLMAAJ_TASK_DATA,
37+
id="llmaaj_task",
38+
),
6039
],
6140
indirect=True,
6241
)

tests/model_explainability/lm_eval/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import List
22
import re
3+
from pyhelper_utils.general import tts
34
from kubernetes.dynamic import DynamicClient
45
from ocp_resources.lm_eval_job import LMEvalJob
56
from ocp_resources.pod import Pod
@@ -100,9 +101,9 @@ def validate_lmeval_job_pod_and_logs(lmevaljob_pod: Pod) -> None:
100101
r"INFO\sdriver\supdate status: job completed\s\{\"state\":\s\{\"state\""
101102
r":\"Complete\",\"reason\":\"Succeeded\",\"message\":\"job completed\""
102103
)
103-
lmevaljob_pod.wait_for_status(status=lmevaljob_pod.Status.RUNNING, timeout=Timeout.TIMEOUT_5MIN)
104+
lmevaljob_pod.wait_for_status(status=lmevaljob_pod.Status.RUNNING, timeout=tts("5m"))
104105
try:
105-
lmevaljob_pod.wait_for_status(status=Pod.Status.SUCCEEDED, timeout=Timeout.TIMEOUT_20MIN)
106+
lmevaljob_pod.wait_for_status(status=Pod.Status.SUCCEEDED, timeout=tts("1h"))
106107
except TimeoutExpiredError as e:
107108
raise UnexpectedFailureError("LMEval job pod failed from a running state.") from e
108109
if not bool(re.search(pod_success_log_regex, lmevaljob_pod.log())):

0 commit comments

Comments
 (0)