Skip to content

Commit 609f5f5

Browse files
authored
Add mixed-prefix gsm8k eval and its CPU unit test (#27502)
1 parent fdcd28a commit 609f5f5

4 files changed

Lines changed: 265 additions & 5 deletions

File tree

python/sglang/test/run_eval.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,17 @@ def run_eval(args):
178178
num_shots=getattr(args, "num_shots", 5),
179179
data_path=getattr(args, "gsm8k_data_path", None),
180180
)
181+
elif args.eval_name == "mixed_prefix_gsm8k":
182+
from sglang.test.simple_eval_mixed_prefix_gsm8k import MixedPrefixGSM8KEval
183+
184+
eval_obj = MixedPrefixGSM8KEval(
185+
num_examples=args.num_examples,
186+
num_threads=args.num_threads,
187+
num_shots=args.num_shots,
188+
secondary_pool_size=args.mixed_prefix_gsm8k_secondary_pool_size,
189+
data_path=args.gsm8k_data_path,
190+
seed=args.mixed_prefix_gsm8k_seed,
191+
)
181192
else:
182193
raise ValueError(f"Invalid eval name: {args.eval_name}")
183194

@@ -367,6 +378,18 @@ def run_eval(args):
367378
default=None,
368379
help="Path to GSM8K data file (e.g., test.jsonl)",
369380
)
381+
parser.add_argument(
382+
"--mixed-prefix-gsm8k-secondary-pool-size",
383+
type=int,
384+
default=15,
385+
help="Size of secondary example pool for eval_name=mixed_prefix_gsm8k (default: 15)",
386+
)
387+
parser.add_argument(
388+
"--mixed-prefix-gsm8k-seed",
389+
type=int,
390+
default=42,
391+
help="Seed for per-question random sampling in mixed_prefix_gsm8k (default: 42)",
392+
)
370393

371394
args = parser.parse_args()
372395

