Skip to content

Commit ed9f473

Browse files
fix+test(model): fix SWA mask type, unsqueeze to 4D, combine with padding mask
- Override attn_mask_type to arbitrary for SWA layers so FusedScaleMaskSoftmax routes through ScaledMaskedSoftmax (causal type silently ignores the mask arg) - Unsqueeze get_swa() output [sq, sk] to [1, 1, sq, sk] for the fused CUDA kernel - OR-combine SWA mask with incoming padding mask instead of replacing it - Add TestGemma2DotProductAttention with 7 tests covering all fixes Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Nitin Vegesna <nvegesna@nvidia.com>
1 parent 99f88e4 commit ed9f473

2 files changed

Lines changed: 116 additions & 12 deletions

File tree

src/megatron/bridge/models/gemma/gemma2_provider.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,14 @@ def __init__(
8686
if self.layer_number % 2 == 0:
8787
self.window_size = config.window_size
8888

89-
self.attn_mask_type = attn_mask_type
9089
self.attention_type = attention_type # unused for now
90+
# SWA layers generate an external mask via get_swa() in forward(). With
91+
# AttnMaskType.causal, FusedScaleMaskSoftmax always takes the fused upper-
92+
# triangular causal kernel (ScaledUpperTriangMaskedSoftmax) which never reads
93+
# the mask argument, silently dropping the SWA mask. Switching to arbitrary
94+
# for SWA layers routes through ScaledMaskedSoftmax, which applies the mask.
95+
# Odd-numbered layers remain causal and keep the fast fused causal path.
96+
self.attn_mask_type = AttnMaskType.arbitrary if self.window_size is not None else attn_mask_type
9197

9298
projection_size = self.config.kv_channels * self.config.num_attention_heads
9399

@@ -200,9 +206,13 @@ def forward(
200206
# Attention probs and dropout
201207
# ===========================
202208

203-
# sliding window attention
209+
# sliding window attention: combine SWA mask with any incoming padding mask.
210+
# Both use True=masked-out; logical OR gives the union of masked positions.
211+
# get_swa() returns [sq, sk]; a padding mask is typically [b, 1, sq, sk] —
212+
# PyTorch broadcasts [sq, sk] to [b, 1, sq, sk] correctly under |.
204213
if self.window_size is not None:
205-
attention_mask = get_swa(query.size(0), key.size(0), self.window_size)
214+
swa_mask = get_swa(query.size(0), key.size(0), self.window_size)
215+
attention_mask = swa_mask if attention_mask is None else (swa_mask | attention_mask)
206216

207217
# attention scores and attention mask [b, np, sq, sk]
208218
attention_probs: Tensor = self.scale_mask_softmax(attention_scores, attention_mask)

tests/unit_tests/models/gemma/test_gemma2_provider.py

Lines changed: 103 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ def _make_attention(context_parallel_size: int = 1, window_size: tuple = (4095,
217217
config.attention_softmax_in_fp32 = True
218218
config.attention_dropout = 0.0
219219
config.sequence_parallel = False
220+
config.attn_logit_softcapping = 0.0 # disable softcapping in unit tests
220221
return Gemma2DotProductAttention(
221222
config=config,
222223
layer_number=2, # even layer → SWA active
@@ -247,16 +248,43 @@ def test_packed_seq_raises_value_error(self):
247248
)
248249

249250
def test_swa_applied_when_attention_mask_is_none(self):
250-
"""SWA mask must be generated even when attention_mask=None (the pretrain path)."""
251+
"""SWA mask must be generated even when attention_mask=None (the pretrain path).
252+
253+
Prior to the fix, the gate was:
254+
if attention_mask is not None and self.window_size is not None:
255+
which was never True on the pretrain path (MCore passes attention_mask=None).
256+
After the fix the gate is:
257+
if self.window_size is not None:
258+
We verify this by patching get_swa and confirming it is called from forward()
259+
when attention_mask=None is passed to an even-numbered layer.
260+
"""
251261
attn = _make_attention(window_size=(4095, 0))
252-
# Even layer → self.window_size is set; odd layer → None
253-
assert attn.window_size == (4095, 0)
254-
255-
# get_swa should produce a boolean mask of the correct shape
256-
seq_len = 8
257-
mask = get_swa(seq_len, seq_len, (4095, 0))
258-
assert mask.shape == (seq_len, seq_len)
259-
assert mask.dtype == torch.bool
262+
assert attn.window_size == (4095, 0), "even layer must have window_size set"
263+
264+
sentinel = Mock(name="swa_mask")
265+
with patch(
266+
"megatron.bridge.models.gemma.gemma2_provider.get_swa",
267+
return_value=sentinel,
268+
) as mock_get_swa:
269+
# scale_mask_softmax is called with the mask; stub it so forward() completes.
270+
attn.scale_mask_softmax = Mock(return_value=torch.zeros(1, 8, 4, 4))
271+
attn.attention_dropout = torch.nn.Identity()
272+
seq, batch, heads, head_dim = 4, 1, 8, 32
273+
q = torch.zeros(seq, batch, heads, head_dim)
274+
k = torch.zeros(seq, batch, heads, head_dim)
275+
v = torch.zeros(seq, batch, heads, head_dim)
276+
with patch("megatron.bridge.models.gemma.gemma2_provider.parallel_state") as mock_ps, \
277+
patch("megatron.bridge.models.gemma.gemma2_provider.tensor_parallel") as mock_tp:
278+
buf = torch.zeros(batch * heads, seq, seq)
279+
mock_ps.get_global_memory_buffer.return_value.get_tensor.return_value = buf
280+
mock_tp.get_cuda_rng_tracker.return_value.fork.return_value.__enter__ = lambda s: None
281+
mock_tp.get_cuda_rng_tracker.return_value.fork.return_value.__exit__ = Mock(return_value=False)
282+
attn.forward(query=q, key=k, value=v, attention_mask=None)
283+
284+
mock_get_swa.assert_called_once_with(seq, seq, (4095, 0))
285+
# The SWA mask (our sentinel) must have been passed to scale_mask_softmax.
286+
call_args = attn.scale_mask_softmax.call_args
287+
assert call_args[0][1] is sentinel, "scale_mask_softmax must receive the SWA mask"
260288

261289
def test_odd_layer_has_no_swa(self):
262290
"""Odd-numbered layers must not have a window_size (full attention)."""
@@ -282,6 +310,72 @@ def test_odd_layer_has_no_swa(self):
282310
attention_type="self",
283311
)
284312
assert odd_attn.window_size is None
313+
assert odd_attn.attn_mask_type == AttnMaskType.causal
314+
assert odd_attn.scale_mask_softmax.attn_mask_type == AttnMaskType.causal
315+
316+
def test_swa_layer_uses_arbitrary_mask_type(self):
317+
"""Even-numbered (SWA) layers must override attn_mask_type to arbitrary.
318+
319+
FusedScaleMaskSoftmax with AttnMaskType.causal takes the ScaledUpperTriangMaskedSoftmax
320+
path which silently ignores the mask argument. Switching to arbitrary routes through
321+
ScaledMaskedSoftmax, which correctly applies the externally generated SWA mask.
322+
Odd-numbered layers must keep AttnMaskType.causal to retain the fast fused path.
323+
"""
324+
even_attn = _make_attention(window_size=(4095, 0)) # layer_number=2 (even)
325+
assert even_attn.attn_mask_type == AttnMaskType.arbitrary, (
326+
"SWA layers must use AttnMaskType.arbitrary so FusedScaleMaskSoftmax "
327+
"routes through ScaledMaskedSoftmax and applies the mask"
328+
)
329+
# Also verify the FusedScaleMaskSoftmax instance stored the right type
330+
assert even_attn.scale_mask_softmax.attn_mask_type == AttnMaskType.arbitrary
331+
332+
def test_swa_combined_with_padding_mask(self):
333+
"""When a padding mask is present, forward() must OR it with the SWA mask.
334+
335+
Prior to this fix, the forward() code was:
336+
attention_mask = get_swa(...)
337+
which silently discarded any incoming padding mask. The correct behaviour is:
338+
attention_mask = swa_mask if attention_mask is None else (swa_mask | attention_mask)
339+
Both masks use True=masked-out, so logical OR gives the union of blocked positions.
340+
"""
341+
attn = _make_attention(window_size=(4095, 0))
342+
343+
seq, batch, heads, head_dim = 4, 2, 8, 32
344+
# Padding mask [b, 1, sq, sk]: block last key-position for the first sample only.
345+
padding_mask = torch.zeros(batch, 1, seq, seq, dtype=torch.bool)
346+
padding_mask[0, 0, :, -1] = True
347+
348+
# SWA mask returned by the patched get_swa (all-zeros for simplicity).
349+
swa_mask_val = torch.zeros(seq, seq, dtype=torch.bool)
350+
351+
captured: dict = {}
352+
353+
def fake_softmax(scores, mask):
354+
captured["mask"] = mask
355+
return torch.zeros(batch, heads, seq, seq)
356+
357+
with patch(
358+
"megatron.bridge.models.gemma.gemma2_provider.get_swa",
359+
return_value=swa_mask_val,
360+
) as mock_get_swa:
361+
attn.scale_mask_softmax = fake_softmax
362+
attn.attention_dropout = torch.nn.Identity()
363+
q = torch.zeros(seq, batch, heads, head_dim)
364+
k = torch.zeros(seq, batch, heads, head_dim)
365+
v = torch.zeros(seq, batch, heads, head_dim)
366+
with patch("megatron.bridge.models.gemma.gemma2_provider.parallel_state") as mock_ps, \
367+
patch("megatron.bridge.models.gemma.gemma2_provider.tensor_parallel") as mock_tp:
368+
buf = torch.zeros(batch * heads, seq, seq)
369+
mock_ps.get_global_memory_buffer.return_value.get_tensor.return_value = buf
370+
mock_tp.get_cuda_rng_tracker.return_value.fork.return_value.__enter__ = lambda s: None
371+
mock_tp.get_cuda_rng_tracker.return_value.fork.return_value.__exit__ = Mock(return_value=False)
372+
attn.forward(query=q, key=k, value=v, attention_mask=padding_mask)
373+
374+
mock_get_swa.assert_called_once_with(seq, seq, (4095, 0))
375+
expected = swa_mask_val | padding_mask # [sq, sk] | [b, 1, sq, sk] → [b, 1, sq, sk]
376+
assert torch.equal(captured["mask"], expected), (
377+
"scale_mask_softmax must receive swa_mask | padding_mask, not just swa_mask"
378+
)
285379

286380
def test_window_size_default_is_4095(self):
287381
"""Gemma2ModelProvider.window_size default must be (4095, 0) to match gemma2_bridge convention."""

0 commit comments

Comments
 (0)