Add opt-in handling for failed SkyRLGym rollouts#1641
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces the skip_failed_rollouts feature for non-batched generation, which replaces failed rollouts with zero-reward, loss-masked placeholders to prevent training interruptions. The implementation includes enhanced error handling across SkyRLGymGenerator and SkyRLVLMGymGenerator to ensure environments are properly closed after exceptions, along with new metrics for tracking rollout error rates. Review feedback identifies a high-severity issue where vision features are lost during step-wise trajectories and suggests adding guards against potential ZeroDivisionError when calculating error metrics for empty batches.
| pixel_values = self._normalize_optional_tensor_features( | ||
| [getattr(output, "pixel_values", None) for output in all_outputs] | ||
| ) | ||
| image_grid_thw = self._normalize_optional_tensor_features( | ||
| [getattr(output, "image_grid_thw", None) for output in all_outputs] | ||
| ) |
There was a problem hiding this comment.
The collection of vision features here does not account for StepWiseOutput when step_wise_trajectories=True. In step-wise mode, output is a StepWiseOutput object which does not have a pixel_values attribute; instead, these features are stored within the individual TrajectoryOutput objects in output.step_outputs. Consequently, vision features will be lost during flattening. Additionally, the detection logic on context line 954 will fail to identify vision features in step-wise mode for the same reason.
| if num_rollout_errors == len(stop_reasons): | ||
| logger.warning( | ||
| "All SkyRLGym rollouts in this batch failed and were replaced with loss-masked placeholders." | ||
| ) |
There was a problem hiding this comment.
Potential ZeroDivisionError if stop_reasons is empty. Although batches are typically non-empty, it's safer to guard against this, especially since an empty batch would also trigger the "All rollouts failed" warning incorrectly.
| if num_rollout_errors == len(stop_reasons): | |
| logger.warning( | |
| "All SkyRLGym rollouts in this batch failed and were replaced with loss-masked placeholders." | |
| ) | |
| rollout_metrics["generate/rollout_error_rate"] = num_rollout_errors / len(stop_reasons) if stop_reasons else 0.0 | |
| if stop_reasons and num_rollout_errors == len(stop_reasons): | |
| logger.warning( | |
| "All SkyRLGym rollouts in this batch failed and were replaced with loss-masked placeholders." | |
| ) |
| if result.get("stop_reasons") is not None and has_rollout_error_metric: | ||
| num_rollout_errors = sum(reason == ROLLOUT_ERROR_STOP_REASON for reason in result["stop_reasons"]) | ||
| rollout_metrics["generate/num_rollout_errors"] = num_rollout_errors | ||
| rollout_metrics["generate/rollout_error_rate"] = num_rollout_errors / len(result["stop_reasons"]) |
There was a problem hiding this comment.
Potential ZeroDivisionError if result["stop_reasons"] is empty. Guarding against zero length ensures robustness for empty generator outputs.
| rollout_metrics["generate/rollout_error_rate"] = num_rollout_errors / len(result["stop_reasons"]) | |
| rollout_metrics["generate/rollout_error_rate"] = num_rollout_errors / len(result["stop_reasons"]) if result["stop_reasons"] else 0.0 |
Summary
Adds an opt-in
generator.skip_failed_rolloutsflag for non-batchedSkyRLGymGeneratorrollouts. When enabled, an individual failed rollout is logged and replaced with a structurally valid, loss-masked placeholder row usingstop_reason="rollout_error", allowing the rest of the generation batch to complete.Fixes #1613.
Root Cause
SkyRLGymGenerator.generate()fans out non-batched rollouts throughtqdm.gather. Under normal gather semantics, the first ordinary exception from any rollout propagates and aborts the entire training step, which is brittle for flaky multi-turn agentic environments.Changes
generator.skip_failed_rollouts: falseto the Python config, default YAML, and docs.skip_failed_rollouts=Truewith batched generation, where per-row recovery is ambiguous.asyncio.CancelledErrorso cancellations and interrupts still stop the step.Validation
python3.12 -m py_compile skyrl/train/generators/skyrl_gym_generator.py skyrl/train/generators/skyrl_vlm_generator.py skyrl/train/generators/utils.py skyrl/train/config/config.py tests/train/generators/test_skyrl_gym_generator.py tests/train/generators/test_generator_output_utils.pygit diff --checkuv run --python 3.12 --with transformers --with ruff --extra dev --extra skyrl-train --isolated ruff check skyrl/train/generators/skyrl_gym_generator.py skyrl/train/generators/skyrl_vlm_generator.py skyrl/train/generators/utils.py skyrl/train/config/config.py tests/train/generators/test_skyrl_gym_generator.py tests/train/generators/test_generator_output_utils.pyuv run --python 3.12 --with black --extra dev --extra skyrl-train --isolated black --check --target-version py312 skyrl/train/generators/skyrl_gym_generator.py skyrl/train/generators/skyrl_vlm_generator.py skyrl/train/generators/utils.py skyrl/train/config/config.py tests/train/generators/test_skyrl_gym_generator.py tests/train/generators/test_generator_output_utils.pyuv run --python 3.12 --with transformers --extra dev --extra skyrl-train --isolated pytest tests/train/generators/test_skyrl_gym_generator.py tests/train/generators/test_generator_output_utils.py tests/train/generators/test_skyrl_vlm_generator.py tests/train/generators/test_utils.py tests/train/test_config.py tests/train/test_trainer_utils.py -qThe combined pytest slice passed with
173 passed, 4 warnings; the warnings were existing Ray/Hydra/legacy-config noise.