import pytest
from timeout_sampler import TimeoutSampler

from tests.llama_stack.constants import LlamaStackProviders
from utilities.constants import Timeout
7+
# Fully-qualified benchmark id for the arc_easy task served by the TrustyAI
# LMEval provider ("<provider_id>::arc_easy"). The pasted source had spaces
# injected inside the f-string (`} ::arc_easy`), which would corrupt the id.
TRUSTYAI_LMEVAL_ARCEASY = f"{LlamaStackProviders.Eval.TRUSTYAI_LMEVAL}::arc_easy"
49
510
611@pytest .mark .parametrize (
@pytest.mark.rawdeployment
class TestLlamaStackLMEvalProvider:
    """
    Tests for the LlamaStack LMEval provider.

    1. Register the LLM that will be evaluated.
    2. Register the arc_easy benchmark.
    3. Run the evaluation and wait until it's completed.
    """

    def test_lmeval_register_benchmark(self, llama_stack_client):
        """Register the model under test and the arc_easy benchmark, then verify the listing."""
        # Register the LLM that the benchmark will evaluate, served by the vLLM
        # inference provider.
        llama_stack_client.models.register(
            provider_id=LlamaStackProviders.Inference.VLLM_INFERENCE, model_type="llm", model_id="qwen"
        )

        llama_stack_client.benchmarks.register(
            benchmark_id=TRUSTYAI_LMEVAL_ARCEASY,
            dataset_id=TRUSTYAI_LMEVAL_ARCEASY,
            scoring_functions=["string"],
            provider_id=LlamaStackProviders.Eval.TRUSTYAI_LMEVAL,
            provider_benchmark_id="string",
            metadata={"tokenized_requests": False, "tokenizer": "google/flan-t5-small"},
        )
        benchmarks = llama_stack_client.benchmarks.list()

        # Exactly one benchmark should exist, and it should be the one just registered.
        assert len(benchmarks) == 1
        assert benchmarks[0].identifier == TRUSTYAI_LMEVAL_ARCEASY
        assert benchmarks[0].provider_id == LlamaStackProviders.Eval.TRUSTYAI_LMEVAL

    def test_llamastack_run_eval(self, patched_trustyai_operator_configmap_allow_online, llama_stack_client):
        """Run the arc_easy evaluation and poll until the job reports completion."""
        job = llama_stack_client.eval.run_eval(
            benchmark_id=TRUSTYAI_LMEVAL_ARCEASY,
            # NOTE(review): the three lines opening benchmark_config/eval_candidate were
            # elided in the pasted diff (new-file lines 54-56) and are reconstructed
            # here — confirm against the original commit.
            benchmark_config={
                "eval_candidate": {
                    "model": "qwen",
                    "type": "model",
                    "provider_id": LlamaStackProviders.Eval.TRUSTYAI_LMEVAL,
                    # max_tokens kept small (10) to keep the eval fast.
                    "sampling_params": {"temperature": 0.7, "top_p": 0.9, "max_tokens": 10},
                },
                "scoring_params": {},
                # Only 2 examples: this is a smoke test of the provider, not a real eval.
                "num_examples": 2,
            },
        )

        # Poll the job status every 30s; TimeoutSampler raises if the job has not
        # completed within 10 minutes.
        status_sampler = TimeoutSampler(
            wait_timeout=Timeout.TIMEOUT_10MIN,
            sleep=30,
            func=lambda: llama_stack_client.eval.jobs.status(
                job_id=job.job_id, benchmark_id=TRUSTYAI_LMEVAL_ARCEASY
            ).status,
        )

        for job_status in status_sampler:
            if job_status == "completed":
                break
            # Fail fast on a terminal failure instead of polling until the 10-minute
            # timeout. NOTE(review): terminal states assumed to be "failed"/"cancelled"
            # per llama-stack JobStatus — confirm against the client version in use.
            assert job_status not in ("failed", "cancelled"), f"Evaluation job ended with status: {job_status}"