Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion skyrl/backends/skyrl_train/utils/ppo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ def compute_approx_kl(
raise ValueError(f"Invalid KL estimator type: {kl_estimator_type}")

if loss_mask is not None:
kld = kld * loss_mask
# Multiplying by `loss_mask` can leak `nan` from masked positions,
# so route masked positions to 0.0 directly while keeping mask scaling elsewhere
kld = torch.where(loss_mask.bool(), kld * loss_mask, 0.0)
return kld


Expand Down
24 changes: 24 additions & 0 deletions tests/backends/skyrl_train/utils/test_ppo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
register_advantage_estimator,
register_policy_loss,
)
from skyrl.backends.skyrl_train.utils.torch_utils import masked_mean


@pytest.fixture
Expand Down Expand Up @@ -66,6 +67,29 @@ def test_compute_approx_kl(dummy_data):
assert torch.allclose(kl_k3, expected_k3, atol=1e-4), "k3 estimator is not correct"


@pytest.mark.parametrize("kl_estimator_type", ["k1", "k2", "k3", "abs"])
def test_compute_approx_kl_applies_loss_mask(kl_estimator_type: str) -> None:
    """Masked-out positions yield exactly 0.0 (even for nan inputs); kept positions scale by the mask."""
    lp = torch.tensor([[0.2, 0.3, 0.5, 0.7]])
    # The last position pairs a nan base log-prob with a 0.0 mask entry; the
    # assertions below verify that the nan never propagates into the result.
    lp_base = torch.tensor([[0.1, 0.2, 0.4, float("nan")]])
    mask = torch.tensor([[1.0, 0.5, 0.25, 0.0]])

    # Reference run with no mask, used for the per-position scaling check below.
    unmasked = compute_approx_kl(lp, lp_base, None, kl_estimator_type=kl_estimator_type)
    kld = compute_approx_kl(lp, lp_base, mask, kl_estimator_type=kl_estimator_type)

    # The fully masked position must contribute nothing, regardless of its non-finite input.
    assert torch.isfinite(kld).all(), f"{kl_estimator_type}: kld leaked non-finite values: {kld}"
    assert kld[0, 3].item() == 0.0, f"{kl_estimator_type}: masked position not zeroed"
    assert torch.isfinite(masked_mean(kld, mask)), f"{kl_estimator_type}: masked_mean is non-finite"

    # Every kept position equals the unmasked value scaled by its (soft) mask weight.
    expected_kept = unmasked[:, :3] * mask[:, :3]
    assert torch.allclose(
        kld[:, :3], expected_kept, atol=1e-6
    ), f"{kl_estimator_type}: soft mask scaling not preserved: {kld[:, :3]} vs {expected_kept}"


def test_compute_reinforce_plus_plus_outcome_advantage_returns_and_masking():
"""REINFORCE++ returns should be discounted sums with reset after EOS; advantages masked."""
token_level_rewards = torch.tensor([[1.0, 2.0, 3.0]])
Expand Down
Loading