fix+feat(gemma2): fix SWA correctness bugs and add FlexAttention fused softcap+SWA path#4308

Open
nvegesna-netizen wants to merge 5 commits into
NVIDIA-NeMo:mainfrom
nvegesna-netizen:nvegesna/fix-gemma2-swa-and-assertions

Conversation

@nvegesna-netizen

@nvegesna-netizen nvegesna-netizen commented Jun 11, 2026

Contributor

Fixes four silent correctness bugs in Gemma2DotProductAttention that caused sliding window attention (SWA) to either do nothing or produce wrong masks during pretraining and fine-tuning. On top of the corrected unfused path, adds Gemma2FlexDotProductAttention — a fused attention subclass that uses PyTorch FlexAttention (built-in since PyTorch 2.5) to run softcap and SWA in a single Triton kernel, falling back to the fixed unfused parent when a padding mask is present or dropout is active.

Changelog

gemma2_provider.py: Fix SWA gate from if attention_mask is not None and self.window_size is not None: to if self.window_size is not None:. The old gate was never True on MCore's pretrain path because MCore passes attention_mask=None, so SWA was silently skipped for every pretrain step. The mask must always be generated for SWA (even-numbered) layers: these layers use attn_mask_type=arbitrary, and FusedScaleMaskSoftmax with arbitrary + mask=None routes through ScaledSoftmax — plain softmax with no causal masking — so omitting the mask for any sequence length (including short ones where the window covers the full sequence) drops causal masking entirely. get_swa() degenerates to a pure causal mask when the window covers all positions, so always building it is correct.

gemma2_provider.py: Fix AttnMaskType.causal → AttnMaskType.arbitrary for SWA (even-numbered) layers. FusedScaleMaskSoftmax with causal routes through ScaledUpperTriangMaskedSoftmax.apply(input, scale), which takes no mask argument and silently discards the externally generated SWA mask. Switching to arbitrary routes through ScaledMaskedSoftmax, which applies it correctly. Odd-numbered layers remain causal and keep the fast fused path.
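
A toy model of the routing described in the two entries above may help when reviewing. It is a sketch only: plain Python lists stand in for tensors, `toy_masked_softmax` is a hypothetical name, and none of this is Megatron's actual code. It only mirrors the three behaviours stated above (causal ignores the external mask, arbitrary applies it, arbitrary with no mask is a plain softmax).

```python
import math

NEG_INF = float("-inf")


def softmax(row):
    """Numerically stable softmax over one row of scores."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def toy_masked_softmax(scores, mask_type, mask=None):
    """Toy dispatch. scores is [sq][sk]; mask is [sq][sk] with True = masked out."""
    out = []
    for i, row in enumerate(scores):
        if mask_type == "causal":
            # Causal route: builds its own triangular mask and ignores `mask`.
            row = [x if j <= i else NEG_INF for j, x in enumerate(row)]
        elif mask is not None:
            # Arbitrary route with a mask: the external mask is applied.
            row = [NEG_INF if mask[i][j] else x for j, x in enumerate(row)]
        # Arbitrary route with mask=None: plain softmax, no masking at all.
        out.append(softmax(row))
    return out


scores = [[0.0] * 4 for _ in range(4)]
# Sliding window of 2 tokens: mask the future and anything more than 1 token back.
swa = [[(j > i) or (i - j > 1) for j in range(4)] for i in range(4)]

dropped = toy_masked_softmax(scores, "causal", swa)      # window silently ignored
applied = toy_masked_softmax(scores, "arbitrary", swa)   # window respected
leaky = toy_masked_softmax(scores, "arbitrary", None)    # future tokens visible
```

In this toy, the last query row of `dropped` still attends to position 0, while `applied` restricts it to the last two positions, and `leaky` lets the first query attend to future keys.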

gemma2_provider.py: Fix SWA mask dimensionality — get_swa() returns [sq, sk]; the fused CUDA softmax kernel requires a 4D mask. Now unsqueezed to [1, 1, sq, sk] when there is no padding mask, broadcasting correctly to [b, np, sq, sk].

gemma2_provider.py: Fix fine-tuning padding mask combination — SWA mask is now ORed with any incoming padding mask (swa_mask | attention_mask) rather than replacing it. Both masks use True=masked-out; [sq, sk] broadcasts to [b, 1, sq, sk].
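
As a rough illustration of the OR combination, here is a framework-free sketch. Nested lists stand in for tensors, the size-1 head dimension is omitted, and `combine_masks` is a hypothetical helper; in the provider the same effect comes from `swa_mask | attention_mask` with broadcasting.

```python
def combine_masks(swa_mask, padding_mask):
    """OR a [sq][sk] SWA mask into a [b][sq][sk] padding mask.

    Both masks use True = masked out, so a position stays visible only
    if neither mask hides it.
    """
    sq = len(swa_mask)
    sk = len(swa_mask[0])
    return [
        [[swa_mask[i][j] or pad[i][j] for j in range(sk)] for i in range(sq)]
        for pad in padding_mask
    ]


# Causal-only SWA mask for a 3-token sequence (window covers everything).
swa = [[j > i for j in range(3)] for i in range(3)]

# Batch of 2: sample 0 has no padding, sample 1 has its last key padded.
no_pad = [[False] * 3 for _ in range(3)]
last_key_padded = [[j == 2 for j in range(3)] for _ in range(3)]

combined = combine_masks(swa, [no_pad, last_key_padded])
```

Replacing the padding mask with the SWA mask, as the old code did, would leave the padded key in sample 1 visible to the last query.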

gemma2_provider.py: Fix Gemma2ModelProvider.window_size default from (4096, 0) to (4095, 0). window_size[0] counts only the tokens before the current position and excludes the current token itself, so a 4096-token window requires 4095, consistent with the gemma2_bridge.py round-trip convention.
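
The off-by-one can be checked with a small stand-in for the mask builder. This is a sketch under an assumed convention (query i sees keys from i - left through i + right, True = masked out); `toy_swa_mask` is a hypothetical name and is not the real get_swa().

```python
def toy_swa_mask(sq, sk, window_size):
    """Build a [sq][sk] boolean mask, True = masked out.

    Assumed convention: with window_size = (left, right), query i may
    attend to keys in [i - left, i + right], aligned to the end of the
    key sequence when sk > sq.
    """
    left, right = window_size
    offset = sk - sq
    return [
        [not (-left <= j - i - offset <= right) for j in range(sk)]
        for i in range(sq)
    ]


def visible_count(mask_row):
    """Number of key positions a query row can attend to."""
    return sum(not masked for masked in mask_row)


# Scaled-down analogue of the 4096 vs 4095 question: a 4-token window.
last_row_left_3 = toy_swa_mask(8, 8, (3, 0))[-1]
last_row_left_4 = toy_swa_mask(8, 8, (4, 0))[-1]
```

Under this convention a query sees left + 1 tokens, so `(3, 0)` yields a 4-token window and `(4, 0)` yields 5. The same arithmetic gives 4096 visible tokens for `(4095, 0)`. When the window is at least as long as the sequence, the mask reduces to the plain causal triangle.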

gemma2_provider.py: Replace bare assert statements (stripped by -O) with raise ValueError for CP > 1 and packed-seq guards.
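
The guard pattern looks roughly like the following sketch. The function name, parameter names, and messages are illustrative assumptions, not the provider's actual code. Unlike `assert`, these checks still run when Python is started with -O.

```python
def check_supported(context_parallel_size, packed_seq_params=None):
    """Reject configurations this attention path does not handle.

    Hypothetical helper: explicit raises survive `python -O`, whereas
    `assert` statements are compiled out under that flag.
    """
    if context_parallel_size > 1:
        raise ValueError(
            f"Context parallelism is not supported here (got CP={context_parallel_size})."
        )
    if packed_seq_params is not None:
        raise ValueError("Packed sequences are not supported here.")
```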

gemma2_provider.py: Add Gemma2FlexDotProductAttention — uses PyTorch FlexAttention (built-in, PyTorch 2.5+) to fuse softcap and SWA into a single Triton kernel via score_mod and block_mask. Module-level _get_softcap_score_mod is lru_cache-decorated so all layers with the same softcap share one function object, avoiding N redundant torch.compile recompilations at startup (torch.compile guards on id(fn)). flex_attention itself is wrapped with torch.compile to trigger the fused Triton path. Falls back to the unfused parent when FlexAttention is unavailable, a padding mask is present, or dropout is active. gemma2_layer_spec updated to use this class as core_attention.
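
The sharing invariant can be shown without PyTorch. In this sketch `math.tanh` stands in for `torch.tanh` and the function names are illustrative; it shows only the caching pattern, not the module's exact code.

```python
import math
from functools import lru_cache


@lru_cache(maxsize=None)
def get_softcap_score_mod(softcap):
    """Return one shared score_mod closure per distinct softcap value."""

    def score_mod(score, b, h, q_idx, kv_idx):
        # Softcap: squash the score smoothly into (-softcap, softcap).
        return softcap * math.tanh(score / softcap)

    return score_mod


layer_0_mod = get_softcap_score_mod(50.0)
layer_1_mod = get_softcap_score_mod(50.0)
same_object = layer_0_mod is layer_1_mod  # True: both layers share one closure
```

Because every layer with the same softcap receives the identical function object, anything keyed on the function's identity sees a single function rather than one per layer.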

test_gemma2_provider.py: Add TestGemma2DotProductAttention with 7 tests covering all correctness fixes: CP ValueError, packed-seq ValueError, SWA gate fires on pretrain path (attention_mask=None), odd layers have no SWA, even layers use AttnMaskType.arbitrary, SWA combined with padding mask via OR, and window_size default. Add TestGemma2FlexDotProductAttention with 12 tests covering fused path invocation, all fallback conditions (no FlexAttention, dropout, padding mask), output shape, softmax scale, score_mod softcap correctness, lru_cache sharing invariant, block mask caching, SWA layer block mask, and packed-seq ValueError.

Additional Information

Related to:

@copy-pr-bot

copy-pr-bot Bot commented Jun 11, 2026


This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@nvegesna-netizen nvegesna-netizen marked this pull request as draft June 11, 2026 17:00
@nvegesna-netizen nvegesna-netizen marked this pull request as ready for review June 12, 2026 09:35
@nvegesna-netizen nvegesna-netizen marked this pull request as draft June 12, 2026 09:35
@nvegesna-netizen nvegesna-netizen changed the title fix(gemma2): fix sliding window attention correctness (SWA disabled in pretrain, mask dropped, window_size off-by-one) fix+feat(gemma2): fix SWA correctness bugs and add FlexAttention fused softcap+SWA path Jun 12, 2026
@nvegesna-netizen nvegesna-netizen marked this pull request as ready for review June 12, 2026 09:44
@yaoyu-33 yaoyu-33 added area:model Model implementations and HF bridge logic bug Something isn't working community-request needs-review PR is ready for code review and waiting on a reviewer labels Jun 12, 2026
nvegesna-netizen and others added 4 commits June 12, 2026 09:09
… and assert→ValueError

Three bugs in Gemma2DotProductAttention fixed:

1. SWA silently disabled (correctness blocker): the gate
   `if attention_mask is not None and self.window_size is not None`
   was never True in causal pretraining because MCore passes
   attention_mask=None on the standard forward path. All Gemma2
   pretraining ran full causal attention instead of sliding window
   attention. Fix: drop the `attention_mask is not None` condition.

2. window_size dataclass default (4096, 0) → (4095, 0): gemma2_bridge.py
   converts HF sliding_window via `(sliding_window - 1, 0)` (line 54).
   When AutoBridge is not used and the provider default is taken directly,
   the effective window was 4097 tokens instead of 4096. Aligned default
   with the bridge convention.

3. Bare assert → ValueError for CP > 1 and packed_seq_params: bare
   `assert` statements are silenced by Python -O and produce unhelpful
   tracebacks. Replaced both with explicit ValueError raises with
   descriptive messages.

Tests: update existing window_size assertion; add TestGemma2DotProductAttention
covering CP>1 ValueError, packed_seq ValueError, SWA-mask-when-mask-is-None,
odd-layer has no window_size, and window_size default.

Signed-off-by: Nitin Vegesna <nvegesna@nvidia.com>
…ding mask

- Override attn_mask_type to arbitrary for SWA layers so FusedScaleMaskSoftmax
  routes through ScaledMaskedSoftmax (causal type silently ignores the mask arg)
- Unsqueeze get_swa() output [sq, sk] to [1, 1, sq, sk] for the fused CUDA kernel
- OR-combine SWA mask with incoming padding mask instead of replacing it
- Add TestGemma2DotProductAttention with 7 tests covering all fixes

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Nitin Vegesna <nvegesna@nvidia.com>
Uses PyTorch FlexAttention (built-in, PyTorch 2.5+) to fuse softcap and SWA
into a single Triton kernel via score_mod and block_mask. Falls back to the
unfused parent when a padding mask is present, dropout is active, or
FlexAttention is unavailable (pretraining always takes the fused path).

- _get_softcap_score_mod: lru_cache-decorated so all layers share one function
  object, avoiding N redundant torch.compile recompilations at startup
- flex_attention wrapped with torch.compile to trigger the fused Triton path
- Block mask cache keyed by (sq, sk) per instance to avoid rebuilding each step
- gemma2_layer_spec updated to use Gemma2FlexDotProductAttention as core_attention
- TestGemma2FlexDotProductAttention: 12 tests covering fused path, all fallback
  conditions, output shape, scale, score_mod correctness, lru_cache sharing

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Nitin Vegesna <nvegesna@nvidia.com>
…eq lengths

The previous optimization skipped get_swa() when sq <= window_size[0] + 1,
assuming the window covered the full sequence. But SWA layers use
attn_mask_type=arbitrary, which routes through ScaledSoftmax (plain softmax,
no causal mask) when mask=None — dropping causal masking entirely for short
sequences, not just the SWA restriction.

get_swa() encodes the causal triangular structure via triu/tril and
degenerates to a pure causal mask when the window covers all positions,
so always calling it is both correct and sufficient.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Nitin Vegesna <nvegesna@nvidia.com>
@nvegesna-netizen nvegesna-netizen force-pushed the nvegesna/fix-gemma2-swa-and-assertions branch from 83495ce to beba296 Compare June 12, 2026 16:11
@yaoyu-33

Contributor

/claude review

@yaoyu-33 yaoyu-33 added ready-to-merge PR is approved, current, and only waiting for CI to pass before merge community-request and removed needs-review PR is ready for code review and waiting on a reviewer labels Jun 14, 2026
Comment on lines +84 to +85
    def _score_mod(score, b, h, q_idx, kv_idx):
        return softcap * torch.tanh(score / softcap)

Contributor


Bug: division by zero when softcap=0.0

The unfused path guards against zero softcap (if not scale: return logits in logit_softcapping), but this score_mod doesn't. If attn_logit_softcapping is set to 0.0 (to disable softcapping), score / softcap produces inf/nan, and 0.0 * tanh(nan) = nan for any score value of exactly 0.0 (common for padding positions).

Gemma2 always uses 50.0 so this isn't triggered today, but it's a latent correctness bug if anyone disables softcapping on the fused path.

Suggested change
-    def _score_mod(score, b, h, q_idx, kv_idx):
-        return softcap * torch.tanh(score / softcap)
+    def _score_mod(score, b, h, q_idx, kv_idx):
+        if softcap == 0.0:
+            return score
+        return softcap * torch.tanh(score / softcap)
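
A scalar sketch of the guard and of the arithmetic it avoids. It assumes IEEE float semantics on the tensor path. Plain Python floats raise ZeroDivisionError on `x / 0.0` instead of producing inf or nan, so the unguarded results are written out by hand from explicit inf and nan inputs. `softcap_score` is an illustrative name.

```python
import math


def softcap_score(score, softcap):
    """Guarded softcap: a softcap of 0.0 means softcapping is disabled."""
    if softcap == 0.0:
        return score
    return softcap * math.tanh(score / softcap)


# What the unguarded expression evaluates to under IEEE rules when softcap == 0.0:
# score != 0: score / 0.0 -> inf, tanh(inf) -> 1.0, 0.0 * 1.0 -> 0.0
nonzero_score_result = 0.0 * math.tanh(math.inf)
# score == 0: 0.0 / 0.0 -> nan, tanh(nan) -> nan, 0.0 * nan -> nan
zero_score_result = 0.0 * math.tanh(math.nan)
```

So without the guard, nonzero scores would collapse to 0.0 and scores of exactly 0.0 would become NaN; with it, scores pass through unchanged.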

_flex_attn_func = torch.compile(_flex_candidate)
_create_flex_block_mask = _flex_mask_candidate
_HAVE_FLEX_ATTN = True
logger.warning("Gemma2: PyTorch FlexAttention available — softcap+SWA fused via Triton score_mod.")

Contributor


Nit: logger.warning for a success message is semantically wrong — this is informational, not a warning condition. Both this line and line 73 should use logger.info.

Suggested change
-    logger.warning("Gemma2: PyTorch FlexAttention available — softcap+SWA fused via Triton score_mod.")
+    logger.info("Gemma2: PyTorch FlexAttention available — softcap+SWA fused via Triton score_mod.")

@claude

claude Bot commented Jun 14, 2026

Contributor

Light Code Review

Solid set of fixes. The FlexAttention fused path is well-structured with appropriate fallback conditions, and the lru_cache sharing for score_mod is a nice touch to avoid torch.compile recompilation overhead.
Issues:

  1. Bug: _get_softcap_score_mod division by zero when softcap=0.0 (inline comment) - The unfused logit_softcapping() function guards against zero scale, but the score_mod closure does not. score / 0.0 produces inf/nan, and 0.0 * tanh(nan) = nan. Gemma2 always uses 50.0 so this is latent, but it will silently produce NaN if anyone disables softcapping on the fused path.
  2. Minor: logger.warning for informational messages (inline comment) - Lines 67 and 73 use logger.warning for feature-availability diagnostics. Both should be logger.info.
  3. Minor: _build_flex_block_mask ignores the right window component - _flex_window_size[1] (the right window) is stored but never read in _build_flex_block_mask. For Gemma2 where right=0, the causal constraint handles it. But if right > 0 were used, the SWA mask would silently omit the right-window constraint.

Test Coverage - thorough. Two gaps worth considering: (a) No test for softcap=0.0 on the fused path (would catch the division-by-zero bug above); (b) No test for GQA with FlexAttention (_make_flex_attention uses num_query_groups == num_attention_heads so enable_gqa=True is never exercised).

Suggested test cases: No perf tests impacted.


Labels

area:model Model implementations and HF bridge logic bug Something isn't working ready-to-merge PR is approved, current, and only waiting for CI to pass before merge


2 participants