
Commit 417e848

Rodrian7 and claude committed
fix(gla): mask v_decay_exp to avoid NaN at padded chunk tail

When _chunk_fwd_h_kernel_varlen runs with seq_real_lens (sequences padded to a multiple of chunk_size internally), the trailing padded positions hold v_tile = 0 while the per-token decay exponent b_g_last - b_g can be very negative or even -inf in bf16. The product 0 * inf evaluates to NaN, contaminating the recurrent state and propagating downstream.

Mask v_decay_exp to a large negative value (-1e9) on positions beyond L_chunk so that exp(...) underflows to 0 and the masked v_tile contribution is exactly 0.

Triggered by any sequence whose length is not a multiple of chunk_size (= 64 by default), since each such sequence is internally padded by _align_varlen_inputs.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

1 parent 2f84fb4 commit 417e848

1 file changed

Lines changed: 6 additions & 4 deletions

python/sgl_jax/srt/kernels/simple_gla/simple_gla.py

@@ -381,11 +381,13 @@ def store_fn():
         else:
             effective_remaining = eos - t0
         # tpu not support scalar bf16 mul
-        b_g_last = (
-            g_gamma_ref[i_h].astype(jnp.float32) * jnp.minimum(BT, effective_remaining)
-        ).astype(g_gamma_ref.dtype)
+        L_chunk = jnp.minimum(BT, effective_remaining)
+        b_g_last = (g_gamma_ref[i_h].astype(jnp.float32) * L_chunk).astype(g_gamma_ref.dtype)
         scratch_ref[...] *= exp(b_g_last)
-        v_tile = (v_tile * exp(b_g_last - b_g)[:, None]).astype(v_tile.dtype)
+        # Mask exponent to avoid NaN (0 * inf) in padding positions
+        v_decay_exp = b_g_last - b_g
+        v_decay_exp = jnp.where(jnp.arange(BT) < L_chunk, v_decay_exp, -1e9)
+        v_tile = (v_tile * exp(v_decay_exp)[:, None]).astype(v_tile.dtype)
 
         if gk_ref is not None:
             gk_tile = gk_ref[(0, slice(None), slice(None))]  # BT * BK
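
The masking pattern in this change can be reproduced outside the kernel. The following is a minimal sketch, not the kernel code: it uses NumPy as a stand-in for jax.numpy so it runs on CPU, and the chunk size, the number of real tokens, and the exponent values are all made up for illustration. In particular, the padded positions are given a large positive exponent here on the assumption that this is what makes exp(...) overflow to inf and produce the 0 * inf case.

```python
import numpy as np

BT = 8       # toy chunk size (hypothetical; the commit message cites 64 as the default)
L_chunk = 5  # toy number of real tokens in the chunk; rows 5..7 are padding

# Hypothetical decay exponents. The last three entries are out-of-range values
# at padded positions, chosen so that exp() overflows to inf in float32.
v_decay_exp = np.array(
    [-0.4, -0.3, -0.2, -0.1, 0.0, 200.0, 200.0, 200.0], dtype=np.float32
)

# Value tile: real rows are nonzero, padded rows are exactly zero.
v_tile = np.ones((BT, 2), dtype=np.float32)
v_tile[L_chunk:] = 0.0

# Without masking: padded rows compute 0 * inf, which is NaN.
with np.errstate(over="ignore", invalid="ignore"):
    unmasked = v_tile * np.exp(v_decay_exp)[:, None]

# With masking: replace the exponent at padded positions by a large negative
# number so that exp() underflows to 0 and the product is exactly 0.
valid = np.arange(BT) < L_chunk
masked_exp = np.where(valid, v_decay_exp, np.float32(-1e9))
masked = v_tile * np.exp(masked_exp)[:, None]

print("NaN without mask:", bool(np.isnan(unmasked).any()))
print("NaN with mask:   ", bool(np.isnan(masked).any()))
```

Masking the exponent rather than the product keeps the fix to a single select before exp(...), and it leaves the real rows untouched because the select passes their exponents through unchanged.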
