"""GSM8K benchmark for sglang-jax.

Sends concurrent requests to a running sglang-jax server's /generate endpoint
and measures accuracy and throughput on the GSM8K (or GSM8K Platinum) dataset.

Usage:
    # Start server first:
    # python3 -m sgl_jax.launch_server --model-path <model> --port 30000 ...

    # Run benchmark:
    python bench_sglang_jax.py --base-url http://localhost:30000 --num-questions 200
"""

import argparse
import ast
import asyncio
import json
import os
import re
import tempfile
import time
import urllib.request

import aiohttp
import numpy as np
from datasets import load_dataset
from tqdm import tqdm

# Sentinel returned when no numeric answer can be extracted from a string.
INVALID = -9999999


def read_jsonl(path):
    """Yield one parsed JSON object per non-empty line of a JSONL file."""
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line:
                yield json.loads(line)


def download_and_cache_file(url):
    """Download a file into a temp-dir cache (unless already present) and return its path."""
    cache_dir = os.path.join(tempfile.gettempdir(), "sgl_jax_bench_cache")
    os.makedirs(cache_dir, exist_ok=True)
    filename = url.split("/")[-1]
    cache_path = os.path.join(cache_dir, filename)
    if not os.path.isfile(cache_path):
        print(f"Downloading {url} to {cache_path}...")
        urllib.request.urlretrieve(url, cache_path)
    return cache_path


def get_one_example(lines, i, include_answer):
    """Format example i as a Question/Answer pair, optionally including the gold answer."""
    ret = "Question: " + lines[i]["question"] + "\nAnswer:"
    if include_answer:
        ret += " " + lines[i]["answer"]
    return ret


def get_few_shot_examples(lines, k):
    """Concatenate the first k examples (with answers) into a few-shot prompt prefix."""
    ret = ""
    for i in range(k):
        ret += get_one_example(lines, i, True) + "\n\n"
    return ret


def get_answer_value(answer_str):
    """Extract the last number in answer_str, or INVALID if none can be parsed."""
    answer_str = answer_str.replace(",", "")
    numbers = re.findall(r"\d+", answer_str)
    if len(numbers) < 1:
        return INVALID
    try:
        # literal_eval rejects some digit runs, e.g. those with leading zeros.
        return ast.literal_eval(numbers[-1])
    except (SyntaxError, ValueError):
        return INVALID

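# Worked examples of the extraction above (illustrative inputs):
#   get_answer_value("... so she pays $1,234 in total.")  -> 1234
#   get_answer_value("The answer is #### 72")             -> 72
#   get_answer_value("No digits here.")                   -> INVALID
# Only the last run of digits is kept, so signs and decimal points are dropped.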

async def send_request(session, base_url, text, sampling_params, semaphore, pbar):
    """POST one prompt to /generate, gated by the shared concurrency semaphore."""
    payload = {
        "text": text,
        "sampling_params": sampling_params,
        "stream": False,
    }
    async with semaphore:
        timeout = aiohttp.ClientTimeout(total=300)
        async with session.post(f"{base_url}/generate", json=payload, timeout=timeout) as response:
            if response.status != 200:
                error_text = await response.text()
                raise RuntimeError(f"Request failed with status {response.status}: {error_text}")
            result = await response.json()
            pbar.update(1)
            return result


async def run_batch(base_url, questions, sampling_params, parallel):
    """Fan out all prompts concurrently, capped at `parallel` in-flight requests."""
    semaphore = asyncio.Semaphore(parallel)
    pbar = tqdm(total=len(questions), desc="Generating")

    async with aiohttp.ClientSession() as session:
        tasks = [
            send_request(session, base_url, q, sampling_params, semaphore, pbar) for q in questions
        ]
        results = await asyncio.gather(*tasks)

    pbar.close()
    return results

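# run_batch() can also be exercised on its own as a quick smoke test
# (hypothetical prompt; assumes a server is already running at the URL
# from the module docstring):
#
#   results = asyncio.run(
#       run_batch(
#           "http://localhost:30000",
#           ["Question: What is 2+2?\nAnswer:"],
#           {"temperature": 0, "max_new_tokens": 16},
#           parallel=1,
#       )
#   )
#   print(results[0]["text"])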

def main(args):
    # Load tokenizer if enable_thinking is set
    tokenizer = None
    if args.enable_thinking:
        from transformers import AutoTokenizer

        assert (
            args.tokenizer_path is not None
        ), "--tokenizer-path is required when --enable-thinking is set"
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, trust_remote_code=True)

    # Read data
    if args.platinum:
        print("Loading GSM8K Platinum dataset from HuggingFace...")
        dataset = load_dataset("madrylab/gsm8k-platinum", "main", split="test")
        lines = [{"question": item["question"], "answer": item["answer"]} for item in dataset]
    else:
        data_path = args.data_path
        url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
        if not os.path.isfile(data_path):
            data_path = download_and_cache_file(url)
        lines = list(read_jsonl(data_path))

    # Construct prompts
    num_questions = args.num_questions
    num_shots = args.num_shots
    few_shot_examples = get_few_shot_examples(lines, num_shots)

    questions = []
    labels = []
    for i in range(min(num_questions, len(lines))):
        raw_question = few_shot_examples + get_one_example(lines, i, False)
        if tokenizer is not None:
            messages = [{"role": "user", "content": raw_question}]
            # enable_thinking is honored by chat templates that support it (e.g. Qwen).
            raw_question = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=True,
            )
        questions.append(raw_question)
        labels.append(get_answer_value(lines[i]["answer"]))
    assert all(label != INVALID for label in labels)

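    # With the defaults, each prompt looks roughly like this (illustrative;
    # GSM8K gold answers end in "#### <number>", which is what
    # get_answer_value() extracts for the labels above):
    #
    #   Question: <few-shot question 1>
    #   Answer: <worked solution> #### 72
    #
    #   ...
    #
    #   Question: <question being graded>
    #   Answer:
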
    # Sampling parameters
    sampling_params = {
        "temperature": args.temperature,
        "top_p": args.top_p,
        "max_new_tokens": args.max_new_tokens,
        # Stop once the model starts drafting a new question of its own.
        "stop": ["Question", "Assistant:", "<|separator|>"],
    }

    # Run requests
    print(
        f"Running {len(questions)} requests against {args.base_url} "
        f"(parallelism={args.parallel})..."
    )
    tic = time.perf_counter()
    results = asyncio.run(run_batch(args.base_url, questions, sampling_params, args.parallel))
    latency = time.perf_counter() - tic

    # Extract predictions
    preds = []
    for r in results:
        preds.append(get_answer_value(r["text"]))

    # Compute accuracy
    acc = np.mean(np.array(preds) == np.array(labels))
    invalid = np.mean(np.array(preds) == INVALID)

    # Compute speed
    num_output_tokens = sum(r["meta_info"]["completion_tokens"] for r in results)
    output_throughput = num_output_tokens / latency

    # Print results
    print(f"Accuracy: {acc:.3f}")
    print(f"Invalid: {invalid:.3f}")
    print(f"Latency: {latency:.3f} s")
    print(f"Output throughput: {output_throughput:.3f} token/s")

    # Dump raw outputs
    if args.output_file:
        with open(args.output_file, "w") as f:
            for i, r in enumerate(results):
                f.write(f"=== Question {i} ===\n")
                f.write(questions[i] + "\n")
                f.write("=== Answer ===\n")
                f.write(r["text"] + "\n")
                f.write(f"=== Prediction: {preds[i]}, Label: {labels[i]} ===\n\n")
        print(f"Raw outputs saved to {args.output_file}")

    # Dump results
    with open(args.result_file, "a") as fout:
        value = {
            "task": "gsm8k-platinum" if args.platinum else "gsm8k",
            "backend": "sgl-jax",
            "latency": round(latency, 3),
            "accuracy": round(acc, 3),
            # Record the number of requests actually sent (the dataset may be
            # smaller than --num-questions).
            "num_requests": len(questions),
            "other": {
                "num_questions": args.num_questions,
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")
    print(f"Results appended to {args.result_file}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="GSM8K benchmark for sglang-jax")
    parser.add_argument(
        "--base-url",
        type=str,
        default="http://localhost:30000",
        help="Base URL of the sglang-jax server",
    )
    parser.add_argument("--num-shots", type=int, default=5)
    parser.add_argument("--data-path", type=str, default="test.jsonl")
    parser.add_argument("--num-questions", type=int, default=200)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--top-p", type=float, default=1.0)
    parser.add_argument("--parallel", type=int, default=64, help="Max concurrent requests")
    parser.add_argument(
        "--result-file",
        type=str,
        default="bench_results.jsonl",
        help="Path to append JSON result summary",
    )
    parser.add_argument(
        "--output-file",
        type=str,
        default=None,
        help="Path to write detailed per-question outputs",
    )
    parser.add_argument(
        "--enable-thinking",
        action="store_true",
        help="Enable thinking mode by wrapping prompts with chat template",
    )
    parser.add_argument(
        "--tokenizer-path",
        type=str,
        default=None,
        help="Path to tokenizer (required when --enable-thinking is set)",
    )
    parser.add_argument(
        "--platinum",
        action="store_true",
        help="Use GSM8K Platinum dataset (drop-in replacement with corrected labels)",
    )
    args = parser.parse_args()
    main(args)