python/sglang/test/simple_eval_gsm8k.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,20 +56,29 @@ def __init__(
5656
else:
5757
filename = download_and_cache_file(GSM8K_URL)
5858

59-
self._lines = list(read_jsonl(filename))
60-
self._few_shot_prompt = get_few_shot_examples(self._lines, num_shots)
61-
59+
all_lines = list(read_jsonl(filename))
60+
pool_size = self._setup_prefix_pool(all_lines, num_shots)
6261
# The evaluation data should not include the few-shot examples to prevent data leakage.
63-
self._lines = self._lines[num_shots:]
62+
self._lines = all_lines[pool_size:]
6463
if num_examples is not None:
64+
# Slice caps silently when num_examples exceeds the available lines,
65+
# matching upstream: callers like test_basic_sanity_eagle3 pass a
66+
# num_examples larger than the dataset on purpose.
6567
self._lines = self._lines[:num_examples]
6668

69+
def _setup_prefix_pool(self, all_lines: list, num_shots: int) -> int:
70+
self._few_shot_prompt = get_few_shot_examples(all_lines, num_shots)
71+
return num_shots
72+
73+
def _build_prefix(self, idx: int) -> str:
74+
return self._few_shot_prompt
75+
6776
def __call__(self, sampler: SamplerBase) -> EvalResult:
6877
def fn(idx: int) -> SingleEvalResult:
6978
question = get_one_example(self._lines, idx, include_answer=False)
7079
correct_answer = get_answer_value(self._lines[idx]["answer"])
7180

72-
prompt_content = self._few_shot_prompt + question
81+
prompt_content = self._build_prefix(idx) + question
7382
prompt_messages = [
7483
sampler._pack_message(content=prompt_content, role="user")
7584
]
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import random
2+
from typing import Optional
3+
4+
from sglang.test.simple_eval_gsm8k import GSM8KEval, get_one_example
5+
6+
7+
class MixedPrefixGSM8KEval(GSM8KEval):
8+
9+
def __init__(
10+
self,
11+
num_examples: Optional[int],
12+
num_threads: int,
13+
num_shots: int,
14+
secondary_pool_size: int,
15+
data_path: Optional[str],
16+
seed: int,
17+
):
18+
self._secondary_pool_size = secondary_pool_size
19+
self._seed = seed
20+
super().__init__(
21+
num_examples=num_examples,
22+
num_threads=num_threads,
23+
num_shots=num_shots,
24+
data_path=data_path,
25+
)
26+
27+
def _setup_prefix_pool(self, all_lines: list, num_shots: int) -> int:
28+
overall_pool_size = num_shots + self._secondary_pool_size
29+
if len(all_lines) < overall_pool_size + 1:
30+
raise ValueError(
31+
f"GSM8K dataset has {len(all_lines)} examples but mixed-prefix "
32+
f"eval needs at least {overall_pool_size + 1} "
33+
f"(num_shots {num_shots} + secondary "
34+
f"{self._secondary_pool_size} + 1 test)."
35+
)
36+
self._primary_shots = all_lines[:num_shots]
37+
self._secondary_pool = all_lines[num_shots:overall_pool_size]
38+
return overall_pool_size
39+
40+
def _build_prefix(self, idx: int) -> str:
41+
rng = random.Random(self._seed + idx)
42+
num_primary = rng.randint(0, self._num_shots)
43+
secondary_size = rng.randint(0, self._secondary_pool_size)
44+
secondary_indices = rng.sample(range(len(self._secondary_pool)), secondary_size)
45+
primary = self._primary_shots[:num_primary]
46+
secondary = [self._secondary_pool[i] for i in secondary_indices]
47+
combined = primary + secondary
48+
return "".join(
49+
get_one_example(combined, i, include_answer=True) + "\n\n"
50+
for i in range(len(combined))
51+
)
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
import json
2+
import os
3+
import tempfile
4+
import unittest
5+
from typing import List, Tuple
6+
7+
from sglang.test.ci.ci_register import register_cpu_ci
8+
from sglang.test.simple_eval_gsm8k import get_one_example
9+
from sglang.test.simple_eval_mixed_prefix_gsm8k import MixedPrefixGSM8KEval
10+
from sglang.test.test_utils import CustomTestCase
11+
12+
register_cpu_ci(est_time=5, suite="base-b-test-cpu")
13+
14+
15+
def _write_synthetic_dataset(path: str, n: int) -> None:
16+
with open(path, "w") as f:
17+
for i in range(n):
18+
f.write(
19+
json.dumps(
20+
{
21+
"question": f"Synthetic question {i}: what is {i} + {i}?",
22+
"answer": f"The answer is {2 * i}. #### {2 * i}",
23+
}
24+
)
25+
+ "\n"
26+
)
27+
28+
29+
class TestMixedPrefixGSM8KEval(CustomTestCase):
30+
NUM_SHOTS = 4
31+
SECONDARY_POOL_SIZE = 12
32+
NUM_EXAMPLES = 40
33+
34+
@classmethod
35+
def setUpClass(cls):
36+
cls._tmpdir = tempfile.TemporaryDirectory()
37+
cls._data_path = os.path.join(cls._tmpdir.name, "synthetic.jsonl")
38+
_write_synthetic_dataset(cls._data_path, 100)
39+
40+
@classmethod
41+
def tearDownClass(cls):
42+
cls._tmpdir.cleanup()
43+
44+
def _make_eval(self, seed: int = 42, num_examples=None):
45+
return MixedPrefixGSM8KEval(
46+
num_examples=(
47+
num_examples if num_examples is not None else self.NUM_EXAMPLES
48+
),
49+
num_threads=1,
50+
num_shots=self.NUM_SHOTS,
51+
secondary_pool_size=self.SECONDARY_POOL_SIZE,
52+
data_path=self._data_path,
53+
seed=seed,
54+
)
55+
56+
def _primary_lines(self, evaluator) -> List[str]:
57+
return [
58+
get_one_example(evaluator._primary_shots, j, include_answer=True) + "\n\n"
59+
for j in range(self.NUM_SHOTS)
60+
]
61+
62+
def _decompose(self, evaluator, prefix: str) -> Tuple[int, List[str]]:
63+
k = 0
64+
for line in self._primary_lines(evaluator):
65+
if prefix.startswith(line):
66+
prefix = prefix[len(line) :]
67+
k += 1
68+
else:
69+
break
70+
remainder_questions: List[str] = []
71+
if prefix:
72+
chunks = prefix.split("\n\n")
73+
for chunk in chunks:
74+
if chunk.startswith("Question: "):
75+
q_text = chunk[len("Question: ") :].split("\nAnswer:")[0]
76+
remainder_questions.append(q_text)
77+
return k, remainder_questions
78+
79+
def test_primary_segment_is_strict_prefix_of_primary_shots(self):
80+
e = self._make_eval()
81+
for i in range(self.NUM_EXAMPLES):
82+
k, _ = self._decompose(e, e._build_prefix(i))
83+
self.assertGreaterEqual(k, 0)
84+
self.assertLessEqual(k, self.NUM_SHOTS)
85+
86+
def test_remainder_questions_come_from_secondary_pool(self):
87+
e = self._make_eval()
88+
secondary_qs = {item["question"] for item in e._secondary_pool}
89+
for i in range(self.NUM_EXAMPLES):
90+
_, remainder = self._decompose(e, e._build_prefix(i))
91+
for q in remainder:
92+
self.assertIn(q, secondary_qs)
93+
94+
def test_remainder_no_duplicates_within_one_query(self):
95+
e = self._make_eval()
96+
for i in range(self.NUM_EXAMPLES):
97+
_, remainder = self._decompose(e, e._build_prefix(i))
98+
self.assertEqual(
99+
len(remainder),
100+
len(set(remainder)),
101+
f"query {i} has duplicate secondary samples",
102+
)
103+
104+
def test_remainder_size_within_secondary_pool_bound(self):
105+
e = self._make_eval()
106+
for i in range(self.NUM_EXAMPLES):
107+
_, remainder = self._decompose(e, e._build_prefix(i))
108+
self.assertGreaterEqual(len(remainder), 0)
109+
self.assertLessEqual(len(remainder), self.SECONDARY_POOL_SIZE)
110+
111+
def test_primary_depth_takes_multiple_values(self):
112+
e = self._make_eval()
113+
ks = {
114+
self._decompose(e, e._build_prefix(i))[0] for i in range(self.NUM_EXAMPLES)
115+
}
116+
self.assertGreater(len(ks), 2, f"k values seen: {ks}")
117+
118+
def test_secondary_size_takes_multiple_values(self):
119+
e = self._make_eval()
120+
sizes = {
121+
len(self._decompose(e, e._build_prefix(i))[1])
122+
for i in range(self.NUM_EXAMPLES)
123+
}
124+
self.assertGreater(len(sizes), 2, f"sizes seen: {sizes}")
125+
126+
def test_two_queries_share_min_primary_prefix(self):
127+
e = self._make_eval()
128+
lines = self._primary_lines(e)
129+
prefixes = [e._build_prefix(i) for i in range(self.NUM_EXAMPLES)]
130+
ks = [self._decompose(e, p)[0] for p in prefixes]
131+
for i in range(self.NUM_EXAMPLES):
132+
for j in range(i + 1, self.NUM_EXAMPLES):
133+
shared = "".join(lines[: min(ks[i], ks[j])])
134+
self.assertTrue(prefixes[i].startswith(shared))
135+
self.assertTrue(prefixes[j].startswith(shared))
136+
137+
def test_build_prefix_is_deterministic(self):
138+
a = self._make_eval(seed=42)
139+
b = self._make_eval(seed=42)
140+
for i in range(self.NUM_EXAMPLES):
141+
self.assertEqual(a._build_prefix(i), b._build_prefix(i))
142+
143+
def test_seed_actually_matters(self):
144+
a = self._make_eval(seed=42)
145+
b = self._make_eval(seed=43)
146+
differences = sum(
147+
1
148+
for i in range(self.NUM_EXAMPLES)
149+
if a._build_prefix(i) != b._build_prefix(i)
150+
)
151+
self.assertGreater(differences, self.NUM_EXAMPLES // 2)
152+
153+
def test_pools_and_test_lines_pairwise_disjoint(self):
154+
e = self._make_eval(num_examples=None)
155+
primary_qs = {item["question"] for item in e._primary_shots}
156+
secondary_qs = {item["question"] for item in e._secondary_pool}
157+
test_qs = {item["question"] for item in e._lines}
158+
self.assertEqual(primary_qs & secondary_qs, set())
159+
self.assertEqual(primary_qs & test_qs, set())
160+
self.assertEqual(secondary_qs & test_qs, set())
161+
162+
def test_insufficient_dataset_raises(self):
163+
tiny = os.path.join(self._tmpdir.name, "tiny.jsonl")
164+
_write_synthetic_dataset(tiny, n=5)
165+
with self.assertRaises(ValueError):
166+
MixedPrefixGSM8KEval(
167+
num_examples=1,
168+
num_threads=1,
169+
num_shots=self.NUM_SHOTS,
170+
secondary_pool_size=self.SECONDARY_POOL_SIZE,
171+
data_path=tiny,
172+
seed=42,
173+
)
174+
175+
176+
if __name__ == "__main__":
177+
unittest.main()

0 commit comments

Comments
 (0)