Skip to content

Commit d5d0133

Browse files
feat(gemma2): add Gemma2FlexDotProductAttention with fused softcap+SWA
Uses PyTorch FlexAttention (built-in, PyTorch 2.5+) to fuse softcap and SWA into a single Triton kernel via score_mod and block_mask. Falls back to the unfused parent when a padding mask is present, dropout is active, or FlexAttention is unavailable (pretraining always takes the fused path). - _get_softcap_score_mod: lru_cache-decorated so all layers share one function object, avoiding N redundant torch.compile recompilations at startup - flex_attention wrapped with torch.compile to trigger the fused Triton path - Block mask cache keyed by (sq, sk) per instance to avoid rebuilding each step - gemma2_layer_spec updated to use Gemma2FlexDotProductAttention as core_attention - TestGemma2FlexDotProductAttention: 12 tests covering fused path, all fallback conditions, output shape, scale, score_mod correctness, lru_cache sharing Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Nitin Vegesna <nvegesna@nvidia.com>
1 parent ed9f473 commit d5d0133

2 files changed

Lines changed: 496 additions & 18 deletions

File tree

src/megatron/bridge/models/gemma/gemma2_provider.py

Lines changed: 144 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import functools
16+
import logging
1517
import math
1618
from dataclasses import dataclass
1719
from typing import Callable, Optional, Union
@@ -49,6 +51,43 @@
4951
from megatron.bridge.models.gpt_provider import GPTModelProvider
5052

5153

