@@ -231,7 +231,7 @@ def load_model(self):
231231 tokenizer = self .tokenizer ,
232232 cache_config = cache_config ,
233233 )
234- elif self .sampler_type == "sglang-jax " :
234+ elif self .sampler_type == "sglang_jax " :
235235 from tunix .google .stubs import sglang_jax_sampler_stub as sglang_jax_sampler # pylint: disable=g-import-not-at-top
236236
237237 mapping_config = mappings .MappingConfig .build (
@@ -254,6 +254,27 @@ def load_model(self):
254254 mapping_config = mapping_config ,
255255 ),
256256 )
257+ elif self .sampler_type == "vllm" :
258+ from tunix .google .stubs import vllm_sampler_stub as vllm_sampler # pylint: disable=g-import-not-at-top
259+
260+ mapping_config = mappings .MappingConfig .build (
261+ mapping_obj = None ,
262+ model = self .model ,
263+ backend = "vllm" ,
264+ )
265+ self .sampler_vllm = vllm_sampler .VllmSampler (
266+ tokenizer = self .tokenizer ,
267+ config = vllm_sampler .VllmConfig (
268+ mesh = self .mesh ,
269+ max_model_len = self .max_prompt_length
270+ + self .max_generation_steps
271+ + 100 ,
272+ model_version = self .model_version ,
273+ hbm_utilization = 0.4 ,
274+ init_with_random_weights = False ,
275+ mapping_config = mapping_config ,
276+ ),
277+ )
257278 else :
258279 raise ValueError (f"Unsupported sampler type: { self .sampler_type } " )
259280
@@ -358,7 +379,7 @@ def generate(
358379 eos_tokens = [stop_token_id ],
359380 seed = jax .random .PRNGKey (seed ) if seed is not None else None ,
360381 )
361- elif self .sampler_type == "sglang-jax " :
382+ elif self .sampler_type == "sglang_jax " :
362383 out_data = self .sampler_sglang (
363384 input_strings = prompts ,
364385 max_generation_steps = safe_gen_length ,
@@ -370,6 +391,18 @@ def generate(
370391 echo = False ,
371392 pad_output = True ,
372393 )
394+ elif self .sampler_type == "vllm" :
395+ out_data = self .sampler_vllm (
396+ input_strings = prompts ,
397+ max_generation_steps = safe_gen_length ,
398+ max_prompt_length = self .max_prompt_length ,
399+ temperature = temperature ,
400+ top_p = top_p ,
401+ top_k = top_k ,
402+ seed = seed ,
403+ echo = False ,
404+ pad_output = True ,
405+ )
373406 else :
374407 raise ValueError (f"Unsupported sampler type: { self .sampler_type } " )
375408 return out_data .text
@@ -585,6 +618,7 @@ def evaluate(
585618# %%
586619# AIME-2024
587620model_version = "agentica-org/DeepScaleR-1.5B-Preview"
621+ # model_version = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
588622dataset = AIME_2024_DATA_PATH
589623model_config , model_path = MODEL_MAPPING [model_version ]
590624
0 commit comments