Fix multi-turn generator token budget accounting #1642
Conversation
Code Review
This pull request implements a refined token budgeting mechanism for multi-turn generators, distinguishing between trajectory-level assistant token limits and the model's maximum context window. It introduces logic to cap per-turn inference requests based on the remaining budget and context, adds validation to ensure configuration consistency, and updates documentation and tests accordingly. Feedback identifies an opportunity to simplify token calculation logic in skyrl_gym_generator.py and points out a redundant variable assignment in skyrl_vlm_generator.py.
```python
turn_generated_tokens = sum(turn_output.get_turn_loss_mask())
if not self.use_conversation_multi_turn and not retokenize_chat_history:
    new_resp_tokens = turn_output.output_ids.copy()
    if new_resp_tokens and new_resp_tokens[-1] == self.tokenizer.eos_token_id:
        new_resp_tokens = new_resp_tokens[:-1]
    turn_generated_tokens = len(new_resp_tokens)
```
The logic for calculating `turn_generated_tokens` can be simplified to avoid an unnecessary list copy and slicing when `use_conversation_multi_turn` is `False`. Additionally, moving the `sum` call into an `else` block avoids a redundant calculation in that scenario.
Suggested change:

```diff
-turn_generated_tokens = sum(turn_output.get_turn_loss_mask())
-if not self.use_conversation_multi_turn and not retokenize_chat_history:
-    new_resp_tokens = turn_output.output_ids.copy()
-    if new_resp_tokens and new_resp_tokens[-1] == self.tokenizer.eos_token_id:
-        new_resp_tokens = new_resp_tokens[:-1]
-    turn_generated_tokens = len(new_resp_tokens)
+if not self.use_conversation_multi_turn and not retokenize_chat_history:
+    turn_generated_tokens = len(turn_output.output_ids)
+    if turn_output.output_ids and turn_output.output_ids[-1] == self.tokenizer.eos_token_id:
+        turn_generated_tokens -= 1
+else:
+    turn_generated_tokens = sum(turn_output.get_turn_loss_mask())
```
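For the trailing-EOS branch, the original and suggested formulations count the same thing; a quick standalone check (using a hypothetical `eos_token_id` value in place of the tokenizer attribute) illustrates that slicing off a trailing EOS and decrementing the length agree:

```python
# Standalone check that the original and suggested counting logic agree.
# eos_token_id is a hypothetical stand-in for self.tokenizer.eos_token_id.
eos_token_id = 2

def count_original(output_ids):
    # Original: copy, slice off a trailing EOS, then take the length.
    new_resp_tokens = output_ids.copy()
    if new_resp_tokens and new_resp_tokens[-1] == eos_token_id:
        new_resp_tokens = new_resp_tokens[:-1]
    return len(new_resp_tokens)

def count_suggested(output_ids):
    # Suggested: take the length, then decrement if the last token is EOS.
    n = len(output_ids)
    if output_ids and output_ids[-1] == eos_token_id:
        n -= 1
    return n

for ids in ([], [5], [5, 2], [2], [5, 6, 7]):
    assert count_original(ids) == count_suggested(ids)
```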
```python
if generated_tokens_used >= max_tokens and not done:
    stop_reason = "length"
    done = True
    new_obs = []
```
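The check above is the tail of per-turn budget accounting. A self-contained sketch of how a trajectory-level assistant budget and the model window can jointly bound each turn (function and variable names here are illustrative, not the PR's actual API):

```python
# Hedged sketch of trajectory-level assistant-token accounting across turns.
# turn_lens simulates how many tokens the model would generate each turn.
def run_trajectory(max_generate_length, max_model_len, prompt_len, turn_lens):
    generated_tokens_used = 0
    input_len = prompt_len
    stop_reason = None
    for want in turn_lens:
        remaining = max_generate_length - generated_tokens_used  # budget left
        room = max_model_len - input_len  # decode room in the context window
        max_tokens = max(0, min(remaining, room))
        if max_tokens == 0:
            stop_reason = "length"  # no decode room remains
            break
        produced = min(want, max_tokens)  # engine honors the per-request cap
        generated_tokens_used += produced
        input_len += produced  # response tokens join the context
        if generated_tokens_used >= max_generate_length:
            stop_reason = "length"  # trajectory budget exhausted
            break
    return generated_tokens_used, stop_reason
```

With a 100-token budget and a roomy window, two 60-token turns are clipped to exactly 100 generated tokens and the trajectory stops with reason `"length"`.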
Summary
Fixes #406 by making `generator.sampling_params.max_generate_length` a trajectory-level assistant-token budget for multi-turn SkyRL rollouts instead of accidentally treating observation/context tokens as generated completion budget.

Problem
Multi-turn rollouts can grow the prompt with environment observations between assistant turns. Before this change, async generator paths did not consistently shrink the per-request vLLM `max_tokens` by the remaining assistant budget, and some paths mixed three separate limits:

- `generator.sampling_params.max_generate_length`: the intended assistant/generated-token budget for the trajectory
- `generator.max_input_length`: the input/context cap checked before each turn
- `generator.inference_engine.engine_init_kwargs.max_model_len`: the vLLM prompt-plus-completion context window

That could either allow trajectories to exceed the intended assistant-token budget, or let vLLM receive a request whose prompt plus completion could not fit in the context window.
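A minimal sketch of how a safe per-request completion cap can be derived from these limits (the function and parameter names are illustrative, not the PR's actual helpers):

```python
def safe_per_request_max_tokens(remaining_budget: int,
                                max_model_len: int,
                                input_len: int) -> int:
    """Cap one turn's completion length by both the remaining trajectory
    budget and the context room left in the model window.

    Illustrative sketch only; names are not the PR's actual API.
    """
    context_room = max_model_len - input_len  # prompt + completion must fit
    return max(0, min(remaining_budget, context_room))

# Example: a 4096-token window with a 3900-token prompt leaves 196 tokens
# of decode room even if 500 tokens of assistant budget remain.
```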
Changes
- Read the vLLM context window `max_model_len` and compute a safe per-request `max_tokens`.
- Updated `SkyRLGymGenerator.agent_loop` to:
  - cap each turn's request by `max_model_len - current_input_length`
  - stop with stop reason `length` when no decode room remains
  - treat `sampling_params["max_tokens"]` as the trajectory budget
- Updated `SkyRLVLMGymGenerator` with the same per-turn assistant-budget and `max_model_len` behavior.
- Capped the per-request `max_tokens` in the remaining generator paths.
- Count `loss_mask` assistant tokens first, with the packed prompt+response length retained as a secondary training-tensor guard.
- Added validation across `max_generate_length`, `max_input_length`, and vLLM `max_model_len`.

Validation
```shell
uv run --with ruff --isolated ruff check skyrl/train/generators/utils.py skyrl/train/generators/skyrl_gym_generator.py skyrl/train/generators/skyrl_vlm_generator.py skyrl/train/utils/utils.py skyrl/train/config/config.py examples/train/mini_swe_agent/mini_swe_generator.py tests/train/generators/test_skyrl_gym_generator.py tests/train/generators/test_skyrl_vlm_generator.py tests/train/test_config.py
uv run --extra dev --extra skyrl-train --with transformers --isolated pytest tests/train/generators/test_skyrl_gym_generator.py tests/train/generators/test_skyrl_vlm_generator.py tests/train/test_config.py tests/train/generators/test_utils.py
```

Result: 109 tests passed.
Note: `--with transformers` is needed on this macOS environment because the current uv override constrains Transformers to Linux by default.