
fix(gla): mask v_decay_exp to avoid NaN at padded chunk tail#1073

Merged
Rodrian7 merged 1 commit into sgl-project:cjx/gla-kernel from Rodrian7:wyx/gla-nan-guard
May 14, 2026

Conversation

@Rodrian7
Collaborator

Summary

Fix a NaN bug in _chunk_fwd_h_kernel_varlen that triggers whenever a
sequence length is not a multiple of chunk_size (= 64 by default).

Root cause

When seq_real_lens is provided, sequences are internally padded to a
multiple of chunk_size by _align_varlen_inputs. Inside the kernel,
the padded tail of v_tile is 0, while the per-token decay exponent
b_g_last - b_g can be very negative and overflow to -inf in bf16.
The resulting 0 * inf product evaluates to NaN, contaminating the
recurrent state and propagating downstream.
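The failure mode can be reproduced with plain IEEE-754 floats (a minimal sketch; the variable names are illustrative, not the kernel's):

```python
import math

# Padded tail of v_tile is zero; the decay factor blows up to inf once
# the bf16 exponent arithmetic overflows (sketched here with Python floats):
v_pad = 0.0
overflowed_decay = float("inf")

# IEEE-754 defines 0 * inf as NaN, and any NaN entering the recurrent
# state poisons every later timestep:
poisoned = v_pad * overflowed_decay
print(math.isnan(poisoned))  # True
```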

Fix

Mask v_decay_exp to a large negative value (-1e9) on positions
beyond L_chunk, so exp(...) underflows to 0 and the masked
v_tile contribution is exactly 0.

L_chunk = jnp.minimum(BT, effective_remaining)
b_g_last = (g_gamma_ref[i_h].astype(jnp.float32) * L_chunk).astype(g_gamma_ref.dtype)
scratch_ref[...] *= exp(b_g_last)
v_decay_exp = b_g_last - b_g
v_decay_exp = jnp.where(jnp.arange(BT) < L_chunk, v_decay_exp, -1e9)
v_tile = (v_tile * exp(v_decay_exp)[:, None]).astype(v_tile.dtype)
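The effect of the mask can be checked with plain Python floats (illustrative values, not the kernel's actual tensors):

```python
import math

BT, L_chunk = 4, 2                      # tile length, valid tokens (example)
v_decay_exp = [-0.5, -1.0, float("-inf"), float("-inf")]  # tail overflowed
v_tile = [0.3, 0.7, 0.0, 0.0]           # padded tail is zero

# Apply the same mask as the fix: positions >= L_chunk get -1e9
masked = [e if i < L_chunk else -1e9 for i, e in enumerate(v_decay_exp)]

# math.exp(-1e9) underflows to exactly 0.0, so each padded position
# contributes an exact 0.0 instead of a NaN-producing 0 * inf
contrib = [v * math.exp(e) for v, e in zip(v_tile, masked)]
print(contrib[L_chunk:])  # [0.0, 0.0]
```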

Scope

  • Only python/sgl_jax/srt/kernels/simple_gla/simple_gla.py
    (1 file, +6/-4).
  • No public API change, no caller change.

Test plan

  • End-to-end serving on Ling-2.6 with prompts whose token counts are
    not multiples of chunk_size; verify no NaN in output and
    generations match the previously-correct path.
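For reference, the padding that triggers the bug (each sequence rounded up to a multiple of chunk_size, per the description of _align_varlen_inputs above) can be sketched as:

```python
CHUNK_SIZE = 64  # default chunk_size from the PR description

def padded_len(seq_len: int, chunk: int = CHUNK_SIZE) -> int:
    # Round seq_len up to the next multiple of chunk; the positions between
    # seq_len and the padded length are the zero-filled tail where the NaN
    # used to appear. (Sketch only; not the library's actual helper.)
    return -(-seq_len // chunk) * chunk

print(padded_len(100))  # 128 -> 28 padded tail positions
print(padded_len(128))  # 128 -> already aligned, bug not triggered
```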

@gemini-code-assist

Summary of Changes


This pull request addresses a numerical stability issue in the GLA kernel where padded sequence inputs resulted in NaN values. By masking the decay exponent calculation, the kernel now correctly handles sequences that are not perfectly aligned with the default chunk size, preventing the contamination of recurrent states.

Highlights

  • NaN Bug Fix: Implemented a masking strategy for v_decay_exp to prevent NaN values from occurring at the padded tails of chunks when sequence lengths are not multiples of the chunk size.
  • Numerical Stability: Added a conditional mask that sets the decay exponent to -1e9 for padded positions, ensuring that the exponential function underflows to zero instead of producing invalid numerical results.

When _chunk_fwd_h_kernel_varlen runs with seq_real_lens (sequences
padded to a multiple of chunk_size internally), the trailing padded
positions hold v_tile = 0 while the per-token decay exponent
b_g_last - b_g can be very negative or even -inf in bf16. The
product 0 * inf evaluates to NaN, contaminating the recurrent state
and propagating downstream.

Mask v_decay_exp to a large negative value (-1e9) on positions
beyond L_chunk so that exp(...) underflows to 0 and the masked
v_tile contribution is exactly 0.

Triggered by any sequence whose length is not a multiple of
chunk_size (= 64 by default), since each such sequence is internally
padded by _align_varlen_inputs.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@Rodrian7 Rodrian7 force-pushed the wyx/gla-nan-guard branch from 47e4e15 to 417e848 Compare May 13, 2026 16:04
@Rodrian7 Rodrian7 merged commit 691d63b into sgl-project:cjx/gla-kernel May 14, 2026
1 check passed