Skip to content

Commit eaadde9

Browse files
committed
improve lmeval provider tests
1 parent 6395477 commit eaadde9

File tree

1 file changed

+26
-14
lines changed

1 file changed

+26
-14
lines changed
Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
import pytest
22

33
from tests.llama_stack.constants import LlamaStackProviders
4+
from utilities.constants import Timeout
5+
from timeout_sampler import TimeoutSampler
6+
7+
8+
TRUSTYAI_LMEVAL_ARCEASY = f"{LlamaStackProviders.Eval.TRUSTYAI_LMEVAL}::arc_easy"
49

510

611
@pytest.mark.parametrize(
@@ -19,34 +24,30 @@
1924
@pytest.mark.rawdeployment
2025
class TestLlamaStackLMEvalProvider:
2126
"""
22-
Adds basic tests for the LlamaStack LMEval provider.
23-
27+
Tests for the LlamaStack LMEval provider.
2428
1. Register the LLM that will be evaluated.
25-
2. Register the arc_easy benchmark (eval)
26-
3. TODO: Add test for run_eval
29+
2. Register the arc_easy benchmark.
30+
3. Run the evaluation and wait until it's completed.
2731
"""
2832

2933
def test_lmeval_register_benchmark(self, llama_stack_client):
    """Register the evaluated model and the arc_easy benchmark, then verify registration.

    Steps (as shown by the calls below):
    1. Register the ``qwen`` LLM with the vLLM inference provider so it can be evaluated.
    2. Register the arc_easy benchmark against the TrustyAI LMEval provider.
    3. Assert the benchmark list contains exactly that one benchmark with the
       expected identifier and provider id.

    Args:
        llama_stack_client: project fixture providing a LlamaStack client
            (NOTE(review): exact client type not visible in this view — confirm).
    """
    # Register the LLM under evaluation; model_id "qwen" is referenced by the
    # eval run in the sibling test.
    llama_stack_client.models.register(
        provider_id=LlamaStackProviders.Inference.VLLM_INFERENCE, model_type="llm", model_id="qwen"
    )

    # Register the arc_easy benchmark. benchmark_id and dataset_id share the
    # module-level TRUSTYAI_LMEVAL_ARCEASY constant ("<provider>::arc_easy").
    llama_stack_client.benchmarks.register(
        benchmark_id=TRUSTYAI_LMEVAL_ARCEASY,
        dataset_id=TRUSTYAI_LMEVAL_ARCEASY,
        scoring_functions=["string"],
        provider_id=LlamaStackProviders.Eval.TRUSTYAI_LMEVAL,
        provider_benchmark_id="string",
        metadata={"tokenized_requests": False, "tokenizer": "google/flan-t5-small"},
    )

    benchmarks = llama_stack_client.benchmarks.list()

    # Exactly one benchmark should exist, matching what was just registered.
    assert len(benchmarks) == 1
    assert benchmarks[0].identifier == TRUSTYAI_LMEVAL_ARCEASY
    assert benchmarks[0].provider_id == LlamaStackProviders.Eval.TRUSTYAI_LMEVAL
5051

5152
def test_llamastack_run_eval(self, patched_trustyai_operator_configmap_allow_online, llama_stack_client):
5253
job = llama_stack_client.eval.run_eval(
@@ -56,10 +57,21 @@ def test_llamastack_run_eval(self, patched_trustyai_operator_configmap_allow_onl
5657
"model": "qwen",
5758
"type": "model",
5859
"provider_id": LlamaStackProviders.Eval.TRUSTYAI_LMEVAL,
59-
"sampling_params": {"temperature": 0.7, "top_p": 0.9, "max_tokens": 256},
60+
"sampling_params": {"temperature": 0.7, "top_p": 0.9, "max_tokens": 10},
6061
},
61-
"num_examples": 100,
62+
"scoring_params": {},
63+
"num_examples": 2,
6264
},
6365
)
6466

65-
print("hi")
67+
samples = TimeoutSampler(
68+
wait_timeout=Timeout.TIMEOUT_10MIN,
69+
sleep=30,
70+
func=lambda: llama_stack_client.eval.jobs.status(
71+
job_id=job.job_id, benchmark_id=TRUSTYAI_LMEVAL_ARCEASY
72+
).status,
73+
)
74+
75+
for sample in samples:
76+
if sample == "completed":
77+
break

0 commit comments

Comments (0)