File tree Expand file tree Collapse file tree 3 files changed +16
-4
lines changed
Expand file tree Collapse file tree 3 files changed +16
-4
lines changed Original file line number Diff line number Diff line change @@ -71,6 +71,9 @@ class VllmConfig:
7171 init = False , default_factory = dict
7272 )
7373
74+ # vLLM sampler args that can be directly passed in without additional processing, e.g. temperature, stop etc.
75+ sampler_kwargs : Dict [str , Any ] = dataclasses .field (default_factory = dict )
76+
7477 def __post_init__ (self , engine_kwargs : Optional [Dict [str , Any ]]):
7578 engine_kwargs = engine_kwargs or {}
7679 self ._processed_engine_kwargs = engine_kwargs
@@ -418,16 +421,21 @@ def __call__(
418421 if seed is not None :
419422 sampling_params .seed = seed
420423
421- if kwargs :
424+ kwargs .update (self .config .sampler_kwargs )
425+ if kwargs :
422426 try :
423- sampling_params .update (** kwargs )
424427 logging .log_first_n (
425428 logging .INFO ,
426429 "Received additional kwargs that are not explicitly defined in"
427430 f" the method signature: { kwargs } . These will be forwarded to the"
428431 " underlying sampler, but please ensure that they are valid." ,
429432 1 ,
430- )
433+ )
434+ for key , value in kwargs .items ():
435+ logging .info (
436+ "Sampler kwargs setting key '%s' with value '%s'." , key , value
437+ )
438+ setattr (sampling_params , key , value )
431439 except Exception as e :
432440 logging .log_first_n (
433441 logging .INFO ,
Original file line number Diff line number Diff line change @@ -157,9 +157,12 @@ class RolloutConfig:
157157 # Maximum number of concurrent sequences allowed to be processed in vLLM.
158158 rollout_vllm_max_num_seqs : Optional [int ] = None
159159
160- # Additional keyword arguments forwarded directly to the vLLM sampler/ engine.
160+ # Additional keyword arguments forwarded directly to the vLLM engine constructor .
161161 rollout_vllm_kwargs : dict [str , Any ] = dataclasses .field (default_factory = dict )
162162
163+ # Additional keyword arguments forwarded directly to the vLLM sampler.
164+ rollout_vllm_sampler_kwargs : dict [str , Any ] = dataclasses .field (default_factory = dict )
165+
163166 # SG-Lang JAX specific rollout configs.
164167
165168 # Model version for SG-Lang JAX rollout engine.
Original file line number Diff line number Diff line change @@ -68,6 +68,7 @@ def __init__(
6868 "hf_config_path" : rollout_config .rollout_vllm_hf_config_path ,
6969 ** rollout_config .rollout_vllm_kwargs ,
7070 },
71+ sampler_kwargs = rollout_config .rollout_vllm_sampler_kwargs ,
7172 ),
7273 )
7374 state = nnx .state (model )
You can’t perform that action at this time.
0 commit comments