
Fix multi-turn generator token budget accounting #1642

Open
taivu1998 wants to merge 1 commit into NovaSky-AI:main from taivu1998:tdv/issue-406-token-budget

Conversation

@taivu1998

Summary

Fixes #406 by making generator.sampling_params.max_generate_length a trajectory-level assistant-token budget for multi-turn SkyRL rollouts, instead of accidentally counting observation/context tokens against the generated-completion budget.

Problem

Multi-turn rollouts can grow the prompt with environment observations between assistant turns. Before this change, the async generator paths did not consistently shrink the per-request vLLM max_tokens to the remaining assistant budget, and some paths conflated three separate limits:

  • generator.sampling_params.max_generate_length: intended assistant/generated-token budget for the trajectory
  • generator.max_input_length: input/context cap checked before each turn
  • generator.inference_engine.engine_init_kwargs.max_model_len: vLLM prompt-plus-completion context window

That could either allow trajectories to exceed the intended assistant-token budget, or let vLLM receive a request whose prompt plus requested completion could not fit in the model's context window.

Changes

  • Added shared generator helpers to normalize backend-shaped sampling params, read max_model_len, and compute safe per-request max_tokens.
  • Updated SkyRLGymGenerator.agent_loop (see the sketch after this list) to:
    • cap each turn by remaining assistant-token budget
    • additionally cap by max_model_len - current_input_length
    • avoid counting environment observations against the assistant budget
    • stop cleanly with length when no decode room remains
    • preserve eval/custom sampling_params["max_tokens"] as the trajectory budget
  • Updated SkyRLVLMGymGenerator with the same per-turn assistant-budget and max_model_len behavior.
  • Kept batched single-turn generation backend-shaped by always passing max_tokens.
  • Updated Mini-SWE post-processing so truncation is driven first by loss_mask assistant tokens, with the packed prompt+response length retained as a secondary guard on training tensor size.
  • Added validation and docs clarifying the difference between max_generate_length, max_input_length, and vLLM max_model_len.

Validation

  • uv run --with ruff --isolated ruff check skyrl/train/generators/utils.py skyrl/train/generators/skyrl_gym_generator.py skyrl/train/generators/skyrl_vlm_generator.py skyrl/train/utils/utils.py skyrl/train/config/config.py examples/train/mini_swe_agent/mini_swe_generator.py tests/train/generators/test_skyrl_gym_generator.py tests/train/generators/test_skyrl_vlm_generator.py tests/train/test_config.py
  • uv run --extra dev --extra skyrl-train --with transformers --isolated pytest tests/train/generators/test_skyrl_gym_generator.py tests/train/generators/test_skyrl_vlm_generator.py tests/train/test_config.py tests/train/generators/test_utils.py

Result: 109 tests passed.

Note: --with transformers is needed on this macOS environment because the current uv override constrains Transformers to Linux by default.

@taivu1998 taivu1998 marked this pull request as ready for review May 11, 2026 03:11
Contributor

@gemini-code-assist (Bot) left a comment


Code Review

This pull request implements a refined token budgeting mechanism for multi-turn generators, distinguishing between trajectory-level assistant token limits and the model's maximum context window. It introduces logic to cap per-turn inference requests based on the remaining budget and context, adds validation to ensure configuration consistency, and updates documentation and tests accordingly. Feedback identifies an opportunity to simplify token calculation logic in skyrl_gym_generator.py and points out a redundant variable assignment in skyrl_vlm_generator.py.

Comment on lines +461 to +466
```python
turn_generated_tokens = sum(turn_output.get_turn_loss_mask())
if not self.use_conversation_multi_turn and not retokenize_chat_history:
    new_resp_tokens = turn_output.output_ids.copy()
    if new_resp_tokens and new_resp_tokens[-1] == self.tokenizer.eos_token_id:
        new_resp_tokens = new_resp_tokens[:-1]
    turn_generated_tokens = len(new_resp_tokens)
```
Contributor


Severity: medium

The logic for calculating turn_generated_tokens can be simplified to avoid an unnecessary list copy and slicing when use_conversation_multi_turn is False. Additionally, moving the sum call into an else block avoids redundant calculation in that scenario.

Suggested change

```diff
-turn_generated_tokens = sum(turn_output.get_turn_loss_mask())
-if not self.use_conversation_multi_turn and not retokenize_chat_history:
-    new_resp_tokens = turn_output.output_ids.copy()
-    if new_resp_tokens and new_resp_tokens[-1] == self.tokenizer.eos_token_id:
-        new_resp_tokens = new_resp_tokens[:-1]
-    turn_generated_tokens = len(new_resp_tokens)
+if not self.use_conversation_multi_turn and not retokenize_chat_history:
+    turn_generated_tokens = len(turn_output.output_ids)
+    if turn_output.output_ids and turn_output.output_ids[-1] == self.tokenizer.eos_token_id:
+        turn_generated_tokens -= 1
+else:
+    turn_generated_tokens = sum(turn_output.get_turn_loss_mask())
```

```python
if generated_tokens_used >= max_tokens and not done:
    stop_reason = "length"
    done = True
    new_obs = []
```
Contributor


Severity: medium

This assignment to new_obs is redundant because done is set to True immediately before, which will cause the loop to terminate. Since the conversation extension at line 213 is guarded by if not done:, this empty list is never utilized.



Development

Successfully merging this pull request may close these issues.

Make per-turn max_tokens handling (vs. truncation at the end) more explicit and better for multi-turn
