diff --git a/skyrl/backends/skyrl_train/utils/ppo_utils.py b/skyrl/backends/skyrl_train/utils/ppo_utils.py
index 64606b26a7..8a173a13e0 100644
--- a/skyrl/backends/skyrl_train/utils/ppo_utils.py
+++ b/skyrl/backends/skyrl_train/utils/ppo_utils.py
@@ -119,7 +119,9 @@ def compute_approx_kl(
         raise ValueError(f"Invalid KL estimator type: {kl_estimator_type}")
 
     if loss_mask is not None:
-        kld = kld * loss_mask
+        # Multiplying by `loss_mask` can leak `nan` from masked positions,
+        # so route masked positions to 0.0 directly while keeping mask scaling elsewhere
+        kld = torch.where(loss_mask.bool(), kld * loss_mask, 0.0)
 
     return kld
 
diff --git a/tests/backends/skyrl_train/utils/test_ppo_utils.py b/tests/backends/skyrl_train/utils/test_ppo_utils.py
index 6a34e93de5..81055debda 100644
--- a/tests/backends/skyrl_train/utils/test_ppo_utils.py
+++ b/tests/backends/skyrl_train/utils/test_ppo_utils.py
@@ -26,6 +26,7 @@
     register_advantage_estimator,
     register_policy_loss,
 )
+from skyrl.backends.skyrl_train.utils.torch_utils import masked_mean
 
 
 @pytest.fixture
@@ -66,6 +67,29 @@ def test_compute_approx_kl(dummy_data):
     assert torch.allclose(kl_k3, expected_k3, atol=1e-4), "k3 estimator is not correct"
 
 
+@pytest.mark.parametrize("kl_estimator_type", ["k1", "k2", "k3", "abs"])
+def test_compute_approx_kl_applies_loss_mask(kl_estimator_type: str) -> None:
+    """Scales kept positions; masked positions become 0.0, even when their inputs are nan/inf."""
+    log_probs = torch.tensor([[0.2, 0.3, 0.5, 0.7]])
+    # Position 3: nan input + mask=0.0; assertions below check the nan doesn't leak
+    log_probs_base = torch.tensor([[0.1, 0.2, 0.4, float("nan")]])
+    mask = torch.tensor([[1.0, 0.5, 0.25, 0.0]])
+
+    kld = compute_approx_kl(log_probs, log_probs_base, mask, kl_estimator_type=kl_estimator_type)
+
+    # A masked position must contribute nothing, even when its input is non-finite
+    assert torch.isfinite(kld).all(), f"{kl_estimator_type}: kld leaked non-finite values: {kld}"
+    assert kld[0, 3].item() == 0.0, f"{kl_estimator_type}: masked position not zeroed"
+    assert torch.isfinite(masked_mean(kld, mask)), f"{kl_estimator_type}: masked_mean is non-finite"
+
+    # Soft-mask values scale each kept position multiplicatively
+    unmasked = compute_approx_kl(log_probs, log_probs_base, None, kl_estimator_type=kl_estimator_type)
+    expected_kept = unmasked[:, :3] * mask[:, :3]
+    assert torch.allclose(
+        kld[:, :3], expected_kept, atol=1e-6
+    ), f"{kl_estimator_type}: soft mask scaling not preserved: {kld[:, :3]} vs {expected_kept}"
+
+
 def test_compute_reinforce_plus_plus_outcome_advantage_returns_and_masking():
     """REINFORCE++ returns should be discounted sums with reset after EOS; advantages masked."""
     token_level_rewards = torch.tensor([[1.0, 2.0, 3.0]])