Skip to content

Commit b418aa7

Browse files
committed
Add support for vLLM sampler kwargs.
1 parent f2dcd33 commit b418aa7

File tree

3 files changed

+16
-4
lines changed

3 files changed

+16
-4
lines changed

tunix/generate/vllm_sampler.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ class VllmConfig:
7171
init=False, default_factory=dict
7272
)
7373

74+
# vLLM sampler args that are passed through directly without additional processing, e.g., temperature, stop, etc.
75+
sampler_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
76+
7477
def __post_init__(self, engine_kwargs: Optional[Dict[str, Any]]):
7578
engine_kwargs = engine_kwargs or {}
7679
self._processed_engine_kwargs = engine_kwargs
@@ -418,16 +421,21 @@ def __call__(
418421
if seed is not None:
419422
sampling_params.seed = seed
420423

421-
if kwargs:
424+
kwargs.update(self.config.sampler_kwargs)
425+
if kwargs:
422426
try:
423-
sampling_params.update(**kwargs)
424427
logging.log_first_n(
425428
logging.INFO,
426429
"Received additional kwargs that are not explicitly defined in"
427430
f" the method signature: {kwargs}. These will be forwarded to the"
428431
" underlying sampler, but please ensure that they are valid.",
429432
1,
430-
)
433+
)
434+
for key, value in kwargs.items():
435+
logging.info(
436+
"Sampler kwargs setting key '%s' with value '%s'.", key, value
437+
)
438+
setattr(sampling_params, key, value)
431439
except Exception as e:
432440
logging.log_first_n(
433441
logging.INFO,

tunix/rl/rollout/base_rollout.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,12 @@ class RolloutConfig:
157157
# Maximum number of concurrent sequences allowed to be processed in vLLM.
158158
rollout_vllm_max_num_seqs: Optional[int] = None
159159

160-
# Additional keyword arguments forwarded directly to the vLLM sampler/engine.
160+
# Additional keyword arguments forwarded directly to the vLLM engine constructor.
161161
rollout_vllm_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict)
162162

163+
# Additional keyword arguments forwarded directly to the vLLM sampler.
164+
rollout_vllm_sampler_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict)
165+
163166
# SG-Lang JAX specific rollout configs.
164167

165168
# Model version for SG-Lang JAX rollout engine.

tunix/rl/rollout/vllm_rollout.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def __init__(
6868
"hf_config_path": rollout_config.rollout_vllm_hf_config_path,
6969
**rollout_config.rollout_vllm_kwargs,
7070
},
71+
sampler_kwargs=rollout_config.rollout_vllm_sampler_kwargs,
7172
),
7273
)
7374
state = nnx.state(model)

0 commit comments

Comments
 (0)