54+
logger = logging.getLogger(__name__)
55+
56+
_HAVE_FLEX_ATTN = False
57+
_flex_attn_func = None
58+
_create_flex_block_mask = None
59+
60+
try:
61+
from torch.nn.attention.flex_attention import create_block_mask as _flex_mask_candidate
62+
from torch.nn.attention.flex_attention import flex_attention as _flex_candidate
63+
64+
_flex_attn_func = torch.compile(_flex_candidate)
65+
_create_flex_block_mask = _flex_mask_candidate
66+
_HAVE_FLEX_ATTN = True
67+
logger.warning("Gemma2: PyTorch FlexAttention available — softcap+SWA fused via Triton score_mod.")
68+
del _flex_candidate, _flex_mask_candidate
69+
except ImportError:
70+
pass
71+
72+
if not _HAVE_FLEX_ATTN:
73+
logger.warning("Gemma2: FlexAttention not available — using unfused attention fallback.")
74+
75+
76+
@functools.lru_cache(maxsize=None)
77+
def _get_softcap_score_mod(softcap: float):
78+
"""Return a score_mod closure for the given softcap, cached so all layers share one object.
79+
80+
torch.compile guards on score_mod identity (id(fn)), so sharing one object across the
81+
N attention layers avoids N redundant Triton kernel recompilations at startup.
82+
"""
83+
84+
def _score_mod(score, b, h, q_idx, kv_idx):
85+
return softcap * torch.tanh(score / softcap)
86+
87+
_score_mod.__qualname__ = f"softcap_score_mod_{softcap}"
88+
return _score_mod
89+
90+
5291
class Gemma2DotProductAttention(MegatronModule):
5392
"""
5493
Region where selective activation recomputation is applied.
@@ -208,11 +247,20 @@ def forward(
208247

209248
# sliding window attention: combine SWA mask with any incoming padding mask.
210249
# Both use True=masked-out; logical OR gives the union of masked positions.
211-
# get_swa() returns [sq, sk]; a padding mask is typically [b, 1, sq, sk] —
212-
# PyTorch broadcasts [sq, sk] to [b, 1, sq, sk] correctly under |.
213-
if self.window_size is not None:
250+
# get_swa() returns [sq, sk]; the fused CUDA softmax kernel requires a 4D
251+
# mask [b, np, sq, sk], so we unsqueeze to [1, 1, sq, sk] when there is no
252+
# padding mask. When a padding mask [b, 1, sq, sk] is present, the | already
253+
# produces a 4D result via broadcasting.
254+
# Skip mask generation when the window fully covers the sequence: masking only
255+
# fires when query index i > window_size[0], i.e. seq_q > window_size[0] + 1.
256+
# For seq_length=4096 with window=4095 this is a no-op, so we stay on the
257+
# fast ScaledUpperTriangMaskedSoftmax (and FlexAttention) path.
258+
if self.window_size is not None and query.size(0) > self.window_size[0] + 1:
214259
swa_mask = get_swa(query.size(0), key.size(0), self.window_size)
215-
attention_mask = swa_mask if attention_mask is None else (swa_mask | attention_mask)
260+
if attention_mask is None:
261+
attention_mask = swa_mask.unsqueeze(0).unsqueeze(0)
262+
else:
263+
attention_mask = swa_mask | attention_mask
216264

217265
# attention scores and attention mask [b, np, sq, sk]
218266
attention_probs: Tensor = self.scale_mask_softmax(attention_scores, attention_mask)
@@ -262,6 +310,97 @@ def forward(
262310
return context
263311

264312

313+
class Gemma2FlexDotProductAttention(Gemma2DotProductAttention):
314+
"""Gemma2 fused attention with native softcap and sliding window support.
315+
316+
Uses PyTorch FlexAttention (built-in, PyTorch 2.5+) to fuse softcap and SWA into
317+
a single Triton kernel. Falls back to the unfused parent when a padding
318+
attention_mask is present (fine-tuning / variable-length batches) or when
319+
dropout is active. Pretraining always uses the fused path.
320+
"""
321+
322+
def __init__(
323+
self,
324+
config: TransformerConfig,
325+
layer_number: int,
326+
attn_mask_type: AttnMaskType,
327+
attention_type: str,
328+
attention_dropout: float = None,
329+
**kwargs,
330+
):
331+
super().__init__(config, layer_number, attn_mask_type, attention_type, attention_dropout, **kwargs)
332+
# softcap passed directly to the fused kernel; avoids post-hoc tanh rescaling
333+
self.softcap = float(getattr(config, "attn_logit_softcapping", 0.0) or 0.0)
334+
# Gemma2 uses 1/sqrt(query_pre_attn_scalar=224), not 1/sqrt(head_dim) — must override
335+
self.softmax_scale = 1.0 / self.norm_factor
336+
self.dropout_p = config.attention_dropout if attention_dropout is None else attention_dropout
337+
# window_size for FlexAttention block_mask: (-1, -1) = full causal; (left, right) = SWA
338+
self._flex_window_size = (-1, -1) if self.window_size is None else (self.window_size[0], self.window_size[1])
339+
340+
if _HAVE_FLEX_ATTN:
341+
self._flex_score_mod = _get_softcap_score_mod(self.softcap)
342+
self._flex_block_mask_cache: dict = {}
343+
344+
def _build_flex_block_mask(self, sq: int, sk: int, device: torch.device):
345+
"""Build a FlexAttention block_mask encoding causal + optional SWA."""
346+
window_left = self._flex_window_size[0]
347+
if window_left < 0:
348+
349+
def _mask(b, h, q_idx, kv_idx):
350+
return q_idx >= kv_idx
351+
352+
else:
353+
w = window_left
354+
355+
def _mask(b, h, q_idx, kv_idx, _w=w):
356+
return (q_idx >= kv_idx) & (q_idx - kv_idx <= _w)
357+
358+
return _create_flex_block_mask(_mask, B=None, H=None, Q_LEN=sq, KV_LEN=sk, device=device)
359+
360+
def forward(
361+
self,
362+
query: Tensor,
363+
key: Tensor,
364+
value: Tensor,
365+
attention_mask: Tensor,
366+
attn_mask_type: AttnMaskType = None,
367+
packed_seq_params: PackedSeqParams = None,
368+
**kwargs,
369+
):
370+
"""Forward: FlexAttention fused path when possible, unfused fallback otherwise."""
371+
if packed_seq_params is not None:
372+
raise ValueError(
373+
"Packed sequence is not supported by DotProductAttention. Use TEDotProductAttention instead."
374+
)
375+
376+
dropout_p = self.dropout_p if self.training else 0.0
377+
fused_eligible = attention_mask is None and dropout_p == 0.0
378+
379+
if _HAVE_FLEX_ATTN and fused_eligible:
380+
# FlexAttention path — expects [b, np, sq, hn]
381+
sq, b, np_heads, hn = query.shape
382+
q = query.permute(1, 2, 0, 3)
383+
k = key.permute(1, 2, 0, 3)
384+
v = value.permute(1, 2, 0, 3)
385+
cache_key = (sq, key.size(0))
386+
if cache_key not in self._flex_block_mask_cache:
387+
self._flex_block_mask_cache[cache_key] = self._build_flex_block_mask(
388+
*cache_key, query.device
389+
)
390+
out = _flex_attn_func(
391+
q, k, v,
392+
score_mod=self._flex_score_mod,
393+
block_mask=self._flex_block_mask_cache[cache_key],
394+
scale=self.softmax_scale,
395+
enable_gqa=(k.size(1) != q.size(1)),
396+
)
397+
return out.permute(2, 0, 1, 3).contiguous().view(sq, b, np_heads * hn)
398+
399+
return super().forward(
400+
query, key, value, attention_mask, attn_mask_type=attn_mask_type, packed_seq_params=None, **kwargs
401+
)
402+
403+
265404
class TERowParallelLinearLayerNorm(TERowParallelLinear):
266405
"""Modified From TERowParallelLinear with an additional Post-LN."""
267406

@@ -326,7 +465,7 @@ def gemma2_layer_spec(config: "GPTModelProvider") -> ModuleSpec:
326465
params={"attn_mask_type": AttnMaskType.causal},
327466
submodules=SelfAttentionSubmodules(
328467
linear_qkv=TELayerNormColumnParallelLinear,
329-
core_attention=Gemma2DotProductAttention, # use unfused SDPA for attn logit softcapping
468+
core_attention=Gemma2FlexDotProductAttention, # FlexAttention fast path; falls back to unfused when unavailable
330469
linear_proj=TERowParallelLinearLayerNorm, # post attn RMSNorm
331470
),
332471
),

0 commit comments

Comments
 (0)