
Commit 31720ac

lc5211 authored and The tunix Authors committed
[Tunix] Add vLLM sampler for math eval.
PiperOrigin-RevId: 875875039
1 parent 52a611c commit 31720ac


examples/deepscaler/math_eval_nb.py

Lines changed: 36 additions & 2 deletions
@@ -231,7 +231,7 @@ def load_model(self):
           tokenizer=self.tokenizer,
           cache_config=cache_config,
       )
-    elif self.sampler_type == "sglang-jax":
+    elif self.sampler_type == "sglang_jax":
       from tunix.google.stubs import sglang_jax_sampler_stub as sglang_jax_sampler  # pylint: disable=g-import-not-at-top
 
       mapping_config = mappings.MappingConfig.build(
@@ -254,6 +254,27 @@ def load_model(self):
               mapping_config=mapping_config,
           ),
       )
+    elif self.sampler_type == "vllm":
+      from tunix.google.stubs import vllm_sampler_stub as vllm_sampler  # pylint: disable=g-import-not-at-top
+
+      mapping_config = mappings.MappingConfig.build(
+          mapping_obj=None,
+          model=self.model,
+          backend="vllm",
+      )
+      self.sampler_vllm = vllm_sampler.VllmSampler(
+          tokenizer=self.tokenizer,
+          config=vllm_sampler.VllmConfig(
+              mesh=self.mesh,
+              max_model_len=self.max_prompt_length
+              + self.max_generation_steps
+              + 100,
+              model_version=self.model_version,
+              hbm_utilization=0.4,
+              init_with_random_weights=False,
+              mapping_config=mapping_config,
+          ),
+      )
     else:
       raise ValueError(f"Unsupported sampler type: {self.sampler_type}")
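
Note (not part of the commit): the new branch mirrors the existing sglang_jax path: build a MappingConfig for the "vllm" backend, then wrap it in a VllmSampler. A minimal sketch of the context-window sizing, with hypothetical lengths; only the formula itself comes from the diff:

    # vLLM pre-allocates a fixed context window (max_model_len). The diff
    # sizes it as the longest allowed prompt plus the longest allowed
    # generation, plus a 100-token safety margin (e.g. for template or
    # special tokens added at sampling time).
    max_prompt_length = 1024       # hypothetical eval setting
    max_generation_steps = 8192    # hypothetical eval setting
    max_model_len = max_prompt_length + max_generation_steps + 100
    assert max_model_len == 9316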

@@ -358,7 +379,7 @@ def generate(
           eos_tokens=[stop_token_id],
           seed=jax.random.PRNGKey(seed) if seed is not None else None,
       )
-    elif self.sampler_type == "sglang-jax":
+    elif self.sampler_type == "sglang_jax":
       out_data = self.sampler_sglang(
           input_strings=prompts,
           max_generation_steps=safe_gen_length,
@@ -370,6 +391,18 @@ def generate(
           echo=False,
           pad_output=True,
       )
+    elif self.sampler_type == "vllm":
+      out_data = self.sampler_vllm(
+          input_strings=prompts,
+          max_generation_steps=safe_gen_length,
+          max_prompt_length=self.max_prompt_length,
+          temperature=temperature,
+          top_p=top_p,
+          top_k=top_k,
+          seed=seed,
+          echo=False,
+          pad_output=True,
+      )
     else:
       raise ValueError(f"Unsupported sampler type: {self.sampler_type}")
     return out_data.text
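
Note (not part of the commit): all three backends are called through the same keyword interface, so generate() only branches on the sampler tag. A sketch of that shape as a standalone helper; this is an illustrative condensation of the if/elif chain above, not the file's actual structure:

    # Hypothetical helper: any sampler exposing this keyword interface
    # (input_strings, max_generation_steps, ...) can be swapped in, which
    # is what lets the vllm branch reuse the sglang_jax call signature.
    def run_sampler(sampler, prompts, safe_gen_length, max_prompt_length,
                    temperature, top_p, top_k, seed):
      out_data = sampler(
          input_strings=prompts,
          max_generation_steps=safe_gen_length,
          max_prompt_length=max_prompt_length,
          temperature=temperature,
          top_p=top_p,
          top_k=top_k,
          seed=seed,
          echo=False,
          pad_output=True,
      )
      return out_data.text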
@@ -585,6 +618,7 @@ def evaluate(
 # %%
 # AIME-2024
 model_version = "agentica-org/DeepScaleR-1.5B-Preview"
+# model_version = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
 dataset = AIME_2024_DATA_PATH
 model_config, model_path = MODEL_MAPPING[model_version]
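
Note (not part of the commit): the trailing hunk just adds a commented-out alternative checkpoint; model_version doubles as the key into MODEL_MAPPING, so switching eval targets is a one-line edit. A sketch of that lookup, with placeholder mapping values:

    # Hypothetical shape of MODEL_MAPPING: version string -> (config, path).
    MODEL_MAPPING = {
        "agentica-org/DeepScaleR-1.5B-Preview": ("<model_config>", "<ckpt_path>"),
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B": ("<model_config>", "<ckpt_path>"),
    }
    model_version = "agentica-org/DeepScaleR-1.5B-Preview"
    model_config, model_path = MODEL_MAPPING[model_version]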
