Skip to content

Commit 2df8f51

Browse files
jamesbraza and claude authored
Fix nan leak from masked positions in compute_approx_kl (#1635)
Replace `kld * loss_mask` with `torch.where(loss_mask.bool(), kld * loss_mask, 0.0)` in `compute_approx_kl`, so `nan` at masked positions (where `0 * nan = nan`) can no longer leak through into `policy_kl` / `final_loss`. Closes #1633 --------- Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 811f8ee commit 2df8f51

2 files changed

Lines changed: 27 additions & 1 deletion

File tree

skyrl/backends/skyrl_train/utils/ppo_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,9 @@ def compute_approx_kl(
119119
raise ValueError(f"Invalid KL estimator type: {kl_estimator_type}")
120120

121121
if loss_mask is not None:
122-
kld = kld * loss_mask
122+
# Multiplying by `loss_mask` can leak `nan` from masked positions,
123+
# so route masked positions to 0.0 directly while keeping mask scaling elsewhere
124+
kld = torch.where(loss_mask.bool(), kld * loss_mask, 0.0)
123125
return kld
124126

125127

tests/backends/skyrl_train/utils/test_ppo_utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
register_advantage_estimator,
2727
register_policy_loss,
2828
)
29+
from skyrl.backends.skyrl_train.utils.torch_utils import masked_mean
2930

3031

3132
@pytest.fixture
@@ -66,6 +67,29 @@ def test_compute_approx_kl(dummy_data):
6667
assert torch.allclose(kl_k3, expected_k3, atol=1e-4), "k3 estimator is not correct"
6768

6869

70+
@pytest.mark.parametrize("kl_estimator_type", ["k1", "k2", "k3", "abs"])
71+
def test_compute_approx_kl_applies_loss_mask(kl_estimator_type: str) -> None:
72+
"""Scales kept positions; masked positions become 0.0, even when their inputs are nan/inf."""
73+
log_probs = torch.tensor([[0.2, 0.3, 0.5, 0.7]])
74+
# Position 3: nan input + mask=0.0; assertions below check the nan doesn't leak
75+
log_probs_base = torch.tensor([[0.1, 0.2, 0.4, float("nan")]])
76+
mask = torch.tensor([[1.0, 0.5, 0.25, 0.0]])
77+
78+
kld = compute_approx_kl(log_probs, log_probs_base, mask, kl_estimator_type=kl_estimator_type)
79+
80+
# A masked position must contribute nothing, even when its input is non-finite
81+
assert torch.isfinite(kld).all(), f"{kl_estimator_type}: kld leaked non-finite values: {kld}"
82+
assert kld[0, 3].item() == 0.0, f"{kl_estimator_type}: masked position not zeroed"
83+
assert torch.isfinite(masked_mean(kld, mask)), f"{kl_estimator_type}: masked_mean is non-finite"
84+
85+
# Soft-mask values scale each kept position multiplicatively
86+
unmasked = compute_approx_kl(log_probs, log_probs_base, None, kl_estimator_type=kl_estimator_type)
87+
expected_kept = unmasked[:, :3] * mask[:, :3]
88+
assert torch.allclose(
89+
kld[:, :3], expected_kept, atol=1e-6
90+
), f"{kl_estimator_type}: soft mask scaling not preserved: {kld[:, :3]} vs {expected_kept}"
91+
92+
6993
def test_compute_reinforce_plus_plus_outcome_advantage_returns_and_masking():
7094
"""REINFORCE++ returns should be discounted sums with reset after EOS; advantages masked."""
7195
token_level_rewards = torch.tensor([[1.0, 2.0, 3.0]])

0 commit comments

Comments (0)