Skip to content

Commit a5ccc22

Browse files
committed
add accuracy test case
1 parent cbd5cf1 commit a5ccc22

5 files changed

Lines changed: 83 additions & 24 deletions

File tree

.github/workflows/gke-connectivity-smoke.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@ jobs:
259259
git checkout "${REPO_REF}"
260260
python3 -m pip install --upgrade pip
261261
python3 -m pip install -e "python[all]"
262+
python3 -m pip install evalscope
262263
python3 test/srt/mulit_host/run_suite.py
263264
env:
264265
- name: JOB_COMPLETION_INDEX

test/srt/mulit_host/multi_host_suite.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import subprocess
22
import sys
33
import time
4-
from dataclasses import dataclass
5-
from typing import Callable, Literal
4+
from dataclasses import dataclass, field
5+
from typing import Any, Callable, Literal
66

77

88
@dataclass(frozen=True)
@@ -21,10 +21,11 @@ class PerfCase:
2121
@dataclass(frozen=True)
2222
class AccuracyCase:
2323
name: str
24-
eval_name: str
25-
num_examples: int
26-
num_threads: int
27-
temperature: float = 0.0
24+
dataset: str
25+
model_id: str
26+
eval_batch_size: int = 32
27+
generation_config: dict[str, Any] = field(default_factory=dict)
28+
limit: int | None = None
2829
dry_run_result: Literal["success", "failed"] = "success"
2930

3031

test/srt/mulit_host/run_suite.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import importlib
44
import json
55
import os
6+
import subprocess
67
import sys
78
import threading
89
import time
@@ -159,8 +160,49 @@ def run_case(case: PerfCase | AccuracyCase, model_path: str, port: int) -> None:
159160
if isinstance(case, PerfCase):
160161
run_perf_case(case, model_path, port)
161162
return
163+
if isinstance(case, AccuracyCase):
164+
run_accuracy_case(case, port)
165+
return
166+
167+
raise NotImplementedError(f"Unsupported case type: {type(case).__name__}")
168+
169+
170+
def run_accuracy_case(case: AccuracyCase, port: int) -> None:
    """Run one accuracy case by invoking the ``evalscope eval`` CLI.

    The CLI is pointed at ``http://127.0.0.1:<port>/v1`` with
    ``--eval-type openai_api`` and a placeholder API key. Only the process
    exit status is checked here; no accuracy score is read or compared
    (warn-only mode).

    Args:
        case: Accuracy case supplying ``model_id``, ``dataset`` and
            ``eval_batch_size``, plus the optional ``generation_config``
            and ``limit``.
        port: Local port used to build the API URL passed to evalscope.

    Raises:
        RuntimeError: If evalscope exits with a non-zero return code.
    """
    # Function-scope import: only needed to render the command for the log.
    import shlex

    api_url = f"http://127.0.0.1:{port}/v1"
    cmd = [
        "evalscope",
        "eval",
        "--model",
        case.model_id,
        "--api-url",
        api_url,
        "--api-key",
        "EMPTY",
        "--eval-type",
        "openai_api",
        "--datasets",
        case.dataset,
        "--eval-batch-size",
        str(case.eval_batch_size),
    ]
    # Optional flags are passed only when the case sets them; an empty
    # generation_config or a None limit omits the flag entirely.
    if case.generation_config:
        cmd.extend(["--generation-config", json.dumps(case.generation_config)])
    if case.limit is not None:
        cmd.extend(["--limit", str(case.limit)])

    _log(
        "Running accuracy case "
        f"name={case.name}, dataset={case.dataset}, "
        f"eval_batch_size={case.eval_batch_size}, "
        f"generation_config={case.generation_config}, limit={case.limit}"
    )
    # shlex.join quotes every argument, so the JSON generation config (which
    # contains spaces and double quotes) is logged in a form that can be
    # pasted back into a shell unchanged. A plain ' '.join would not be.
    _log(f"Command: {shlex.join(cmd)}")
    # check=False: the return code is inspected below so the error message
    # can name the failing case.
    completed = subprocess.run(cmd, check=False)
    if completed.returncode != 0:
        raise RuntimeError(
            f"evalscope exited with code {completed.returncode} for case={case.name}"
        )
    _log(f"Accuracy case {case.name} completed (warn-only mode, accuracy not gated)")
164206

165207

166208
def stop_server_process(server_process) -> None:

test/srt/mulit_host/test_mimo_flash.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1-
from multi_host_suite import ModelRun, ModelRunConfig, MultiHostSuite, PerfCase
1+
from multi_host_suite import (
2+
AccuracyCase,
3+
ModelRun,
4+
ModelRunConfig,
5+
MultiHostSuite,
6+
PerfCase,
7+
)
28

39

410
def get_suites() -> list[MultiHostSuite]:
@@ -41,16 +47,23 @@ def get_suites() -> list[MultiHostSuite]:
4147
),
4248
),
4349
cases=[
44-
PerfCase(
45-
name="mimo-flash-benchmark",
46-
input_len=16384,
47-
output_len=1024,
48-
num_prompts=256,
49-
max_concurrency=64,
50-
request_rate=100,
51-
seed=12345,
52-
flush_cache=True,
53-
)
50+
# PerfCase(
51+
# name="mimo-flash-benchmark",
52+
# input_len=16384,
53+
# output_len=1024,
54+
# num_prompts=256,
55+
# max_concurrency=64,
56+
# request_rate=100,
57+
# seed=12345,
58+
# flush_cache=True,
59+
# ),
60+
AccuracyCase(
61+
name="mimo-flash-accuracy",
62+
dataset="gsm8k",
63+
model_id="XiaomiMiMo/MiMo-V2-Flash",
64+
eval_batch_size=32,
65+
generation_config={"temperature": 0.8, "top_p": 0.95},
66+
),
5467
],
5568
)
5669
],

test/srt/mulit_host/test_multi_host_suite.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,10 @@ def test_dry_run_suite_preserves_run_and_case_order(self):
4141
),
4242
AccuracyCase(
4343
name="mmlu-smoke",
44-
eval_name="mmlu",
45-
num_examples=20,
46-
num_threads=32,
44+
dataset="mmlu",
45+
model_id="deepseek-ai/DeepSeek-V2-Lite",
46+
eval_batch_size=32,
47+
limit=20,
4748
),
4849
],
4950
),
@@ -58,9 +59,10 @@ def test_dry_run_suite_preserves_run_and_case_order(self):
5859
cases=[
5960
AccuracyCase(
6061
name="mmlu-smoke",
61-
eval_name="mmlu",
62-
num_examples=20,
63-
num_threads=32,
62+
dataset="mmlu",
63+
model_id="XiaomiMiMo/MiMo-7B-RL",
64+
eval_batch_size=32,
65+
limit=20,
6466
),
6567
],
6668
),

0 commit comments

Comments
 (0)