fix+feat(gemma2): fix SWA correctness bugs and add FlexAttention fused softcap+SWA path#4308
Open
nvegesna-netizen wants to merge 5 commits into
… and assert→ValueError

Three bugs in Gemma2DotProductAttention fixed:

1. SWA silently disabled (correctness blocker): the gate `if attention_mask is not None and self.window_size is not None` was never True in causal pretraining because MCore passes attention_mask=None on the standard forward path. All Gemma2 pretraining ran full causal attention instead of sliding window attention. Fix: drop the `attention_mask is not None` condition.

2. window_size dataclass default (4096, 0) → (4095, 0): gemma2_bridge.py converts HF sliding_window via `(sliding_window - 1, 0)` (line 54). When AutoBridge is not used and the provider default is taken directly, the effective window was 4097 tokens instead of 4096. Aligned default with the bridge convention.

3. Bare assert → ValueError for CP > 1 and packed_seq_params: bare `assert` statements are silenced by Python -O and produce unhelpful tracebacks. Replaced both with explicit ValueError raises with descriptive messages.

Tests: update existing window_size assertion; add TestGemma2DotProductAttention covering CP>1 ValueError, packed_seq ValueError, SWA-mask-when-mask-is-None, odd-layer has no window_size, and window_size default.

Signed-off-by: Nitin Vegesna <nvegesna@nvidia.com>
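To illustrate the third fix in this commit, here is a minimal sketch of guards written as explicit raises. The helper name `check_swa_inputs`, its parameters, and the message wording are hypothetical and are not taken from gemma2_provider.py; the sketch only shows the pattern.

```python
def check_swa_inputs(context_parallel_size, packed_seq_params=None):
    """Reject configurations this attention path does not handle.

    Hypothetical helper: names and messages are illustrative only.
    """
    # An explicit raise is kept under `python -O`; a bare `assert` would be
    # compiled away and the unsupported configuration would pass silently.
    if context_parallel_size > 1:
        raise ValueError(
            f"context parallelism is not supported here (got CP={context_parallel_size})"
        )
    if packed_seq_params is not None:
        raise ValueError("packed sequences are not supported here")
```

A caller that passes `context_parallel_size=2` gets a ValueError naming the offending value, whether or not the interpreter was started with `-O`.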
…ding mask

- Override attn_mask_type to arbitrary for SWA layers so FusedScaleMaskSoftmax routes through ScaledMaskedSoftmax (causal type silently ignores the mask arg)
- Unsqueeze get_swa() output [sq, sk] to [1, 1, sq, sk] for the fused CUDA kernel
- OR-combine SWA mask with incoming padding mask instead of replacing it
- Add TestGemma2DotProductAttention with 7 tests covering all fixes

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Nitin Vegesna <nvegesna@nvidia.com>
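The unsqueeze and OR-combine steps in this commit can be sketched without tensors. The following uses nested Python lists in place of boolean tensors, with True meaning masked out; `combine_masks` is a hypothetical name, and the real code expresses the combination as `swa_mask | attention_mask` with broadcasting.

```python
def combine_masks(swa_mask, padding_mask=None):
    """Combine a [sq][sk] SWA mask with an optional [b][1][sq][sk] padding mask.

    True means "masked out" in both masks. Hypothetical list-based stand-in
    for the tensor operation; not the code from the PR.
    """
    if padding_mask is None:
        # No padding mask: add two leading singleton dimensions,
        # the list equivalent of unsqueezing to [1, 1, sq, sk].
        return [[[list(row) for row in swa_mask]]]
    combined = []
    for sample in padding_mask:      # batch dimension b
        heads = []
        for pad_2d in sample:        # singleton head dimension
            heads.append([
                # A position is masked if either mask says so.
                [s or p for s, p in zip(swa_row, pad_row)]
                for swa_row, pad_row in zip(swa_mask, pad_2d)
            ])
        combined.append(heads)
    return combined


# Causal mask for 3 tokens, and one sample whose last key is padding.
swa = [[False, True, True], [False, False, True], [False, False, False]]
pad = [[[[False, False, True]] * 3]]
merged = combine_masks(swa, pad)
```

In `merged`, the last query row still masks the padded key even though the causal mask alone would allow it, which is the behaviour that replacing the padding mask would lose.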
Uses PyTorch FlexAttention (built-in, PyTorch 2.5+) to fuse softcap and SWA into a single Triton kernel via score_mod and block_mask. Falls back to the unfused parent when a padding mask is present, dropout is active, or FlexAttention is unavailable (pretraining always takes the fused path).

- _get_softcap_score_mod: lru_cache-decorated so all layers share one function object, avoiding N redundant torch.compile recompilations at startup
- flex_attention wrapped with torch.compile to trigger the fused Triton path
- Block mask cache keyed by (sq, sk) per instance to avoid rebuilding each step
- gemma2_layer_spec updated to use Gemma2FlexDotProductAttention as core_attention
- TestGemma2FlexDotProductAttention: 12 tests covering fused path, all fallback conditions, output shape, scale, score_mod correctness, lru_cache sharing

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Nitin Vegesna <nvegesna@nvidia.com>
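The lru_cache sharing described in this commit can be sketched as follows. This is an illustration under assumptions: `get_softcap_score_mod` is a stand-in name, `math.tanh` on plain floats replaces the tensor tanh a real score_mod would use, and the value 50.0 is only an example. The five-argument signature follows the FlexAttention score_mod convention.

```python
import math
from functools import lru_cache


@lru_cache(maxsize=None)
def get_softcap_score_mod(softcap):
    """Return a score_mod closure; cached so equal softcaps share one object."""

    def score_mod(score, b, h, q_idx, kv_idx):
        # Squash the raw attention logit into [-softcap, softcap].
        # The batch, head, and index arguments are unused by softcapping.
        return softcap * math.tanh(score / softcap)

    return score_mod


# Two "layers" asking for the same softcap receive the identical function.
layer0_mod = get_softcap_score_mod(50.0)
layer1_mod = get_softcap_score_mod(50.0)
```

Because both layers hold the same function object, anything that keys on the identity of the function sees one callable instead of one per layer.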
…eq lengths

The previous optimization skipped get_swa() when sq <= window_size[0] + 1, assuming the window covered the full sequence. But SWA layers use attn_mask_type=arbitrary, which routes through ScaledSoftmax (plain softmax, no causal mask) when mask=None, dropping causal masking entirely for short sequences, not just the SWA restriction. get_swa() encodes the causal triangular structure via triu/tril and degenerates to a pure causal mask when the window covers all positions, so always calling it is both correct and sufficient.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Nitin Vegesna <nvegesna@nvidia.com>
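The claim that the sliding-window mask degenerates to a causal mask can be checked with a small list-based sketch. `swa_mask` and `causal_mask` below are hypothetical stand-ins, not the get_swa() implementation; they assume sq == sk and use True for masked-out positions.

```python
def swa_mask(sq, sk, window_size):
    """Return a [sq][sk] boolean mask, True = masked out.

    window_size = (left, right): query i may attend to keys j with
    i - left <= j <= i + right. Assumes sq == sk for simplicity.
    """
    left, right = window_size
    return [
        [not (i - left <= j <= i + right) for j in range(sk)]
        for i in range(sq)
    ]


def causal_mask(sq, sk):
    """Return a [sq][sk] mask that hides every future key (j > i)."""
    return [[j > i for j in range(sk)] for i in range(sq)]


# A window of 2 tokens (1 prior + the current one) on a 4-token sequence.
narrow = swa_mask(4, 4, (1, 0))
# A window at least as long as the sequence: only the causal part remains.
wide = swa_mask(4, 4, (4095, 0))
```

With `(1, 0)` each query sees at most two keys, while with `(4095, 0)` the result equals `causal_mask(4, 4)`, so building the mask for short sequences still supplies the causal structure.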
83495ce to
beba296
Fixes four silent correctness bugs in Gemma2DotProductAttention that caused sliding window attention (SWA) to either do nothing or produce wrong masks during pretraining and fine-tuning. On top of the corrected unfused path, adds Gemma2FlexDotProductAttention — a fused attention subclass that uses PyTorch FlexAttention (built-in since PyTorch 2.5) to run softcap and SWA in a single Triton kernel, falling back to the fixed unfused parent when a padding mask is present or dropout is active.
Changelog
gemma2_provider.py: Fix SWA gate from if attention_mask is not None and self.window_size is not None: to if self.window_size is not None:. The old gate was never True on MCore's pretrain path because MCore passes attention_mask=None, so SWA was silently skipped for every pretrain step. The mask must always be generated for SWA (even-numbered) layers: these layers use attn_mask_type=arbitrary, and FusedScaleMaskSoftmax with arbitrary + mask=None routes through ScaledSoftmax — plain softmax with no causal masking — so omitting the mask for any sequence length (including short ones where the window covers the full sequence) drops causal masking entirely. get_swa() degenerates to a pure causal mask when the window covers all positions, so always building it is correct.
gemma2_provider.py: Fix AttnMaskType.causal → AttnMaskType.arbitrary for SWA (even-numbered) layers. FusedScaleMaskSoftmax with causal routes through ScaledUpperTriangMaskedSoftmax.apply(input, scale), which takes no mask argument and silently discards the externally generated SWA mask. Switching to arbitrary routes through ScaledMaskedSoftmax, which applies it correctly. Odd-numbered layers remain causal and keep the fast fused path.
gemma2_provider.py: Fix SWA mask dimensionality — get_swa() returns [sq, sk]; the fused CUDA softmax kernel requires a 4D mask. Now unsqueezed to [1, 1, sq, sk] when there is no padding mask, broadcasting correctly to [b, np, sq, sk].
gemma2_provider.py: Fix fine-tuning padding mask combination — SWA mask is now ORed with any incoming padding mask (swa_mask | attention_mask) rather than replacing it. Both masks use True=masked-out; [sq, sk] broadcasts to [b, 1, sq, sk].
gemma2_provider.py: Fix Gemma2ModelProvider.window_size default from (4096, 0) to (4095, 0). window_size[0] counts prior tokens exclusively, so a 4096-token window requires 4095, consistent with the gemma2_bridge.py round-trip convention.
gemma2_provider.py: Replace bare assert statements (stripped by -O) with raise ValueError for CP > 1 and packed-seq guards.
gemma2_provider.py: Add Gemma2FlexDotProductAttention — uses PyTorch FlexAttention (built-in, PyTorch 2.5+) to fuse softcap and SWA into a single Triton kernel via score_mod and block_mask. Module-level _get_softcap_score_mod is lru_cache-decorated so all layers with the same softcap share one function object, avoiding N redundant torch.compile recompilations at startup (torch.compile guards on id(fn)). flex_attention itself is wrapped with torch.compile to trigger the fused Triton path. Falls back to the unfused parent when FlexAttention is unavailable, a padding mask is present, or dropout is active. gemma2_layer_spec updated to use this class as core_attention.
test_gemma2_provider.py: Add TestGemma2DotProductAttention with 7 tests covering all correctness fixes: CP ValueError, packed-seq ValueError, SWA gate fires on pretrain path (attention_mask=None), odd layers have no SWA, even layers use AttnMaskType.arbitrary, SWA combined with padding mask via OR, and window_size default. Add TestGemma2FlexDotProductAttention with 12 tests covering fused path invocation, all fallback conditions (no FlexAttention, dropout, padding mask), output shape, softmax scale, score_mod softcap correctness, lru_cache sharing invariant, block mask caching, SWA layer block mask, and packed-seq ValueError.
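The window_size off-by-one in the changelog comes down to simple arithmetic, sketched here. Both function names are hypothetical; the conversion `(sliding_window - 1, 0)` is the bridge convention quoted in the PR, and the effective length assumes window_size[0] counts prior tokens and the current token is always visible.

```python
def hf_to_window_size(sliding_window):
    """Convert an HF-style sliding_window length to a (left, right) tuple."""
    # The current token is always visible, so only sliding_window - 1
    # prior tokens fit in the window.
    return (sliding_window - 1, 0)


def effective_window(window_size):
    """Number of keys a query can attend to, including itself."""
    left, right = window_size
    return left + right + 1


old_default = (4096, 0)
new_default = hf_to_window_size(4096)
```

Under these assumptions the old default yields a 4097-token window and the new default yields 4096, matching what the bridge produces from an HF config.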
Additional Information
Related to:
Malay/dsv3 2509 (669) into r0.1.0 #752: confirms MCore passes attention_mask=None on the pretrain path