Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions tests/model_explainability/lm_eval/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Task data for an LMEval job exercising a custom unitxt task: one system
# prompt ("sp_0") and one serialized unitxt input/output template ("tp_0"),
# wired together by a single task recipe against the "cards.wnli" card.
CUSTOM_UNITXT_TASK_DATA = {
    "task_list": {
        "custom": {
            "systemPrompts": [
                {"name": "sp_0", "value": "Be concise. At every point give the shortest acceptable answer."}
            ],
            "templates": [
                {
                    "name": "tp_0",
                    # Serialized unitxt template JSON. The doubled backslash
                    # ("\\n") keeps a literal backslash-n in the Python string
                    # so it is still an escape sequence when the JSON is parsed.
                    "value": '{ "__type__": "input_output_template", '
                    '"input_format": "{text_a_type}: {text_a}\\n'
                    '{text_b_type}: {text_b}", '
                    '"output_format": "{label}", '
                    '"target_prefix": '
                    '"The {type_of_relation} class is ", '
                    '"instruction": "Given a {text_a_type} and {text_b_type} '
                    'classify the {type_of_relation} of the {text_b_type} to one of {classes}.",'
                    ' "postprocessors": [ "processors.take_first_non_empty_line",'
                    ' "processors.lower_case_till_punc" ] }',
                }
            ],
        },
        # The recipe references the prompt and template defined above by name.
        "taskRecipes": [{"card": {"name": "cards.wnli"}, "systemPrompt": {"ref": "sp_0"}, "template": {"ref": "tp_0"}}],
    }
}

# Task data for an LMEval job using an LLM-as-a-judge ("llmaaj") setup.
# Each "value" field holds a serialized unitxt JSON artifact (note the
# "__type__" discriminators): a judge prompt template, a task definition,
# an llm_as_judge metric, and a task recipe whose inline card loads the
# MT-Bench single-score GPT-4 judgement dataset from HuggingFace.
LLMAAJ_TASK_DATA = {
    "task_list": {
        "custom": {
            "templates": [
                {
                    # MT-Bench single-turn judge prompt: instructs the judge
                    # model to rate the answer 1-10 as "[[rating]]", which the
                    # extract_mt_bench_rating_judgment postprocessor parses.
                    # "\\n"/"\\\"" stay escaped for the JSON parser; bare "\n"
                    # is a real newline in the serialized JSON text itself.
                    "name": "response_assessment.rating.mt_bench_single_turn",
                    "value": '{\n "__type__": "input_output_template",\n "instruction":'
                    ' "Please act as an impartial judge and evaluate the quality of the '
                    "response provided by an AI assistant to the user question displayed below."
                    " Your evaluation should consider factors such as the helpfulness, relevance,"
                    " accuracy, depth, creativity, and level of detail of the response. Begin your"
                    " evaluation by providing a short explanation. Be as objective as possible. "
                    "After providing your explanation, you must rate the response on a scale of 1 to 10"
                    ' by strictly following this format: \\"[[rating]]\\", for example: \\"Rating: '
                    '[[5]]\\".\\n\\n",\n "input_format": "[Question]\\n{question}\\n\\n[The Start '
                    "of Assistant's Answer]\\n{answer}\\n[The End of Assistant's Answer]\",\n "
                    '"output_format": "[[{rating}]]",\n "postprocessors": [\n '
                    '"processors.extract_mt_bench_rating_judgment"\n ]\n}\n',
                }
            ],
            "tasks": [
                {
                    # unitxt task schema: question/answer strings in, a float
                    # rating out, scored with Spearman correlation.
                    "name": "response_assessment.rating.single_turn",
                    "value": '{\n "__type__": "task",\n "input_fields": {\n '
                    '"question": "str",\n "answer": "str"\n },\n '
                    '"outputs": {\n "rating": "float"\n },\n '
                    '"metrics": [\n "metrics.spearman"\n ]\n}\n',
                }
            ],
            "metrics": [
                {
                    # llm_as_judge metric backed by a HF-pipeline inference
                    # engine running the "rgeada/tiny-untrained-granite" model,
                    # using the template and task defined above.
                    "name": "llmaaj_metric",
                    "value": '{\n "__type__": "llm_as_judge",\n "inference_model": {\n '
                    '"__type__": "hf_pipeline_based_inference_engine",\n '
                    '"model_name": "rgeada/tiny-untrained-granite",\n '
                    '"max_new_tokens": 256,\n "use_fp16": true},\n '
                    '"template": "templates.response_assessment.rating.mt_bench_single_turn",\n '
                    '"task": "response_assessment.rating.single_turn",\n '
                    '"main_score": "mistral_7b_instruct_v0_2_huggingface_template_mt_bench_single_turn"\n}',
                }
            ],
        },
        "taskRecipes": [
            {
                # Inline custom task card: loads the train split of
                # "OfirArviv/mt_bench_single_score_gpt4_judgement", renames it
                # to "test", filters to first-turn rows with empty references,
                # and renames/unwraps fields into the question/answer/rating
                # schema used by the task above.
                # NOTE(review): "template" and "metrics" sit INSIDE the "card"
                # mapping here rather than as recipe-level siblings — confirm
                # this nesting is what the LMEval CR schema expects.
                "card": {
                    "custom": '{\n "__type__": "task_card",\n "loader": '
                    '{\n "__type__": "load_hf",\n '
                    '"path": "OfirArviv/mt_bench_single_score_gpt4_judgement",\n '
                    '"split": "train"\n },\n "preprocess_steps": [\n '
                    '{\n "__type__": "rename_splits",\n '
                    '"mapper": {\n "train": "test"\n }\n },\n '
                    '{\n "__type__": "filter_by_condition",\n '
                    '"values": {\n "turn": 1\n },\n '
                    '"condition": "eq"\n },\n {\n '
                    '"__type__": "filter_by_condition",\n '
                    '"values": {\n "reference": "[]"\n },\n '
                    '"condition": "eq"\n },\n {\n '
                    '"__type__": "rename",\n "field_to_field": {\n '
                    '"model_input": "question",\n '
                    '"score": "rating",\n '
                    '"category": "group",\n '
                    '"model_output": "answer"\n }\n },\n '
                    '{\n "__type__": "literal_eval",\n '
                    '"field": "question"\n },\n '
                    '{\n "__type__": "copy",\n '
                    '"field": "question/0",\n '
                    '"to_field": "question"\n },\n '
                    '{\n "__type__": "literal_eval",\n '
                    '"field": "answer"\n },\n {\n '
                    '"__type__": "copy",\n '
                    '"field": "answer/0",\n '
                    '"to_field": "answer"\n }\n ],\n '
                    '"task": "tasks.response_assessment.rating.single_turn",\n '
                    '"templates": [\n '
                    '"templates.response_assessment.rating.mt_bench_single_turn"\n ]\n}\n',
                    "template": {"ref": "response_assessment.rating.mt_bench_single_turn"},
                    "metrics": [{"ref": "llmaaj_metric"}],
                }
            }
        ],
    }
}
Empty file.
35 changes: 7 additions & 28 deletions tests/model_explainability/lm_eval/test_lm_eval.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
from typing import List


