|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +import functools |
| 16 | +import logging |
15 | 17 | import math |
16 | 18 | from dataclasses import dataclass |
17 | 19 | from typing import Callable, Optional, Union |
|
49 | 51 | from megatron.bridge.models.gpt_provider import GPTModelProvider |
50 | 52 |
|
51 | 53 |
|
logger = logging.getLogger(__name__)

# Module-level FlexAttention handles. They stay at these "unavailable" defaults
# unless both the import and the torch.compile step below succeed, so the rest
# of the module only ever has to test _HAVE_FLEX_ATTN.
_HAVE_FLEX_ATTN = False
_flex_attn_func = None
_create_flex_block_mask = None

try:
    from torch.nn.attention.flex_attention import create_block_mask as _flex_mask_candidate
    from torch.nn.attention.flex_attention import flex_attention as _flex_candidate
except ImportError:
    # This PyTorch build has no flex_attention module; the fallback notice
    # below is logged instead.
    pass
else:
    try:
        # torch.compile raises RuntimeError on interpreter/platform combinations
        # it does not support. Catch it so importing this module never fails
        # just because the optional fast path cannot be set up.
        _flex_attn_func = torch.compile(_flex_candidate)
    except RuntimeError as err:
        logger.warning("Gemma2: torch.compile could not wrap FlexAttention (%s).", err)
    else:
        _create_flex_block_mask = _flex_mask_candidate
        _HAVE_FLEX_ATTN = True
        logger.warning("Gemma2: PyTorch FlexAttention available — softcap+SWA fused via Triton score_mod.")
    # Drop the temporary import aliases so they do not linger in the module namespace.
    del _flex_candidate, _flex_mask_candidate

if not _HAVE_FLEX_ATTN:
    logger.warning("Gemma2: FlexAttention not available — using unfused attention fallback.")
| 74 | + |
| 75 | + |
@functools.lru_cache(maxsize=None)
def _get_softcap_score_mod(softcap: float):
    """Return a score_mod closure for the given softcap, cached so all layers share one object.

    torch.compile guards on score_mod identity (id(fn)), so sharing one object across the
    N attention layers avoids N redundant Triton kernel recompilations at startup.

    Args:
        softcap: Logit soft-capping value. A falsy value (0.0 or None) means
            soft-capping is disabled.

    Returns:
        A callable ``(score, b, h, q_idx, kv_idx) -> score``. For a non-zero
        softcap it computes ``softcap * tanh(score / softcap)``; for a disabled
        softcap it returns ``score`` unchanged.
    """
    if not softcap:
        # Disabled softcap: the capped formula would divide by zero and turn
        # every logit into 0 or NaN, so leave the scores untouched instead.
        def _score_mod(score, b, h, q_idx, kv_idx):
            return score

    else:

        def _score_mod(score, b, h, q_idx, kv_idx):
            return softcap * torch.tanh(score / softcap)

    # Distinct qualname per softcap value makes the closure identifiable in traces.
    _score_mod.__qualname__ = f"softcap_score_mod_{softcap}"
    return _score_mod
