Commit 417e848
fix(gla): mask v_decay_exp to avoid NaN at padded chunk tail

When _chunk_fwd_h_kernel_varlen runs with seq_real_lens (sequences
padded to a multiple of chunk_size internally), the trailing padded
positions hold v_tile = 0, while the per-token decay exponent
b_g_last - b_g there can be extreme or non-finite in bf16, so
exp(b_g_last - b_g) can evaluate to inf. The product 0 * inf is NaN,
which contaminates the recurrent state and propagates downstream.

Mask v_decay_exp to a large negative value (-1e9) on positions
beyond L_chunk so that exp(...) underflows to 0 and the masked
v_tile contribution is exactly 0.
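
The failure and the fix can be reproduced outside the kernel. The following is a minimal pure-Python sketch, not the kernel code: the sizes, the placeholder exponent values, and the use of inf as a stand-in for a non-finite padded-tail exponent are all assumptions made for illustration. Only the names v_tile, v_decay_exp, L_chunk and the -1e9 mask value come from the commit message.

```python
import math

CHUNK_SIZE = 64  # default chunk size named in the commit message
L_CHUNK = 40     # hypothetical count of real tokens in the final chunk
PAD = CHUNK_SIZE - L_CHUNK

# Hypothetical inputs: real tokens have a finite exponent and a nonzero value;
# padded positions have v = 0 and a non-finite exponent (illustrative stand-in).
v_decay_exp = [-0.5] * L_CHUNK + [math.inf] * PAD
v_tile = [1.0] * L_CHUNK + [0.0] * PAD

# Unmasked: exp(inf) is inf, and 0 * inf is NaN on every padded position.
unmasked = [v * math.exp(e) for v, e in zip(v_tile, v_decay_exp)]

# Masked: beyond L_CHUNK the exponent is replaced by -1e9, so exp(...)
# underflows to 0 and the padded contribution is exactly 0.
masked_exp = [e if i < L_CHUNK else -1e9 for i, e in enumerate(v_decay_exp)]
masked = [v * math.exp(e) for v, e in zip(v_tile, masked_exp)]
```

Masking the exponent rather than the product matters here: the product is only safe once the exp(...) factor is guaranteed finite.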

Triggered by any sequence whose length is not a multiple of
chunk_size (= 64 by default), since each such sequence is internally
padded by _align_varlen_inputs.
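
To see which lengths are affected, the padding arithmetic can be sketched as below. These helpers are hypothetical and are not the implementation of _align_varlen_inputs; they only assume that padding rounds each length up to the next multiple of chunk_size, as the message describes.

```python
def padded_len(seq_len: int, chunk_size: int = 64) -> int:
    """Smallest multiple of chunk_size that is >= seq_len (ceiling division)."""
    return -(-seq_len // chunk_size) * chunk_size


def padded_tail(seq_len: int, chunk_size: int = 64) -> int:
    """Number of padded positions appended after the real tokens."""
    return padded_len(seq_len, chunk_size) - seq_len


# A length of 100 leaves 28 padded positions in the last chunk,
# while an exact multiple such as 128 leaves none.
tails = {n: padded_tail(n) for n in (1, 64, 100, 128)}
```

Any sequence with a nonzero padded tail has positions beyond L_chunk in its final chunk, which is where the mask applies.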

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

1 parent 2f84fb4 commit 417e848
1 file changed
Lines changed: 6 additions & 4 deletions
Hunk: original lines 381-391, new lines 381-393 (content of the changed lines not preserved)