-
Notifications
You must be signed in to change notification settings - Fork 321
Add opt-in handling for failed SkyRLGym rollouts #1641
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
taivu1998
wants to merge
1
commit into
NovaSky-AI:main
Choose a base branch
from
taivu1998:tdv/issue-1613-skip-failed-rollouts
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -9,7 +9,7 @@ | |||||||||||||||||||
| import copy | ||||||||||||||||||||
| from concurrent.futures import ThreadPoolExecutor | ||||||||||||||||||||
| from dataclasses import asdict, dataclass | ||||||||||||||||||||
| from typing import Any, Dict, List, Optional, Tuple, Union | ||||||||||||||||||||
| from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union | ||||||||||||||||||||
| from uuid import uuid4 | ||||||||||||||||||||
|
|
||||||||||||||||||||
| import torch | ||||||||||||||||||||
|
|
@@ -32,6 +32,7 @@ | |||||||||||||||||||
| TrajectoryID, | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| from skyrl.train.generators.utils import ( | ||||||||||||||||||||
| ROLLOUT_ERROR_STOP_REASON, | ||||||||||||||||||||
| apply_overlong_filtering, | ||||||||||||||||||||
| get_custom_chat_template, | ||||||||||||||||||||
| get_generation_prompt_ids, | ||||||||||||||||||||
|
|
@@ -169,7 +170,8 @@ def __init__( | |||||||||||||||||||
| self.generation_prompt_ids = get_generation_prompt_ids(tokenizer) if self.use_conversation_multi_turn else None | ||||||||||||||||||||
| if self.skyrl_gym_cfg.max_env_workers > 0: | ||||||||||||||||||||
| self.env_executor = ThreadPoolExecutor( | ||||||||||||||||||||
| max_workers=self.skyrl_gym_cfg.max_env_workers, thread_name_prefix="skyrl-gym-env-" | ||||||||||||||||||||
| max_workers=self.skyrl_gym_cfg.max_env_workers, | ||||||||||||||||||||
| thread_name_prefix="skyrl-gym-env-", | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| else: | ||||||||||||||||||||
| self.env_executor = None | ||||||||||||||||||||
|
|
@@ -205,6 +207,8 @@ def _validate_cfg(self, generator_cfg: GeneratorConfig): | |||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||
| "`chat_template_kwargs` is not compatible with `batched=True` since the chat templating is handled by the inference engine" | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| if generator_cfg.skip_failed_rollouts and generator_cfg.batched: | ||||||||||||||||||||
| raise ValueError("`skip_failed_rollouts=True` is only supported with `batched=False`.") | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if self.generator_cfg.step_wise_trajectories: | ||||||||||||||||||||
| if self.batched: | ||||||||||||||||||||
|
|
@@ -228,6 +232,80 @@ async def _run_in_executor_if_available(self, func, *args, **kwargs): | |||||||||||||||||||
| else: | ||||||||||||||||||||
| return func(*args, **kwargs) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| async def _close_env_after_exception(self, env, context: str): | ||||||||||||||||||||
| try: | ||||||||||||||||||||
| await self._run_in_executor_if_available(env.close) | ||||||||||||||||||||
| except Exception as close_exc: | ||||||||||||||||||||
| logger.opt(exception=close_exc).warning( | ||||||||||||||||||||
| "Failed to close SkyRL-Gym environment after {} failure: {}", | ||||||||||||||||||||
| context, | ||||||||||||||||||||
| close_exc, | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| def _uses_rollout_logprobs(self, sampling_params: Optional[dict]) -> bool: | ||||||||||||||||||||
| if sampling_params is not None: | ||||||||||||||||||||
| return sampling_params.get("logprobs", None) is not None | ||||||||||||||||||||
| return self.generator_cfg.sampling_params.logprobs is not None | ||||||||||||||||||||
|
|
||||||||||||||||||||
| def _placeholder_token_id(self) -> int: | ||||||||||||||||||||
| for attr_name in ("eos_token_id", "pad_token_id"): | ||||||||||||||||||||
| token_id = getattr(self.tokenizer, attr_name, None) | ||||||||||||||||||||
| if token_id is not None: | ||||||||||||||||||||
| return token_id | ||||||||||||||||||||
| return 0 | ||||||||||||||||||||
|
|
||||||||||||||||||||
| def _failed_rollout_placeholder(self, include_logprobs: bool) -> Union[TrajectoryOutput, StepWiseOutput]: | ||||||||||||||||||||
| token_id = self._placeholder_token_id() | ||||||||||||||||||||
| reward: Union[float, List[float]] = ( | ||||||||||||||||||||
| 0.0 if self.custom_chat_template and not self.generator_cfg.step_wise_trajectories else [0.0] | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| output = TrajectoryOutput( | ||||||||||||||||||||
| response_ids=[token_id], | ||||||||||||||||||||
| reward=reward, | ||||||||||||||||||||
| stop_reason=ROLLOUT_ERROR_STOP_REASON, | ||||||||||||||||||||
| loss_mask=[0], | ||||||||||||||||||||
| prompt_ids=[token_id], | ||||||||||||||||||||
| rollout_logprobs=[0.0] if include_logprobs else None, | ||||||||||||||||||||
| env_metrics={}, | ||||||||||||||||||||
| rollout_expert_indices=None, | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| if self.generator_cfg.step_wise_trajectories: | ||||||||||||||||||||
| return StepWiseOutput(step_outputs=[output]) | ||||||||||||||||||||
| return output | ||||||||||||||||||||
|
|
||||||||||||||||||||
| async def _safe_rollout( | ||||||||||||||||||||
| self, | ||||||||||||||||||||
| idx: int, | ||||||||||||||||||||
| env_class: str, | ||||||||||||||||||||
| trajectory_id: Optional[TrajectoryID], | ||||||||||||||||||||
| rollout: Awaitable[Union[TrajectoryOutput, StepWiseOutput]], | ||||||||||||||||||||
| include_logprobs: bool, | ||||||||||||||||||||
| ) -> Union[TrajectoryOutput, StepWiseOutput]: | ||||||||||||||||||||
| try: | ||||||||||||||||||||
| return await rollout | ||||||||||||||||||||
| except asyncio.CancelledError: | ||||||||||||||||||||
| raise | ||||||||||||||||||||
| except Exception as exc: | ||||||||||||||||||||
| trajectory = trajectory_id.to_string() if trajectory_id is not None else None | ||||||||||||||||||||
| logger.opt(exception=exc).warning( | ||||||||||||||||||||
| "SkyRLGym rollout {} failed for env_class={} trajectory_id={} with {}: {}; " | ||||||||||||||||||||
| "substituting zero-reward placeholder with stop_reason={}", | ||||||||||||||||||||
| idx, | ||||||||||||||||||||
| env_class, | ||||||||||||||||||||
| trajectory, | ||||||||||||||||||||
| type(exc).__name__, | ||||||||||||||||||||
| exc, | ||||||||||||||||||||
| ROLLOUT_ERROR_STOP_REASON, | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| return self._failed_rollout_placeholder(include_logprobs=include_logprobs) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| def _normalize_optional_tensor_features(self, values: List[Optional[torch.Tensor]]) -> Optional[List[torch.Tensor]]: | ||||||||||||||||||||
| ref = next((value for value in values if value is not None), None) | ||||||||||||||||||||
| if ref is None: | ||||||||||||||||||||
| return None | ||||||||||||||||||||
| placeholder = torch.empty(0, *ref.shape[1:], dtype=ref.dtype, device=ref.device) | ||||||||||||||||||||
| return [value if value is not None else placeholder for value in values] | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # ------------------------------------------------------------------ | ||||||||||||||||||||
| # Subclass hooks. Default implementations are no-ops so generic envs | ||||||||||||||||||||
| # see the upstream behavior; subclasses (e.g. RLMGymGenerator) override. | ||||||||||||||||||||
|
|
@@ -313,18 +391,26 @@ async def agent_loop( | |||||||||||||||||||
| chat_history = copy.deepcopy(prompt) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # init() returns the first prompt to be given to the model, and optional metadata dict | ||||||||||||||||||||
| chat_history, _ = await self._run_in_executor_if_available(env.init, chat_history) | ||||||||||||||||||||
| try: | ||||||||||||||||||||
| chat_history, _ = await self._run_in_executor_if_available(env.init, chat_history) | ||||||||||||||||||||
| except Exception: | ||||||||||||||||||||
| await self._close_env_after_exception(env, "env.init") | ||||||||||||||||||||
| raise | ||||||||||||||||||||
| initial_chat_history_length = len(chat_history) | ||||||||||||||||||||
| initial_input_ids = self.tokenizer.apply_chat_template( | ||||||||||||||||||||
| chat_history, | ||||||||||||||||||||
| # If retokenize_chat_history==True, avoid including the generation prompt in both the | ||||||||||||||||||||
| # prompt_ids and response_ids due to how `response_encodings["input_ids"]` works. | ||||||||||||||||||||
| add_generation_prompt=not retokenize_chat_history, | ||||||||||||||||||||
| chat_template=self.custom_chat_template if retokenize_chat_history else None, | ||||||||||||||||||||
| tokenize=True, | ||||||||||||||||||||
| return_dict=False, | ||||||||||||||||||||
| **self.generator_cfg.chat_template_kwargs, | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| try: | ||||||||||||||||||||
| initial_input_ids = self.tokenizer.apply_chat_template( | ||||||||||||||||||||
| chat_history, | ||||||||||||||||||||
| # If retokenize_chat_history==True, avoid including the generation prompt in both the | ||||||||||||||||||||
| # prompt_ids and response_ids due to how `response_encodings["input_ids"]` works. | ||||||||||||||||||||
| add_generation_prompt=not retokenize_chat_history, | ||||||||||||||||||||
| chat_template=self.custom_chat_template if retokenize_chat_history else None, | ||||||||||||||||||||
| tokenize=True, | ||||||||||||||||||||
| return_dict=False, | ||||||||||||||||||||
| **self.generator_cfg.chat_template_kwargs, | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| except Exception: | ||||||||||||||||||||
| await self._close_env_after_exception(env, "initial chat templating") | ||||||||||||||||||||
| raise | ||||||||||||||||||||
|
|
||||||||||||||||||||
| initial_prompt_length = len(initial_input_ids) | ||||||||||||||||||||
| loss_mask = [] # this excludes the prompt | ||||||||||||||||||||
|
|
@@ -343,7 +429,7 @@ async def agent_loop( | |||||||||||||||||||
|
|
||||||||||||||||||||
| agent_loop_output = StepWiseOutput(step_outputs=[]) if is_step_wise else None | ||||||||||||||||||||
|
|
||||||||||||||||||||
| get_logprobs = self.generator_cfg.sampling_params.logprobs is not None | ||||||||||||||||||||
| get_logprobs = self._uses_rollout_logprobs(sampling_params) | ||||||||||||||||||||
| agent_loop_state = AgentLoopState( | ||||||||||||||||||||
| chat_history=chat_history, | ||||||||||||||||||||
| input_ids=initial_input_ids, | ||||||||||||||||||||
|
|
@@ -354,7 +440,6 @@ async def agent_loop( | |||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| while not agent_loop_state.done: | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if len(agent_loop_state.input_ids) > max_input_length: | ||||||||||||||||||||
| stop_reason = "length" | ||||||||||||||||||||
| break | ||||||||||||||||||||
|
|
@@ -374,9 +459,15 @@ async def agent_loop( | |||||||||||||||||||
| agent_loop_state.rollout_logprobs = None | ||||||||||||||||||||
|
|
||||||||||||||||||||
| engine_input = InferenceEngineInput( | ||||||||||||||||||||
| prompt_token_ids=[agent_loop_state.input_ids], session_ids=[session_id], sampling_params=sampling_params | ||||||||||||||||||||
| prompt_token_ids=[agent_loop_state.input_ids], | ||||||||||||||||||||
| session_ids=[session_id], | ||||||||||||||||||||
| sampling_params=sampling_params, | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| engine_output = await self.inference_engine_client.generate(engine_input, model=self.policy_model_name) | ||||||||||||||||||||
| try: | ||||||||||||||||||||
| engine_output = await self.inference_engine_client.generate(engine_input, model=self.policy_model_name) | ||||||||||||||||||||
| except Exception: | ||||||||||||||||||||
| await self._close_env_after_exception(env, "inference generation") | ||||||||||||||||||||
| raise | ||||||||||||||||||||
| output = engine_output["responses"][0] | ||||||||||||||||||||
| output_ids = engine_output["response_ids"][0] | ||||||||||||||||||||
| stop_reason = engine_output["stop_reasons"][0] | ||||||||||||||||||||
|
|
@@ -408,7 +499,11 @@ async def agent_loop( | |||||||||||||||||||
| added_eos = True | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # 2. Environment step | ||||||||||||||||||||
| env_step_output: BaseTextEnvStepOutput = await self._run_in_executor_if_available(env.step, output) | ||||||||||||||||||||
| try: | ||||||||||||||||||||
| env_step_output: BaseTextEnvStepOutput = await self._run_in_executor_if_available(env.step, output) | ||||||||||||||||||||
| except Exception: | ||||||||||||||||||||
| await self._close_env_after_exception(env, "env.step") | ||||||||||||||||||||
| raise | ||||||||||||||||||||
| new_obs = env_step_output["observations"] | ||||||||||||||||||||
| step_reward: float = env_step_output["reward"] | ||||||||||||||||||||
| agent_loop_state.done = env_step_output["done"] | ||||||||||||||||||||
|
|
@@ -571,7 +666,10 @@ async def agent_loop( | |||||||||||||||||||
| return agent_loop_output | ||||||||||||||||||||
|
|
||||||||||||||||||||
| def _build_per_token_rewards( | ||||||||||||||||||||
| self, per_step_rewards: List[Tuple[float, Optional[int]]], response_ids: List[int], appended_eos_token: bool | ||||||||||||||||||||
| self, | ||||||||||||||||||||
| per_step_rewards: List[Tuple[float, Optional[int]]], | ||||||||||||||||||||
| response_ids: List[int], | ||||||||||||||||||||
| appended_eos_token: bool, | ||||||||||||||||||||
| ) -> Union[float, List[float]]: | ||||||||||||||||||||
| """ | ||||||||||||||||||||
| Build reward output from per-step rewards. | ||||||||||||||||||||
|
|
@@ -794,20 +892,24 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False | |||||||||||||||||||
| if self.batched: | ||||||||||||||||||||
| return await self.generate_batched(prompts, env_classes, env_extras, max_tokens, sampling_params) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| get_logprobs = self._uses_rollout_logprobs(sampling_params) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Async agent loop to generate trajectories in parallel. | ||||||||||||||||||||
| tasks = [] | ||||||||||||||||||||
| for i in range(len(prompts)): | ||||||||||||||||||||
| tasks.append( | ||||||||||||||||||||
| self.agent_loop( | ||||||||||||||||||||
| prompts[i], | ||||||||||||||||||||
| env_classes[i], | ||||||||||||||||||||
| env_extras[i], | ||||||||||||||||||||
| max_tokens, | ||||||||||||||||||||
| max_input_length, | ||||||||||||||||||||
| sampling_params=sampling_params, | ||||||||||||||||||||
| trajectory_id=trajectory_ids[i] if trajectory_ids is not None else None, | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| trajectory_id = trajectory_ids[i] if trajectory_ids is not None else None | ||||||||||||||||||||
| rollout = self.agent_loop( | ||||||||||||||||||||
| prompts[i], | ||||||||||||||||||||
| env_classes[i], | ||||||||||||||||||||
| env_extras[i], | ||||||||||||||||||||
| max_tokens, | ||||||||||||||||||||
| max_input_length, | ||||||||||||||||||||
| sampling_params=sampling_params, | ||||||||||||||||||||
| trajectory_id=trajectory_id, | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| if self.generator_cfg.skip_failed_rollouts: | ||||||||||||||||||||
| rollout = self._safe_rollout(i, env_classes[i], trajectory_id, rollout, get_logprobs) | ||||||||||||||||||||
| tasks.append(rollout) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| all_outputs = await tqdm.gather( | ||||||||||||||||||||
| *tasks, | ||||||||||||||||||||
|
|
@@ -850,18 +952,15 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False | |||||||||||||||||||
| out_trajectory_ids = None | ||||||||||||||||||||
|
|
||||||||||||||||||||
| has_vision_features = any(getattr(output, "pixel_values", None) is not None for output in all_outputs) | ||||||||||||||||||||
| pixel_values = ( | ||||||||||||||||||||
| [getattr(output, "pixel_values", None) for output in all_outputs] if has_vision_features else None | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| image_grid_thw = ( | ||||||||||||||||||||
| [getattr(output, "image_grid_thw", None) for output in all_outputs] if has_vision_features else None | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if sampling_params is not None: | ||||||||||||||||||||
| # sampling params will be a dict in the format of the inference engine backend | ||||||||||||||||||||
| get_logprobs = sampling_params.get("logprobs", None) is not None | ||||||||||||||||||||
| else: | ||||||||||||||||||||
| get_logprobs = self.generator_cfg.sampling_params.logprobs is not None | ||||||||||||||||||||
| pixel_values = None | ||||||||||||||||||||
| image_grid_thw = None | ||||||||||||||||||||
| if has_vision_features: | ||||||||||||||||||||
| pixel_values = self._normalize_optional_tensor_features( | ||||||||||||||||||||
| [getattr(output, "pixel_values", None) for output in all_outputs] | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| image_grid_thw = self._normalize_optional_tensor_features( | ||||||||||||||||||||
| [getattr(output, "image_grid_thw", None) for output in all_outputs] | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if get_logprobs: | ||||||||||||||||||||
| if self.generator_cfg.step_wise_trajectories: | ||||||||||||||||||||
|
|
@@ -880,6 +979,14 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False | |||||||||||||||||||
| rollout_expert_indices = None | ||||||||||||||||||||
|
|
||||||||||||||||||||
| rollout_metrics = get_rollout_metrics(responses, rewards, env_metrics, env_classes, loss_masks) | ||||||||||||||||||||
| if self.generator_cfg.skip_failed_rollouts: | ||||||||||||||||||||
| num_rollout_errors = sum(reason == ROLLOUT_ERROR_STOP_REASON for reason in stop_reasons) | ||||||||||||||||||||
| rollout_metrics["generate/num_rollout_errors"] = num_rollout_errors | ||||||||||||||||||||
| rollout_metrics["generate/rollout_error_rate"] = num_rollout_errors / len(stop_reasons) | ||||||||||||||||||||
| if num_rollout_errors == len(stop_reasons): | ||||||||||||||||||||
| logger.warning( | ||||||||||||||||||||
| "All SkyRLGym rollouts in this batch failed and were replaced with loss-masked placeholders." | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
Comment on lines
+986
to
+989
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Potential
Suggested change
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| if self.generator_cfg.zero_reward_on_non_stop: | ||||||||||||||||||||
| # set reward to 0 if the stop reason is not "stop" | ||||||||||||||||||||
|
|
||||||||||||||||||||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The collection of vision features here does not account for
StepWiseOutputwhenstep_wise_trajectories=True. In step-wise mode,outputis aStepWiseOutputobject which does not have apixel_valuesattribute; instead, these features are stored within the individualTrajectoryOutputobjects inoutput.step_outputs. Consequently, vision features will be lost during flattening. Additionally, the detection logic on context line 954 will fail to identify vision features in step-wise mode for the same reason.