@@ -217,6 +217,7 @@ def _make_attention(context_parallel_size: int = 1, window_size: tuple = (4095,
217217 config .attention_softmax_in_fp32 = True
218218 config .attention_dropout = 0.0
219219 config .sequence_parallel = False
220+ config .attn_logit_softcapping = 0.0 # disable softcapping in unit tests
220221 return Gemma2DotProductAttention (
221222 config = config ,
222223 layer_number = 2 , # even layer → SWA active
@@ -247,16 +248,43 @@ def test_packed_seq_raises_value_error(self):
247248 )
248249
249250 def test_swa_applied_when_attention_mask_is_none (self ):
250- """SWA mask must be generated even when attention_mask=None (the pretrain path)."""
251+ """SWA mask must be generated even when attention_mask=None (the pretrain path).
252+
253+ Prior to the fix, the gate was:
254+ if attention_mask is not None and self.window_size is not None:
255+ which was never True on the pretrain path (MCore passes attention_mask=None).
256+ After the fix the gate is:
257+ if self.window_size is not None:
258+ We verify this by patching get_swa and confirming it is called from forward()
259+ when attention_mask=None is passed to an even-numbered layer.
260+ """
251261 attn = _make_attention (window_size = (4095 , 0 ))
252- # Even layer → self.window_size is set; odd layer → None
253- assert attn .window_size == (4095 , 0 )
254-
255- # get_swa should produce a boolean mask of the correct shape
256- seq_len = 8
257- mask = get_swa (seq_len , seq_len , (4095 , 0 ))
258- assert mask .shape == (seq_len , seq_len )
259- assert mask .dtype == torch .bool
262+ assert attn .window_size == (4095 , 0 ), "even layer must have window_size set"
263+
264+ sentinel = Mock (name = "swa_mask" )
265+ with patch (
266+ "megatron.bridge.models.gemma.gemma2_provider.get_swa" ,
267+ return_value = sentinel ,
268+ ) as mock_get_swa :
269+ # scale_mask_softmax is called with the mask; stub it so forward() completes.
270+ attn .scale_mask_softmax = Mock (return_value = torch .zeros (1 , 8 , 4 , 4 ))
271+ attn .attention_dropout = torch .nn .Identity ()
272+ seq , batch , heads , head_dim = 4 , 1 , 8 , 32
273+ q = torch .zeros (seq , batch , heads , head_dim )
274+ k = torch .zeros (seq , batch , heads , head_dim )
275+ v = torch .zeros (seq , batch , heads , head_dim )
276+ with patch ("megatron.bridge.models.gemma.gemma2_provider.parallel_state" ) as mock_ps , \
277+ patch ("megatron.bridge.models.gemma.gemma2_provider.tensor_parallel" ) as mock_tp :
278+ buf = torch .zeros (batch * heads , seq , seq )
279+ mock_ps .get_global_memory_buffer .return_value .get_tensor .return_value = buf
280+ mock_tp .get_cuda_rng_tracker .return_value .fork .return_value .__enter__ = lambda s : None
281+ mock_tp .get_cuda_rng_tracker .return_value .fork .return_value .__exit__ = Mock (return_value = False )
282+ attn .forward (query = q , key = k , value = v , attention_mask = None )
283+
284+ mock_get_swa .assert_called_once_with (seq , seq , (4095 , 0 ))
285+ # The SWA mask (our sentinel) must have been passed to scale_mask_softmax.
286+ call_args = attn .scale_mask_softmax .call_args
287+ assert call_args [0 ][1 ] is sentinel , "scale_mask_softmax must receive the SWA mask"
260288
261289 def test_odd_layer_has_no_swa (self ):
262290 """Odd-numbered layers must not have a window_size (full attention)."""
@@ -282,6 +310,72 @@ def test_odd_layer_has_no_swa(self):
282310 attention_type = "self" ,
283311 )
284312 assert odd_attn .window_size is None
313+ assert odd_attn .attn_mask_type == AttnMaskType .causal
314+ assert odd_attn .scale_mask_softmax .attn_mask_type == AttnMaskType .causal
315+
316+ def test_swa_layer_uses_arbitrary_mask_type (self ):
317+ """Even-numbered (SWA) layers must override attn_mask_type to arbitrary.
318+
319+ FusedScaleMaskSoftmax with AttnMaskType.causal takes the ScaledUpperTriangMaskedSoftmax
320+ path which silently ignores the mask argument. Switching to arbitrary routes through
321+ ScaledMaskedSoftmax, which correctly applies the externally generated SWA mask.
322+ Odd-numbered layers must keep AttnMaskType.causal to retain the fast fused path.
323+ """
324+ even_attn = _make_attention (window_size = (4095 , 0 )) # layer_number=2 (even)
325+ assert even_attn .attn_mask_type == AttnMaskType .arbitrary , (
326+ "SWA layers must use AttnMaskType.arbitrary so FusedScaleMaskSoftmax "
327+ "routes through ScaledMaskedSoftmax and applies the mask"
328+ )
329+ # Also verify the FusedScaleMaskSoftmax instance stored the right type
330+ assert even_attn .scale_mask_softmax .attn_mask_type == AttnMaskType .arbitrary
331+
332+ def test_swa_combined_with_padding_mask (self ):
333+ """When a padding mask is present, forward() must OR it with the SWA mask.
334+
335+ Prior to this fix, the forward() code was:
336+ attention_mask = get_swa(...)
337+ which silently discarded any incoming padding mask. The correct behaviour is:
338+ attention_mask = swa_mask if attention_mask is None else (swa_mask | attention_mask)
339+ Both masks use True=masked-out, so logical OR gives the union of blocked positions.
340+ """
341+ attn = _make_attention (window_size = (4095 , 0 ))
342+
343+ seq , batch , heads , head_dim = 4 , 2 , 8 , 32
344+ # Padding mask [b, 1, sq, sk]: block last key-position for the first sample only.
345+ padding_mask = torch .zeros (batch , 1 , seq , seq , dtype = torch .bool )
346+ padding_mask [0 , 0 , :, - 1 ] = True
347+
348+ # SWA mask returned by the patched get_swa (all-zeros for simplicity).
349+ swa_mask_val = torch .zeros (seq , seq , dtype = torch .bool )
350+
351+ captured : dict = {}
352+
353+ def fake_softmax (scores , mask ):
354+ captured ["mask" ] = mask
355+ return torch .zeros (batch , heads , seq , seq )
356+
357+ with patch (
358+ "megatron.bridge.models.gemma.gemma2_provider.get_swa" ,
359+ return_value = swa_mask_val ,
360+ ) as mock_get_swa :
361+ attn .scale_mask_softmax = fake_softmax
362+ attn .attention_dropout = torch .nn .Identity ()
363+ q = torch .zeros (seq , batch , heads , head_dim )
364+ k = torch .zeros (seq , batch , heads , head_dim )
365+ v = torch .zeros (seq , batch , heads , head_dim )
366+ with patch ("megatron.bridge.models.gemma.gemma2_provider.parallel_state" ) as mock_ps , \
367+ patch ("megatron.bridge.models.gemma.gemma2_provider.tensor_parallel" ) as mock_tp :
368+ buf = torch .zeros (batch * heads , seq , seq )
369+ mock_ps .get_global_memory_buffer .return_value .get_tensor .return_value = buf
370+ mock_tp .get_cuda_rng_tracker .return_value .fork .return_value .__enter__ = lambda s : None
371+ mock_tp .get_cuda_rng_tracker .return_value .fork .return_value .__exit__ = Mock (return_value = False )
372+ attn .forward (query = q , key = k , value = v , attention_mask = padding_mask )
373+
374+ mock_get_swa .assert_called_once_with (seq , seq , (4095 , 0 ))
375+ expected = swa_mask_val | padding_mask # [sq, sk] | [b, 1, sq, sk] → [b, 1, sq, sk]
376+ assert torch .equal (captured ["mask" ], expected ), (
377+ "scale_mask_softmax must receive swa_mask | padding_mask, not just swa_mask"
378+ )
285379
286380 def test_window_size_default_is_4095 (self ):
287381 """Gemma2ModelProvider.window_size default must be (4095, 0) to match gemma2_bridge convention."""
0 commit comments