[SGLang Async Rollout] Validate prompt_len + max_resp_len <= max_model_len before generation #1627
Conversation
```diff
@@ -482,7 +482,10 @@ async def _async_rollout_a_request(self, req: AsyncRolloutRequest, do_sample: bool
             else:
                 raise ValueError(f"Unexpected tool calling last message state: {_req.messages[-1]}")
         elif _req.state == AsyncRolloutRequestStateEnum.RUNNING:
-            generation_prompt = _req.get_generation_prompt(self.tokenizer)
+            generation_prompt_ids = _req.get_generation_prompt(self.tokenizer)
+            if len(generation_prompt_ids) + self.config.response_length >= self.config.max_model_len:
```
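For readers skimming the diff: the added condition is a straightforward pre-generation length guard that compares the accumulated prompt length plus the configured response budget against the model's context window before the request reaches the engine. A minimal standalone sketch of that pattern (the helper name and the early-exit handling are illustrative assumptions, not the exact PR code):

```python
# Illustrative sketch of the pre-generation length guard; not the exact PR code.
def should_stop_rollout(prompt_ids: list, response_length: int, max_model_len: int) -> bool:
    """Return True when the prompt plus the configured response budget cannot
    fit in the model's context window, so the request should not be sent."""
    return len(prompt_ids) + response_length >= max_model_len

# Numbers from the error in the PR description: 23,769 prompt tokens plus a
# 10,240-token completion budget exceeds the 32,768-token context window.
assert should_stop_rollout([0] * 23_769, 10_240, 32_768) is True
```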
It might be better to allow the generation to reach `max_model_len` when `len(generation_prompt_ids) < max_model_len` and `len(generation_prompt_ids) + response_length >= max_model_len`.

For example, suppose a model with a tool call reaches 30,000 tokens, the maximum model length is 32,768, and the response length is 8,192. The model could still generate a correct answer between 30,000 and 32,768 tokens, but under the current design that range is skipped.
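To make the trade-off concrete, here is the arithmetic from that example in a short snippet (variable names are illustrative):

```python
# Numbers from the example above; variable names are illustrative.
prompt_len = 30_000      # tokens accumulated after the tool call
max_model_len = 32_768   # model's maximum context length
response_length = 8_192  # configured response budget

# Current design: the request is skipped because 38,192 >= 32,768 ...
skipped = prompt_len + response_length >= max_model_len          # True

# ... even though 2,768 tokens of context are still available for an answer.
remaining_budget = max_model_len - prompt_len                    # 2,768
clamped_max_new_tokens = min(response_length, remaining_budget)  # 2,768
```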
I totally agree: that would be the ideal behavior. Unfortunately, SGLang currently enforces a check in its engine (link) that throws an error if `input_token_len + max_new_tokens` exceeds the model's max context length, and that is the error this PR tries to avoid. I'm not sure if there's a specific reason behind this design choice in SGLang, but it does restrict us from fully utilizing the available context window.
> I totally agree: that would be the ideal behavior. Unfortunately, SGLang currently enforces a check in its engine (link) that throws an error if `input_token_len + max_new_tokens` exceeds the model's max context length, and that is the error this PR tries to avoid. I'm not sure if there's a specific reason behind this design choice in SGLang, but it does restrict us from fully utilizing the available context window.

Could we define a `max_new_tokens` local variable that dynamically switches between `config.response_length` and `max_model_len - len(generation_prompt_ids)`, and allow it to be overridden within the sampling kwargs? Does that approach work for you?
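A minimal sketch of that suggestion, assuming the sampling kwargs are passed as a plain dict whose `max_new_tokens` entry is honored per request (function and variable names are assumptions, not the actual rollout code):

```python
# Illustrative sketch of the proposed per-request clamping; not the PR's actual code.
def clamp_sampling_params(sampling_params: dict,
                          prompt_len: int,
                          response_length: int,
                          max_model_len: int):
    """Return sampling kwargs with max_new_tokens clamped so that
    prompt_len + max_new_tokens never exceeds max_model_len.
    Returns None when no budget remains and the request should be finalized."""
    remaining = max_model_len - prompt_len
    if remaining <= 0:
        return None
    return {**sampling_params, "max_new_tokens": min(response_length, remaining)}

# With the numbers from the earlier example, the 8,192-token budget is
# clamped down to the 2,768 tokens that still fit in the context window.
params = clamp_sampling_params({"temperature": 1.0}, 30_000, 8_192, 32_768)
assert params["max_new_tokens"] == 2_768
```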
Brilliant idea! Just updated.
Given that the `generate_sequence` method sends a batch of requests to SGLang together, I didn't add this logic for it.
LGTM
[SGLang Async Rollout] Validate prompt_len + max_resp_len <= max_model_len before generation (volcengine#1627)

### Checklist Before Starting

- [x] Search for similar PR(s).

### What does this PR do?

This PR adds a validation step to prevent generation requests that exceed the model’s maximum context length in SGLang. Without this check, multi-turn RL training can fail when the combined length of the prompt and the maximum response exceeds the model limit. The new validation ensures `prompt_len + max_resp_len <= max_model_len` before sending requests to the SGLang engine.

### Test

Successfully tested with my multi-turn RL dataset with `max_turns==30`, which kept failing with the following error before this change (Qwen2.5-32B-Instruct + GRPO):

```
Traceback (most recent call last):
  File "/home/jobuser/resources/verl/trainer/main_ppo.py", line 64, in main
    run_ppo(config)
  File "/home/jobuser/resources/verl/trainer/main_ppo.py", line 76, in run_ppo
    ray.get(runner.run.remote(config))
  File "/home/jobuser/.local/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/ray/_private/worker.py", line 2822, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/home/jobuser/.local/lib/python3.10/site-packages/ray/_private/worker.py", line 930, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(ValueError): ray::TaskRunner.run() (pid=1150536, ip=100.96.248.206, actor_id=85b22be1ed8ef671c739638a01000000, repr=<main_ppo.TaskRunner object at 0x796b0bba7010>)
  File "/home/jobuser/resources/verl/trainer/main_ppo.py", line 183, in run
    trainer.fit()
  File "/home/jobuser/resources/verl/trainer/ppo/ray_trainer.py", line 872, in fit
    val_metrics = self._validate()
  File "/home/jobuser/resources/verl/trainer/ppo/ray_trainer.py", line 607, in _validate
    test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)
  File "/home/jobuser/resources/verl/single_controller/ray/base.py", line 49, in func
    output = ray.get(output)
ray.exceptions.RayTaskError(ValueError): ray::WorkerDict.actor_rollout_generate_sequences() (pid=1169888, ip=100.96.248.206, actor_id=6deb9fd4b4ff01530920ada301000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x7e41e90afa90>)
  File "/home/jobuser/resources/verl/single_controller/ray/base.py", line 625, in func
    return getattr(self.worker_dict[key], name)(*args, **kwargs)
  File "/home/jobuser/resources/verl/single_controller/base/decorator.py", line 534, in inner
    return func(*args, **kwargs)
  File "/home/jobuser/resources/verl/workers/fsdp_workers.py", line 630, in generate_sequences
    output = self.rollout.generate_sequences_with_tools(prompts=prompts)
  File "/home/jobuser/resources/verl/utils/debug/performance.py", line 78, in f
    return self.log(decorated_function, *args, **kwargs)
  File "/home/jobuser/resources/verl/utils/debug/performance.py", line 88, in log
    output = func(*args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/jobuser/resources/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py", line 613, in generate_sequences_with_tools
    output_req_list = loop.run_until_complete(
  File "uvloop/loop.pyx", line 1518, in uvloop.loop.Loop.run_until_complete
  File "/home/jobuser/resources/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py", line 529, in _async_rollout_a_request
    output = await self._engine.async_generate(
  File "/home/jobuser/.local/lib/python3.10/site-packages/sglang/srt/entrypoints/engine.py", line 265, in async_generate
    return await generator.__anext__()
  File "/home/jobuser/.local/lib/python3.10/site-packages/sglang/srt/managers/tokenizer_manager.py", line 403, in generate_request
    tokenized_obj = await self._tokenize_one_request(obj)
  File "/home/jobuser/.local/lib/python3.10/site-packages/sglang/srt/managers/tokenizer_manager.py", line 450, in _tokenize_one_request
    self._validate_token_len(obj, input_ids)
  File "/home/jobuser/.local/lib/python3.10/site-packages/sglang/srt/managers/tokenizer_manager.py", line 482, in _validate_token_len
    raise ValueError(error_msg)
ValueError: Requested token count exceeds the model's maximum context length of 32768 tokens. You requested a total of 34009 tokens: 23769 tokens from the input messages and 10240 tokens for the completion. Please reduce the number of tokens in the input messages or the completion to fit within the limit.
```

### Additional Info.

- **Inference**: SGLang

### Checklist Before Submitting

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting).
- [x] Add `[BREAKING]` to the PR title if it breaks any API.
- [ ] Update the documentation about your changes in the [docs](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add CI test(s) if necessary.