Skip to content

Commit a5efcd7

Browse files
committed
Add support for vLLM sampler kwargs.
1 parent efb4913 commit a5efcd7

File tree

5 files changed

+186
-8
lines changed

5 files changed

+186
-8
lines changed

tests/generate/vllm_sampler_test.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from tunix.generate import mappings
3232
from tunix.generate import sampler as vanilla_sampler
3333
from tunix.generate import vllm_sampler
34+
from tunix.models.dummy_model_creator import create_dummy_model
3435
from tunix.models.llama3 import model as llama_lib
3536
from tunix.models.llama3 import params as llama_params
3637
from tunix.sft import utils as base_utils
@@ -357,6 +358,169 @@ async def dispatch_requests():
357358
),
358359
)
359360

361+
def test_vllm_sampler_sampling_kwargs(self):
    """Verify config-level and call-level sampling kwargs reach SamplingParams."""
    dummy_model = create_dummy_model(
        model_class=llama_lib.Llama3,
        config=llama_lib.ModelConfig.llama3p2_1b(),
        mesh=self.mesh,
        random_seed=3,
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_path)
    templated_inputs = tc.batch_templatize(
        ["Hello, my name is Tom."], tokenizer
    )

    sampler_config = vllm_sampler.VllmConfig(
        mesh=self.mesh,
        hbm_utilization=0.2,
        init_with_random_weights=True,
        tpu_backend_type="jax",
        mapping_config=mappings.MappingConfig.build(dummy_model),
        server_mode=False,
        # Config-level sampling kwargs, expected to land on SamplingParams.
        sampling_kwargs={
            "frequency_penalty": 0.5,
            "presence_penalty": 0.3,
        },
        engine_kwargs={
            "model": self.model_path,
            "max_model_len": 512,
            "enable_prefix_caching": True,
        },
    )

    sampler = vllm_sampler.VllmSampler(
        tokenizer=tokenizer,
        config=sampler_config,
    )
    sampler.load_checkpoint(nnx.state(dummy_model))

    # Wrap llm.generate so we can inspect the SamplingParams it receives
    # while still delegating to the real implementation.
    real_generate = sampler.llm.generate
    seen_params = []

    def spying_generate(prompts, sampling_params, **kwargs):
        seen_params.append(sampling_params)
        return real_generate(prompts, sampling_params, **kwargs)

    sampler.llm.generate = spying_generate

    # Pass one extra sampling kwarg at call time on top of the config ones.
    sampler(
        input_strings=templated_inputs,
        max_generation_steps=128,
        max_prompt_length=None,
        temperature=0.0,
        top_k=1,
        seed=0,
        echo=False,
        pad_output=True,
        min_tokens=10,
    )

    self.assertLen(seen_params, 1)
    params = seen_params[0]
    # Config-level kwargs made it through...
    self.assertEqual(params.frequency_penalty, 0.5)
    self.assertEqual(params.presence_penalty, 0.3)
    # ...and so did the call-level kwarg.
    self.assertEqual(params.min_tokens, 10)
442+
443+
def test_vllm_sampler_sampling_kwargs_override(self):
    """Verify call-level sampling kwargs take precedence over config ones."""
    dummy_model = create_dummy_model(
        model_class=llama_lib.Llama3,
        config=llama_lib.ModelConfig.llama3p2_1b(),
        mesh=self.mesh,
        random_seed=3,
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_path)
    templated_inputs = tc.batch_templatize(
        ["Hello, my name is Tom."], tokenizer
    )

    sampler_config = vllm_sampler.VllmConfig(
        mesh=self.mesh,
        hbm_utilization=0.2,
        init_with_random_weights=True,
        tpu_backend_type="jax",
        mapping_config=mappings.MappingConfig.build(dummy_model),
        server_mode=False,
        # Config sets frequency_penalty=0.5; the call below overrides it.
        sampling_kwargs={
            "frequency_penalty": 0.5,
            "presence_penalty": 0.3,
        },
        engine_kwargs={
            "model": self.model_path,
            "max_model_len": 512,
            "enable_prefix_caching": True,
        },
    )

    sampler = vllm_sampler.VllmSampler(
        tokenizer=tokenizer,
        config=sampler_config,
    )
    sampler.load_checkpoint(nnx.state(dummy_model))

    # Wrap llm.generate so we can inspect the SamplingParams it receives
    # while still delegating to the real implementation.
    real_generate = sampler.llm.generate
    seen_params = []

    def spying_generate(prompts, sampling_params, **kwargs):
        seen_params.append(sampling_params)
        return real_generate(prompts, sampling_params, **kwargs)

    sampler.llm.generate = spying_generate

    # Call-level kwarg overrides the config value (0.5 -> 0.8).
    sampler(
        input_strings=templated_inputs,
        max_generation_steps=128,
        max_prompt_length=None,
        temperature=0.0,
        top_k=1,
        seed=0,
        echo=False,
        pad_output=True,
        frequency_penalty=0.8,
    )

    self.assertLen(seen_params, 1)
    params = seen_params[0]
    # The call-level kwarg won over the config value...
    self.assertEqual(params.frequency_penalty, 0.8)
    # ...while untouched config kwargs still apply.
    self.assertEqual(params.presence_penalty, 0.3)
