Commit 2df8f51
Fix nan leak from masked positions in
Replace `kld * loss_mask` with `torch.where(loss_mask.bool(), kld *
loss_mask, 0.0)` in `compute_approx_kl`, so `nan` at masked positions
(where `0 * nan = nan`) can no longer leak through into `policy_kl` /
`final_loss`.
Closes #1633
---------
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>compute_approx_kl (#1635)1 parent 811f8ee commit 2df8f51
2 files changed
Lines changed: 27 additions & 1 deletion
File tree
- skyrl/backends/skyrl_train/utils
- tests/backends/skyrl_train/utils
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
119 | 119 | | |
120 | 120 | | |
121 | 121 | | |
122 | | - | |
| 122 | + | |
| 123 | + | |
| 124 | + | |
123 | 125 | | |
124 | 126 | | |
125 | 127 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
26 | 26 | | |
27 | 27 | | |
28 | 28 | | |
| 29 | + | |
29 | 30 | | |
30 | 31 | | |
31 | 32 | | |
| |||
66 | 67 | | |
67 | 68 | | |
68 | 69 | | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
69 | 93 | | |
70 | 94 | | |
71 | 95 | | |
| |||
0 commit comments