Skip to content

Commit 7907875

Browse files
Bob-Chen222IamleosclaudeJamesBrianD
authored
Feat/deepseek non absorption (#941)
* deepseek MLA * Feat/attention sink (#207) * feat: add RPA v3 kernel with tpu-inference optimizations and attention sink - Port RPA v3 kernel with all 9 tpu-inference optimizations - Support attention sink, custom mask, xai_temperature, sliding window - Update ref function and tests to use v3 as baseline - Switch benchmark to use v3 kernel Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * wrap kv cache update with shard map * fix: attention_sink TP sharding and MHA+bf16 q_packing alignment (#220) - Shard attention_sink by tensor axis for TP compatibility - Move attention_sink preparation into prepare_inputs, pad per kv_head then pad q_heads_per_kv dimension to match kernel's padded layout - Restructure l/m initialization to avoid redundant writes - Add attention_sink tests: MHA, GQA(4q/1kv), single-head(1q/1kv) - Update test_flashattention_dp.py to use v3 ref and support attention_sink Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> * feat: change kv cache shape to match rpav3 (#228) * tune rpav3 (#248) * correct on simple prompt * fix: test flashattention failed Refactor native attention to cleanly handle 5D KV cache: - _get_and_update_kv_cache only returns 5D fused buffer - forward_attention receives 5D, flattens to 3D internally for attention computation - update_kv_cache kernel accepts only 5D inputs (flattens internally for Pallas) - Remove scattered 3D/5D dual-path compatibility code Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * add bench script for temp record * deepseekv3 accurayc aligned on num_hidden_layers=4 * update model config per model file * revert layer * fix redundancy * add benchmark * shorten * fix * delete comment * add padding comment * update * fix lint * update deepseek attention * delete benchmark * add mimo * separate yarn of deepseek and grok * add scripts --------- Co-authored-by: leos <leos@primatrix.ai> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: RamezesDong 
<donghouze666@outlook.com>
1 parent bc77064 commit 7907875

9 files changed

Lines changed: 1300 additions & 7 deletions

File tree

benchmark/gsm8k/__init__.py

Whitespace-only changes.
Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
"""GSM8K benchmark for sglang-jax.
2+
3+
Sends concurrent requests to a running sglang-jax server's /generate endpoint
4+
and measures accuracy and throughput on the GSM8K (or GSM8K Platinum) dataset.
5+
6+
Usage:
7+
# Start server first:
8+
# python3 -m sgl_jax.launch_server --model-path <model> --port 30000 ...
9+
10+
# Run benchmark:
11+
python bench_sglang_jax.py --base-url http://localhost:30000 --num-questions 200
12+
"""
13+
14+
import argparse
15+
import ast
16+
import asyncio
17+
import json
18+
import os
19+
import re
20+
import tempfile
21+
import time
22+
import urllib.request
23+
24+
import aiohttp
25+
import numpy as np
26+
from datasets import load_dataset
27+
from tqdm import tqdm
28+
29+
INVALID = -9999999
30+
31+
32+
def read_jsonl(path):
    """Yield one parsed JSON object per non-empty line of *path*."""
    with open(path) as fh:
        for raw in fh:
            stripped = raw.strip()
            if not stripped:
                continue
            yield json.loads(stripped)
38+
39+
40+
def download_and_cache_file(url):
    """Return a local path for *url*, downloading into a temp-dir cache on first use.

    The cache key is the URL's final path component, stored under
    ``$TMPDIR/sgl_jax_bench_cache``; an existing file is reused without
    re-downloading.
    """
    cache_dir = os.path.join(tempfile.gettempdir(), "sgl_jax_bench_cache")
    os.makedirs(cache_dir, exist_ok=True)
    cache_path = os.path.join(cache_dir, url.split("/")[-1])
    if os.path.isfile(cache_path):
        return cache_path
    print(f"Downloading {url} to {cache_path}...")
    urllib.request.urlretrieve(url, cache_path)
    return cache_path
49+
50+
51+
def get_one_example(lines, i, include_answer):
    """Format example *i* as a "Question: ...\\nAnswer:" prompt.

    When *include_answer* is true the gold answer is appended after
    "Answer:" (used to build few-shot context).
    """
    prompt = "Question: " + lines[i]["question"] + "\nAnswer:"
    return (prompt + " " + lines[i]["answer"]) if include_answer else prompt
56+
57+
58+
def get_few_shot_examples(lines, k):
    """Concatenate the first *k* answered examples as few-shot context."""
    return "".join(get_one_example(lines, i, True) + "\n\n" for i in range(k))
63+
64+
65+
def get_answer_value(answer_str):
66+
answer_str = answer_str.replace(",", "")
67+
numbers = re.findall(r"\d+", answer_str)
68+
if len(numbers) < 1:
69+
return INVALID
70+
try:
71+
return ast.literal_eval(numbers[-1])
72+
except SyntaxError:
73+
return INVALID
74+
75+
76+
async def send_request(session, base_url, text, sampling_params, semaphore, pbar):
    """POST one prompt to the server's /generate endpoint and return its JSON body.

    Concurrency is bounded by *semaphore*; *pbar* is advanced once per
    completed request. Raises RuntimeError on any non-200 response.
    """
    request_body = {
        "text": text,
        "sampling_params": sampling_params,
        "stream": False,
    }
    async with semaphore:
        # Per-request cap so one stuck generation cannot hang the whole batch.
        client_timeout = aiohttp.ClientTimeout(total=300)
        async with session.post(
            f"{base_url}/generate", json=request_body, timeout=client_timeout
        ) as response:
            if response.status != 200:
                error_text = await response.text()
                raise RuntimeError(f"Request failed with status {response.status}: {error_text}")
            result = await response.json()
        pbar.update(1)
        return result
91+
92+
93+
async def run_batch(base_url, questions, sampling_params, parallel):
    """Send every question concurrently (at most *parallel* in flight).

    Returns the per-question JSON results in the same order as *questions*.
    """
    semaphore = asyncio.Semaphore(parallel)
    pbar = tqdm(total=len(questions), desc="Generating")

    async with aiohttp.ClientSession() as session:
        pending = [
            send_request(session, base_url, question, sampling_params, semaphore, pbar)
            for question in questions
        ]
        # gather preserves input order regardless of completion order.
        results = await asyncio.gather(*pending)

    pbar.close()
    return results
105+
106+
107+
def main(args):
    """Run the GSM8K benchmark against a running sglang-jax server.

    Loads GSM8K (or GSM8K Platinum), builds few-shot prompts, fires them
    concurrently at the server's /generate endpoint, then reports accuracy,
    invalid-answer rate, latency, and output throughput. A one-line JSON
    summary is appended to ``args.result_file``.
    """
    # Load tokenizer only when thinking mode needs chat-template wrapping.
    tokenizer = None
    if args.enable_thinking:
        from transformers import AutoTokenizer

        assert (
            args.tokenizer_path is not None
        ), "--tokenizer-path is required when --enable-thinking is set"
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, trust_remote_code=True)

    # Read data
    if args.platinum:
        print("Loading GSM8K Platinum dataset from HuggingFace...")
        dataset = load_dataset("madrylab/gsm8k-platinum", "main", split="test")
        lines = [{"question": item["question"], "answer": item["answer"]} for item in dataset]
    else:
        data_path = args.data_path
        url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
        if not os.path.isfile(data_path):
            data_path = download_and_cache_file(url)
        lines = list(read_jsonl(data_path))

    # Construct prompts: shared few-shot context + one unanswered question each.
    # Clamp to the dataset size instead of slicing a throwaway copy of `lines`.
    num_questions = min(args.num_questions, len(lines))
    few_shot_examples = get_few_shot_examples(lines, args.num_shots)

    questions = []
    labels = []
    for i in range(num_questions):
        raw_question = few_shot_examples + get_one_example(lines, i, False)
        if tokenizer is not None:
            messages = [{"role": "user", "content": raw_question}]
            raw_question = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=True,
            )
        questions.append(raw_question)
        labels.append(get_answer_value(lines[i]["answer"]))
    # Every gold answer must parse; a failure here means the dataset is malformed.
    assert all(label != INVALID for label in labels)

    # Sampling parameters
    sampling_params = {
        "temperature": args.temperature,
        "top_p": args.top_p,
        "max_new_tokens": args.max_new_tokens,
        "stop": ["Question", "Assistant:", "<|separator|>"],
    }

    # Run requests
    print(
        f"Running {len(questions)} requests against {args.base_url} "
        f"(parallelism={args.parallel})..."
    )
    tic = time.perf_counter()
    results = asyncio.run(run_batch(args.base_url, questions, sampling_params, args.parallel))
    latency = time.perf_counter() - tic

    # Extract predictions
    preds = [get_answer_value(r["text"]) for r in results]

    # Compute accuracy
    acc = np.mean(np.array(preds) == np.array(labels))
    invalid = np.mean(np.array(preds) == INVALID)

    # Compute speed
    num_output_tokens = sum(r["meta_info"]["completion_tokens"] for r in results)
    output_throughput = num_output_tokens / latency

    # Print results
    print(f"Accuracy: {acc:.3f}")
    print(f"Invalid: {invalid:.3f}")
    print(f"Latency: {latency:.3f} s")
    print(f"Output throughput: {output_throughput:.3f} token/s")

    # Dump raw outputs
    if args.output_file:
        with open(args.output_file, "w") as f:
            for i, r in enumerate(results):
                f.write(f"=== Question {i} ===\n")
                f.write(questions[i] + "\n")
                f.write("=== Answer ===\n")
                f.write(r["text"] + "\n")
                f.write(f"=== Prediction: {preds[i]}, Label: {labels[i]} ===\n\n")
        print(f"Raw outputs saved to {args.output_file}")

    # Dump results (append one JSON line per run)
    with open(args.result_file, "a") as fout:
        value = {
            "task": "gsm8k-platinum" if args.platinum else "gsm8k",
            "backend": "sgl-jax",
            "latency": round(latency, 3),
            "accuracy": round(acc, 3),
            "num_requests": args.num_questions,
            "other": {
                "num_questions": args.num_questions,
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")
    print(f"Results appended to {args.result_file}")
213+
214+
215+
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="GSM8K benchmark for sglang-jax")
    # Server connection
    parser.add_argument(
        "--base-url",
        type=str,
        default="http://localhost:30000",
        help="Base URL of the sglang-jax server",
    )
    # Dataset / prompt construction
    parser.add_argument("--num-shots", type=int, default=5)
    parser.add_argument("--data-path", type=str, default="test.jsonl")
    parser.add_argument("--num-questions", type=int, default=200)
    # Generation settings
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--top-p", type=float, default=1.0)
    parser.add_argument("--parallel", type=int, default=64, help="Max concurrent requests")
    # Output locations
    parser.add_argument(
        "--result-file", type=str, default="bench_results.jsonl",
        help="Path to append JSON result summary",
    )
    parser.add_argument(
        "--output-file", type=str, default=None,
        help="Path to write detailed per-question outputs",
    )
    # Thinking-mode options
    parser.add_argument(
        "--enable-thinking", action="store_true",
        help="Enable thinking mode by wrapping prompts with chat template",
    )
    parser.add_argument(
        "--tokenizer-path", type=str, default=None,
        help="Path to tokenizer (required when --enable-thinking is set)",
    )
    # Dataset variant
    parser.add_argument(
        "--platinum", action="store_true",
        help="Use GSM8K Platinum dataset (drop-in replacement with corrected labels)",
    )
    main(parser.parse_args())

python/sgl_jax/srt/configs/model_config.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,8 @@ def __init__(
173173
self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads,
174174
)
175175
self.v_head_dim = getattr(self.hf_text_config, "v_head_dim", self.head_dim)
176-
177176
self.attention_arch = AttentionArch.MHA
177+
self._apply_model_specific_config()
178178
self.num_attention_heads = self.hf_text_config.num_attention_heads
179179
self.num_key_value_heads = getattr(self.hf_text_config, "num_key_value_heads", None)
180180

@@ -350,6 +350,25 @@ def is_dynamic_fp8_act(cfg):
350350
logger.info("No quantization config found in HF config or user config")
351351
return None
352352

353+
def _apply_model_specific_config(self) -> None:
354+
"""Invoke the model class's optional `patch_model_config` hook so model
355+
files can own their own config overrides (attention_arch, head_dim,
356+
MLA-specific dims, etc.) instead of a centralized if/elif chain here.
357+
358+
Runs during ModelConfig construction — before ModelRunner reads
359+
`attention_arch` for backend selection — so patches land in time.
360+
Import is lazy because model modules import ModelConfig back.
361+
"""
362+
from sgl_jax.srt.models.registry import ModelRegistry
363+
364+
try:
365+
model_cls, _ = ModelRegistry.resolve_model_cls(self.hf_config.architectures)
366+
except ValueError:
367+
return
368+
patch = getattr(model_cls, "patch_model_config", None)
369+
if patch is not None:
370+
patch(self)
371+
353372
@staticmethod
354373
def from_server_args(
355374
server_args: ServerArgs,

python/sgl_jax/srt/layers/attention/mla.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def __init__(
7777
kernel_axes=(None, None),
7878
scope_name="q_a_proj",
7979
)
80-
self.q_a_layernorm = RMSNorm(q_lora_rank, param_dtype=jnp.float32)
80+
self.q_a_layernorm = RMSNorm(q_lora_rank, dtype=dtype)
8181
self.q_b_proj = LinearBase(
8282
q_lora_rank,
8383
num_heads * self.qk_head_dim,
@@ -97,7 +97,7 @@ def __init__(
9797
kernel_axes=(None, None),
9898
scope_name="kv_a_proj",
9999
)
100-
self.kv_a_layernorm = RMSNorm(kv_lora_rank, param_dtype=jnp.float32)
100+
self.kv_a_layernorm = RMSNorm(kv_lora_rank, dtype=dtype)
101101
self.kv_b_proj = LinearBase(
102102
kv_lora_rank,
103103
num_heads * (qk_nope_head_dim + v_head_dim),

0 commit comments

Comments
 (0)