from tests.model_explainability.lm_eval.constants import LLMAAJ_TASK_DATA, CUSTOM_UNITXT_TASK_DATA
from tests.model_explainability.utils import validate_tai_component_images

from tests.model_explainability.lm_eval.utils import get_lmeval_tasks, validate_lmeval_job_pod_and_logs
Expand All @@ -28,35 +28,14 @@
),
pytest.param(
{"name": "test-lmeval-hf-custom-task"},
{
"task_list": {
"custom": {
"systemPrompts": [
{"name": "sp_0", "value": "Be concise. At every point give the shortest acceptable answer."}
],
"templates": [
{
"name": "tp_0",
"value": '{ "__type__": "input_output_template", '
'"input_format": "{text_a_type}: {text_a}\\n'
'{text_b_type}: {text_b}", '
'"output_format": "{label}", '
'"target_prefix": '
'"The {type_of_relation} class is ", '
'"instruction": "Given a {text_a_type} and {text_b_type} '
'classify the {type_of_relation} of the {text_b_type} to one of {classes}.",'
' "postprocessors": [ "processors.take_first_non_empty_line",'
' "processors.lower_case_till_punc" ] }',
}
],
},
"taskRecipes": [
{"card": {"name": "cards.wnli"}, "systemPrompt": {"ref": "sp_0"}, "template": {"ref": "tp_0"}}
],
}
},
CUSTOM_UNITXT_TASK_DATA,
id="custom_task",
),
pytest.param(
{"name": "test-lmeval-hf-llmaaj"},
LLMAAJ_TASK_DATA,
id="llmaaj_task",
),
],
indirect=True,
)
Expand Down
5 changes: 3 additions & 2 deletions tests/model_explainability/lm_eval/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List
import re
from pyhelper_utils.general import tts
from kubernetes.dynamic import DynamicClient
from ocp_resources.lm_eval_job import LMEvalJob
from ocp_resources.pod import Pod
Expand Down Expand Up @@ -100,9 +101,9 @@ def validate_lmeval_job_pod_and_logs(lmevaljob_pod: Pod) -> None:
r"INFO\sdriver\supdate status: job completed\s\{\"state\":\s\{\"state\""
r":\"Complete\",\"reason\":\"Succeeded\",\"message\":\"job completed\""
)
lmevaljob_pod.wait_for_status(status=lmevaljob_pod.Status.RUNNING, timeout=Timeout.TIMEOUT_5MIN)
lmevaljob_pod.wait_for_status(status=lmevaljob_pod.Status.RUNNING, timeout=tts("5m"))
try:
lmevaljob_pod.wait_for_status(status=Pod.Status.SUCCEEDED, timeout=Timeout.TIMEOUT_20MIN)
lmevaljob_pod.wait_for_status(status=Pod.Status.SUCCEEDED, timeout=tts("1h"))
except TimeoutExpiredError as e:
raise UnexpectedFailureError("LMEval job pod failed from a running state.") from e
if not bool(re.search(pod_success_log_regex, lmevaljob_pod.log())):
Expand Down