
Fix nan leak from masked positions in compute_approx_kl #1635

Merged: erictang000 merged 2 commits into NovaSky-AI:main from EdisonScientific:fix-compute-approx-kl-nan-leak on May 8, 2026

Conversation

@jamesbraza (Contributor) commented May 8, 2026

Replace `kld * loss_mask` with `torch.where(loss_mask.bool(), kld * loss_mask, 0.0)` in `compute_approx_kl`, so that `nan` at masked positions (where `0 * nan = nan`) can no longer leak through into `policy_kl` / `final_loss`.

Closes #1633
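The failure mode and the fix can be sketched in a few standalone lines (a minimal illustration, not the repo's code; tensor values are made up):

```python
import torch

kld = torch.tensor([0.1, 0.2, float("nan")])   # nan lands at a masked position
loss_mask = torch.tensor([1.0, 1.0, 0.0])

# Mask-by-multiplication: IEEE 754 gives 0 * nan = nan, so the nan survives.
leaky = kld * loss_mask

# torch.where forces masked positions to 0.0 regardless of the value there.
fixed = torch.where(loss_mask.bool(), kld * loss_mask, torch.zeros_like(kld))

# The nan in `leaky` poisons any downstream reduction such as masked_mean,
# while `fixed` stays finite.
leaky_mean = leaky.sum() / loss_mask.sum()   # nan
fixed_mean = fixed.sum() / loss_mask.sum()   # 0.15
```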

Masking via `kld * loss_mask` propagates `nan` from masked positions
because IEEE 754 defines `0 * nan = nan`, poisoning the downstream
masked_mean and any metric (e.g. policy_kl, final_loss) that consumes
the KL scalar. Switch to `masked_fill` so masked positions are forced
to 0.0 regardless of the input value there. Autograd is unaffected.

Add a parametrized regression test covering all four estimator types
(k1, k2, k3, abs) that injects `nan` at a masked position and asserts
the output and downstream `masked_mean` stay finite.
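A test of that shape might look like the sketch below. The `compute_approx_kl` here is a local stand-in for illustration (the real one lives in `skyrl/backends/skyrl_train/utils/ppo_utils.py`, and its actual signature and estimator formulas may differ); only the estimator names (`k1`, `k2`, `k3`, `abs`) and the mask handling mirror the PR:

```python
import torch
import pytest

def compute_approx_kl(log_probs, log_probs_base, loss_mask, estimator="k1"):
    """Stand-in for the real function, using common KL-estimator forms."""
    log_ratio = log_probs - log_probs_base
    if estimator == "k1":
        kld = log_ratio
    elif estimator == "k2":
        kld = log_ratio ** 2 / 2
    elif estimator == "k3":
        kld = torch.expm1(-log_ratio) + log_ratio
    elif estimator == "abs":
        kld = log_ratio.abs()
    # The fix under test: masked positions are forced to 0.0.
    return torch.where(loss_mask.bool(), kld * loss_mask, torch.zeros_like(kld))

@pytest.mark.parametrize("estimator", ["k1", "k2", "k3", "abs"])
def test_compute_approx_kl_masks_out_nan(estimator):
    log_probs = torch.tensor([-0.1, -0.2, float("nan")])
    log_probs_base = torch.tensor([-0.1, -0.3, -0.2])
    loss_mask = torch.tensor([1.0, 1.0, 0.0])  # nan sits at the masked slot
    kl = compute_approx_kl(log_probs, log_probs_base, loss_mask, estimator)
    assert torch.isfinite(kl).all()
    masked_mean = (kl * loss_mask).sum() / loss_mask.sum()
    assert torch.isfinite(masked_mean)
```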

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

@gemini-code-assist (Bot) left a comment


Code Review

This pull request updates the KL divergence computation to prevent nan leakage from masked positions by replacing direct multiplication with masked_fill. A corresponding unit test was added to verify the fix. The reviewer recommended using torch.where instead of masked_fill to preserve potential soft masking functionality while still addressing the nan leakage issue.

Comment thread on `skyrl/backends/skyrl_train/utils/ppo_utils.py` (outdated)
Switch the mask-sanitization step from `masked_fill(~mask.bool(), 0.0)`
to `torch.where(mask.bool(), kld * mask, 0.0)` so non-binary mask values
still scale the kept positions multiplicatively, while masked (mask==0)
positions are still forced to 0.0 so non-finite inputs there cannot leak.

Combine the two prior regression test cases into one parametrized test
(`test_compute_approx_kl_applies_loss_mask`) that exercises both
invariants in one shot: a soft mask `{1.0, 0.5, 0.25, 0.0}` with `nan`
injected at the masked position, asserting kept-position scaling and
masked-position zeroing for all four estimator types.
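The behavioral difference between the two forms can be shown directly (a standalone sketch with made-up values, using the soft mask `{1.0, 0.5, 0.25, 0.0}` from the test):

```python
import torch

kld = torch.tensor([4.0, 4.0, 4.0, float("nan")])
mask = torch.tensor([1.0, 0.5, 0.25, 0.0])

# masked_fill zeroes only where mask == 0, discarding the 0.5/0.25 scaling.
masked_fill_version = kld.masked_fill(~mask.bool(), 0.0)

# torch.where keeps the multiplicative scaling at kept positions while still
# forcing masked positions to 0.0, so the nan cannot leak.
where_version = torch.where(mask.bool(), kld * mask, torch.zeros_like(kld))

# masked_fill_version -> [4.0, 4.0, 4.0, 0.0]   (soft mask ignored)
# where_version       -> [4.0, 2.0, 1.0, 0.0]   (soft mask respected)
```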

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@erictang000 (Collaborator) left a comment


nice, this lgtm, thanks!

@erictang000 erictang000 merged commit 2df8f51 into NovaSky-AI:main on May 8, 2026
4 of 5 checks passed
@jamesbraza jamesbraza deleted the fix-compute-approx-kl-nan-leak branch May 8, 2026 23:52


Development

Successfully merging this pull request may close these issues.

Bug: compute_approx_kl mask-by-multiplication leaks nan from masked positions into policy_kl / final_loss
