Skip to content

Commit 5b3dd09

Browse files
authored
feat: add new LMEval tests with chat template (#322)
* feat: improve lmeval tests
* fix: remove unused function
1 parent 635ed2d commit 5b3dd09

File tree

3 files changed: 76 additions and 58 deletions.

tests/model_explainability/lm_eval/conftest.py

Lines changed: 20 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -25,42 +25,26 @@
2525

2626
@pytest.fixture(scope="function")
2727
def lmevaljob_hf(
28-
admin_client: DynamicClient, model_namespace: Namespace, patched_trustyai_operator_configmap_allow_online: ConfigMap
28+
request: FixtureRequest,
29+
admin_client: DynamicClient,
30+
model_namespace: Namespace,
31+
patched_trustyai_operator_configmap_allow_online: ConfigMap,
2932
) -> Generator[LMEvalJob, None, None]:
3033
with LMEvalJob(
3134
client=admin_client,
32-
name="test-job",
35+
name=LMEVALJOB_NAME,
3336
namespace=model_namespace.name,
3437
model="hf",
35-
model_args=[{"name": "pretrained", "value": "google/flan-t5-base"}],
36-
task_list={
37-
"custom": {
38-
"systemPrompts": [
39-
{"name": "sp_0", "value": "Be concise. At every point give the shortest acceptable answer."}
40-
],
41-
"templates": [
42-
{
43-
"name": "tp_0",
44-
"value": '{ "__type__": "input_output_template", '
45-
'"input_format": "{text_a_type}: {text_a}\\n'
46-
'{text_b_type}: {text_b}", '
47-
'"output_format": "{label}", '
48-
'"target_prefix": '
49-
'"The {type_of_relation} class is ", '
50-
'"instruction": "Given a {text_a_type} and {text_b_type} '
51-
'classify the {type_of_relation} of the {text_b_type} to one of {classes}.",'
52-
' "postprocessors": [ "processors.take_first_non_empty_line",'
53-
' "processors.lower_case_till_punc" ] }',
54-
}
55-
],
56-
},
57-
"taskRecipes": [
58-
{"card": {"name": "cards.wnli"}, "systemPrompt": {"ref": "sp_0"}, "template": {"ref": "tp_0"}}
59-
],
60-
},
38+
model_args=[{"name": "pretrained", "value": "Qwen/Qwen2.5-0.5B"}],
39+
task_list=request.param.get("task_list"),
6140
log_samples=True,
6241
allow_online=True,
6342
allow_code_execution=True,
43+
system_instruction="Be concise. At every point give the shortest acceptable answer.",
44+
chat_template={
45+
"enabled": True,
46+
},
47+
limit="0.01",
6448
) as job:
6549
yield job
6650

@@ -80,6 +64,7 @@ def lmevaljob_local_offline(
8064
model="hf",
8165
model_args=[{"name": "pretrained", "value": "/opt/app-root/src/hf_home/flan"}],
8266
task_list=request.param.get("task_list"),
67+
limit="0.01",
8368
log_samples=True,
8469
offline={"storage": {"pvcName": "lmeval-data"}},
8570
pod={
@@ -402,6 +387,13 @@ def lmevaljob_hf_pod(admin_client: DynamicClient, lmevaljob_hf: LMEvalJob) -> Ge
402387
yield get_lmevaljob_pod(client=admin_client, lmevaljob=lmevaljob_hf)
403388

404389

390+
@pytest.fixture(scope="function")
391+
def lmevaljob_local_offline_pod(
392+
admin_client: DynamicClient, lmevaljob_local_offline: LMEvalJob
393+
) -> Generator[Pod, Any, Any]:
394+
yield get_lmevaljob_pod(client=admin_client, lmevaljob=lmevaljob_local_offline)
395+
396+
405397
@pytest.fixture(scope="function")
406398
def lmevaljob_vllm_emulator_pod(
407399
admin_client: DynamicClient, lmevaljob_vllm_emulator: LMEvalJob

tests/model_explainability/lm_eval/test_lm_eval.py

Lines changed: 56 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,65 @@
11
import pytest
22

3-
from tests.model_explainability.lm_eval.utils import verify_lmevaljob_running
43
from utilities.constants import Timeout
54

65
LMEVALJOB_COMPLETE_STATE: str = "Complete"
76

87

98
@pytest.mark.parametrize(
10-
"model_namespace",
9+
"model_namespace, lmevaljob_hf",
1110
[
1211
pytest.param(
13-
{"name": "test-lmeval-huggingface"},
14-
)
12+
{"name": "test-lmeval-hf-arc"}, {"task_list": {"taskNames": ["arc_challenge"]}}, id="arc_challenge"
13+
),
14+
pytest.param(
15+
{"name": "test-lmeval-hf-mmlu"},
16+
{"task_list": {"taskNames": ["mmlu_astronomy_generative"]}},
17+
id="mmlu_astronomy_generative",
18+
),
19+
pytest.param({"name": "test-lmeval-hf-hellaswag"}, {"task_list": {"taskNames": ["hellaswag"]}}, id="hellaswag"),
20+
pytest.param(
21+
{"name": "test-lmeval-hf-truthfulqa"}, {"task_list": {"taskNames": ["truthfulqa_gen"]}}, id="truthfulqa_gen"
22+
),
23+
pytest.param(
24+
{"name": "test-lmeval-hf-winogrande"}, {"task_list": {"taskNames": ["winogrande"]}}, id="winogrande"
25+
),
26+
pytest.param(
27+
{"name": "test-lmeval-hf-custom-task"},
28+
{
29+
"task_list": {
30+
"custom": {
31+
"systemPrompts": [
32+
{"name": "sp_0", "value": "Be concise. At every point give the shortest acceptable answer."}
33+
],
34+
"templates": [
35+
{
36+
"name": "tp_0",
37+
"value": '{ "__type__": "input_output_template", '
38+
'"input_format": "{text_a_type}: {text_a}\\n'
39+
'{text_b_type}: {text_b}", '
40+
'"output_format": "{label}", '
41+
'"target_prefix": '
42+
'"The {type_of_relation} class is ", '
43+
'"instruction": "Given a {text_a_type} and {text_b_type} '
44+
'classify the {type_of_relation} of the {text_b_type} to one of {classes}.",'
45+
' "postprocessors": [ "processors.take_first_non_empty_line",'
46+
' "processors.lower_case_till_punc" ] }',
47+
}
48+
],
49+
},
50+
"taskRecipes": [
51+
{"card": {"name": "cards.wnli"}, "systemPrompt": {"ref": "sp_0"}, "template": {"ref": "tp_0"}}
52+
],
53+
}
54+
},
55+
id="custom_task",
56+
),
1557
],
1658
indirect=True,
1759
)
1860
def test_lmeval_huggingface_model(admin_client, model_namespace, lmevaljob_hf_pod):
19-
"""Basic test that verifies that LMEval can run successfully pulling a model from HuggingFace."""
61+
"""Tests that verify running common evaluations (and a custom one) on a model pulled directly from HuggingFace.
62+
On each test we run a different evaluation task, limiting it to 1% of the questions on each eval."""
2063
lmevaljob_hf_pod.wait_for_status(status=lmevaljob_hf_pod.Status.SUCCEEDED, timeout=Timeout.TIMEOUT_20MIN)
2164

2265

@@ -39,10 +82,12 @@ def test_lmeval_local_offline_builtin_tasks_flan_arceasy(
3982
admin_client,
4083
model_namespace,
4184
lmeval_data_downloader_pod,
42-
lmevaljob_local_offline,
85+
lmevaljob_local_offline_pod,
4386
):
4487
"""Test that verifies that LMEval can run successfully in local, offline mode using builtin tasks"""
45-
verify_lmevaljob_running(client=admin_client, lmevaljob=lmevaljob_local_offline)
88+
lmevaljob_local_offline_pod.wait_for_status(
89+
status=lmevaljob_local_offline_pod.Status.SUCCEEDED, timeout=Timeout.TIMEOUT_20MIN
90+
)
4691

4792

4893
@pytest.mark.parametrize(
@@ -72,10 +117,12 @@ def test_lmeval_local_offline_unitxt_tasks_flan_20newsgroups(
72117
admin_client,
73118
model_namespace,
74119
lmeval_data_downloader_pod,
75-
lmevaljob_local_offline,
120+
lmevaljob_local_offline_pod,
76121
):
77122
"""Test that verifies that LMEval can run successfully in local, offline mode using unitxt"""
78-
verify_lmevaljob_running(client=admin_client, lmevaljob=lmevaljob_local_offline)
123+
lmevaljob_local_offline_pod.wait_for_status(
124+
status=lmevaljob_local_offline_pod.Status.SUCCEEDED, timeout=Timeout.TIMEOUT_20MIN
125+
)
79126

80127

81128
@pytest.mark.parametrize(

tests/model_explainability/lm_eval/utils.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,33 +3,12 @@
33
from ocp_resources.pod import Pod
44

55
from utilities.constants import Timeout
6-
from utilities.infra import check_pod_status_in_time
76
from simple_logger.logger import get_logger
87

98

109
LOGGER = get_logger(name=__name__)
1110

1211

13-
def verify_lmevaljob_running(client: DynamicClient, lmevaljob: LMEvalJob) -> None:
14-
"""
15-
Verifies that an LMEvalJob Pod reaches Running state and maintains Running/Succeeded state.
16-
Waits for Pod to enter Running state, then checks it stays Running or Succeeded for 2 minutes.
17-
18-
Args:
19-
client: DynamicClient instance for interacting with Kubernetes
20-
lmevaljob: LMEvalJob object representing the job to verify
21-
22-
Raises:
23-
TimeoutError: If Pod doesn't reach Running state within 10 minutes
24-
AssertionError: If Pod doesn't stay in one of the desired states for 2 minutes
25-
"""
26-
27-
lmevaljob_pod = Pod(client=client, name=lmevaljob.name, namespace=lmevaljob.namespace, wait_for_resource=True)
28-
lmevaljob_pod.wait_for_status(status=lmevaljob_pod.Status.RUNNING, timeout=Timeout.TIMEOUT_20MIN)
29-
30-
check_pod_status_in_time(pod=lmevaljob_pod, status={lmevaljob_pod.Status.RUNNING, lmevaljob_pod.Status.SUCCEEDED})
31-
32-
3312
def get_lmevaljob_pod(client: DynamicClient, lmevaljob: LMEvalJob, timeout: int = Timeout.TIMEOUT_2MIN) -> Pod:
3413
"""
3514
Gets the pod corresponding to a given LMEvalJob and waits for it to be ready.

0 commit comments

Comments
 (0)