2 changes: 1 addition & 1 deletion examples/train/mini_swe_agent/README.md
@@ -58,7 +58,7 @@ For issues with SkyRL or the Mini-SWE-Agent integration, please [open an Issue](

### Common Issues

- **Context length errors**: If you see `ValueError: The decoder prompt (length xxxx) is longer than the maximum model length`, increase `max_input_length` and `max_generate_length` or reduce steps in `swebench.yaml`.
- **Context length errors**: If you see `ValueError: The decoder prompt (length xxxx) is longer than the maximum model length`, increase the vLLM `engine_init_kwargs.max_model_len`, reduce `max_input_length`, or reduce steps in `swebench.yaml`. `max_generate_length` is the assistant-token budget for a trajectory and does not increase the model context window.

- **All zero rewards**: If rewards are consistently zero, the task may be too difficult. Consider:
- Filtering data for a better mix of easy/hard samples
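An illustration of the budget relationships behind the reworded bullet above. This is a hypothetical sketch — the helper and the numbers are invented for exposition; only the config-field names come from this PR:

```python
def diagnose_context_error(
    prompt_len: int,
    max_model_len: int,        # generator.engine_init_kwargs.max_model_len
    max_generate_length: int,  # generator.sampling_params.max_generate_length
) -> str:
    """Hypothetical helper: classify why a rollout request hits the vLLM limit."""
    if prompt_len >= max_model_len:
        # Matches "The decoder prompt (length xxxx) is longer than the maximum
        # model length". Raising max_generate_length cannot fix this, because
        # it budgets assistant tokens and does not grow the context window.
        return "raise engine_init_kwargs.max_model_len or reduce max_input_length/steps"
    if prompt_len + max_generate_length > max_model_len:
        # After this PR the per-request max_tokens is capped to fit the window,
        # so the turn succeeds but the trajectory may stop with reason "length".
        return "ok: generation capped to fit the window"
    return "ok"


print(diagnose_context_error(prompt_len=33_000, max_model_len=32_768, max_generate_length=1_024))
```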
57 changes: 39 additions & 18 deletions examples/train/mini_swe_agent/mini_swe_generator.py
@@ -1,28 +1,37 @@
import asyncio
from dataclasses import dataclass
from typing import Dict, List, Optional, Any, Tuple
import yaml
import traceback
import ray
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

from minisweagent.models import get_model
import ray
import yaml
from minisweagent.agents.default import DefaultAgent
from minisweagent.run.utils.save import save_traj
from minisweagent.config import get_config_path
from .mini_swe_utils import evaluate_trajectory, get_sb_environment
from minisweagent.models import get_model
from minisweagent.run.utils.save import save_traj

from skyrl.train.config import GeneratorConfig, SkyRLGymConfig
from skyrl.train.generators.skyrl_gym_generator import SkyRLGymGenerator, GeneratorOutput, GeneratorInput
from skyrl.train.generators.base import TrajectoryID, TrainingPhase, BatchMetadata
from skyrl.backends.skyrl_train.inference_engines.base import ConversationType
from skyrl.backends.skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient
from skyrl.backends.skyrl_train.inference_engines.utils import get_sampling_params_for_backend
from skyrl.backends.skyrl_train.inference_engines.inference_engine_client import (
InferenceEngineClient,
)
from skyrl.backends.skyrl_train.inference_engines.utils import (
get_sampling_params_for_backend,
)
from skyrl.train.config import GeneratorConfig, SkyRLGymConfig
from skyrl.train.generators.base import BatchMetadata, TrainingPhase, TrajectoryID
from skyrl.train.generators.skyrl_gym_generator import (
GeneratorInput,
GeneratorOutput,
SkyRLGymGenerator,
)
from skyrl.train.generators.utils import (
get_rollout_metrics,
get_response_ids_and_loss_mask_from_messages,
get_rollout_metrics,
)

from .mini_swe_utils import evaluate_trajectory, get_sb_environment


@dataclass
class MiniSWEGeneratorConfig(GeneratorConfig):
@@ -199,15 +208,27 @@ async def minisweagent_agent_loop(
# Extract prompt ids
prompt_ids = initial_input_ids

# Calculate maximum response tokens allowed
max_response_tokens = max_tokens + max_input_length - initial_prompt_length
# Truncate by assistant-token budget first. Environment/user observations are kept only
# insofar as they fit the secondary packed-sequence guard below; they do not consume
# max_generate_length because their loss mask is 0.
assistant_tokens = 0
assistant_budget_response_tokens = len(response_ids)
assistant_budget_exceeded = False
for idx, mask in enumerate(loss_mask):
assistant_tokens += int(bool(mask))
if assistant_tokens > max_tokens:
assistant_budget_response_tokens = idx
assistant_budget_exceeded = True
break

# Keep the packed prompt+response sequence bounded for training tensor sizes.
packed_response_tokens = max(0, max_tokens + max_input_length - initial_prompt_length)
max_response_tokens = min(assistant_budget_response_tokens, packed_response_tokens)

# Determine stop reason
stop_reason = "complete" # Default for trial completion
if len(response_ids) > max_response_tokens:
if assistant_budget_exceeded or len(response_ids) > packed_response_tokens:
stop_reason = "length"

# Truncate to maximum allowed length
response_ids = response_ids[:max_response_tokens]
loss_mask = loss_mask[:max_response_tokens]

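For readers skimming the hunk above, here is the same two-stage truncation as a standalone, runnable sketch (variable names follow the diff; the toy token IDs and masks are invented). It shows why loss-masked observation tokens survive the assistant budget:

```python
from typing import List, Tuple


def truncate_response(
    response_ids: List[int],
    loss_mask: List[int],
    max_tokens: int,            # sampling_params.max_generate_length
    max_input_length: int,
    initial_prompt_length: int,
) -> Tuple[List[int], List[int], str]:
    # Stage 1: only loss-masked (assistant) tokens count against max_tokens.
    assistant_tokens = 0
    assistant_budget_response_tokens = len(response_ids)
    assistant_budget_exceeded = False
    for idx, mask in enumerate(loss_mask):
        assistant_tokens += int(bool(mask))
        if assistant_tokens > max_tokens:
            assistant_budget_response_tokens = idx
            assistant_budget_exceeded = True
            break

    # Stage 2: bound the packed prompt+response sequence for training tensors.
    packed_response_tokens = max(0, max_tokens + max_input_length - initial_prompt_length)
    max_response_tokens = min(assistant_budget_response_tokens, packed_response_tokens)

    stop_reason = "complete"
    if assistant_budget_exceeded or len(response_ids) > packed_response_tokens:
        stop_reason = "length"
    return response_ids[:max_response_tokens], loss_mask[:max_response_tokens], stop_reason


# Two observation tokens (mask 0) cost nothing against a 3-token assistant
# budget; truncation lands on the 4th assistant token instead.
ids, mask = [1, 2, 3, 4, 5, 6], [1, 1, 0, 0, 1, 1]
print(truncate_response(ids, mask, max_tokens=3, max_input_length=10, initial_prompt_length=4))
# -> ([1, 2, 3, 4, 5], [1, 1, 0, 0, 1], 'length')
```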
10 changes: 8 additions & 2 deletions skyrl/train/config/config.py
@@ -424,6 +424,9 @@ class FullyAsyncConfig(BaseConfig):
@dataclass
class SamplingParams(BaseConfig):
max_generate_length: int = 1024
"""Trajectory-level assistant/generated-token budget. In multi-turn generators,
environment observation tokens are loss-masked and do not count against this budget.
The vLLM request field is ``max_tokens`` and may be reduced per turn to fit context."""
repetition_penalty: float = 1.0
temperature: float = 1.0
top_p: float = 1.0
@@ -496,7 +499,9 @@ class InferenceEngineConfig(BaseConfig):
"""When True, pass ``language_model_only=True`` to the vLLM engine so that
multimodal models (e.g. Qwen3.5) skip vision encoder initialization."""
engine_init_kwargs: Dict[str, Any] = field(default_factory=dict)
"""Pass-through kwargs for the vLLM engine. Names must match the engine's args."""
"""Pass-through kwargs for the vLLM engine. Names must match the engine's args. If
``max_model_len`` is set, rollout requests are capped so input tokens plus per-request
generated tokens fit within that window."""
override_existing_update_group: str = "auto"
"""``"auto"``, ``"enable"``, or ``"disable"``."""
external_proxy_url: Optional[str] = None
@@ -528,7 +533,8 @@ class GeneratorConfig(BaseConfig):
batched: bool = False
max_turns: int = 1
max_input_length: Optional[int] = None
"""Max generator input length for multi-turn conversations. For single-turn, set equal to ``max_prompt_length``."""
"""Max input/context length allowed before each generation turn. For single-turn, set
equal to ``max_prompt_length``. Distinct from ``sampling_params.max_generate_length``."""
chat_template: ChatTemplateConfig = field(default_factory=ChatTemplateConfig)
chat_template_kwargs: Dict[str, Any] = field(default_factory=dict)
"""Kwargs passed to ``tokenizer.apply_chat_template``."""
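The two docstrings above (``max_generate_length`` and ``engine_init_kwargs``) interact per turn. A sketch of the cap they describe, consistent with the ``compute_request_max_tokens`` call site in the generator diff further down — treat this as an illustration of the arithmetic, not the actual helper:

```python
from typing import Optional


def compute_request_max_tokens(
    remaining_budget: int,         # max_generate_length minus tokens generated so far
    prompt_len: int,               # input_ids length for the current turn
    max_model_len: Optional[int],  # engine_init_kwargs.max_model_len, if set
) -> int:
    """Cap a turn's max_tokens so prompt + completion fits the model window."""
    request_max_tokens = remaining_budget
    if max_model_len is not None:
        request_max_tokens = min(request_max_tokens, max_model_len - prompt_len)
    return request_max_tokens


# 8192-token window, 7500-token prompt, 2048 budget tokens left:
print(compute_request_max_tokens(2048, 7500, 8192))  # 692
# A non-positive result means the loop stops with stop_reason="length":
print(compute_request_max_tokens(2048, 8300, 8192))  # -108
```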
9 changes: 7 additions & 2 deletions skyrl/train/config/ppo_base_config.yaml
@@ -295,7 +295,9 @@ generator:
n_samples_per_prompt: 5
async_engine: true
batched: false
max_input_length: ${trainer.max_prompt_length} # max generator input length used for multi-turn conversations - for single turn set equal to max_prompt_length
# Max input/context length checked before each generation turn. For single-turn, set equal to max_prompt_length.
# This is distinct from sampling_params.max_generate_length, which budgets assistant-generated tokens.
max_input_length: ${trainer.max_prompt_length}
# VLLM_ENABLE_V1_MULTIPROCESSING=0 for reproducibility
vllm_v1_disable_multiproc: true
enable_prefix_caching: true
@@ -334,11 +336,14 @@ generator:

# Inference engine arguments. Arguments are passed directly to the vLLM engine, so names must match
# the engine's args. To specify an engine arg in the CLI override, use the format: +generator.engine_init_kwargs.arg_name=value
# If max_model_len is set, each rollout request's max_tokens is capped so prompt+completion fits this window.
engine_init_kwargs: {}

override_existing_update_group: "auto" # "auto", "enable", "disable"
# sampling params for generation phase
sampling_params:
# Trajectory-level assistant/generated-token budget. Multi-turn environment observations are loss-masked
# and do not count against this value.
max_generate_length: 1024
repetition_penalty: 1.0
temperature: 1.0
@@ -395,4 +400,4 @@ generator:
environment:
env_class: "gsm8k"
# NOTE: environment specific defaults for environment.skyrl_gym are set at the following path:
# skyrl_gym: config/skyrl_gym_config/default.yaml
# skyrl_gym: config/skyrl_gym_config/default.yaml
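The sampling_params comments above draw a line between the trajectory-level ``max_generate_length`` and the per-request ``max_tokens`` sent to the backend. A sketch of that bridge, modeled on the ``sampling_params_with_max_tokens`` usage in the generator diff below (the real helper lives in ``skyrl.train.generators.utils``; this version is illustrative):

```python
from typing import Any, Dict


def sampling_params_with_max_tokens(base: Dict[str, Any], max_tokens: int) -> Dict[str, Any]:
    """Copy backend-shaped sampling params, overriding only max_tokens."""
    params = dict(base)
    params["max_tokens"] = max_tokens
    return params


# The config's max_generate_length=1024 becomes the backend default, but a
# turn with only 250 tokens of window/budget left asks the engine for 250.
base = {"max_tokens": 1024, "temperature": 1.0, "top_p": 1.0, "repetition_penalty": 1.0}
print(sampling_params_with_max_tokens(base, 250))
```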
101 changes: 80 additions & 21 deletions skyrl/train/generators/skyrl_gym_generator.py
@@ -8,7 +8,7 @@
import asyncio
import copy
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict, dataclass
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
from uuid import uuid4

@@ -33,9 +33,13 @@
)
from skyrl.train.generators.utils import (
apply_overlong_filtering,
compute_request_max_tokens,
get_custom_chat_template,
get_generation_prompt_ids,
get_max_model_len,
get_rollout_metrics,
normalize_sampling_params,
sampling_params_with_max_tokens,
)
from skyrl_gym.envs.base_text_env import BaseTextEnvStepOutput

@@ -330,11 +334,12 @@ async def agent_loop(
loss_mask = [] # this excludes the prompt
rollout_logprobs = None

# `sampling_params` if provided is a dict in the format expected by the inference engine backend
# we cast default config to a dict for consistency
current_sampling_params: dict = (
sampling_params if sampling_params is not None else asdict(self.generator_cfg.sampling_params)
)
# `sampling_params` if provided is a dict in the format expected by the inference engine backend.
# When absent, normalize the config dataclass into the backend shape here so agent_loop() direct
# callers receive the same behavior as the main generator path.
base_sampling_params = normalize_sampling_params(self.generator_cfg, sampling_params)
max_model_len = get_max_model_len(self.generator_cfg)
generated_tokens_used = 0

# Accumulate per-step rewards. Format: (reward, response_end_token_idx)
per_step_rewards: List[Tuple[float, Optional[int]]] = []
@@ -343,7 +348,7 @@

agent_loop_output = StepWiseOutput(step_outputs=[]) if is_step_wise else None

get_logprobs = self.generator_cfg.sampling_params.logprobs is not None
get_logprobs = base_sampling_params.get("logprobs", None) is not None
agent_loop_state = AgentLoopState(
chat_history=chat_history,
input_ids=initial_input_ids,
@@ -352,6 +357,7 @@
response_end_idx=None,
done=False,
)
new_obs: ConversationType = []

while not agent_loop_state.done:

@@ -373,8 +379,20 @@
agent_loop_state.loss_mask = []
agent_loop_state.rollout_logprobs = None

request_max_tokens = compute_request_max_tokens(
max_tokens - generated_tokens_used,
len(agent_loop_state.input_ids),
max_model_len,
)
if request_max_tokens <= 0:
stop_reason = "length"
break

current_sampling_params = sampling_params_with_max_tokens(base_sampling_params, request_max_tokens)
engine_input = InferenceEngineInput(
prompt_token_ids=[agent_loop_state.input_ids], session_ids=[session_id], sampling_params=sampling_params
prompt_token_ids=[agent_loop_state.input_ids],
session_ids=[session_id],
sampling_params=current_sampling_params,
)
engine_output = await self.inference_engine_client.generate(engine_input, model=self.policy_model_name)
output = engine_output["responses"][0]
@@ -440,6 +458,22 @@
if turn_output.rollout_expert_indices is not None and agent_loop_state.rollout_expert_indices is None:
agent_loop_state.rollout_expert_indices = []

turn_generated_tokens = sum(turn_output.get_turn_loss_mask())
if not self.use_conversation_multi_turn and not retokenize_chat_history:
new_resp_tokens = turn_output.output_ids.copy()
if new_resp_tokens and new_resp_tokens[-1] == self.tokenizer.eos_token_id:
new_resp_tokens = new_resp_tokens[:-1]
turn_generated_tokens = len(new_resp_tokens)
Contributor review comment on lines +461 to +466 (severity: medium):

The logic for calculating turn_generated_tokens can be simplified to avoid an unnecessary list copy and slicing when use_conversation_multi_turn is False. Additionally, moving the sum call into an else block avoids redundant calculation in that scenario.

Suggested change:
turn_generated_tokens = sum(turn_output.get_turn_loss_mask())
if not self.use_conversation_multi_turn and not retokenize_chat_history:
new_resp_tokens = turn_output.output_ids.copy()
if new_resp_tokens and new_resp_tokens[-1] == self.tokenizer.eos_token_id:
new_resp_tokens = new_resp_tokens[:-1]
turn_generated_tokens = len(new_resp_tokens)
if not self.use_conversation_multi_turn and not retokenize_chat_history:
turn_generated_tokens = len(turn_output.output_ids)
if turn_output.output_ids and turn_output.output_ids[-1] == self.tokenizer.eos_token_id:
turn_generated_tokens -= 1
else:
turn_generated_tokens = sum(turn_output.get_turn_loss_mask())

generated_tokens_used += turn_generated_tokens
if generated_tokens_used >= max_tokens and not agent_loop_state.done:
# The trajectory-level assistant budget is exhausted. Do not append the
# next observation, since there will be no following generation request.
stop_reason = "length"
agent_loop_state.done = True
new_obs = []
turn_output.new_obs = []
turn_output.obs_ids = []

if is_step_wise:
# current response + observation ids
turn_response_ids = turn_output.output_ids + turn_output.obs_ids
@@ -489,6 +523,23 @@
rollout_expert_indices_out = None
response_ids = None

if not is_step_wise and not retokenize_chat_history and agent_loop_state.response_end_idx is None:
agent_loop_output = TrajectoryOutput(
response_ids=[],
reward=[],
stop_reason=stop_reason,
loss_mask=[],
prompt_ids=prompt_ids,
rollout_logprobs=[] if get_logprobs else None,
env_metrics=env_metrics,
rollout_expert_indices=None,
)
return self._post_process_agent_loop_output(
agent_loop_output,
env_extras,
trajectory_id,
)

# Prepare the final loss_mask, response_ids and rollout_logprobs.
# We remove the final observation messages/token IDs here
# Note that during the agent loop, we still add the final observation messages/tokens because we terminate the agent loop if the input length
@@ -531,17 +582,21 @@
if not self.use_conversation_multi_turn:
assert response_ids is not None and loss_mask is not None
if stop_reason != "length" and response_ids and response_ids[-1] != self.tokenizer.eos_token_id:
response_ids.append(self.tokenizer.eos_token_id)
# TODO(Charlie): this should be 0? Otherwise logprobs will be extremely off. But if it is loss
# masked with 0, why bother adding it?
loss_mask.append(1)
if rollout_logprobs is not None:
rollout_logprobs.append(0.0)
if rollout_expert_indices_out is not None and rollout_expert_indices_out:
layer_num = len(rollout_expert_indices_out[0])
topk = len(rollout_expert_indices_out[0][0]) if layer_num > 0 else 0
rollout_expert_indices_out.append([[0] * topk for _ in range(layer_num)])
appended_eos_token = True
if generated_tokens_used < max_tokens:
response_ids.append(self.tokenizer.eos_token_id)
# TODO(Charlie): this should be 0? Otherwise logprobs will be extremely off. But if it is loss
# masked with 0, why bother adding it?
loss_mask.append(1)
if rollout_logprobs is not None:
rollout_logprobs.append(0.0)
if rollout_expert_indices_out is not None and rollout_expert_indices_out:
layer_num = len(rollout_expert_indices_out[0])
topk = len(rollout_expert_indices_out[0][0]) if layer_num > 0 else 0
rollout_expert_indices_out.append([[0] * topk for _ in range(layer_num)])
generated_tokens_used += 1
appended_eos_token = True
else:
stop_reason = "length"

if self.generator_cfg.step_wise_trajectories:
for per_step_output, (reward, resp_end_idx) in zip(agent_loop_output.step_outputs, per_step_rewards):
@@ -713,7 +768,10 @@ async def generate_batched(
tokenize=True,
return_dict=False,
)
engine_input = InferenceEngineInput(prompt_token_ids=prompt_token_ids, sampling_params=sampling_params)
request_sampling_params = sampling_params_with_max_tokens(
normalize_sampling_params(self.generator_cfg, sampling_params), max_tokens
)
engine_input = InferenceEngineInput(prompt_token_ids=prompt_token_ids, sampling_params=request_sampling_params)
engine_output = await self.inference_engine_client.generate(engine_input, model=self.policy_model_name)
outputs = engine_output["responses"]
responses = engine_output["response_ids"]
@@ -788,7 +846,8 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False
if self.generator_cfg.step_wise_trajectories:
assert trajectory_ids is not None, "`trajectory_ids` is a required field for step wise training"
sampling_params: Optional[dict] = input_batch.get("sampling_params", None)
max_tokens = self.generator_cfg.sampling_params.max_generate_length
base_sampling_params = normalize_sampling_params(self.generator_cfg, sampling_params)
max_tokens = int(base_sampling_params.get("max_tokens", self.generator_cfg.sampling_params.max_generate_length))
max_input_length = self.generator_cfg.max_input_length

if self.batched:
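Taken together, the generator changes above thread one budget across turns: each request is capped up front, the pending observation is dropped once the budget is spent, and the final EOS token is appended only if a budget token remains. A toy trace of that accounting (invented numbers; not the real control flow):

```python
max_tokens = 10  # sampling_params.max_generate_length
generated_tokens_used = 0
stop_reason = "complete"

# (assistant tokens produced, observation tokens returned) per turn
turns = [(4, 3), (4, 5), (4, 2)]

for assistant_toks, obs_toks in turns:
    # Per-request cap: never ask the engine for more than the remaining budget.
    request_max_tokens = max_tokens - generated_tokens_used
    if request_max_tokens <= 0:
        stop_reason = "length"
        break
    generated_tokens_used += min(assistant_toks, request_max_tokens)
    if generated_tokens_used >= max_tokens:
        # Budget exhausted: the pending observation is dropped, since no
        # further generation request will follow it.
        stop_reason = "length"
        break
    # Otherwise the obs_toks observation tokens are appended loss-masked,
    # costing no budget.

print(generated_tokens_used, stop_reason)  # 10 length
```

The EOS guard at the end of the diff applies the same rule: the token is appended (and counted, since its loss mask is 1) only while ``generated_tokens_used < max_tokens``; otherwise the trajectory is marked ``stop_reason="length"``.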