fix(gla): mask v_decay_exp to avoid NaN at padded chunk tail #1073
Conversation
When `_chunk_fwd_h_kernel_varlen` runs with `seq_real_lens` (sequences padded internally to a multiple of `chunk_size`), the trailing padded positions hold `v_tile = 0` while the per-token decay exponent `b_g_last - b_g` can be very negative, or even `-inf` in bf16. The product `0 * inf` evaluates to `NaN`, contaminating the recurrent state and propagating downstream.

Mask `v_decay_exp` to a large negative value (`-1e9`) on positions beyond `L_chunk` so that `exp(...)` underflows to `0` and the masked `v_tile` contribution is exactly `0`.

Triggered by any sequence whose length is not a multiple of `chunk_size` (= 64 by default), since each such sequence is internally padded by `_align_varlen_inputs`.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
47e4e15 to 417e848
Summary

Fix a NaN bug in `_chunk_fwd_h_kernel_varlen` that triggers whenever a sequence length is not a multiple of `chunk_size` (= 64 by default).

Root cause

When `seq_real_lens` is provided, sequences are internally padded to a multiple of `chunk_size` by `_align_varlen_inputs`. Inside the kernel, the padded tail of `v_tile` is `0` while the per-token decay exponent `b_g_last - b_g` can be very negative, overflowing to `-inf` in bf16. The product `0 * inf` then evaluates to `NaN`, contaminating the recurrent state and propagating downstream.
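The failure mode can be reproduced outside the kernel. The actual code is a JAX/Pallas kernel, but plain numpy exhibits the same IEEE-754 behavior: zero times infinity is NaN, while a finite large-negative exponent underflows cleanly to zero.

```python
import numpy as np

# Padded tail position: the value is zero, but the decay exponent has
# overflowed to infinity in low precision (bf16 in the kernel).
v_pad = np.float32(0.0)
exp_overflowed = np.float32(np.inf)

bad = v_pad * exp_overflowed            # IEEE-754: 0 * inf -> NaN
print(bad)                              # nan

# With the exponent clamped to a large negative finite value instead,
# exp underflows to exactly 0 and the padded contribution is exactly 0.
good = v_pad * np.exp(np.float32(-1e9))
print(good)                             # 0.0
```

This is why the fix targets the exponent rather than the product: once `exp(...)` is exactly `0`, multiplying by the zero-padded `v_tile` can never produce `0 * inf`.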
Fix

Mask `v_decay_exp` to a large negative value (`-1e9`) on positions beyond `L_chunk`, so `exp(...)` underflows to `0` and the masked `v_tile` contribution is exactly `0`.

Scope

`python/sgl_jax/srt/kernels/simple_gla/simple_gla.py` (1 file, +6/-4).
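A rough numpy sketch of the masking described under Fix. The function and argument names (`decay_weighted_v`, `g_tile`, `g_last`) are hypothetical stand-ins for the kernel's tile variables; the real change lives inside the Pallas kernel, but the masking pattern is the same.

```python
import numpy as np

def decay_weighted_v(v_tile, g_tile, g_last, L_chunk):
    """Apply the per-token decay exp(g_last - g_t) to v_tile, masking the
    padded tail (positions >= L_chunk) so it contributes exactly zero."""
    pos = np.arange(v_tile.shape[0])
    v_decay_exp = g_last - g_tile            # may be +/-inf at padded slots
    # Mask the exponent beyond L_chunk: exp(-1e9) underflows to 0, so the
    # padded (zero) v rows are multiplied by 0, never by inf.
    v_decay_exp = np.where(pos < L_chunk, v_decay_exp, np.float32(-1e9))
    return v_tile * np.exp(v_decay_exp)[:, None]

# Chunk of 4 with 2 real tokens and 2 padded ones (v == 0, g == -inf),
# mimicking the tail produced by _align_varlen_inputs.
v_tile = np.array([[1.0, 1.0], [2.0, 2.0], [0.0, 0.0], [0.0, 0.0]], np.float32)
g_tile = np.array([-0.5, -1.0, -np.inf, -np.inf], np.float32)
out = decay_weighted_v(v_tile, g_tile, g_last=np.float32(-1.0), L_chunk=2)
```

Without the `np.where`, the padded rows would compute `0 * exp(+inf)`, i.e. `0 * inf = NaN`; with it, `out` is finite and the tail rows are exactly zero.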
Test plan

Run sequences whose lengths are not multiples of `chunk_size`; verify no `NaN` in the output and that generations match the previously-correct path.