
Commit beba296

fix(model): always build SWA mask to preserve causal masking at all seq lengths

The previous optimization skipped get_swa() when sq <= window_size[0] + 1, assuming the window covered the full sequence. But SWA layers use attn_mask_type=arbitrary, which routes through ScaledSoftmax (plain softmax, no causal mask) when mask=None. For short sequences this dropped causal masking entirely, not just the SWA restriction. get_swa() encodes the causal triangular structure via triu/tril and degenerates to a pure causal mask when the window covers all positions, so always calling it is both correct and sufficient.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Nitin Vegesna <nvegesna@nvidia.com>
1 parent d5d0133 commit beba296
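
The two claims in the commit message can be checked with a small pure-Python sketch. It is an illustration only: `swa_mask`, `causal_mask`, and `masked_softmax` are hypothetical stand-ins written for this example, not the real `get_swa()` or `ScaledSoftmax`, and the band condition inside `swa_mask` is an assumption about what the triu/tril construction encodes. The convention follows the diff: True means a position is masked out.

```python
import math


def swa_mask(seq_q, seq_k, window):
    """Hypothetical stand-in for get_swa(): True marks a blocked position.

    Assumes query i may see key j only when j lies in
    [i + offset - left, i + offset + right], with offset = seq_k - seq_q.
    """
    left, right = window
    offset = seq_k - seq_q
    return [
        [not (i + offset - left <= j <= i + offset + right) for j in range(seq_k)]
        for i in range(seq_q)
    ]


def causal_mask(seq):
    """Plain causal mask: every key after the query position is blocked."""
    return [[j > i for j in range(seq)] for i in range(seq)]


def masked_softmax(scores, mask=None):
    """Row-wise softmax; with mask=None nothing is blocked (plain softmax)."""
    out = []
    for i, row in enumerate(scores):
        kept = [
            s if mask is None or not mask[i][j] else -math.inf
            for j, s in enumerate(row)
        ]
        peak = max(kept)
        exps = [math.exp(s - peak) for s in kept]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


seq = 4
scores = [[0.0] * seq for _ in range(seq)]

# A window that covers the whole sequence reduces to the causal mask.
full_window = swa_mask(seq, seq, (seq - 1, 0))
assert full_window == causal_mask(seq)

# With no mask, the first query spreads weight over future keys.
unmasked = masked_softmax(scores)
# With the mask, the first query can only attend to itself.
masked = masked_softmax(scores, full_window)
print(unmasked[0])
print(masked[0])
```

With `mask=None` the first query assigns equal weight to all four keys, including the three future ones, which is the leak the commit describes. Passing the full-window mask restores causal behaviour.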

2 files changed

Lines changed: 12 additions & 14 deletions


src/megatron/bridge/models/gemma/gemma2_provider.py

Lines changed: 6 additions & 5 deletions
@@ -251,11 +251,12 @@ def forward(
 # mask [b, np, sq, sk], so we unsqueeze to [1, 1, sq, sk] when there is no
 # padding mask. When a padding mask [b, 1, sq, sk] is present, the | already
 # produces a 4D result via broadcasting.
-# Skip mask generation when the window fully covers the sequence: masking only
-# fires when query index i > window_size[0], i.e. seq_q > window_size[0] + 1.
-# For seq_length=4096 with window=4095 this is a no-op, so we stay on the
-# fast ScaledUpperTriangMaskedSoftmax (and FlexAttention) path.
-if self.window_size is not None and query.size(0) > self.window_size[0] + 1:
+# The mask is always generated for SWA layers: attn_mask_type=arbitrary means
+# FusedScaleMaskSoftmax routes through ScaledSoftmax (no causal masking) when
+# mask=None, so omitting the mask for short sequences would drop causal masking
+# entirely. get_swa() encodes causal structure via triu/tril and degenerates to
+# a pure causal mask when the window fully covers the sequence.
+if self.window_size is not None:
     swa_mask = get_swa(query.size(0), key.size(0), self.window_size)
     if attention_mask is None:
         attention_mask = swa_mask.unsqueeze(0).unsqueeze(0)
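
The difference between the removed and the added condition can be tabulated with a short sketch. `old_gate` and `new_gate` are hypothetical helper names that copy the two `if` conditions from this hunk. The window `(4095, 0)` and sequence length 4096 come from the removed comment, and the other sequence lengths are picked for illustration.

```python
def old_gate(window_size, seq_q):
    """Pre-fix condition: build the mask only when seq_q exceeds the window."""
    return window_size is not None and seq_q > window_size[0] + 1


def new_gate(window_size, seq_q):
    """Post-fix condition: build the mask for every SWA layer."""
    return window_size is not None


window = (4095, 0)
decisions = {
    seq_q: (old_gate(window, seq_q), new_gate(window, seq_q))
    for seq_q in (4, 4096, 4097)
}
for seq_q, (old, new) in decisions.items():
    print(f"seq_q={seq_q}: old gate builds mask={old}, new gate builds mask={new}")
```

Under the old condition any sequence of length 4096 or shorter received no mask at all, which per the commit message meant no causal masking either. Layers without a window (`window_size=None`) are unaffected by the change.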

tests/unit_tests/models/gemma/test_gemma2_provider.py

Lines changed: 6 additions & 9 deletions
@@ -255,15 +255,14 @@ def test_swa_applied_when_attention_mask_is_none(self):
 Prior to the fix, the gate was:
     if attention_mask is not None and self.window_size is not None:
 which was never True on the pretrain path (MCore passes attention_mask=None).
-After the fix the gate is:
-    if self.window_size is not None and query.size(0) > self.window_size[0] + 1:
-We verify this by patching get_swa and confirming it is called from forward()
-when attention_mask=None is passed to an even-numbered layer whose window is
-smaller than the sequence length. window=(2, 0) with seq=4: 4 > 3, so the
-guard fires and the SWA mask is built and unsqueezed to [1, 1, sq, sk].
+The gate is now simply:
+    if self.window_size is not None:
+The mask is always built for SWA layers: omitting it when the window covers the
+full sequence would drop causal masking entirely because attn_mask_type=arbitrary
+routes through ScaledSoftmax (plain softmax, no causal mask) when mask=None.
+get_swa() degenerates to a pure causal mask when the window covers all positions.
 """
 seq, batch, heads, head_dim = 4, 1, 8, 32
-# window=(2, 0): seq=4 > window+1=3, so the SWA guard fires.
 attn = _make_attention(window_size=(2, 0))
 assert attn.window_size == (2, 0), "even layer must have window_size set"
 
@@ -346,9 +345,7 @@ def test_swa_combined_with_padding_mask(self):
 which silently discarded any incoming padding mask. The correct behaviour is:
     attention_mask = swa_mask if attention_mask is None else (swa_mask | attention_mask)
 Both masks use True=masked-out, so logical OR gives the union of blocked positions.
-window=(2, 0) with seq=4: 4 > 3, so the SWA guard fires.
 """
-# window=(2, 0): seq=4 > window+1=3, so the SWA guard fires.
 attn = _make_attention(window_size=(2, 0))
 
 seq, batch, heads, head_dim = 4, 2, 8, 32
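
The union-of-masks behaviour described in this docstring can be sketched without torch using nested lists. Everything here is illustrative: `swa_mask` and `combine_masks` are hypothetical helpers, the band condition is an assumed reading of what `get_swa()` produces, and the padding pattern is made up. The per-sample loop plays the role that broadcasting a [sq, sk] mask against a [b, 1, sq, sk] mask would play in the real code.

```python
def swa_mask(seq_q, seq_k, window):
    """Hypothetical stand-in for get_swa(): True marks a blocked position."""
    left, right = window
    offset = seq_k - seq_q
    return [
        [not (i + offset - left <= j <= i + offset + right) for j in range(seq_k)]
        for i in range(seq_q)
    ]


def combine_masks(swa, padding_mask=None):
    """Return a 4D nested-list mask: [1][1][sq][sk] or [b][1][sq][sk].

    Both inputs use True = masked out, so OR yields the union of blocked
    positions.
    """
    if padding_mask is None:
        return [[swa]]  # mirrors unsqueeze(0).unsqueeze(0)
    return [
        [[[s or p for s, p in zip(swa_row, pad_row)]
          for swa_row, pad_row in zip(swa, sample[0])]]
        for sample in padding_mask
    ]


seq, batch = 4, 2
swa = swa_mask(seq, seq, (2, 0))

# Illustrative padding mask of shape [b][1][sq][sk]:
# sample 0 has no padding, sample 1 has key position 3 padded out.
padding = [
    [[[False] * seq for _ in range(seq)]],
    [[[j == 3 for j in range(seq)] for _ in range(seq)]],
]

combined = combine_masks(swa, padding)
print(combined[0][0][3])  # sample 0, last query: SWA restriction only
print(combined[1][0][3])  # sample 1, last query: SWA restriction plus padding
```

For the last query of sample 1, key 0 is blocked by the sliding window and key 3 by padding, so only keys 1 and 2 remain visible. Overwriting the padding mask with the SWA mask alone, as the docstring says the earlier code did, would have left key 3 visible.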