| 89 | + |
| 90 | + |
52 | 91 | class Gemma2DotProductAttention(MegatronModule): |
53 | 92 | """ |
54 | 93 | Region where selective activation recomputation is applied. |
@@ -208,11 +247,20 @@ def forward( |
208 | 247 |
|
209 | 248 | # sliding window attention: combine SWA mask with any incoming padding mask. |
210 | 249 | # Both use True=masked-out; logical OR gives the union of masked positions. |
211 | | - # get_swa() returns [sq, sk]; a padding mask is typically [b, 1, sq, sk] — |
212 | | - # PyTorch broadcasts [sq, sk] to [b, 1, sq, sk] correctly under |. |
213 | | - if self.window_size is not None: |
| 250 | + # get_swa() returns [sq, sk]; the fused CUDA softmax kernel requires a 4D |
| 251 | + # mask [b, np, sq, sk], so we unsqueeze to [1, 1, sq, sk] when there is no |
| 252 | + # padding mask. When a padding mask [b, 1, sq, sk] is present, the | already |
| 253 | + # produces a 4D result via broadcasting. |
| 254 | + # Skip mask generation when the window fully covers the sequence: masking only |
| 255 | + # fires when query index i > window_size[0], i.e. seq_q > window_size[0] + 1. |
| 256 | + # For seq_length=4096 with window=4095 this is a no-op, so we stay on the |
| 257 | + # fast ScaledUpperTriangMaskedSoftmax (and FlexAttention) path. |
| 258 | + if self.window_size is not None and query.size(0) > self.window_size[0] + 1: |
214 | 259 | swa_mask = get_swa(query.size(0), key.size(0), self.window_size) |
215 | | - attention_mask = swa_mask if attention_mask is None else (swa_mask | attention_mask) |
| 260 | + if attention_mask is None: |
| 261 | + attention_mask = swa_mask.unsqueeze(0).unsqueeze(0) |
| 262 | + else: |
| 263 | + attention_mask = swa_mask | attention_mask |
216 | 264 |
|
217 | 265 | # attention scores and attention mask [b, np, sq, sk] |
218 | 266 | attention_probs: Tensor = self.scale_mask_softmax(attention_scores, attention_mask) |
@@ -262,6 +310,97 @@ def forward( |
262 | 310 | return context |
263 | 311 |
|
264 | 312 |
|
class Gemma2FlexDotProductAttention(Gemma2DotProductAttention):
    """Gemma2 fused attention with native softcap and sliding window support.

    Uses PyTorch FlexAttention (built-in, PyTorch 2.5+) to fuse softcap and SWA into
    a single Triton kernel. Falls back to the unfused parent when FlexAttention is
    unavailable, when a padding attention_mask is present (fine-tuning /
    variable-length batches) or when dropout is active.
    """

    def __init__(
        self,
        config: TransformerConfig,
        layer_number: int,
        attn_mask_type: AttnMaskType,
        attention_type: str,
        attention_dropout: Optional[float] = None,
        **kwargs,
    ):
        """Initialise the parent attention, then derive the FlexAttention settings.

        Args:
            config: Transformer config; ``attn_logit_softcapping`` and
                ``attention_dropout`` are read from it.
            layer_number: Forwarded to the parent constructor.
            attn_mask_type: Forwarded to the parent constructor.
            attention_type: Forwarded to the parent constructor.
            attention_dropout: Overrides ``config.attention_dropout`` when not None.
            **kwargs: Forwarded to the parent constructor.
        """
        super().__init__(config, layer_number, attn_mask_type, attention_type, attention_dropout, **kwargs)
        # softcap passed directly to the fused kernel; avoids post-hoc tanh rescaling.
        # A missing or None config value collapses to 0.0, which means "disabled".
        self.softcap = float(getattr(config, "attn_logit_softcapping", 0.0) or 0.0)
        # Gemma2 uses 1/sqrt(query_pre_attn_scalar=224), not 1/sqrt(head_dim) — must override
        self.softmax_scale = 1.0 / self.norm_factor
        self.dropout_p = config.attention_dropout if attention_dropout is None else attention_dropout
        # window_size for FlexAttention block_mask: (-1, -1) = full causal; (left, right) = SWA
        self._flex_window_size = (-1, -1) if self.window_size is None else (self.window_size[0], self.window_size[1])

        if _HAVE_FLEX_ATTN:
            # With softcap disabled (0.0) the tanh score_mod would divide by zero,
            # so pass score_mod=None and let FlexAttention use the raw scores.
            self._flex_score_mod = _get_softcap_score_mod(self.softcap) if self.softcap else None
            # Block masks keyed by (q_len, kv_len, device).
            self._flex_block_mask_cache: dict = {}

    def _build_flex_block_mask(self, sq: int, sk: int, device: torch.device):
        """Build a FlexAttention block_mask encoding causal + optional SWA.

        Args:
            sq: Query sequence length.
            sk: Key/value sequence length.
            device: Device the block mask is created on.

        Returns:
            The object returned by ``create_block_mask`` with batch and head
            dimensions left as None.
        """
        # NOTE(review): q_idx/kv_idx are compared without an (sk - sq) offset, which
        # looks intended for sq == sk — confirm before using with a longer KV sequence.
        window_left = self._flex_window_size[0]
        if window_left < 0:
            # No sliding window configured: plain causal mask.
            def _mask(b, h, q_idx, kv_idx):
                return q_idx >= kv_idx

        else:
            # Causal and at most `window_left` positions back from the query.
            def _mask(b, h, q_idx, kv_idx):
                return (q_idx >= kv_idx) & (q_idx - kv_idx <= window_left)

        return _create_flex_block_mask(_mask, B=None, H=None, Q_LEN=sq, KV_LEN=sk, device=device)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        attention_mask: Tensor,
        attn_mask_type: AttnMaskType = None,
        packed_seq_params: PackedSeqParams = None,
        **kwargs,
    ):
        """Forward: FlexAttention fused path when possible, unfused fallback otherwise.

        The fused path is taken only when FlexAttention was set up at import time,
        no ``attention_mask`` is supplied, and the effective dropout is zero.

        Args:
            query: Unpacked as ``[sq, b, np, hn]`` on the fused path.
            key: Permuted the same way as ``query``; its first dim is the KV length.
            value: Permuted the same way as ``query``.
            attention_mask: Any non-None value forces the unfused parent path.
            attn_mask_type: Forwarded to the parent on the fallback path only.
            packed_seq_params: Must be None.
            **kwargs: Forwarded to the parent on the fallback path only.

        Returns:
            On the fused path, a tensor viewed as ``[sq, b, np * hn]``; otherwise
            whatever the parent ``forward`` returns.

        Raises:
            ValueError: If ``packed_seq_params`` is not None.
        """
        if packed_seq_params is not None:
            raise ValueError(
                "Packed sequence is not supported by DotProductAttention. Use TEDotProductAttention instead."
            )

        # Dropout only applies in training mode; the fused call below has no dropout.
        dropout_p = self.dropout_p if self.training else 0.0
        fused_eligible = attention_mask is None and dropout_p == 0.0

        if _HAVE_FLEX_ATTN and fused_eligible:
            # FlexAttention path — expects [b, np, sq, hn]
            sq, b, np_heads, hn = query.shape
            sk = key.size(0)
            q = query.permute(1, 2, 0, 3)
            k = key.permute(1, 2, 0, 3)
            v = value.permute(1, 2, 0, 3)

            # The device is part of the key so a mask built on one device is
            # never reused after the module or its inputs move to another.
            cache_key = (sq, sk, query.device)
            block_mask = self._flex_block_mask_cache.get(cache_key)
            if block_mask is None:
                block_mask = self._build_flex_block_mask(sq, sk, query.device)
                self._flex_block_mask_cache[cache_key] = block_mask

            out = _flex_attn_func(
                q,
                k,
                v,
                score_mod=self._flex_score_mod,
                block_mask=block_mask,
                scale=self.softmax_scale,
                # Fewer KV heads than query heads means grouped-query attention.
                enable_gqa=(k.size(1) != q.size(1)),
            )
            # Back to [sq, b, np, hn], then merge the head dims into one hidden dim.
            return out.permute(2, 0, 1, 3).contiguous().view(sq, b, np_heads * hn)

        return super().forward(
            query, key, value, attention_mask, attn_mask_type=attn_mask_type, packed_seq_params=None, **kwargs
        )
| 402 | + |
| 403 | + |
265 | 404 | class TERowParallelLinearLayerNorm(TERowParallelLinear): |
266 | 405 | """Modified From TERowParallelLinear with an additional Post-LN.""" |
267 | 406 |
|
@@ -326,7 +465,7 @@ def gemma2_layer_spec(config: "GPTModelProvider") -> ModuleSpec: |
326 | 465 | params={"attn_mask_type": AttnMaskType.causal}, |
327 | 466 | submodules=SelfAttentionSubmodules( |
328 | 467 | linear_qkv=TELayerNormColumnParallelLinear, |
329 | | - core_attention=Gemma2DotProductAttention, # use unfused SDPA for attn logit softcapping |
| 468 | + core_attention=Gemma2FlexDotProductAttention, # FlexAttention fast path; falls back to unfused when unavailable |
330 | 469 | linear_proj=TERowParallelLinearLayerNorm, # post attn RMSNorm |
331 | 470 | ), |
332 | 471 | ), |
|
0 commit comments