3 changes: 3 additions & 0 deletions docs/content/docs/configuration/config.mdx
@@ -672,6 +672,8 @@ generator:

zero_reward_on_non_stop: false

skip_failed_rollouts: false

apply_overlong_filtering: false

```
@@ -747,5 +749,6 @@ For more details on how different placement options work, please refer to the [p
### Misc Configuration

- `generator.zero_reward_on_non_stop`: Whether to set the reward to 0 if the `stop_reason` is not `stop`. This is useful, for example, with format rewards: if the LLM didn't finish its response, we typically don't want to reward it. This is a general setting for all environments.
- `generator.skip_failed_rollouts`: Whether to skip individual failed non-batched rollouts by replacing each failed row with a zero-reward, loss-masked placeholder whose `stop_reason` is `rollout_error` (see the sketch after this list). This catches normal rollout exceptions only; cancellations and interrupts still stop the training step.
- `generator.apply_overlong_filtering`: Whether to apply DAPO Overlong Filtering to the loss masks. For each trajectory that exceeds the max length (i.e., truncated and does not end with an EOS token), this masks out every token in the loss mask.
- `generator.step_wise_trajectories`: Whether to return outputs in a step-wise fashion. If `true`, then the generator will return multi-turn generations with the (prompt, response) pair of each turn being a separate trajectory. Advantages are computed based on the last step of each trajectory and propagated to the previous steps.
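
To make the placeholder behavior concrete, here is a minimal sketch of the row substituted for a failed rollout, mirroring the `_failed_rollout_placeholder` helper added in this PR; the standalone construction and the bare `tokenizer` variable below are illustrative, not the actual call site.

```python
# Minimal sketch of the placeholder row substituted for a failed rollout
# (mirrors _failed_rollout_placeholder in skyrl_gym_generator.py).
token_id = tokenizer.eos_token_id  # falls back to pad_token_id, then 0
placeholder = TrajectoryOutput(
    response_ids=[token_id],      # a single harmless token
    reward=[0.0],                 # zero reward (scalar 0.0 with a custom chat template)
    stop_reason="rollout_error",  # ROLLOUT_ERROR_STOP_REASON
    loss_mask=[0],                # fully masked out of the loss
    prompt_ids=[token_id],
    rollout_logprobs=None,        # [0.0] when rollout logprobs are requested
    env_metrics={},
    rollout_expert_indices=None,
)
```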
2 changes: 2 additions & 0 deletions skyrl/train/config/config.py
@@ -544,6 +544,8 @@ class GeneratorConfig(BaseConfig):
eval_n_samples_per_prompt: int = 1
zero_reward_on_non_stop: bool = False
"""Set reward to 0 when ``stop_reason`` is not ``"stop"`` (i.e., generation was truncated or aborted)."""
skip_failed_rollouts: bool = False
"""Replace failed non-batched rollouts with zero-reward, loss-masked placeholders."""
apply_overlong_filtering: bool = False
"""Apply DAPO Overlong Filtering: mask out all tokens in the loss mask for trajectories that
exceed max length (truncated, no EOS token)."""
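
As a rough illustration of the filtering this docstring describes, a sketch follows; the real helper lives in `skyrl.train.generators.utils`, and the signature below is assumed rather than taken from this PR.

```python
from typing import List

# Sketch of DAPO overlong filtering as described above (signature assumed):
# zero out the entire loss mask of any trajectory that does not end with EOS.
def overlong_filter(
    loss_masks: List[List[int]],
    response_ids: List[List[int]],
    eos_token_id: int,
) -> List[List[int]]:
    return [
        mask if ids and ids[-1] == eos_token_id else [0] * len(mask)
        for mask, ids in zip(loss_masks, response_ids)
    ]
```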
7 changes: 6 additions & 1 deletion skyrl/train/config/ppo_base_config.yaml
@@ -380,6 +380,11 @@ generator:
# TODO (erictang000): Show clear ablations for benefits of this on GSM8K or SQL.
zero_reward_on_non_stop: false

# Whether to skip individual failed rollouts by substituting zero-reward,
# loss-masked placeholder rows with stop_reason="rollout_error".
# This is only supported for non-batched generation.
skip_failed_rollouts: false

# Whether to apply DAPO Overlong Filtering to the loss masks.
# For each trajectory that exceeds the max length (i.e., truncated and does not end with an
# EOS token), this masks out every token in the loss mask.
@@ -395,4 +400,4 @@ generator:
environment:
env_class: "gsm8k"
# NOTE: environment specific defaults for environment.skyrl_gym are set at the following path:
# skyrl_gym: config/skyrl_gym_config/default.yaml
# skyrl_gym: config/skyrl_gym_config/default.yaml
189 changes: 148 additions & 41 deletions skyrl/train/generators/skyrl_gym_generator.py
@@ -9,7 +9,7 @@
import copy
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
from uuid import uuid4

import torch
@@ -32,6 +32,7 @@
TrajectoryID,
)
from skyrl.train.generators.utils import (
ROLLOUT_ERROR_STOP_REASON,
apply_overlong_filtering,
get_custom_chat_template,
get_generation_prompt_ids,
@@ -169,7 +170,8 @@ def __init__(
self.generation_prompt_ids = get_generation_prompt_ids(tokenizer) if self.use_conversation_multi_turn else None
if self.skyrl_gym_cfg.max_env_workers > 0:
self.env_executor = ThreadPoolExecutor(
max_workers=self.skyrl_gym_cfg.max_env_workers, thread_name_prefix="skyrl-gym-env-"
max_workers=self.skyrl_gym_cfg.max_env_workers,
thread_name_prefix="skyrl-gym-env-",
)
else:
self.env_executor = None
@@ -205,6 +207,8 @@ def _validate_cfg(self, generator_cfg: GeneratorConfig):
raise ValueError(
"`chat_template_kwargs` is not compatible with `batched=True` since the chat templating is handled by the inference engine"
)
if generator_cfg.skip_failed_rollouts and generator_cfg.batched:
raise ValueError("`skip_failed_rollouts=True` is only supported with `batched=False`.")

if self.generator_cfg.step_wise_trajectories:
if self.batched:
@@ -228,6 +232,80 @@ async def _run_in_executor_if_available(self, func, *args, **kwargs):
else:
return func(*args, **kwargs)

async def _close_env_after_exception(self, env, context: str):
try:
await self._run_in_executor_if_available(env.close)
except Exception as close_exc:
logger.opt(exception=close_exc).warning(
"Failed to close SkyRL-Gym environment after {} failure: {}",
context,
close_exc,
)

def _uses_rollout_logprobs(self, sampling_params: Optional[dict]) -> bool:
if sampling_params is not None:
return sampling_params.get("logprobs", None) is not None
return self.generator_cfg.sampling_params.logprobs is not None

def _placeholder_token_id(self) -> int:
for attr_name in ("eos_token_id", "pad_token_id"):
token_id = getattr(self.tokenizer, attr_name, None)
if token_id is not None:
return token_id
return 0

def _failed_rollout_placeholder(self, include_logprobs: bool) -> Union[TrajectoryOutput, StepWiseOutput]:
token_id = self._placeholder_token_id()
reward: Union[float, List[float]] = (
0.0 if self.custom_chat_template and not self.generator_cfg.step_wise_trajectories else [0.0]
)
output = TrajectoryOutput(
response_ids=[token_id],
reward=reward,
stop_reason=ROLLOUT_ERROR_STOP_REASON,
loss_mask=[0],
prompt_ids=[token_id],
rollout_logprobs=[0.0] if include_logprobs else None,
env_metrics={},
rollout_expert_indices=None,
)
if self.generator_cfg.step_wise_trajectories:
return StepWiseOutput(step_outputs=[output])
return output

async def _safe_rollout(
self,
idx: int,
env_class: str,
trajectory_id: Optional[TrajectoryID],
rollout: Awaitable[Union[TrajectoryOutput, StepWiseOutput]],
include_logprobs: bool,
) -> Union[TrajectoryOutput, StepWiseOutput]:
try:
return await rollout
except asyncio.CancelledError:
raise
except Exception as exc:
trajectory = trajectory_id.to_string() if trajectory_id is not None else None
logger.opt(exception=exc).warning(
"SkyRLGym rollout {} failed for env_class={} trajectory_id={} with {}: {}; "
"substituting zero-reward placeholder with stop_reason={}",
idx,
env_class,
trajectory,
type(exc).__name__,
exc,
ROLLOUT_ERROR_STOP_REASON,
)
return self._failed_rollout_placeholder(include_logprobs=include_logprobs)

def _normalize_optional_tensor_features(self, values: List[Optional[torch.Tensor]]) -> Optional[List[torch.Tensor]]:
ref = next((value for value in values if value is not None), None)
if ref is None:
return None
placeholder = torch.empty(0, *ref.shape[1:], dtype=ref.dtype, device=ref.device)
return [value if value is not None else placeholder for value in values]

# ------------------------------------------------------------------
# Subclass hooks. Default implementations are no-ops so generic envs
# see the upstream behavior; subclasses (e.g. RLMGymGenerator) override.
@@ -313,18 +391,26 @@ async def agent_loop(
chat_history = copy.deepcopy(prompt)

# init() returns the first prompt to be given to the model, and optional metadata dict
chat_history, _ = await self._run_in_executor_if_available(env.init, chat_history)
try:
chat_history, _ = await self._run_in_executor_if_available(env.init, chat_history)
except Exception:
await self._close_env_after_exception(env, "env.init")
raise
initial_chat_history_length = len(chat_history)
initial_input_ids = self.tokenizer.apply_chat_template(
chat_history,
# If retokenize_chat_history==True, avoid including the generation prompt in both the
# prompt_ids and response_ids due to how `response_encodings["input_ids"]` works.
add_generation_prompt=not retokenize_chat_history,
chat_template=self.custom_chat_template if retokenize_chat_history else None,
tokenize=True,
return_dict=False,
**self.generator_cfg.chat_template_kwargs,
)
try:
initial_input_ids = self.tokenizer.apply_chat_template(
chat_history,
# If retokenize_chat_history==True, avoid including the generation prompt in both the
# prompt_ids and response_ids due to how `response_encodings["input_ids"]` works.
add_generation_prompt=not retokenize_chat_history,
chat_template=self.custom_chat_template if retokenize_chat_history else None,
tokenize=True,
return_dict=False,
**self.generator_cfg.chat_template_kwargs,
)
except Exception:
await self._close_env_after_exception(env, "initial chat templating")
raise

initial_prompt_length = len(initial_input_ids)
loss_mask = [] # this excludes the prompt
@@ -343,7 +429,7 @@

agent_loop_output = StepWiseOutput(step_outputs=[]) if is_step_wise else None

get_logprobs = self.generator_cfg.sampling_params.logprobs is not None
get_logprobs = self._uses_rollout_logprobs(sampling_params)
agent_loop_state = AgentLoopState(
chat_history=chat_history,
input_ids=initial_input_ids,
@@ -354,7 +440,6 @@
)

while not agent_loop_state.done:

if len(agent_loop_state.input_ids) > max_input_length:
stop_reason = "length"
break
@@ -374,9 +459,15 @@
agent_loop_state.rollout_logprobs = None

engine_input = InferenceEngineInput(
prompt_token_ids=[agent_loop_state.input_ids], session_ids=[session_id], sampling_params=sampling_params
prompt_token_ids=[agent_loop_state.input_ids],
session_ids=[session_id],
sampling_params=sampling_params,
)
engine_output = await self.inference_engine_client.generate(engine_input, model=self.policy_model_name)
try:
engine_output = await self.inference_engine_client.generate(engine_input, model=self.policy_model_name)
except Exception:
await self._close_env_after_exception(env, "inference generation")
raise
output = engine_output["responses"][0]
output_ids = engine_output["response_ids"][0]
stop_reason = engine_output["stop_reasons"][0]
@@ -408,7 +499,11 @@ async def agent_loop(
added_eos = True

# 2. Environment step
env_step_output: BaseTextEnvStepOutput = await self._run_in_executor_if_available(env.step, output)
try:
env_step_output: BaseTextEnvStepOutput = await self._run_in_executor_if_available(env.step, output)
except Exception:
await self._close_env_after_exception(env, "env.step")
raise
new_obs = env_step_output["observations"]
step_reward: float = env_step_output["reward"]
agent_loop_state.done = env_step_output["done"]
@@ -571,7 +666,10 @@ async def agent_loop(
return agent_loop_output

def _build_per_token_rewards(
self, per_step_rewards: List[Tuple[float, Optional[int]]], response_ids: List[int], appended_eos_token: bool
self,
per_step_rewards: List[Tuple[float, Optional[int]]],
response_ids: List[int],
appended_eos_token: bool,
) -> Union[float, List[float]]:
"""
Build reward output from per-step rewards.
@@ -794,20 +892,24 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False
if self.batched:
return await self.generate_batched(prompts, env_classes, env_extras, max_tokens, sampling_params)

get_logprobs = self._uses_rollout_logprobs(sampling_params)

# Async agent loop to generate trajectories in parallel.
tasks = []
for i in range(len(prompts)):
tasks.append(
self.agent_loop(
prompts[i],
env_classes[i],
env_extras[i],
max_tokens,
max_input_length,
sampling_params=sampling_params,
trajectory_id=trajectory_ids[i] if trajectory_ids is not None else None,
)
trajectory_id = trajectory_ids[i] if trajectory_ids is not None else None
rollout = self.agent_loop(
prompts[i],
env_classes[i],
env_extras[i],
max_tokens,
max_input_length,
sampling_params=sampling_params,
trajectory_id=trajectory_id,
)
if self.generator_cfg.skip_failed_rollouts:
rollout = self._safe_rollout(i, env_classes[i], trajectory_id, rollout, get_logprobs)
tasks.append(rollout)

all_outputs = await tqdm.gather(
*tasks,
@@ -850,18 +952,15 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False
out_trajectory_ids = None

has_vision_features = any(getattr(output, "pixel_values", None) is not None for output in all_outputs)
pixel_values = (
[getattr(output, "pixel_values", None) for output in all_outputs] if has_vision_features else None
)
image_grid_thw = (
[getattr(output, "image_grid_thw", None) for output in all_outputs] if has_vision_features else None
)

if sampling_params is not None:
# sampling params will be a dict in the format of the inference engine backend
get_logprobs = sampling_params.get("logprobs", None) is not None
else:
get_logprobs = self.generator_cfg.sampling_params.logprobs is not None
pixel_values = None
image_grid_thw = None
if has_vision_features:
pixel_values = self._normalize_optional_tensor_features(
[getattr(output, "pixel_values", None) for output in all_outputs]
)
image_grid_thw = self._normalize_optional_tensor_features(
[getattr(output, "image_grid_thw", None) for output in all_outputs]
)
Contributor comment on lines +958 to +963 (severity: high):

The collection of vision features here does not account for `StepWiseOutput` when `step_wise_trajectories=True`. In step-wise mode, `output` is a `StepWiseOutput` object, which does not have a `pixel_values` attribute; instead, these features are stored within the individual `TrajectoryOutput` objects in `output.step_outputs`. Consequently, vision features will be lost during flattening. The detection logic on context line 954 will also fail to identify vision features in step-wise mode, for the same reason.
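
One possible shape for a fix, sketched under the assumption that in step-wise mode the vision features live on the per-step `TrajectoryOutput` objects; the helper below is hypothetical and not part of this PR.

```python
# Hypothetical helper (not in this PR): gather vision features whether the
# rollout produced a flat TrajectoryOutput or a step-wise StepWiseOutput.
def collect_pixel_values(output):
    if isinstance(output, StepWiseOutput):
        # Step-wise mode: features live on the per-step TrajectoryOutputs.
        return [getattr(step, "pixel_values", None) for step in output.step_outputs]
    return [getattr(output, "pixel_values", None)]
```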


if get_logprobs:
if self.generator_cfg.step_wise_trajectories:
Expand All @@ -880,6 +979,14 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False
rollout_expert_indices = None

rollout_metrics = get_rollout_metrics(responses, rewards, env_metrics, env_classes, loss_masks)
if self.generator_cfg.skip_failed_rollouts:
num_rollout_errors = sum(reason == ROLLOUT_ERROR_STOP_REASON for reason in stop_reasons)
rollout_metrics["generate/num_rollout_errors"] = num_rollout_errors
rollout_metrics["generate/rollout_error_rate"] = num_rollout_errors / len(stop_reasons)
if num_rollout_errors == len(stop_reasons):
logger.warning(
"All SkyRLGym rollouts in this batch failed and were replaced with loss-masked placeholders."
)
Contributor comment on lines +986 to +989 (severity: medium):

Potential `ZeroDivisionError` if `stop_reasons` is empty. Although batches are typically non-empty, it's safer to guard against this, especially since an empty batch would also incorrectly trigger the "All rollouts failed" warning.

Suggested change:

```diff
-            if num_rollout_errors == len(stop_reasons):
-                logger.warning(
-                    "All SkyRLGym rollouts in this batch failed and were replaced with loss-masked placeholders."
-                )
+            rollout_metrics["generate/rollout_error_rate"] = num_rollout_errors / len(stop_reasons) if stop_reasons else 0.0
+            if stop_reasons and num_rollout_errors == len(stop_reasons):
+                logger.warning(
+                    "All SkyRLGym rollouts in this batch failed and were replaced with loss-masked placeholders."
+                )
```


if self.generator_cfg.zero_reward_on_non_stop:
# set reward to 0 if the stop reason is not "stop"
Expand Down