Skip to content

Commit 841ee9f

Browse files
committed
Add stop-string support to RL.
1 parent d272058 commit 841ee9f

File tree

3 files changed

+12
-1
lines changed

3 files changed

+12
-1
lines changed

src/maxtext/configs/post_train/rl.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,10 @@ enable_dp_attention: False
149149
# Performance tuning for samplers
150150
max_num_batched_tokens: null
151151
max_num_seqs: null
152+
# If True, enables asynchronous scheduling in vLLM for faster generation
153+
async_scheduling: True
154+
# Stop generation when any of these strings is generated
155+
stop_strings: ["</answer>"]
152156

153157
# ====== Checkpoint Configuration ======
154158
enable_checkpointing: True

src/maxtext/configs/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1589,8 +1589,10 @@ class VLLM(BaseModel):
15891589
hbm_utilization_vllm: float = Field(0.72, description="Target HBM utilization for vLLM.")
15901590
swap_space_vllm_gb: int = Field(2, description="Swap space in GB for vLLM.")
15911591
enable_dp_attention: bool = Field(False, description="Enable the attn_dp mesh axis in vLLM.")
1592+
async_scheduling: bool = Field(False, description="Enable asynchronous scheduling in vLLM.")
15921593
max_num_batched_tokens: Optional[int] = Field(None, description="Max number of batched tokens in vLLM.")
15931594
max_num_seqs: Optional[int] = Field(None, description="Max number of sequences in vLLM.")
1595+
stop_strings: Optional[list[str]] = Field(None, description="List of stop strings for vLLM decoding.")
15941596
vllm_additional_config: dict[str, Any] = Field(default_factory=dict, description="Additional vLLM config options.")
15951597
vllm_hf_overrides: dict[str, Any] = Field(
15961598
default_factory=dict,

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,6 @@ def get_rollout_kwargs_for_data_parallelism(sampler_config, num_sampler_devices)
253253
)
254254
rollout_kwargs["tensor_parallel_size"] = tp
255255
rollout_kwargs["data_parallel_size"] = dp
256-
rollout_kwargs["rollout_vllm_async_scheduling"] = True
257256

258257
return rollout_kwargs
259258

@@ -542,9 +541,15 @@ def _filter_long_prompts(x):
542541
rollout_vllm_enable_dp_attention=trainer_config.enable_dp_attention,
543542
rollout_vllm_max_num_batched_tokens=trainer_config.max_num_batched_tokens,
544543
rollout_vllm_max_num_seqs=trainer_config.max_num_seqs,
544+
rollout_vllm_async_scheduling=trainer_config.async_scheduling,
545545
rollout_vllm_kwargs={
546546
"hf_overrides": trainer_config.vllm_hf_overrides,
547547
},
548+
rollout_vllm_sampling_kwargs={
549+
"stop": trainer_config.stop_strings,
550+
"detokenize": trainer_config.stop_strings is not None,
551+
"include_stop_str_in_output": trainer_config.stop_strings is not None,
552+
},
548553
**get_rollout_kwargs_for_data_parallelism(sampler_config, len(sampler_devices)),
549554
),
550555
)

0 commit comments

Comments
 (0)