Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 162 additions & 12 deletions src/megatron/bridge/models/gemma/gemma2_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import logging
import math
from dataclasses import dataclass
from typing import Callable, Optional, Union
Expand Down Expand Up @@ -49,6 +51,43 @@
from megatron.bridge.models.gpt_provider import GPTModelProvider


logger = logging.getLogger(__name__)

# Optional FlexAttention backend. These three names keep their "unavailable" values
# unless both the import and the torch.compile wrap below succeed.
_HAVE_FLEX_ATTN = False
_flex_attn_func = None
_create_flex_block_mask = None

try:
    from torch.nn.attention.flex_attention import create_block_mask as _flex_mask_candidate
    from torch.nn.attention.flex_attention import flex_attention as _flex_candidate
except ImportError:
    # This PyTorch build has no torch.nn.attention.flex_attention; reported below.
    pass
else:
    try:
        # torch.compile may raise RuntimeError on interpreter/platform combinations it
        # does not support. That must not abort the import of this module; it only
        # means the fused path stays disabled.
        _flex_attn_func = torch.compile(_flex_candidate)
    except RuntimeError as exc:
        logger.debug("Gemma2: torch.compile(flex_attention) failed: %s", exc)
    else:
        _create_flex_block_mask = _flex_mask_candidate
        _HAVE_FLEX_ATTN = True
        # Status message, not a problem report, hence INFO rather than WARNING.
        logger.info("Gemma2: PyTorch FlexAttention available — softcap+SWA fused via Triton score_mod.")
    finally:
        # The import-only aliases are never part of the module namespace afterwards,
        # whether or not compilation succeeded.
        del _flex_candidate, _flex_mask_candidate

if not _HAVE_FLEX_ATTN:
    logger.warning("Gemma2: FlexAttention not available — using unfused attention fallback.")


@functools.lru_cache(maxsize=None)
def _get_softcap_score_mod(softcap: float):
    """Return a FlexAttention ``score_mod`` that applies tanh soft-capping.

    The result is cached per ``softcap`` value so that every caller asking for the
    same cap receives the *same* function object. The intent is that all attention
    layers share one ``score_mod``, so the compiled attention kernel is not rebuilt
    once per layer.

    Args:
        softcap: Soft-capping constant ``c``. Scores are mapped to
            ``c * tanh(score / c)``. ``0.0`` means "soft-capping disabled" and yields
            an identity ``score_mod``; evaluating the formula with ``c == 0`` would
            turn every score into 0 or NaN.

    Returns:
        A callable ``(score, b, h, q_idx, kv_idx) -> score`` that ignores the index
        arguments.
    """
    if softcap == 0.0:

        def _score_mod(score, b, h, q_idx, kv_idx):
            # Soft-capping disabled: leave the raw score untouched.
            return score

    else:

        def _score_mod(score, b, h, q_idx, kv_idx):
            return softcap * torch.tanh(score / softcap)

    # Distinct qualname per cap value makes the closure identifiable in traces/logs.
    _score_mod.__qualname__ = f"softcap_score_mod_{softcap}"
    return _score_mod


class Gemma2DotProductAttention(MegatronModule):
"""
Region where selective activation recomputation is applied.
Expand Down Expand Up @@ -77,18 +116,23 @@ def __init__(

self.config: TransformerConfig = config

assert self.config.context_parallel_size == 1, (
"Context parallelism is only supported by TEDotProductAttention!"
)
if self.config.context_parallel_size != 1:
raise ValueError("Context parallelism is only supported by TEDotProductAttention!")

self.layer_number = max(1, layer_number)

self.window_size = None
if self.layer_number % 2 == 0:
self.window_size = config.window_size

self.attn_mask_type = attn_mask_type
self.attention_type = attention_type # unused for now
# SWA layers generate an external mask via get_swa() in forward(). With
# AttnMaskType.causal, FusedScaleMaskSoftmax always takes the fused upper-
# triangular causal kernel (ScaledUpperTriangMaskedSoftmax) which never reads
# the mask argument, silently dropping the SWA mask. Switching to arbitrary
# for SWA layers routes through ScaledMaskedSoftmax, which applies the mask.
# Odd-numbered layers remain causal and keep the fast fused causal path.
self.attn_mask_type = AttnMaskType.arbitrary if self.window_size is not None else attn_mask_type

projection_size = self.config.kv_channels * self.config.num_attention_heads

Expand Down Expand Up @@ -137,9 +181,10 @@ def forward(
Modified from mcore.transformer.dot_product_attention to support Gemma2-specific
final_logit_softcapping.
"""
assert packed_seq_params is None, (
"Packed sequence is not supported by DotProductAttention.Please use TEDotProductAttention instead."
)
if packed_seq_params is not None:
raise ValueError(
"Packed sequence is not supported by DotProductAttention. Use TEDotProductAttention instead."
)

