Describe the Issue
The advantage normalization logic in example_trainer/data.py was found to be numerically unstable for groups with low or zero reward variance.
The original code used scores / max(scores.std(), 1e-8). If all rollouts in a group receive the same score (e.g., all 1.0), the standard deviation is effectively zero. The 1e-8 floor is far below typical reward magnitudes, so any tiny floating-point noise is divided by a near-zero value and amplified into extremely large advantages. This leads to gradient spikes, massive grad_norm values, and eventual training divergence (NaNs).
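The amplification can be reproduced in isolation. The snippet below is a minimal sketch, not the code from data.py: it assumes the scores are centered before being scaled (the expression quoted above does not show that step), and naive_normalize is a hypothetical name used only for illustration.

```python
import numpy as np


def naive_normalize(scores):
    """Center the scores, then divide by the std with a fixed 1e-8 floor.

    Hypothetical stand-in for the logic described in this issue; the
    centering step is an assumption.
    """
    scores = np.asarray(scores, dtype=np.float64)
    centered = scores - scores.mean()
    return centered / max(centered.std(), 1e-8)


# Three rollouts score exactly 1.0; the fourth differs by one part in a million.
noisy_group = [1.0, 1.0, 1.0, 1.0 + 1e-6]
advantages = naive_normalize(noisy_group)
print(advantages)
```

Because the group's standard deviation (about 4e-7) is still above the 1e-8 floor, the 1e-6 difference is rescaled to unit variance and the last rollout ends up with an advantage larger than 1, even though the rewards are identical for all practical purposes.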
Environment/API Details
- Environment Class/Name: example_trainer/data.py
- Environment Configuration: Any configuration with group_size > 1.
- API Endpoint/Method Involved: pad_data_to_good_offset
Steps to Reproduce
- Run a training task where multiple rollouts in the same group receive identical rewards (e.g., all succeed or all fail).
- Observe the calculated advantages in data.py.
- Monitor the grad_norm in wandb.
- Observe intermittent spikes in gradient magnitude that do not correspond to actual policy changes.
Interaction Details (if applicable)
- Expected Behavior:
- The normalization should use a magnitude-relative epsilon (e.g., max(1e-8, 1e-4 * abs(mean))) to ignore statistically insignificant variance.
- If the standard deviation is below this threshold, the advantages should be centered but not scaled.
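The expected behavior could look roughly like the sketch below. It is an illustration under stated assumptions rather than the actual patch: normalize_advantages, abs_eps, and rel_eps are hypothetical names, and the threshold follows the max(1e-8, 1e-4 * abs(mean)) example given above.

```python
import numpy as np


def normalize_advantages(scores, abs_eps=1e-8, rel_eps=1e-4):
    """Center scores; scale by std only when the variance is significant.

    Hypothetical helper: the names and defaults are assumptions that
    mirror the threshold suggested in this issue.
    """
    scores = np.asarray(scores, dtype=np.float64)
    mean = scores.mean()
    centered = scores - mean
    std = centered.std()
    # Magnitude-relative threshold: the noise tolerance grows with the reward scale.
    threshold = max(abs_eps, rel_eps * abs(mean))
    if std < threshold:
        # Variance is statistically insignificant: center but do not scale.
        return centered
    return centered / std


uniform = normalize_advantages([1.0, 1.0, 1.0, 1.0])
noisy = normalize_advantages([1.0, 1.0, 1.0, 1.0 + 1e-6])
mixed = normalize_advantages([0.0, 1.0, 0.0, 1.0])
```

With this rule, the uniform group yields all-zero advantages, the noisy group keeps advantages below 1e-6 in magnitude, and the genuinely mixed group is still normalized to unit variance.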
Setup Details
- OS: Linux
- Python Version: 3.10+
- Atropos Version: commit c20c852
- Relevant Libraries/Versions: numpy, torch
Additional Context & Logs
This fix ensures that the RL signal remains stable even when the environment provides sparse or uniform rewards within a group.