[SGLang Async Rollout] Validate prompt_len + max_resp_len <= max_model_len before generation #1627
Conversation
```diff
@@ -482,7 +482,10 @@ async def _async_rollout_a_request(self, req: AsyncRolloutRequest, do_sample: bo
                 else:
                     raise ValueError(f"Unexpected tool calling last message state: {_req.messages[-1]}")
             elif _req.state == AsyncRolloutRequestStateEnum.RUNNING:
-                generation_prompt = _req.get_generation_prompt(self.tokenizer)
+                generation_prompt_ids = _req.get_generation_prompt(self.tokenizer)
+                if len(generation_prompt_ids) + self.config.response_length >= self.config.max_model_len:
```
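For readers skimming the hunk, here is a self-contained sketch of the guard being added. The function name and the skip handling are illustrative assumptions, since the body of the new `if` branch is not shown above:

```python
# Illustrative sketch of the length guard added above (not the PR's exact code).
def should_skip_generation(prompt_ids: list[int], response_length: int, max_model_len: int) -> bool:
    """True when prompt + maximum response would not fit in the model's context window."""
    return len(prompt_ids) + response_length >= max_model_len

# In the rollout loop, a request that fails this check would be finalized instead of
# being sent to the SGLang engine (the exact handling here is an assumption).
```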
It might be better to allow generation to continue up to `max_model_len` whenever `len(generation_prompt_ids) < max_model_len` but `len(generation_prompt_ids) + response_length >= max_model_len`.

For example, suppose a tool-calling turn leaves the prompt at 30,000 tokens, the maximum model length is 32,768, and the response length is 8,192. The model could still produce a correct answer somewhere between tokens 30,000 and 32,768, but under the current design that whole range is skipped.
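Spelling out the arithmetic from this example, using the check from the diff above:

```python
prompt_len, response_length, max_model_len = 30_000, 8_192, 32_768
print(prompt_len + response_length >= max_model_len)  # True -> the request is skipped
print(max_model_len - prompt_len)                      # 2768 usable tokens are left on the table
```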
I totally agree; that would be the ideal behavior. Unfortunately, SGLang currently enforces a check in its engine (link) that throws an error if `input_token_len + max_new_tokens` exceeds the model's max context length, and that is exactly the error this PR tries to avoid. I'm not sure whether there is a specific reason behind this design choice in SGLang, but it does keep us from fully utilizing the available context window.
> I totally agree; that would be the ideal behavior. Unfortunately, SGLang currently enforces a check in its engine (link) that throws an error if `input_token_len + max_new_tokens` exceeds the model's max context length, and that is the error this PR tries to avoid. I'm not sure whether there is a specific reason behind this design choice in SGLang, but it does keep us from fully utilizing the available context window.
Could we define a `max_new_tokens` local variable that dynamically switches between `config.response_length` and `max_model_len - len(generation_prompt_ids)`, and allow it to be overridden within the sampling kwargs? Does that approach work for you?
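A minimal sketch of that suggestion, assuming the sampling parameters are passed to the engine as a kwargs dict; the function and variable names below are illustrative rather than taken from the final commit:

```python
from typing import Any

# Illustrative: cap max_new_tokens by the remaining context window and override it
# in the sampling kwargs, so prompt + generation never exceeds max_model_len
# (which avoids SGLang's input_token_len + max_new_tokens engine check).
def build_sampling_params(
    base_params: dict[str, Any],
    prompt_len: int,
    response_length: int,
    max_model_len: int,
) -> dict[str, Any] | None:
    max_new_tokens = min(response_length, max_model_len - prompt_len)
    if max_new_tokens <= 0:
        # No room left in the context window: the caller should finish the request
        # instead of issuing a generation call.
        return None
    return {**base_params, "max_new_tokens": max_new_tokens}

# Example: a 30,000-token prompt with max_model_len=32,768 gets max_new_tokens=2,768
# instead of the configured response_length of 8,192.
print(build_sampling_params({"temperature": 1.0}, 30_000, 8_192, 32_768))
```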
Brilliant idea! Just updated.
Given that the `generate_sequence` method sends a batch of requests to SGLang together, I didn't add this logic for it.
LGTM
Checklist Before Starting
What does this PR do?
This PR adds a validation step to prevent generation requests that exceed the model's maximum context length in SGLang. Without this check, multi-turn RL training can fail when the combined length of the prompt and the maximum response exceeds the model limit. The new validation ensures `prompt_len + max_resp_len <= max_model_len` before sending requests to the SGLang engine.
Test
Successfully tested with my multi-turn RL dataset with `max_turns==30`, which kept failing with the SGLang context-length error before this change (Qwen2.5-32B-instruct + GRPO).
Additional Info.
Checklist Before Submitting
- Add `[BREAKING]` to the PR title if it breaks any API.