# ===================================
# Raw attention scores. [b, n/p, s, s]
Expand Down Expand Up @@ -200,9 +245,23 @@ def forward(
# Attention probs and dropout
# ===========================

# sliding window attention
if attention_mask is not None and self.window_size is not None:
attention_mask = get_swa(query.size(0), key.size(0), self.window_size)
# sliding window attention: combine SWA mask with any incoming padding mask.
# Both use True=masked-out; logical OR gives the union of masked positions.
# get_swa() returns [sq, sk]; the fused CUDA softmax kernel requires a 4D
# mask [b, np, sq, sk], so we unsqueeze to [1, 1, sq, sk] when there is no
# padding mask. When a padding mask [b, 1, sq, sk] is present, the | already
# produces a 4D result via broadcasting.
# The mask is always generated for SWA layers: attn_mask_type=arbitrary means
# FusedScaleMaskSoftmax routes through ScaledSoftmax (no causal masking) when
# mask=None, so omitting the mask for short sequences would drop causal masking
# entirely. get_swa() encodes causal structure via triu/tril and degenerates to
# a pure causal mask when the window fully covers the sequence.
if self.window_size is not None:
swa_mask = get_swa(query.size(0), key.size(0), self.window_size)
if attention_mask is None:
attention_mask = swa_mask.unsqueeze(0).unsqueeze(0)
else:
attention_mask = swa_mask | attention_mask

# attention scores and attention mask [b, np, sq, sk]
attention_probs: Tensor = self.scale_mask_softmax(attention_scores, attention_mask)
Expand Down Expand Up @@ -252,6 +311,97 @@ def forward(
return context


class Gemma2FlexDotProductAttention(Gemma2DotProductAttention):
    """Gemma2 core attention with a fused PyTorch FlexAttention fast path.

    ``forward`` takes the fused path only when all of the following hold:

    * FlexAttention was importable and compilable at module load (``_HAVE_FLEX_ATTN``),
    * no ``attention_mask`` is supplied by the caller, and
    * the effective dropout probability is zero (always the case in eval mode).

    On that path a single compiled ``flex_attention`` call applies the logit softcap
    through ``score_mod`` and the causal / sliding-window pattern through a
    ``block_mask``. In every other case the call is delegated to
    ``Gemma2DotProductAttention.forward``.
    """

    def __init__(
        self,
        config: TransformerConfig,
        layer_number: int,
        attn_mask_type: AttnMaskType,
        attention_type: str,
        attention_dropout: Optional[float] = None,
        **kwargs,
    ):
        """Initialise the parent attention and pre-compute the fused-path settings.

        Args:
            config: Transformer configuration; ``attn_logit_softcapping`` and
                ``attention_dropout`` are read from it here.
            layer_number: Forwarded to the parent.
            attn_mask_type: Forwarded to the parent.
            attention_type: Forwarded to the parent.
            attention_dropout: Overrides ``config.attention_dropout`` when not ``None``.
            **kwargs: Forwarded to the parent constructor.
        """
        super().__init__(config, layer_number, attn_mask_type, attention_type, attention_dropout, **kwargs)
        # A missing or falsy attn_logit_softcapping is normalised to 0.0, which is
        # treated below as "soft-capping disabled".
        self.softcap = float(getattr(config, "attn_logit_softcapping", 0.0) or 0.0)
        # Reuse the parent's norm_factor so the fused and unfused paths scale the logits
        # identically (per the original note, Gemma2 scales by query_pre_attn_scalar
        # rather than head_dim -- TODO confirm against the parent's norm_factor).
        self.softmax_scale = 1.0 / self.norm_factor
        self.dropout_p = config.attention_dropout if attention_dropout is None else attention_dropout
        # (-1, -1) = plain causal; otherwise (left, right) taken from the parent's window_size.
        self._flex_window_size = (-1, -1) if self.window_size is None else (self.window_size[0], self.window_size[1])

        if _HAVE_FLEX_ATTN:
            # score_mod=None makes flex_attention leave scores untouched. A zero cap must
            # not reach the tanh closure: 0 * tanh(score / 0) is 0 or NaN for every score.
            self._flex_score_mod = _get_softcap_score_mod(self.softcap) if self.softcap != 0.0 else None
            # Block masks keyed by (q_len, kv_len, device); see forward().
            self._flex_block_mask_cache: dict = {}

    def _build_flex_block_mask(self, sq: int, sk: int, device: torch.device):
        """Build a FlexAttention block mask encoding causal + optional sliding window.

        Args:
            sq: Query sequence length.
            sk: Key/value sequence length.
            device: Device the block mask is created on.

        Returns:
            The object returned by ``create_block_mask`` with batch and head dimensions
            left unspecified (``B=None, H=None``).
        """
        window_left = self._flex_window_size[0]
        if window_left < 0:

            def _mask(b, h, q_idx, kv_idx):
                # Plain causal: a query may attend to itself and to earlier positions.
                return q_idx >= kv_idx

        else:
            w = window_left

            def _mask(b, h, q_idx, kv_idx, _w=w):
                # Causal and at most _w positions back (distance _w itself is allowed).
                return (q_idx >= kv_idx) & (q_idx - kv_idx <= _w)

        return _create_flex_block_mask(_mask, B=None, H=None, Q_LEN=sq, KV_LEN=sk, device=device)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        attention_mask: Tensor,
        attn_mask_type: AttnMaskType = None,
        packed_seq_params: PackedSeqParams = None,
        **kwargs,
    ):
        """Run the fused FlexAttention path when eligible, else the parent's forward.

        Args:
            query: Unpacked below as ``[sq, b, np, hn]``.
            key: Same layout as ``query``; ``key.size(0)`` is the key/value length.
            value: Same layout as ``key``.
            attention_mask: Any non-``None`` mask forces the unfused parent path.
            attn_mask_type: Only forwarded to the parent on the fallback path.
            packed_seq_params: Must be ``None``.
            **kwargs: Only forwarded to the parent on the fallback path.

        Returns:
            On the fused path, a contiguous tensor viewed as ``[sq, b, np * hn]``;
            otherwise whatever the parent's ``forward`` returns.

        Raises:
            ValueError: If ``packed_seq_params`` is not ``None``.
        """
        if packed_seq_params is not None:
            raise ValueError(
                "Packed sequence is not supported by DotProductAttention. Use TEDotProductAttention instead."
            )

        # Dropout is only active in training mode; the fused call does not apply it.
        dropout_p = self.dropout_p if self.training else 0.0
        fused_eligible = attention_mask is None and dropout_p == 0.0

        if _HAVE_FLEX_ATTN and fused_eligible:
            # FlexAttention path: [sq, b, np, hn] -> [b, np, sq, hn]
            sq, b, np_heads, hn = query.shape
            q = query.permute(1, 2, 0, 3)
            k = key.permute(1, 2, 0, 3)
            v = value.permute(1, 2, 0, 3)

            # The block mask is created on a specific device, so the device is part of
            # the cache key; otherwise a mask built before the module was moved would be
            # reused on the new device.
            sk = key.size(0)
            cache_key = (sq, sk, query.device)
            block_mask = self._flex_block_mask_cache.get(cache_key)
            if block_mask is None:
                block_mask = self._build_flex_block_mask(sq, sk, query.device)
                self._flex_block_mask_cache[cache_key] = block_mask

            out = _flex_attn_func(
                q,
                k,
                v,
                score_mod=self._flex_score_mod,
                block_mask=block_mask,
                scale=self.softmax_scale,
                # Fewer key/value heads than query heads -> grouped-query attention.
                enable_gqa=(k.size(1) != q.size(1)),
            )
            # [b, np, sq, hn] -> [sq, b, np * hn]
            return out.permute(2, 0, 1, 3).contiguous().view(sq, b, np_heads * hn)

        return super().forward(
            query, key, value, attention_mask, attn_mask_type=attn_mask_type, packed_seq_params=None, **kwargs
        )


class TERowParallelLinearLayerNorm(TERowParallelLinear):
"""Modified From TERowParallelLinear with an additional Post-LN."""

Expand Down Expand Up @@ -316,7 +466,7 @@ def gemma2_layer_spec(config: "GPTModelProvider") -> ModuleSpec:
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear,
core_attention=Gemma2DotProductAttention, # use unfused SDPA for attn logit softcapping
core_attention=Gemma2FlexDotProductAttention, # FlexAttention fast path; falls back to unfused when unavailable
linear_proj=TERowParallelLinearLayerNorm, # post attn RMSNorm
),
),
Expand Down Expand Up @@ -359,7 +509,7 @@ class Gemma2ModelProvider(GPTModelProvider):
layernorm_epsilon: float = 1e-6
rotary_base: float = 10000

window_size: tuple[int, int] = (4096, 0)
window_size: tuple[int, int] = (4095, 0)
vocab_size: int = 256000

transformer_layer_spec: Union[ModuleSpec, Callable[["GPTModelProvider"], ModuleSpec]] = gemma2_layer_spec
Expand Down
Loading
Loading