[SGLang Async Rollout] Validate prompt_len + max_resp_len <= max_model_len before generation #1627


Merged: 1 commit merged into volcengine:main from the max_len branch on May 24, 2025

Conversation

jybsuper (Contributor) commented May 22, 2025


Checklist Before Starting

  • Search for similar PR(s).

What does this PR do?

This PR adds a validation step to prevent generation requests that exceed the model’s maximum context length in SGLang. Without this check, multi-turn RL training can fail when the combined length of the prompt and the maximum response exceeds the model limit. The new validation ensures prompt_len + max_resp_len <= max_model_len before sending requests to the SGLang engine.
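
In concept, the check is a simple arithmetic guard on tokenized lengths. Below is a minimal sketch of the idea; the standalone function form and its name are illustrative only, while the actual check runs per request inside `_async_rollout_a_request` in `async_sglang_rollout.py` (see the diff further down):

```python
def exceeds_context(prompt_len: int, max_resp_len: int, max_model_len: int) -> bool:
    """Return True if prompt_len + max_resp_len would overflow the model's context window."""
    return prompt_len + max_resp_len > max_model_len


# Numbers from the traceback in the Test section below:
# 23769 prompt tokens + 10240 completion tokens = 34009 > 32768,
# so SGLang rejects the request with a ValueError.
assert exceeds_context(23769, 10240, 32768)
```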

Test

Successfully tested with my multi-turn RL dataset with max_turns==30, which kept failing with the following error before this change (Qwen2.5-32B-Instruct + GRPO):

Traceback (most recent call last):
  File "/home/jobuser/resources/verl/trainer/main_ppo.py", line 64, in main
    run_ppo(config)
  File "/home/jobuser/resources/verl/trainer/main_ppo.py", line 76, in run_ppo
    ray.get(runner.run.remote(config))
  File "/home/jobuser/.local/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/ray/_private/worker.py", line 2822, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/home/jobuser/.local/lib/python3.10/site-packages/ray/_private/worker.py", line 930, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(ValueError): ray::TaskRunner.run() (pid=1150536, ip=100.96.248.206, actor_id=85b22be1ed8ef671c739638a01000000, repr=<main_ppo.TaskRunner object at 0x796b0bba7010>)
  File "/home/jobuser/resources/verl/trainer/main_ppo.py", line 183, in run
    trainer.fit()
  File "/home/jobuser/resources/verl/trainer/ppo/ray_trainer.py", line 872, in fit
    val_metrics = self._validate()
  File "/home/jobuser/resources/verl/trainer/ppo/ray_trainer.py", line 607, in _validate
    test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)
  File "/home/jobuser/resources/verl/single_controller/ray/base.py", line 49, in func
    output = ray.get(output)
ray.exceptions.RayTaskError(ValueError): ray::WorkerDict.actor_rollout_generate_sequences() (pid=1169888, ip=100.96.248.206, actor_id=6deb9fd4b4ff01530920ada301000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x7e41e90afa90>)
  File "/home/jobuser/resources/verl/single_controller/ray/base.py", line 625, in func
    return getattr(self.worker_dict[key], name)(*args, **kwargs)
  File "/home/jobuser/resources/verl/single_controller/base/decorator.py", line 534, in inner
    return func(*args, **kwargs)
  File "/home/jobuser/resources/verl/workers/fsdp_workers.py", line 630, in generate_sequences
    output = self.rollout.generate_sequences_with_tools(prompts=prompts)
  File "/home/jobuser/resources/verl/utils/debug/performance.py", line 78, in f
    return self.log(decorated_function, *args, **kwargs)
  File "/home/jobuser/resources/verl/utils/debug/performance.py", line 88, in log
    output = func(*args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/jobuser/resources/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py", line 613, in generate_sequences_with_tools
    output_req_list = loop.run_until_complete(
  File "uvloop/loop.pyx", line 1518, in uvloop.loop.Loop.run_until_complete
  File "/home/jobuser/resources/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py", line 529, in _async_rollout_a_request
    output = await self._engine.async_generate(
  File "/home/jobuser/.local/lib/python3.10/site-packages/sglang/srt/entrypoints/engine.py", line 265, in async_generate
    return await generator.__anext__()
  File "/home/jobuser/.local/lib/python3.10/site-packages/sglang/srt/managers/tokenizer_manager.py", line 403, in generate_request
    tokenized_obj = await self._tokenize_one_request(obj)
  File "/home/jobuser/.local/lib/python3.10/site-packages/sglang/srt/managers/tokenizer_manager.py", line 450, in _tokenize_one_request
    self._validate_token_len(obj, input_ids)
  File "/home/jobuser/.local/lib/python3.10/site-packages/sglang/srt/managers/tokenizer_manager.py", line 482, in _validate_token_len
    raise ValueError(error_msg)
ValueError: Requested token count exceeds the model's maximum context length of 32768 tokens. You requested a total of 34009 tokens: 23769 tokens from the input messages and 10240 tokens for the completion. Please reduce the number of tokens in the input messages or the completion to fit within the limit.

Additional Info.

  • Inference: SGLang

Checklist Before Submitting

  • Read the Contribute Guide.
  • Apply pre-commit checks.
  • Add [BREAKING] to the PR title if it breaks any API.
  • Update the documentation about your changes in the docs.
  • Add CI test(s) if necessary.

@vermouth1992 requested a review from SwordFaith on May 22, 2025 00:45
@@ -482,7 +482,10 @@ async def _async_rollout_a_request(self, req: AsyncRolloutRequest, do_sample: bool
             else:
                 raise ValueError(f"Unexpected tool calling last message state: {_req.messages[-1]}")
         elif _req.state == AsyncRolloutRequestStateEnum.RUNNING:
-            generation_prompt = _req.get_generation_prompt(self.tokenizer)
+            generation_prompt_ids = _req.get_generation_prompt(self.tokenizer)
+            if len(generation_prompt_ids) + self.config.response_length >= self.config.max_model_len:
Collaborator

It might be better to allow the generation to reach max_model_len when len(generation_prompt_ids) < max_model_len and len(generation_prompt_ids) + response_length >= max_model_len.

For example, suppose a rollout has reached 30,000 tokens after a tool call, the maximum model length is 32,768, and the configured response length is 8,192. The model could still generate a correct answer between 30,000 and 32,768 tokens, but under the current design that range is skipped.

Contributor Author (jybsuper)

I totally agree—that would be the ideal behavior. Unfortunately, SGLang currently enforces a check in its engine (link) that throws an error if input_token_len + max_new_tokens exceeds the model's max context length, and that is the error this PR tries to avoid. I'm not sure if there's a specific reason behind this design choice in SGLang, but it does restrict us from fully utilizing the available context window.

@SwordFaith (Collaborator) commented May 22, 2025

> I totally agree—that would be the ideal behavior. Unfortunately, SGLang currently enforces a check in its engine (link) that throws an error if input_token_len + max_new_tokens exceeds the model's max context length, and that is the error this PR tries to avoid. I'm not sure if there's a specific reason behind this design choice in SGLang, but it does restrict us from fully utilizing the available context window.

Could we define a max_new_tokens local variable that is dynamically set to config.response_length or max_model_len - len(generation_prompt_ids), whichever is smaller, and allow it to be overridden within the sampling kwargs? Does that approach work for you?
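
A minimal sketch of that suggestion follows. The standalone function and its name are hypothetical; the merged change may compute this inline and pass the result through the per-request sampling kwargs, and its exact shape can differ:

```python
def compute_max_new_tokens(prompt_len: int, response_length: int, max_model_len: int) -> int:
    """Cap the per-request generation budget so prompt tokens plus new tokens never
    exceed the model context, while still using whatever room remains."""
    remaining = max(max_model_len - prompt_len, 0)
    return min(response_length, remaining)


# With the reviewer's example above: a 30,000-token prompt, max_model_len = 32,768,
# and response_length = 8,192 -> the request can still generate up to 2,768 tokens
# instead of being skipped outright.
assert compute_max_new_tokens(30_000, 8_192, 32_768) == 2_768
```

The computed value would then override max_new_tokens in the sampling kwargs for that single request, as suggested above.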

Contributor Author (jybsuper)

Brilliant idea! Just updated.

Contributor Author (jybsuper)

Given that the generate_sequence method sends a batch of requests to SGLang together, I didn't add this logic for it.

@jybsuper force-pushed the max_len branch 3 times, most recently from e27b476 to c91171d on May 22, 2025 03:47
@SwordFaith (Collaborator)

LGTM

@vermouth1992 merged commit 7225544 into volcengine:main on May 24, 2025
35 of 36 checks passed
@jybsuper deleted the max_len branch on May 24, 2025 01:55