360524

361525
if __name__ == "__main__":
362526
absltest.main()

tunix/generate/vllm_sampler.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ class VllmConfig:
7171
init=False, default_factory=dict
7272
)
7373

74+
# vLLM sampling kwargs forwarded verbatim to SamplingParams without
# additional processing (e.g., temperature, stop, penalties).
75+
sampling_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
76+
7477
def __post_init__(self, engine_kwargs: Optional[Dict[str, Any]]):
7578
engine_kwargs = engine_kwargs or {}
7679
self._processed_engine_kwargs = engine_kwargs
@@ -418,20 +421,25 @@ def __call__(
418421
if seed is not None:
419422
sampling_params.seed = seed
420423

421-
if kwargs:
424+
self.config.sampling_kwargs.update(kwargs)
425+
if self.config.sampling_kwargs:
422426
try:
423-
sampling_params.update(**kwargs)
424427
logging.log_first_n(
425428
logging.INFO,
426429
"Received additional kwargs that are not explicitly defined in"
427-
f" the method signature: {kwargs}. These will be forwarded to the"
430+
f" the method signature: {self.config.sampling_kwargs}. These will be forwarded to the"
428431
" underlying sampler, but please ensure that they are valid.",
429432
1,
430-
)
431-
except Exception as e:
433+
)
434+
for key, value in self.config.sampling_kwargs.items():
435+
logging.debug(
436+
"Sampler kwargs setting key '%s' with value '%s'.", key, value
437+
)
438+
setattr(sampling_params, key, value)
439+
except (AttributeError, TypeError) as e:
432440
logging.log_first_n(
433441
logging.INFO,
434-
f"Failed to update sampling_params with kwargs: {kwargs}."
442+
f"Failed to update sampling_params with kwargs: {self.config.sampling_kwargs}."
435443
f" Error: {e}",
436444
1,
437445
)

tunix/models/dummy_model_creator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@ def create_dummy_model(
6666

6767
@partial(nnx.jit, static_argnums=(2, 3,))
def make_param(rngs, scale, shape, dt):
    """Sample a scaled standard-normal tensor of the given shape and dtype.

    Calling the `params` stream yields a fresh JAX PRNG key per invocation,
    which is then passed straight to `jax.random.normal`.
    """
    prng_key = rngs.params()
    return jax.random.normal(prng_key, shape, dtype=dt) * scale
7072

7173
def make_random_tensor(path, param, shard=None):
7274
shape = param.shape

tunix/rl/rollout/base_rollout.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,12 @@ class RolloutConfig:
157157
# Maximum number of concurrent sequences allowed to be processed in vLLM.
158158
rollout_vllm_max_num_seqs: Optional[int] = None
159159

160-
# Additional keyword arguments forwarded directly to the vLLM sampler/engine.
160+
# Additional keyword arguments forwarded directly to the vLLM engine constructor.
161161
rollout_vllm_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict)
162162

163+
# Additional keyword arguments forwarded directly to the vLLM sampling params.
164+
rollout_vllm_sampling_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict)
165+
163166
# SG-Lang JAX specific rollout configs.
164167

165168
# Model version for SG-Lang JAX rollout engine.

tunix/rl/rollout/vllm_rollout.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def __init__(
6868
"hf_config_path": rollout_config.rollout_vllm_hf_config_path,
6969
**rollout_config.rollout_vllm_kwargs,
7070
},
71+
sampling_kwargs=rollout_config.rollout_vllm_sampling_kwargs,
7172
),
7273
)
7374
state = nnx.state(model)

0 commit comments

Comments
 (0)