Skip to content

Commit ca13c4b

Browse files
committed
Fix multi-turn generator token budgets
1 parent e4648d4 commit ca13c4b

11 files changed

Lines changed: 541 additions & 53 deletions

File tree

examples/train/mini_swe_agent/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ For issues with SkyRL or the Mini-SWE-Agent integration, please [open an Issue](
5858

5959
### Common Issues
6060

61-
- **Context length errors**: If you see `ValueError: The decoder prompt (length xxxx) is longer than the maximum model length`, increase `max_input_length` and `max_generate_length` or reduce steps in `swebench.yaml`.
61+
- **Context length errors**: If you see `ValueError: The decoder prompt (length xxxx) is longer than the maximum model length`, increase the vLLM `engine_init_kwargs.max_model_len`, reduce `max_input_length`, or reduce steps in `swebench.yaml`. `max_generate_length` is the assistant-token budget for a trajectory and does not increase the model context window.
6262

6363
- **All zero rewards**: If rewards are consistently zero, the task may be too difficult. Consider:
6464
- Filtering data for a better mix of easy/hard samples

examples/train/mini_swe_agent/mini_swe_generator.py

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,37 @@
11
import asyncio
2-
from dataclasses import dataclass
3-
from typing import Dict, List, Optional, Any, Tuple
4-
import yaml
52
import traceback
6-
import ray
3+
from dataclasses import dataclass
74
from pathlib import Path
5+
from typing import Any, Dict, List, Optional, Tuple
86

9-
from minisweagent.models import get_model
7+
import ray
8+
import yaml
109
from minisweagent.agents.default import DefaultAgent
11-
from minisweagent.run.utils.save import save_traj
1210
from minisweagent.config import get_config_path
13-
from .mini_swe_utils import evaluate_trajectory, get_sb_environment
11+
from minisweagent.models import get_model
12+
from minisweagent.run.utils.save import save_traj
1413

15-
from skyrl.train.config import GeneratorConfig, SkyRLGymConfig
16-
from skyrl.train.generators.skyrl_gym_generator import SkyRLGymGenerator, GeneratorOutput, GeneratorInput
17-
from skyrl.train.generators.base import TrajectoryID, TrainingPhase, BatchMetadata
1814
from skyrl.backends.skyrl_train.inference_engines.base import ConversationType
19-
from skyrl.backends.skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient
20-
from skyrl.backends.skyrl_train.inference_engines.utils import get_sampling_params_for_backend
15+
from skyrl.backends.skyrl_train.inference_engines.inference_engine_client import (
16+
InferenceEngineClient,
17+
)
18+
from skyrl.backends.skyrl_train.inference_engines.utils import (
19+
get_sampling_params_for_backend,
20+
)
21+
from skyrl.train.config import GeneratorConfig, SkyRLGymConfig
22+
from skyrl.train.generators.base import BatchMetadata, TrainingPhase, TrajectoryID
23+
from skyrl.train.generators.skyrl_gym_generator import (
24+
GeneratorInput,
25+
GeneratorOutput,
26+
SkyRLGymGenerator,
27+
)
2128
from skyrl.train.generators.utils import (
22-
get_rollout_metrics,
2329
get_response_ids_and_loss_mask_from_messages,
30+
get_rollout_metrics,
2431
)
2532

33+
from .mini_swe_utils import evaluate_trajectory, get_sb_environment
34+
2635

2736
@dataclass
2837
class MiniSWEGeneratorConfig(GeneratorConfig):
@@ -199,15 +208,27 @@ async def minisweagent_agent_loop(
199208
# Extract prompt ids
200209
prompt_ids = initial_input_ids
201210

202-
# Calculate maximum response tokens allowed
203-
max_response_tokens = max_tokens + max_input_length - initial_prompt_length
211+
# Truncate by assistant-token budget first. Environment/user observations are kept only
212+
# insofar as they fit the secondary packed-sequence guard below; they do not consume
213+
# max_generate_length because their loss mask is 0.
214+
assistant_tokens = 0
215+
assistant_budget_response_tokens = len(response_ids)
216+
assistant_budget_exceeded = False
217+
for idx, mask in enumerate(loss_mask):
218+
assistant_tokens += int(bool(mask))
219+
if assistant_tokens > max_tokens:
220+
assistant_budget_response_tokens = idx
221+
assistant_budget_exceeded = True
222+
break
223+
224+
# Keep the packed prompt+response sequence bounded for training tensor sizes.
225+
packed_response_tokens = max(0, max_tokens + max_input_length - initial_prompt_length)
226+
max_response_tokens = min(assistant_budget_response_tokens, packed_response_tokens)
204227

205-
# Determine stop reason
206228
stop_reason = "complete" # Default for trial completion
207-
if len(response_ids) > max_response_tokens:
229+
if assistant_budget_exceeded or len(response_ids) > packed_response_tokens:
208230
stop_reason = "length"
209231

210-
# Truncate to maximum allowed length
211232
response_ids = response_ids[:max_response_tokens]
212233
loss_mask = loss_mask[:max_response_tokens]
213234

skyrl/train/config/config.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,9 @@ class FullyAsyncConfig(BaseConfig):
424424
@dataclass
425425
class SamplingParams(BaseConfig):
426426
max_generate_length: int = 1024
427+
"""Trajectory-level assistant/generated-token budget. In multi-turn generators,
428+
environment observation tokens are loss-masked and do not count against this budget.
429+
The vLLM request field is ``max_tokens`` and may be reduced per turn to fit context."""
427430
repetition_penalty: float = 1.0
428431
temperature: float = 1.0
429432
top_p: float = 1.0
@@ -496,7 +499,9 @@ class InferenceEngineConfig(BaseConfig):
496499
"""When True, pass ``language_model_only=True`` to the vLLM engine so that
497500
multimodal models (e.g. Qwen3.5) skip vision encoder initialization."""
498501
engine_init_kwargs: Dict[str, Any] = field(default_factory=dict)
499-
"""Pass-through kwargs for the vLLM engine. Names must match the engine's args."""
502+
"""Pass-through kwargs for the vLLM engine. Names must match the engine's args. If
503+
``max_model_len`` is set, rollout requests are capped so input tokens plus per-request
504+
generated tokens fit within that window."""
500505
override_existing_update_group: str = "auto"
501506
"""``"auto"``, ``"enable"``, or ``"disable"``."""
502507
external_proxy_url: Optional[str] = None
@@ -528,7 +533,8 @@ class GeneratorConfig(BaseConfig):
528533
batched: bool = False
529534
max_turns: int = 1
530535
max_input_length: Optional[int] = None
531-
"""Max generator input length for multi-turn conversations. For single-turn, set equal to ``max_prompt_length``."""
536+
"""Max input/context length allowed before each generation turn. For single-turn, set
537+
equal to ``max_prompt_length``. Distinct from ``sampling_params.max_generate_length``."""
532538
chat_template: ChatTemplateConfig = field(default_factory=ChatTemplateConfig)
533539
chat_template_kwargs: Dict[str, Any] = field(default_factory=dict)
534540
"""Kwargs passed to ``tokenizer.apply_chat_template``."""

skyrl/train/config/ppo_base_config.yaml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,9 @@ generator:
295295
n_samples_per_prompt: 5
296296
async_engine: true
297297
batched: false
298-
max_input_length: ${trainer.max_prompt_length} # max generator input length used for multi-turn conversations - for single turn set equal to max_prompt_length
298+
# Max input/context length checked before each generation turn. For single-turn, set equal to max_prompt_length.
299+
# This is distinct from sampling_params.max_generate_length, which budgets assistant-generated tokens.
300+
max_input_length: ${trainer.max_prompt_length}
299301
# VLLM_ENABLE_V1_MULTIPROCESSING=0 for reproducibility
300302
vllm_v1_disable_multiproc: true
301303
enable_prefix_caching: true
@@ -334,11 +336,14 @@ generator:
334336

335337
# Inference engine arguments. Arguments are passed directly to the vLLM engine, so names must match
336338
# the engine's args. To specify an engine arg in the CLI override, use the format: +generator.engine_init_kwargs.arg_name=value
339+
# If max_model_len is set, each rollout request's max_tokens is capped so prompt+completion fits this window.
337340
engine_init_kwargs: {}
338341

339342
override_existing_update_group: "auto" # "auto", "enable", "disable"
340343
# sampling params for generation phase
341344
sampling_params:
345+
# Trajectory-level assistant/generated-token budget. Multi-turn environment observations are loss-masked
346+
# and do not count against this value.
342347
max_generate_length: 1024
343348
repetition_penalty: 1.0
344349
temperature: 1.0
@@ -395,4 +400,4 @@ generator:
395400
environment:
396401
env_class: "gsm8k"
397402
# NOTE: environment specific defaults for environment.skyrl_gym are set at the following path:
398-
# skyrl_gym: config/skyrl_gym_config/default.yaml
403+
# skyrl_gym: config/skyrl_gym_config/default.yaml

skyrl/train/generators/skyrl_gym_generator.py

Lines changed: 80 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import asyncio
99
import copy
1010
from concurrent.futures import ThreadPoolExecutor
11-
from dataclasses import asdict, dataclass
11+
from dataclasses import dataclass
1212
from typing import Any, Dict, List, Optional, Tuple, Union
1313
from uuid import uuid4
1414

@@ -33,9 +33,13 @@
3333
)
3434
from skyrl.train.generators.utils import (
3535
apply_overlong_filtering,
36+
compute_request_max_tokens,
3637
get_custom_chat_template,
3738
get_generation_prompt_ids,
39+
get_max_model_len,
3840
get_rollout_metrics,
41+
normalize_sampling_params,
42+
sampling_params_with_max_tokens,
3943
)
4044
from skyrl_gym.envs.base_text_env import BaseTextEnvStepOutput
4145

@@ -330,11 +334,12 @@ async def agent_loop(
330334
loss_mask = [] # this excludes the prompt
331335
rollout_logprobs = None
332336

333-
# `sampling_params` if provided is a dict in the format expected by the inference engine backend
334-
# we cast default config to a dict for consistency
335-
current_sampling_params: dict = (
336-
sampling_params if sampling_params is not None else asdict(self.generator_cfg.sampling_params)
337-
)
337+
# `sampling_params` if provided is a dict in the format expected by the inference engine backend.
338+
# When absent, normalize the config dataclass into the backend shape here so agent_loop() direct
339+
# callers receive the same behavior as the main generator path.
340+
base_sampling_params = normalize_sampling_params(self.generator_cfg, sampling_params)
341+
max_model_len = get_max_model_len(self.generator_cfg)
342+
generated_tokens_used = 0
338343

339344
# Accumulate per-step rewards. Format: (reward, response_end_token_idx)
340345
per_step_rewards: List[Tuple[float, Optional[int]]] = []
@@ -343,7 +348,7 @@ async def agent_loop(
343348

344349
agent_loop_output = StepWiseOutput(step_outputs=[]) if is_step_wise else None
345350

346-
get_logprobs = self.generator_cfg.sampling_params.logprobs is not None
351+
get_logprobs = base_sampling_params.get("logprobs", None) is not None
347352
agent_loop_state = AgentLoopState(
348353
chat_history=chat_history,
349354
input_ids=initial_input_ids,
@@ -352,6 +357,7 @@ async def agent_loop(
352357
response_end_idx=None,
353358
done=False,
354359
)
360+
new_obs: ConversationType = []
355361

356362
while not agent_loop_state.done:
357363

@@ -373,8 +379,20 @@ async def agent_loop(
373379
agent_loop_state.loss_mask = []
374380
agent_loop_state.rollout_logprobs = None
375381

382+
request_max_tokens = compute_request_max_tokens(
383+
max_tokens - generated_tokens_used,
384+
len(agent_loop_state.input_ids),
385+
max_model_len,
386+
)
387+
if request_max_tokens <= 0:
388+
stop_reason = "length"
389+
break
390+
391+
current_sampling_params = sampling_params_with_max_tokens(base_sampling_params, request_max_tokens)
376392
engine_input = InferenceEngineInput(
377-
prompt_token_ids=[agent_loop_state.input_ids], session_ids=[session_id], sampling_params=sampling_params
393+
prompt_token_ids=[agent_loop_state.input_ids],
394+
session_ids=[session_id],
395+
sampling_params=current_sampling_params,
378396
)
379397
engine_output = await self.inference_engine_client.generate(engine_input, model=self.policy_model_name)
380398
output = engine_output["responses"][0]
@@ -440,6 +458,22 @@ async def agent_loop(
440458
if turn_output.rollout_expert_indices is not None and agent_loop_state.rollout_expert_indices is None:
441459
agent_loop_state.rollout_expert_indices = []
442460

461+
turn_generated_tokens = sum(turn_output.get_turn_loss_mask())
462+
if not self.use_conversation_multi_turn and not retokenize_chat_history:
463+
new_resp_tokens = turn_output.output_ids.copy()
464+
if new_resp_tokens and new_resp_tokens[-1] == self.tokenizer.eos_token_id:
465+
new_resp_tokens = new_resp_tokens[:-1]
466+
turn_generated_tokens = len(new_resp_tokens)
467+
generated_tokens_used += turn_generated_tokens
468+
if generated_tokens_used >= max_tokens and not agent_loop_state.done:
469+
# The trajectory-level assistant budget is exhausted. Do not append the
470+
# next observation, since there will be no following generation request.
471+
stop_reason = "length"
472+
agent_loop_state.done = True
473+
new_obs = []
474+
turn_output.new_obs = []
475+
turn_output.obs_ids = []
476+
443477
if is_step_wise:
444478
# current response + observation ids
445479
turn_response_ids = turn_output.output_ids + turn_output.obs_ids
@@ -489,6 +523,23 @@ async def agent_loop(
489523
rollout_expert_indices_out = None
490524
response_ids = None
491525

526+
if not is_step_wise and not retokenize_chat_history and agent_loop_state.response_end_idx is None:
527+
agent_loop_output = TrajectoryOutput(
528+
response_ids=[],
529+
reward=[],
530+
stop_reason=stop_reason,
531+
loss_mask=[],
532+
prompt_ids=prompt_ids,
533+
rollout_logprobs=[] if get_logprobs else None,
534+
env_metrics=env_metrics,
535+
rollout_expert_indices=None,
536+
)
537+
return self._post_process_agent_loop_output(
538+
agent_loop_output,
539+
env_extras,
540+
trajectory_id,
541+
)
542+
492543
# Prepare the final loss_mask, response_ids and rollout_logprobs.
493544
# We remove the final observation messages /token IDs here
494545
# Note that during the agent loop, we still add the final observation messages/tokens because we terminate the agent loop if the input length
@@ -531,17 +582,21 @@ async def agent_loop(
531582
if not self.use_conversation_multi_turn:
532583
assert response_ids is not None and loss_mask is not None
533584
if stop_reason != "length" and response_ids and response_ids[-1] != self.tokenizer.eos_token_id:
534-
response_ids.append(self.tokenizer.eos_token_id)
535-
# TODO(Charlie): this should be 0? Otherwise logprobs will be extremely off. But if it is loss
536-
# masked with 0, why bother adding it?
537-
loss_mask.append(1)
538-
if rollout_logprobs is not None:
539-
rollout_logprobs.append(0.0)
540-
if rollout_expert_indices_out is not None and rollout_expert_indices_out:
541-
layer_num = len(rollout_expert_indices_out[0])
542-
topk = len(rollout_expert_indices_out[0][0]) if layer_num > 0 else 0
543-
rollout_expert_indices_out.append([[0] * topk for _ in range(layer_num)])
544-
appended_eos_token = True
585+
if generated_tokens_used < max_tokens:
586+
response_ids.append(self.tokenizer.eos_token_id)
587+
# TODO(Charlie): this should be 0? Otherwise logprobs will be extremely off. But if it is loss
588+
# masked with 0, why bother adding it?
589+
loss_mask.append(1)
590+
if rollout_logprobs is not None:
591+
rollout_logprobs.append(0.0)
592+
if rollout_expert_indices_out is not None and rollout_expert_indices_out:
593+
layer_num = len(rollout_expert_indices_out[0])
594+
topk = len(rollout_expert_indices_out[0][0]) if layer_num > 0 else 0
595+
rollout_expert_indices_out.append([[0] * topk for _ in range(layer_num)])
596+
generated_tokens_used += 1
597+
appended_eos_token = True
598+
else:
599+
stop_reason = "length"
545600

546601
if self.generator_cfg.step_wise_trajectories:
547602
for per_step_output, (reward, resp_end_idx) in zip(agent_loop_output.step_outputs, per_step_rewards):
@@ -713,7 +768,10 @@ async def generate_batched(
713768
tokenize=True,
714769
return_dict=False,
715770
)
716-
engine_input = InferenceEngineInput(prompt_token_ids=prompt_token_ids, sampling_params=sampling_params)
771+
request_sampling_params = sampling_params_with_max_tokens(
772+
normalize_sampling_params(self.generator_cfg, sampling_params), max_tokens
773+
)
774+
engine_input = InferenceEngineInput(prompt_token_ids=prompt_token_ids, sampling_params=request_sampling_params)
717775
engine_output = await self.inference_engine_client.generate(engine_input, model=self.policy_model_name)
718776
outputs = engine_output["responses"]
719777
responses = engine_output["response_ids"]
@@ -788,7 +846,8 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False
788846
if self.generator_cfg.step_wise_trajectories:
789847
assert trajectory_ids is not None, "`trajectory_ids` is a required field for step wise training"
790848
sampling_params: Optional[dict] = input_batch.get("sampling_params", None)
791-
max_tokens = self.generator_cfg.sampling_params.max_generate_length
849+
base_sampling_params = normalize_sampling_params(self.generator_cfg, sampling_params)
850+
max_tokens = int(base_sampling_params.get("max_tokens", self.generator_cfg.sampling_params.max_generate_length))
792851
max_input_length = self.generator_cfg.max_input_length
793852

794853
if self.batched:

0 commit comments

Comments
 (0)