fix+feat(gemma2): fix SWA correctness bugs and add FlexAttention fused softcap+SWA path #4308

Open
nvegesna-netizen wants to merge 5 commits into NVIDIA-NeMo:main from nvegesna-netizen:nvegesna/fix-gemma2-swa-and-assertions

@nvegesna-netizen nvegesna-netizen commented Jun 11, 2026

Fixes four silent correctness bugs in Gemma2DotProductAttention that caused sliding window attention (SWA) to either be skipped entirely or apply the wrong mask during pretraining and fine-tuning. On top of the corrected unfused path, adds Gemma2FlexDotProductAttention, a fused attention subclass that uses PyTorch FlexAttention (built in since PyTorch 2.5) to run softcap and SWA in a single Triton kernel. It falls back to the fixed unfused parent when FlexAttention is unavailable, a padding mask is present, or dropout is active.

Changelog

gemma2_provider.py: Fix SWA gate from if attention_mask is not None and self.window_size is not None: to if self.window_size is not None:. The old gate was never True on MCore's pretrain path because MCore passes attention_mask=None, so SWA was silently skipped on every pretrain step. The mask must always be generated for SWA (even-numbered) layers, whatever the sequence length. These layers use attn_mask_type=arbitrary, and FusedScaleMaskSoftmax with arbitrary and mask=None routes through ScaledSoftmax, a plain softmax with no causal masking. Omitting the mask therefore drops causal masking entirely, even for short sequences where the window covers the full sequence. get_swa() degenerates to a pure causal mask when the window covers all positions, so always building it is correct.
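
The two claims above (the old gate never fires when attention_mask=None, and a sliding-window mask collapses to a causal mask once the window covers the sequence) can be illustrated with a small pure-Python sketch. It does not call MCore: sliding_window_mask is a hypothetical stand-in for get_swa(), and the gate functions only mirror the conditions quoted above. True means masked out.

```python
def old_gate(attention_mask, window_size):
    """Condition before the fix: needs an incoming mask to build the SWA mask."""
    return attention_mask is not None and window_size is not None


def new_gate(attention_mask, window_size):
    """Condition after the fix: the SWA mask is built whenever a window is set."""
    return window_size is not None


def sliding_window_mask(seq_len, window_left):
    """Hypothetical stand-in for get_swa(): True = masked out.

    A query at position q may see keys q - window_left .. q (inclusive).
    """
    return [
        [k > q or k < q - window_left for k in range(seq_len)]
        for q in range(seq_len)
    ]


def causal_mask(seq_len):
    """Plain causal mask: only keys after the query are hidden."""
    return [[k > q for k in range(seq_len)] for q in range(seq_len)]


# Pretrain path: no incoming mask, window configured.
print(old_gate(None, (4095, 0)), new_gate(None, (4095, 0)))

# A window wider than the sequence yields exactly the causal mask.
print(sliding_window_mask(4, 4095) == causal_mask(4))
```

With a short window, for example sliding_window_mask(4, 1), the last query row hides the two oldest keys in addition to the future ones, which is the restriction the old gate was silently dropping.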

gemma2_provider.py: Fix AttnMaskType.causal → AttnMaskType.arbitrary for SWA (even-numbered) layers. FusedScaleMaskSoftmax with causal routes through ScaledUpperTriangMaskedSoftmax.apply(input, scale), which takes no mask argument and silently discards the externally generated SWA mask. Switching to arbitrary routes through ScaledMaskedSoftmax, which applies it correctly. Odd-numbered layers remain causal and keep the fast fused path.

gemma2_provider.py: Fix SWA mask dimensionality — get_swa() returns [sq, sk]; the fused CUDA softmax kernel requires a 4D mask. Now unsqueezed to [1, 1, sq, sk] when there is no padding mask, broadcasting correctly to [b, np, sq, sk].

gemma2_provider.py: Fix fine-tuning padding mask combination — SWA mask is now ORed with any incoming padding mask (swa_mask | attention_mask) rather than replacing it. Both masks use True=masked-out; [sq, sk] broadcasts to [b, 1, sq, sk].
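
A minimal PyTorch sketch of the shape handling described in the two items above, assuming the True=masked-out convention. The helper names sliding_window_mask and combine_masks are hypothetical and the mask builder is a stand-in for get_swa(), not the code in gemma2_provider.py.

```python
import torch


def sliding_window_mask(sq, sk, window_left):
    """Hypothetical stand-in for get_swa(): boolean [sq, sk], True = masked out."""
    q_idx = torch.arange(sq).unsqueeze(1)  # [sq, 1]
    k_idx = torch.arange(sk).unsqueeze(0)  # [1, sk]
    future = k_idx > q_idx                 # keys after the query
    too_old = k_idx < q_idx - window_left  # keys that fell out of the window
    return future | too_old


def combine_masks(swa_mask, padding_mask=None):
    """Return a 4D mask, with or without an incoming padding mask."""
    if padding_mask is None:
        # [sq, sk] -> [1, 1, sq, sk]; broadcasts over batch and heads.
        return swa_mask[None, None, :, :]
    # [sq, sk] | [b, 1, sq, sk] -> [b, 1, sq, sk]; a position is masked
    # if either the window or the padding says so.
    return swa_mask | padding_mask


swa = sliding_window_mask(4, 4, window_left=1)
padding = torch.zeros(2, 1, 4, 4, dtype=torch.bool)
padding[1, :, :, 3] = True  # second sample: last key is padding

print(combine_masks(swa).shape)
print(combine_masks(swa, padding).shape)
```

Replacing the padding mask with the SWA mask, as the old code did, would let the second sample attend to its padded key; the OR keeps both restrictions.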

gemma2_provider.py: Fix Gemma2ModelProvider.window_size default from (4096, 0) to (4095, 0). window_size[0] counts only the tokens before the current one (the current token is not included), so a 4096-token window requires 4095, consistent with the gemma2_bridge.py round-trip convention.
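
The off-by-one can be checked with simple arithmetic. In the sketch below, attended_tokens and window_from_hf are hypothetical helper names; the conversion (sliding_window - 1, 0) is the bridge convention described in this PR.

```python
def window_from_hf(sliding_window):
    """Bridge convention: HF counts the current token, window_size[0] does not."""
    return (sliding_window - 1, 0)


def attended_tokens(query_pos, window_left):
    """Number of keys a query sees: itself plus up to window_left prior tokens."""
    first_key = max(0, query_pos - window_left)
    return query_pos - first_key + 1


# A query deep into a long sequence, under the old and new defaults.
print(attended_tokens(10_000, 4096))  # old default (4096, 0)
print(attended_tokens(10_000, 4095))  # new default (4095, 0)
print(window_from_hf(4096))
```

The old default gives 4097 attended tokens and the new one gives 4096, matching the value produced by converting an HF sliding_window of 4096.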

gemma2_provider.py: Replace bare assert statements (stripped by -O) with raise ValueError for CP > 1 and packed-seq guards.
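
A rough sketch of the guard pattern, for illustration only. The function name, parameter names, and messages are hypothetical and are not the ones used in gemma2_provider.py. Unlike assert, these checks still run when Python is started with -O.

```python
def check_supported_config(context_parallel_size, packed_seq_params):
    """Raise ValueError for configurations the attention module does not support."""
    if context_parallel_size > 1:
        raise ValueError(
            "Context parallelism is not supported here "
            f"(got context_parallel_size={context_parallel_size})."
        )
    if packed_seq_params is not None:
        raise ValueError("Packed sequences are not supported here.")


check_supported_config(context_parallel_size=1, packed_seq_params=None)  # passes

try:
    check_supported_config(context_parallel_size=2, packed_seq_params=None)
except ValueError as err:
    print(f"rejected: {err}")
```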

gemma2_provider.py: Add Gemma2FlexDotProductAttention — uses PyTorch FlexAttention (built-in, PyTorch 2.5+) to fuse softcap and SWA into a single Triton kernel via score_mod and block_mask. Module-level _get_softcap_score_mod is lru_cache-decorated so all layers with the same softcap share one function object, avoiding N redundant torch.compile recompilations at startup (torch.compile guards on id(fn)). flex_attention itself is wrapped with torch.compile to trigger the fused Triton path. Falls back to the unfused parent when FlexAttention is unavailable, a padding mask is present, or dropout is active. gemma2_layer_spec updated to use this class as core_attention.
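
The function-sharing idea can be sketched without PyTorch. The factory below is a hypothetical analogue of _get_softcap_score_mod: it uses math.tanh on plain floats so it runs anywhere, whereas a real FlexAttention score_mod would operate on tensors (for example with torch.tanh). The (score, b, h, q_idx, kv_idx) argument order follows the FlexAttention score_mod convention; the softcap values are illustrative.

```python
import math
from functools import lru_cache


@lru_cache(maxsize=None)
def get_softcap_score_mod(softcap):
    """Return one shared score_mod per distinct softcap value."""

    def score_mod(score, b, h, q_idx, kv_idx):
        # Squash the raw score into (-softcap, softcap).
        return softcap * math.tanh(score / softcap)

    return score_mod


# Every layer that asks for the same softcap gets the same function object.
layer_mods = [get_softcap_score_mod(50.0) for _ in range(4)]
print(all(mod is layer_mods[0] for mod in layer_mods))

# A different softcap yields a different function.
print(get_softcap_score_mod(30.0) is layer_mods[0])
```

Without the cache, each layer would hold its own closure, which is the situation the changelog entry describes as causing one recompilation per layer.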

test_gemma2_provider.py: Add TestGemma2DotProductAttention with 7 tests covering all correctness fixes: CP ValueError, packed-seq ValueError, SWA gate fires on pretrain path (attention_mask=None), odd layers have no SWA, even layers use AttnMaskType.arbitrary, SWA combined with padding mask via OR, and window_size default. Add TestGemma2FlexDotProductAttention with 12 tests covering fused path invocation, all fallback conditions (no FlexAttention, dropout, padding mask), output shape, softmax scale, score_mod softcap correctness, lru_cache sharing invariant, block mask caching, SWA layer block mask, and packed-seq ValueError.

Additional Information

Related to:

copy-pr-bot Bot commented Jun 11, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@nvegesna-netizen nvegesna-netizen marked this pull request as draft June 11, 2026 17:00
@nvegesna-netizen nvegesna-netizen marked this pull request as ready for review June 12, 2026 09:35
@nvegesna-netizen nvegesna-netizen marked this pull request as draft June 12, 2026 09:35
@nvegesna-netizen nvegesna-netizen changed the title from "fix(gemma2): fix sliding window attention correctness (SWA disabled in pretrain, mask dropped, window_size off-by-one)" to "fix+feat(gemma2): fix SWA correctness bugs and add FlexAttention fused softcap+SWA path" Jun 12, 2026
@nvegesna-netizen nvegesna-netizen marked this pull request as ready for review June 12, 2026 09:44
@yaoyu-33 yaoyu-33 added the area:model, bug, community-request, and needs-review labels Jun 12, 2026
nvegesna-netizen and others added 4 commits June 12, 2026 09:09
… and assert→ValueError

Three bugs in Gemma2DotProductAttention fixed:

1. SWA silently disabled (correctness blocker): the gate
   `if attention_mask is not None and self.window_size is not None`
   was never True in causal pretraining because MCore passes
   attention_mask=None on the standard forward path. All Gemma2
   pretraining ran full causal attention instead of sliding window
   attention. Fix: drop the `attention_mask is not None` condition.

2. window_size dataclass default (4096, 0) → (4095, 0): gemma2_bridge.py
   converts HF sliding_window via `(sliding_window - 1, 0)` (line 54).
   When AutoBridge is not used and the provider default is taken directly,
   the effective window was 4097 tokens instead of 4096. Aligned default
   with the bridge convention.

3. Bare assert → ValueError for CP > 1 and packed_seq_params: bare
   `assert` statements are silenced by Python -O and produce unhelpful
   tracebacks. Replaced both with explicit ValueError raises with
   descriptive messages.

Tests: update existing window_size assertion; add TestGemma2DotProductAttention
covering CP>1 ValueError, packed_seq ValueError, SWA-mask-when-mask-is-None,
odd-layer has no window_size, and window_size default.

Signed-off-by: Nitin Vegesna <nvegesna@nvidia.com>
…ding mask

- Override attn_mask_type to arbitrary for SWA layers so FusedScaleMaskSoftmax
  routes through ScaledMaskedSoftmax (causal type silently ignores the mask arg)
- Unsqueeze get_swa() output [sq, sk] to [1, 1, sq, sk] for the fused CUDA kernel
- OR-combine SWA mask with incoming padding mask instead of replacing it
- Add TestGemma2DotProductAttention with 7 tests covering all fixes

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Nitin Vegesna <nvegesna@nvidia.com>
Uses PyTorch FlexAttention (built-in, PyTorch 2.5+) to fuse softcap and SWA
into a single Triton kernel via score_mod and block_mask. Falls back to the
unfused parent when a padding mask is present, dropout is active, or
FlexAttention is unavailable (pretraining always takes the fused path).

- _get_softcap_score_mod: lru_cache-decorated so all layers share one function
  object, avoiding N redundant torch.compile recompilations at startup
- flex_attention wrapped with torch.compile to trigger the fused Triton path
- Block mask cache keyed by (sq, sk) per instance to avoid rebuilding each step
- gemma2_layer_spec updated to use Gemma2FlexDotProductAttention as core_attention
- TestGemma2FlexDotProductAttention: 12 tests covering fused path, all fallback
  conditions, output shape, scale, score_mod correctness, lru_cache sharing

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Nitin Vegesna <nvegesna@nvidia.com>
…eq lengths

The previous optimization skipped get_swa() when sq <= window_size[0] + 1,
assuming the window covered the full sequence. But SWA layers use
attn_mask_type=arbitrary, which routes through ScaledSoftmax (plain softmax,
no causal mask) when mask=None — dropping causal masking entirely for short
sequences, not just the SWA restriction.

get_swa() encodes the causal triangular structure via triu/tril and
degenerates to a pure causal mask when the window covers all positions,
so always calling it is both correct and sufficient.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Nitin Vegesna <nvegesna@nvidia.com>
@nvegesna-netizen nvegesna-netizen force-pushed the nvegesna/fix-gemma2-swa-and-assertions branch from 83495ce to beba296 Compare June 12, 2026 16:11