diff --git a/test/prototype/attention/test_fp8_attention.py b/test/prototype/attention/test_fp8_attention.py index f494de07fa..8d8062022d 100644 --- a/test/prototype/attention/test_fp8_attention.py +++ b/test/prototype/attention/test_fp8_attention.py @@ -30,7 +30,29 @@ AttentionBackend, apply_low_precision_attention, ) - from torchao.prototype.attention.fp8_fa3.attention import fp8_fa3_sdpa + from torchao.prototype.attention.fp8_fa3.attention import ( + fp8_fa3_rope_sdpa, + fp8_fa3_sdpa, + ) + + +def _rope_cos_sin(S, D, device): + freqs = 1.0 / (10000.0 ** (torch.arange(0, D, 2, dtype=torch.float32) / D)) + angles = torch.outer(torch.arange(S, dtype=torch.float32), freqs) + cos_half = torch.cos(angles) + sin_half = torch.sin(angles) + cos = torch.cat([cos_half, cos_half], dim=-1).to(device) + sin = torch.cat([sin_half, sin_half], dim=-1).to(device) + return cos, sin + + +def _apply_rope(x, cos, sin): + """NeoX rotate-half RoPE. x: [B, S, H, D], cos/sin: [S, D].""" + D_HALF = x.shape[-1] // 2 + rotate = torch.cat([-x[..., D_HALF:], x[..., :D_HALF]], dim=-1) + return ( + x * cos.unsqueeze(0).unsqueeze(2) + rotate * sin.unsqueeze(0).unsqueeze(2) + ).to(x.dtype) class SimpleAttentionModel(nn.Module): @@ -52,6 +74,30 @@ def forward(self, x): return self.out_proj(attn_out.transpose(1, 2).contiguous().view(B, S, -1)) +class SimpleRoPEAttentionModel(nn.Module): + """Applies RoPE to Q and K immediately before SDPA (Pattern A: RoPE → transpose → SDPA).""" + + def __init__(self, embed_dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False) + + def forward(self, x, cos, sin): + B, S, _ = x.shape + q = self.q_proj(x).view(B, S, self.num_heads, self.head_dim) + k = self.k_proj(x).view(B, S, self.num_heads, self.head_dim) + v = self.v_proj(x).view(B, S, self.num_heads, self.head_dim) + q = _apply_rope(q, cos, sin).transpose(1, 2) + k = _apply_rope(k, cos, sin).transpose(1, 2) + v = v.transpose(1, 2) + attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True) + return self.out_proj(attn_out.transpose(1, 2).contiguous().view(B, S, -1)) + + @common_utils.instantiate_parametrized_tests class TestFP8FA3Attention(TestCase): @unittest.skipUnless( @@ -83,6 +129,41 @@ def test_sdpa_accuracy(self, shape, dtype): f"SQNR {sqnr.item():.2f} dB below 25 dB for shape={shape}, dtype={dtype}", ) + @unittest.skipUnless( + torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available(), + "Requires PyTorch >= 2.11, Hopper GPU, and FA3", + ) + @common_utils.parametrize("shape", [(2, 1024, 8, 64), (1, 1024, 16, 128)]) + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_rope_sdpa_accuracy(self, shape, dtype): + B, S, H, D = shape + q = torch.randn(B, S, H, D, device="cuda", dtype=dtype) + k = torch.randn(B, S, H, D, device="cuda", dtype=dtype) + v = torch.randn(B, S, H, D, device="cuda", dtype=dtype) + cos, sin = _rope_cos_sin(S, D, "cuda") + + with torch.no_grad(): + out_ref = F.scaled_dot_product_attention( + _apply_rope(q, cos, sin).transpose(1, 2), + _apply_rope(k, cos, sin).transpose(1, 2), + v.transpose(1, 2), + is_causal=False, + ) + + activate_flash_attention_impl("FA3") + try: + with torch.no_grad(): + out_fp8 = fp8_fa3_rope_sdpa(q, k, v, cos, sin, is_causal=False) + finally: + restore_flash_attention_impl() + + sqnr = compute_error(out_ref, out_fp8) + self.assertGreater( + sqnr.item(), + 25.0, + f"SQNR {sqnr.item():.2f} dB below 25 dB for shape={shape}, dtype={dtype}", + ) + @unittest.skipUnless( torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available(), "Requires PyTorch >= 2.11, Hopper GPU, and FA3", @@ -122,6 +203,48 @@ def test_monkey_patch_model(self, dtype): f"SQNR {sqnr.item():.2f} dB below 20 dB for dtype={dtype}", ) + @unittest.skipUnless( + torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available(), + "Requires PyTorch >= 2.11, Hopper GPU, and FA3", + ) + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_rope_fusion_model(self, dtype): + embed_dim, num_heads = 512, 8 + model = ( + SimpleRoPEAttentionModel(embed_dim, num_heads) + .to(device="cuda", dtype=dtype) + .eval() + ) + S = 128 + x = torch.randn(2, S, embed_dim, device="cuda", dtype=dtype) + cos, sin = _rope_cos_sin(S, embed_dim // num_heads, "cuda") + + with torch.no_grad(): + out_ref = model(x, cos, sin) + + fp8_model = ( + SimpleRoPEAttentionModel(embed_dim, num_heads) + .to(device="cuda", dtype=dtype) + .eval() + ) + fp8_model.load_state_dict(model.state_dict()) + fp8_model = apply_low_precision_attention( + fp8_model, + backend=AttentionBackend.FP8_FA3, + fuse_rope_using_torch_compile=True, + ) + fp8_model = torch.compile(fp8_model, backend=fp8_model.compile_backend) + + with torch.no_grad(): + out_fp8 = fp8_model(x, cos, sin) + + sqnr = compute_error(out_ref, out_fp8) + self.assertGreater( + sqnr.item(), + 20.0, + f"SQNR {sqnr.item():.2f} dB below 20 dB for dtype={dtype}", + ) + if __name__ == "__main__": run_tests() diff --git a/torchao/prototype/attention/api.py b/torchao/prototype/attention/api.py index 8bbe8c9a34..3879029b04 100644 --- a/torchao/prototype/attention/api.py +++ b/torchao/prototype/attention/api.py @@ -67,6 +67,13 @@ def apply_low_precision_attention( Must be called before ``torch.compile``. KV caching should be disabled before calling (e.g., ``config.use_cache = False`` for HuggingFace models). + + When ``fuse_rope_using_torch_compile=True``, the returned wrapper + exposes a ``compile_backend`` attribute. You must compile with it to get + the RoPE fusion:: + + model = apply_low_precision_attention(model, fuse_rope_using_torch_compile=True) + model = torch.compile(model, backend=model.compile_backend) """ if isinstance(model, _LowPrecisionAttentionWrapper): raise RuntimeError( diff --git a/torchao/prototype/attention/fp8_fa3/__init__.py b/torchao/prototype/attention/fp8_fa3/__init__.py index c0e3007c4e..4a4487f043 100644 --- a/torchao/prototype/attention/fp8_fa3/__init__.py +++ b/torchao/prototype/attention/fp8_fa3/__init__.py @@ -4,12 +4,18 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -"""FP8 attention using FA3 backend.""" +""" +FP8 attention using FA3 backend. +""" -from torchao.prototype.attention.fp8_fa3.attention import fp8_fa3_sdpa +from torchao.prototype.attention.fp8_fa3.attention import ( + fp8_fa3_rope_sdpa, + fp8_fa3_sdpa, +) from torchao.prototype.attention.quantization import _fp8_sdpa_quantize __all__ = [ "fp8_fa3_sdpa", + "fp8_fa3_rope_sdpa", "_fp8_sdpa_quantize", ] diff --git a/torchao/prototype/attention/fp8_fa3/attention.py b/torchao/prototype/attention/fp8_fa3/attention.py index eee06ed77f..bcec1e985b 100644 --- a/torchao/prototype/attention/fp8_fa3/attention.py +++ b/torchao/prototype/attention/fp8_fa3/attention.py @@ -4,15 +4,23 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -"""FP8 SDPA using FA3 backend. +""" +FP8 SDPA using FA3 backend. + +When using these functions directly (not through apply_low_precision_attention), +you must activate FA3 yourself:: -Thin wrapper around ``shared_utils/attention.py``. When using directly, -activate the FA3 flash attention implementation before calling. + activate_flash_attention_impl("FA3") + try: + out = fp8_fa3_sdpa(q, k, v, is_causal=True) + finally: + restore_flash_attention_impl() """ from functools import partial from torchao.prototype.attention.shared_utils.attention import ( + _fp8_rope_sdpa, _fp8_sdpa, ) @@ -20,3 +28,8 @@ fp8_fa3_sdpa.__doc__ = _fp8_sdpa.__doc__ fp8_fa3_sdpa.__name__ = "fp8_fa3_sdpa" fp8_fa3_sdpa.__qualname__ = "fp8_fa3_sdpa" + +fp8_fa3_rope_sdpa = partial(_fp8_rope_sdpa, backend_name="FA3") +fp8_fa3_rope_sdpa.__doc__ = _fp8_rope_sdpa.__doc__ +fp8_fa3_rope_sdpa.__name__ = "fp8_fa3_rope_sdpa" +fp8_fa3_rope_sdpa.__qualname__ = "fp8_fa3_rope_sdpa" diff --git a/torchao/prototype/attention/fp8_fa3/fusion_pass.py b/torchao/prototype/attention/fp8_fa3/fusion_pass.py new file mode 100644 index 0000000000..37bb6b07df --- /dev/null +++ b/torchao/prototype/attention/fp8_fa3/fusion_pass.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from torchao.prototype.attention.fp8_fa3.attention import ( + fp8_fa3_rope_sdpa, + fp8_fa3_sdpa, +) +from torchao.prototype.attention.shared_utils.custom_ops import ( + make_backend_fn, + register_fp8_attention_ops, +) + +_ops = register_fp8_attention_ops( + backend_name="fa3", + rope_sdpa_fn=fp8_fa3_rope_sdpa, + sdpa_fn=fp8_fa3_sdpa, +) + +make_fp8_backend = make_backend_fn(_ops, backend_name="FA3", flash_impl_name="FA3") diff --git a/torchao/prototype/attention/quantization/__init__.py b/torchao/prototype/attention/quantization/__init__.py index 05bb8bcb74..945a52bfc1 100644 --- a/torchao/prototype/attention/quantization/__init__.py +++ b/torchao/prototype/attention/quantization/__init__.py @@ -4,12 +4,14 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -"""Shared FP8 quantization kernels for low-precision attention.""" - -from torchao.prototype.attention.quantization.quantization import ( - _fp8_sdpa_quantize, +from torchao.prototype.attention.quantization.triton_qkv_quantization import ( + triton_fp8_sdpa_quantize as _fp8_sdpa_quantize, +) +from torchao.prototype.attention.quantization.triton_rope_qkv_quantization import ( + triton_fp8_rope_sdpa_quantize as _fp8_rope_sdpa_quantize, ) __all__ = [ "_fp8_sdpa_quantize", + "_fp8_rope_sdpa_quantize", ] diff --git a/torchao/prototype/attention/quantization/triton_rope_qkv_quantization.py b/torchao/prototype/attention/quantization/triton_rope_qkv_quantization.py new file mode 100644 index 0000000000..6ec95dd5af --- /dev/null +++ b/torchao/prototype/attention/quantization/triton_rope_qkv_quantization.py @@ -0,0 +1,745 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Fused RoPE + FP8 quantization kernels for Q, K, V. + +Input: [B, S, H, D], output: [B, H, S, D]. +Supports GQA (different head counts for Q vs K/V). +""" + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + + +def _compute_num_chunks(tensor: torch.Tensor, S: int) -> int: + """Compute optimal number of chunks based on GPU properties.""" + props = torch.cuda.get_device_properties(tensor.device) + num_sms = props.multi_processor_count + B, _, H, _ = tensor.shape # [B, S, H, D] + base_parallelism = B * H + # Target 2-4x SMs for good occupancy/latency hiding + target_blocks = num_sms * 4 + num_chunks = max(1, target_blocks // base_parallelism) + # Ensure each chunk has at least 32 S positions for efficiency + num_chunks = min(num_chunks, S // 32) if S >= 32 else 1 + # Cap at reasonable maximum + num_chunks = min(num_chunks, 64) + # Adjust if S is small + num_chunks = min(num_chunks, S) + return num_chunks + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 256}, num_warps=4), + triton.Config({"BLOCK_SIZE": 512}, num_warps=4), + triton.Config({"BLOCK_SIZE": 1024}, num_warps=8), + triton.Config({"BLOCK_SIZE": 2048}, num_warps=8), + ], + key=["chunk_size", "D_HALF"], +) +@triton.jit +def rope_single_phase1_kernel( + # Input tensor [B, S, H, D] + x_ptr, + # RoPE frequency tensors [S, D] + cos_ptr, + sin_ptr, + # Intermediate output tensor [B, H, S, D] - stores RoPE'd result + x_rope_ptr, + # Output: partial max values [B * H * num_chunks] + partial_max_ptr, + # Input strides (for [B, S, H, D] layout) + stride_in_b, + stride_in_s, + stride_in_h, + stride_in_d, + # Output strides (for [B, H, S, D] layout) + stride_out_b, + stride_out_h, + stride_out_s, + stride_out_d, + # Dimensions + S, + D, + D_HALF, + H, + chunk_size, + num_chunks, + # Block size + BLOCK_SIZE: tl.constexpr, + ROPE_INTERLEAVED: tl.constexpr, +): + """ + Phase 1 for a single tensor (Q or K): Apply RoPE, store to intermediate, + compute partial max. + + Grid: (B, H, num_chunks) + + Supports two RoPE pairing variants (selected by ROPE_INTERLEAVED constexpr): + - NeoX half-split (ROPE_INTERLEAVED=False): pairs (j, j+D/2) for j in [0, D/2) + - Interleaved (ROPE_INTERLEAVED=True): pairs (2i, 2i+1) for i in [0, D/2) + + Each pair shares the same rotation angle and the 2D rotation formula is: + out[first] = x[first]*cos - x[second]*sin + out[second] = x[second]*cos + x[first]*sin + """ + pid_b = tl.program_id(axis=0) + pid_h = tl.program_id(axis=1) + pid_chunk = tl.program_id(axis=2) + + # Compute the S range for this chunk + s_start = pid_chunk * chunk_size + s_end = tl.minimum(s_start + chunk_size, S) + actual_chunk_size = s_end - s_start + + # Number of pairs to process in this chunk + chunk_pairs = actual_chunk_size * D_HALF + + # Base pointers for input [B, S, H, D] + in_base_b = pid_b * stride_in_b + in_base_h = pid_h * stride_in_h + + # Base pointers for output [B, H, S, D] + out_base_b = pid_b * stride_out_b + out_base_h = pid_h * stride_out_h + + # Initialize max accumulator + x_max = 0.0 + + # Linearized iteration over chunk_size * D_HALF pairs + for block_start in range(0, chunk_pairs, BLOCK_SIZE): + offs = block_start + tl.arange(0, BLOCK_SIZE) + mask = offs < chunk_pairs + + # Convert linear offset to (s, pair_idx) coordinates + local_s = offs // D_HALF + pair_idx = offs % D_HALF + s_idx = s_start + local_s + + # Compute element indices for this pair based on RoPE variant + if ROPE_INTERLEAVED: + # FLUX/GPT-J interleaved: pair (2i, 2i+1) + d_first = pair_idx * 2 + d_second = pair_idx * 2 + 1 + else: + # NeoX/LLaMA half-split: pair (j, j+D/2) + d_first = pair_idx + d_second = pair_idx + D_HALF + + # Input offsets [B, S, H, D] + in_offset_first = ( + in_base_b + s_idx * stride_in_s + in_base_h + d_first * stride_in_d + ) + in_offset_second = ( + in_base_b + s_idx * stride_in_s + in_base_h + d_second * stride_in_d + ) + + # Output offsets [B, H, S, D] + out_offset_first = ( + out_base_b + out_base_h + s_idx * stride_out_s + d_first * stride_out_d + ) + out_offset_second = ( + out_base_b + out_base_h + s_idx * stride_out_s + d_second * stride_out_d + ) + + # Load input pairs (each element loaded exactly once) + x_first = tl.load(x_ptr + in_offset_first, mask=mask, other=0.0).to(tl.float32) + x_second = tl.load(x_ptr + in_offset_second, mask=mask, other=0.0).to( + tl.float32 + ) + + # Load cos/sin — both elements in a NeoX pair share the same + # rotation angle. cos[j] == cos[j+D/2] in LLaMA's frequency + # layout, so we only need one load per pair. + cos_offset = s_idx * D + d_first + cos_val = tl.load(cos_ptr + cos_offset, mask=mask, other=1.0).to(tl.float32) + sin_val = tl.load(sin_ptr + cos_offset, mask=mask, other=0.0).to(tl.float32) + + # Apply NeoX RoPE rotation: + # out[j] = x[j] * cos - x[j+D/2] * sin + # out[j+D/2] = x[j+D/2] * cos + x[j] * sin + x_rope_first = tl.math.fma(x_first, cos_val, -(x_second * sin_val)) + x_rope_second = tl.math.fma(x_second, cos_val, x_first * sin_val) + + # Store RoPE'd result to intermediate buffer [B, H, S, D] + tl.store( + x_rope_ptr + out_offset_first, x_rope_first.to(x_first.dtype), mask=mask + ) + tl.store( + x_rope_ptr + out_offset_second, x_rope_second.to(x_first.dtype), mask=mask + ) + + # Update max values (from RoPE'd values) + x_max = tl.maximum(x_max, tl.max(tl.abs(x_rope_first))) + x_max = tl.maximum(x_max, tl.max(tl.abs(x_rope_second))) + + # Store partial max + chunk_idx = pid_b * (H * num_chunks) + pid_h * num_chunks + pid_chunk + tl.store(partial_max_ptr + chunk_idx, x_max) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 512}, num_warps=4), + triton.Config({"BLOCK_SIZE": 1024}, num_warps=4), + triton.Config({"BLOCK_SIZE": 2048}, num_warps=8), + triton.Config({"BLOCK_SIZE": 4096}, num_warps=8), + ], + key=["chunk_size", "D"], +) +@triton.jit +def v_phase1_kernel( + # Input tensor [B, S, H, D] + v_ptr, + # Output: partial max values [B * H * num_chunks] + partial_max_ptr, + # Input strides (for [B, S, H, D] layout) + stride_in_b, + stride_in_s, + stride_in_h, + stride_in_d, + # Dimensions + S, + D, + H, + chunk_size, + num_chunks, + # Block size + BLOCK_SIZE: tl.constexpr, +): + """ + Phase 1 for V: Compute partial absmax (no RoPE applied). + + Grid: (B, H, num_chunks) + + Uses linearized iteration over chunk_size * D elements. + """ + pid_b = tl.program_id(axis=0) + pid_h = tl.program_id(axis=1) + pid_chunk = tl.program_id(axis=2) + + # Compute the S range for this chunk + s_start = pid_chunk * chunk_size + s_end = tl.minimum(s_start + chunk_size, S) + chunk_elements = (s_end - s_start) * D + + # Base pointers for input [B, S, H, D] + in_base_b = pid_b * stride_in_b + in_base_h = pid_h * stride_in_h + + # Initialize max accumulator + v_max = 0.0 + + # Linearized iteration over chunk_size * D elements + for block_start in range(0, chunk_elements, BLOCK_SIZE): + offs = block_start + tl.arange(0, BLOCK_SIZE) + mask = offs < chunk_elements + + # Convert linear offset to (s, d) coordinates + local_s = offs // D + d_idx = offs % D + s_idx = s_start + local_s + + # Input offset [B, S, H, D] + in_offset = in_base_b + s_idx * stride_in_s + in_base_h + d_idx * stride_in_d + + v_val = tl.load(v_ptr + in_offset, mask=mask, other=0.0).to(tl.float32) + v_max = tl.maximum(v_max, tl.max(tl.abs(v_val))) + + # Store partial max + chunk_idx = pid_b * (H * num_chunks) + pid_h * num_chunks + pid_chunk + tl.store(partial_max_ptr + chunk_idx, v_max) + + +@triton.jit +def single_reduce_kernel( + partial_max_ptr, # [B * H * num_chunks] + scale_ptr, + descale_ptr, + H, + num_chunks, +): + """ + Reduce partial maxes and compute scale/descale for a single tensor. + + Grid: (B, H) + """ + pid_b = tl.program_id(axis=0) + pid_h = tl.program_id(axis=1) + + # Reduce across chunks for this (batch, head) + x_max = 0.0 + + base_idx = (pid_b * H + pid_h) * num_chunks + for c in range(num_chunks): + x_max = tl.maximum(x_max, tl.load(partial_max_ptr + base_idx + c)) + + # Compute scale and descale + # FP8 E4M3 max value is 448.0 + FP8_MAX = 448.0 + eps = 1e-12 + scale_idx = pid_b * H + pid_h + + tl.store(scale_ptr + scale_idx, tl.where(x_max > eps, FP8_MAX / x_max, 1.0)) + tl.store(descale_ptr + scale_idx, tl.where(x_max > eps, x_max / FP8_MAX, 1.0)) + + +@triton.jit +def group_reduce_kernel( + partial_max_ptr, # [B * H_q * num_chunks] + scale_ptr, # [B, H_kv] + descale_ptr, # [B, H_kv] + H_q, + H_kv, + groups, # H_q // H_kv + num_chunks, +): + """ + Reduce partial maxes across head groups for GQA Q tensor. + + For each KV group, reduces the max across all Q heads in that group + and all chunks, producing one scale per (batch, kv_head). + + Grid: (B, H_kv) + """ + pid_b = tl.program_id(axis=0) + pid_hkv = tl.program_id(axis=1) + + x_max = 0.0 + + for g in range(groups): + h_q = pid_hkv * groups + g + base_idx = (pid_b * H_q + h_q) * num_chunks + for c in range(num_chunks): + x_max = tl.maximum(x_max, tl.load(partial_max_ptr + base_idx + c)) + + FP8_MAX = 448.0 + eps = 1e-12 + scale_idx = pid_b * H_kv + pid_hkv + + tl.store(scale_ptr + scale_idx, tl.where(x_max > eps, FP8_MAX / x_max, 1.0)) + tl.store(descale_ptr + scale_idx, tl.where(x_max > eps, x_max / FP8_MAX, 1.0)) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 512}, num_warps=4), + triton.Config({"BLOCK_SIZE": 1024}, num_warps=4), + triton.Config({"BLOCK_SIZE": 2048}, num_warps=8), + triton.Config({"BLOCK_SIZE": 4096}, num_warps=8), + ], + key=["chunk_size", "D"], +) +@triton.jit +def rope_single_phase2_kernel( + # Intermediate tensor [B, H, S, D] - already RoPE'd + x_rope_ptr, + # Output tensor [B, H, S, D] - FP8 quantized + x_out_ptr, + # Precomputed scale [B, H_scale] + scale_ptr, + # Strides (for [B, H, S, D] layout) - same for intermediate and output + stride_b, + stride_h, + stride_s, + stride_d, + # Dimensions + S, + D, + H, + chunk_size, + # Scale indexing for GQA: scale has H_scale entries per batch, + # and each group of `groups` heads shares one scale. + # For non-GQA: H_scale = H, groups = 1. + H_scale, + groups, + # Block size + BLOCK_SIZE: tl.constexpr, +): + """ + Phase 2 for a single tensor (Q or K): Quantize pre-computed RoPE'd values to FP8. + + Grid: (B, H, num_chunks) + """ + pid_b = tl.program_id(axis=0) + pid_h = tl.program_id(axis=1) + pid_chunk = tl.program_id(axis=2) + + # Load scale for this head (or head group for GQA) + scale = tl.load(scale_ptr + pid_b * H_scale + pid_h // groups) + + # Compute the S range for this chunk + s_start = pid_chunk * chunk_size + s_end = tl.minimum(s_start + chunk_size, S) + chunk_elements = (s_end - s_start) * D + + # Base pointer + base_offset = pid_b * stride_b + pid_h * stride_h + + # Linearized iteration over chunk_size * D elements + for block_start in range(0, chunk_elements, BLOCK_SIZE): + offs = block_start + tl.arange(0, BLOCK_SIZE) + mask = offs < chunk_elements + + # Convert linear offset to (s, d) coordinates + local_s = offs // D + d_idx = offs % D + s_idx = s_start + local_s + + ptr_offset = base_offset + s_idx * stride_s + d_idx * stride_d + + # Load pre-computed RoPE'd value from intermediate buffer + x_val = tl.load(x_rope_ptr + ptr_offset, mask=mask, other=0.0).to(tl.float32) + + # Quantize to FP8 + x_fp8 = (x_val * scale).to(tl.float8e4nv) + + # Store to output + tl.store(x_out_ptr + ptr_offset, x_fp8, mask=mask) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 512}, num_warps=4), + triton.Config({"BLOCK_SIZE": 1024}, num_warps=4), + triton.Config({"BLOCK_SIZE": 2048}, num_warps=8), + triton.Config({"BLOCK_SIZE": 4096}, num_warps=8), + ], + key=["chunk_size", "D"], +) +@triton.jit +def v_phase2_kernel( + # Original V tensor [B, S, H, D] - needs transpose only + v_ptr, + # Output tensor [B, H, S, D] - FP8 quantized + v_out_ptr, + # Precomputed scale [B, H] + scale_ptr, + # V input strides (for [B, S, H, D] layout) + stride_v_in_b, + stride_v_in_s, + stride_v_in_h, + stride_v_in_d, + # Output strides (for [B, H, S, D] layout) + stride_out_b, + stride_out_h, + stride_out_s, + stride_out_d, + # Dimensions + S, + D, + H, + chunk_size, + # Block size + BLOCK_SIZE: tl.constexpr, +): + """ + Phase 2 for V: Transpose from [B,S,H,D] to [B,H,S,D] and quantize to FP8. + + Grid: (B, H, num_chunks) + """ + pid_b = tl.program_id(axis=0) + pid_h = tl.program_id(axis=1) + pid_chunk = tl.program_id(axis=2) + + # Load scale for this head + scale = tl.load(scale_ptr + pid_b * H + pid_h) + + # Compute the S range for this chunk + s_start = pid_chunk * chunk_size + s_end = tl.minimum(s_start + chunk_size, S) + chunk_elements = (s_end - s_start) * D + + # Base pointers + v_in_base = pid_b * stride_v_in_b + pid_h * stride_v_in_h + out_base = pid_b * stride_out_b + pid_h * stride_out_h + + # Linearized iteration over chunk_size * D elements + for block_start in range(0, chunk_elements, BLOCK_SIZE): + offs = block_start + tl.arange(0, BLOCK_SIZE) + mask = offs < chunk_elements + + # Convert linear offset to (s, d) coordinates + local_s = offs // D + d_idx = offs % D + s_idx = s_start + local_s + + # V input offset [B, S, H, D] + v_in_offset = v_in_base + s_idx * stride_v_in_s + d_idx * stride_v_in_d + + # Output offset [B, H, S, D] + out_offset = out_base + s_idx * stride_out_s + d_idx * stride_out_d + + # Load V from original input (no RoPE, just transpose) + v_val = tl.load(v_ptr + v_in_offset, mask=mask, other=0.0).to(tl.float32) + + # Quantize to FP8 + v_fp8 = (v_val * scale).to(tl.float8e4nv) + + # Store to output + tl.store(v_out_ptr + out_offset, v_fp8, mask=mask) + + +def triton_fp8_rope_sdpa_quantize( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + num_chunks: Optional[int] = None, + rope_interleaved: bool = False, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + """ + Separated RoPE + FP8 quantization for Q, K, V tensors. + + Applies RoPE to Q and K, then quantizes all tensors to FP8 with per-head scaling. + Also performs layout transformation from [B, S, H, D] to [B, H, S, D]. + Each of Q, K, V is processed with independent kernel launches. + + Supports GQA where Q has more heads than K/V (H_q = groups * H_kv). + For GQA, Q is quantized with per-KV-group scaling so that q_descale + has shape [B, H_kv] as required by FA3. + + Args: + q: Query tensor of shape [B, S, H_q, D] in bf16/fp16 + k: Key tensor of shape [B, S, H_kv, D] in bf16/fp16 + v: Value tensor of shape [B, S, H_kv, D] in bf16/fp16 + cos: Cosine frequencies for RoPE, shape [S, D] + sin: Sine frequencies for RoPE, shape [S, D] + num_chunks: Number of chunks to split S dimension into. + If None, automatically selects based on GPU SM count. + + Returns: + q_fp8: Quantized query with RoPE, shape [B, H_q, S, D] in fp8 + k_fp8: Quantized key with RoPE, shape [B, H_kv, S, D] in fp8 + v_fp8: Quantized value, shape [B, H_kv, S, D] in fp8 + q_descale: Query descale factors, shape [B, H_kv] in fp32 + k_descale: Key descale factors, shape [B, H_kv] in fp32 + v_descale: Value descale factors, shape [B, H_kv] in fp32 + """ + assert q.dim() == 4, f"Expected 4D tensor [B, S, H, D], got {q.dim()}D" + assert k.dim() == 4, f"Expected 4D tensor [B, S, H, D], got {k.dim()}D" + assert v.dim() == 4, f"Expected 4D tensor [B, S, H, D], got {v.dim()}D" + assert k.shape == v.shape, ( + f"K and V must have the same shape, got {k.shape} vs {v.shape}" + ) + assert q.shape[0] == k.shape[0], ( + f"Batch size mismatch: {q.shape[0]} vs {k.shape[0]}" + ) + assert q.shape[1] == k.shape[1], ( + f"Sequence length mismatch: {q.shape[1]} vs {k.shape[1]}" + ) + assert q.shape[3] == k.shape[3], f"Head dim mismatch: {q.shape[3]} vs {k.shape[3]}" + assert q.shape[2] % k.shape[2] == 0, ( + f"Q heads ({q.shape[2]}) must be a multiple of K heads ({k.shape[2]})" + ) + assert cos.dim() == 2, f"Expected 2D cos tensor [S, D], got {cos.dim()}D" + assert sin.dim() == 2, f"Expected 2D sin tensor [S, D], got {sin.dim()}D" + + B, S, H_q, D = q.shape + H_kv = k.shape[2] + groups = H_q // H_kv + + assert D % 2 == 0, f"Head dimension D must be even for RoPE, got D={D}" + assert cos.shape == (S, D), f"Expected cos shape [{S}, {D}], got {cos.shape}" + assert sin.shape == (S, D), f"Expected sin shape [{S}, {D}], got {sin.shape}" + + D_HALF = D // 2 + + # Make tensors contiguous if needed + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + cos = cos.contiguous() + sin = sin.contiguous() + + # Compute number of chunks + if num_chunks is None: + num_chunks = _compute_num_chunks(q, S) + chunk_size = (S + num_chunks - 1) // num_chunks + + # Allocate output tensors in [B, H, S, D] layout for SDPA + q_fp8 = torch.empty(B, H_q, S, D, dtype=torch.float8_e4m3fn, device=q.device) + k_fp8 = torch.empty(B, H_kv, S, D, dtype=torch.float8_e4m3fn, device=q.device) + v_fp8 = torch.empty(B, H_kv, S, D, dtype=torch.float8_e4m3fn, device=q.device) + + # Allocate intermediate buffers for RoPE'd Q, K in [B, H, S, D] layout + q_rope_intermediate = torch.empty(B, H_q, S, D, dtype=q.dtype, device=q.device) + k_rope_intermediate = torch.empty(B, H_kv, S, D, dtype=k.dtype, device=q.device) + + # Allocate partial max buffers (one per tensor) + q_partial_max = torch.empty( + B * H_q * num_chunks, dtype=torch.float32, device=q.device + ) + k_partial_max = torch.empty( + B * H_kv * num_chunks, dtype=torch.float32, device=q.device + ) + v_partial_max = torch.empty( + B * H_kv * num_chunks, dtype=torch.float32, device=q.device + ) + + # Allocate scale/descale tensors. + # For GQA, Q scale/descale are [B, H_kv] (per KV group). + # K and V are always [B, H_kv] (per head). + q_scale = torch.empty(B, H_kv, dtype=torch.float32, device=q.device) + k_scale = torch.empty(B, H_kv, dtype=torch.float32, device=q.device) + v_scale = torch.empty(B, H_kv, dtype=torch.float32, device=q.device) + q_descale = torch.empty(B, H_kv, dtype=torch.float32, device=q.device) + k_descale = torch.empty(B, H_kv, dtype=torch.float32, device=q.device) + v_descale = torch.empty(B, H_kv, dtype=torch.float32, device=q.device) + + q_grid_chunked = (B, H_q, num_chunks) + kv_grid_chunked = (B, H_kv, num_chunks) + + # ---- Phase 1: RoPE + max for Q ---- + rope_single_phase1_kernel[q_grid_chunked]( + q, + cos, + sin, + q_rope_intermediate, + q_partial_max, + # Input strides [B, S, H_q, D] + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + # Output strides [B, H_q, S, D] + q_rope_intermediate.stride(0), + q_rope_intermediate.stride(1), + q_rope_intermediate.stride(2), + q_rope_intermediate.stride(3), + S, + D, + D_HALF, + H_q, + chunk_size, + num_chunks, + ROPE_INTERLEAVED=rope_interleaved, + ) + + # ---- Phase 1: RoPE + max for K ---- + rope_single_phase1_kernel[kv_grid_chunked]( + k, + cos, + sin, + k_rope_intermediate, + k_partial_max, + # Input strides [B, S, H_kv, D] + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + # Output strides [B, H_kv, S, D] + k_rope_intermediate.stride(0), + k_rope_intermediate.stride(1), + k_rope_intermediate.stride(2), + k_rope_intermediate.stride(3), + S, + D, + D_HALF, + H_kv, + chunk_size, + num_chunks, + ROPE_INTERLEAVED=rope_interleaved, + ) + + # ---- Phase 1: Max for V (no RoPE) ---- + v_phase1_kernel[kv_grid_chunked]( + v, + v_partial_max, + # Input strides [B, S, H_kv, D] + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + S, + D, + H_kv, + chunk_size, + num_chunks, + ) + + # ---- Reduce ---- + # Q: group reduce across `groups` Q heads per KV head + group_reduce_kernel[(B, H_kv)]( + q_partial_max, q_scale, q_descale, H_q, H_kv, groups, num_chunks + ) + # K, V: per-head reduce + single_reduce_kernel[(B, H_kv)](k_partial_max, k_scale, k_descale, H_kv, num_chunks) + single_reduce_kernel[(B, H_kv)](v_partial_max, v_scale, v_descale, H_kv, num_chunks) + + # ---- Phase 2: Quantize Q from intermediate ---- + # Q scale is [B, H_kv]; each group of `groups` Q heads shares one scale. + rope_single_phase2_kernel[q_grid_chunked]( + q_rope_intermediate, + q_fp8, + q_scale, + # Strides [B, H_q, S, D] + q_fp8.stride(0), + q_fp8.stride(1), + q_fp8.stride(2), + q_fp8.stride(3), + S, + D, + H_q, + chunk_size, + H_kv, + groups, + ) + + # ---- Phase 2: Quantize K from intermediate ---- + # K scale is [B, H_kv]; groups=1 (per-head). + rope_single_phase2_kernel[kv_grid_chunked]( + k_rope_intermediate, + k_fp8, + k_scale, + # Strides [B, H_kv, S, D] + k_fp8.stride(0), + k_fp8.stride(1), + k_fp8.stride(2), + k_fp8.stride(3), + S, + D, + H_kv, + chunk_size, + H_kv, + 1, + ) + + # ---- Phase 2: Transpose + quantize V ---- + v_phase2_kernel[kv_grid_chunked]( + v, + v_fp8, + v_scale, + # V input strides [B, S, H_kv, D] + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + # Output strides [B, H_kv, S, D] + v_fp8.stride(0), + v_fp8.stride(1), + v_fp8.stride(2), + v_fp8.stride(3), + S, + D, + H_kv, + chunk_size, + ) + + return q_fp8, k_fp8, v_fp8, q_descale, k_descale, v_descale diff --git a/torchao/prototype/attention/shared_utils/attention.py b/torchao/prototype/attention/shared_utils/attention.py index 15361b8e37..248668639a 100644 --- a/torchao/prototype/attention/shared_utils/attention.py +++ b/torchao/prototype/attention/shared_utils/attention.py @@ -4,7 +4,12 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -"""Shared FP8 scaled dot-product attention implementation.""" +""" +Shared FP8 scaled dot-product attention implementation. + +Backend-specific modules (``fp8_fa3/attention.py``, etc.) provide thin +named wrappers around these functions via ``functools.partial``. +""" from typing import Optional @@ -21,6 +26,7 @@ ) from torchao.prototype.attention.quantization import ( + _fp8_rope_sdpa_quantize, _fp8_sdpa_quantize, ) @@ -37,10 +43,13 @@ def _fp8_sdpa( *, backend_name: str = "FP8", ) -> torch.Tensor: - """FP8 SDPA implementation shared by all backends. + """FP8 SDPA shared by all backends. + + The correct flash attention implementation (e.g. FA3) must be + activated before calling this function. The high-level + ``apply_low_precision_attention`` API handles this automatically. - Quantizes Q, K, V to FP8 with per-head scaling, then calls - ``_scaled_dot_product_attention_quantized`` under ``SDPBackend.FLASH_ATTENTION``. + Input/output layout: [B, H, S, D]. """ if not _TORCH_VERSION_AT_LEAST_2_11: raise RuntimeError("Low-precision attention requires PyTorch 2.11+.") @@ -53,15 +62,6 @@ def _fp8_sdpa( input_dtype = query.dtype - # Ensure Triton kernels launch on the correct device. - # In the monkey-patch path (no torch.compile), accelerate's hooks move - # tensors to the correct device but don't call torch.cuda.set_device(). - # Triton dispatches based on current_device(), not tensor device, so - # without this guard the kernel launches on the wrong GPU's stream. - _prev_device = torch.cuda.current_device() - if query.device.index is not None and query.device.index != _prev_device: - torch.cuda.set_device(query.device) - q_fp8, k_fp8, v_fp8, descale_q, descale_k, descale_v = _fp8_sdpa_quantize( query, key, value ) @@ -78,8 +78,58 @@ def _fp8_sdpa( v_descale=descale_v, ) - # Restore previous device to avoid side effects on the caller. - if query.device.index is not None and query.device.index != _prev_device: - torch.cuda.set_device(_prev_device) + return out.to(input_dtype) + + +def _fp8_rope_sdpa( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + rope_interleaved: bool = False, + *, + backend_name: str = "FP8", +) -> torch.Tensor: + """Fused RoPE + FP8 SDPA shared by all backends. + + Input layout: [B, S, H, D] (pre-transpose). The fused quantization + kernel handles the transpose to [B, H, S, D] internally. + Output layout: [B, H, S, D]. + """ + if not _TORCH_VERSION_AT_LEAST_2_11: + raise RuntimeError("Low-precision attention requires PyTorch 2.11+.") + if attn_mask is not None: + raise ValueError(f"attn_mask not supported for FP8 {backend_name}") + if dropout_p != 0.0: + raise ValueError( + f"dropout_p must be 0.0 for FP8 {backend_name}, got {dropout_p}" + ) + + input_dtype = query.dtype + + cos = cos.to(query.device) + sin = sin.to(query.device) + + q_fp8, k_fp8, v_fp8, descale_q, descale_k, descale_v = _fp8_rope_sdpa_quantize( + query, key, value, cos, sin, rope_interleaved=rope_interleaved + ) + + with sdpa_kernel(SDPBackend.FLASH_ATTENTION): + out = _scaled_dot_product_attention_quantized( + q_fp8, + k_fp8, + v_fp8, + is_causal=is_causal, + scale=scale, + q_descale=descale_q, + k_descale=descale_k, + v_descale=descale_v, + ) return out.to(input_dtype) diff --git a/torchao/prototype/attention/shared_utils/custom_ops.py b/torchao/prototype/attention/shared_utils/custom_ops.py new file mode 100644 index 0000000000..c9bea64b1c --- /dev/null +++ b/torchao/prototype/attention/shared_utils/custom_ops.py @@ -0,0 +1,161 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Shared custom op registration and compile helpers for FP8 attention backends. + +Custom ops are needed because our FP8 SDPA functions call Triton kernels +that are not traceable by torch.compile. Registering them as custom_ops +tells the compiler to treat them as opaque nodes with known shapes/dtypes. +""" + +from functools import partial +from typing import Callable, NamedTuple + +import torch +import torch._inductor.config as inductor_config +import torch.nn as nn + +from torchao.prototype.attention.shared_utils.fusion_utils import ( + detect_causal_mask, +) +from torchao.prototype.attention.shared_utils.fusion_utils import ( + rope_sdpa_fusion_pass as _shared_fusion_pass, +) + + +class RegisteredOps(NamedTuple): + rope_sdpa_op: object + fp8_sdpa_op: object + + +def register_fp8_attention_ops( + backend_name: str, + rope_sdpa_fn: Callable, + sdpa_fn: Callable, +) -> RegisteredOps: + """Register the RoPE+SDPA and plain SDPA custom ops for a backend.""" + backend = backend_name.lower() + + rope_op_name = f"torchao::fp8_{backend}_rope_sdpa" + + @torch.library.custom_op(rope_op_name, mutates_args=()) + def _rope_sdpa_custom_op( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_causal: bool = False, + scale: float = 0.0, + enable_gqa: bool = False, + rope_interleaved: bool = False, + ) -> torch.Tensor: + actual_scale = scale if scale != 0.0 else None + return rope_sdpa_fn( + q, + k, + v, + cos, + sin, + is_causal=is_causal, + scale=actual_scale, + enable_gqa=enable_gqa, + rope_interleaved=rope_interleaved, + ) + + @_rope_sdpa_custom_op.register_fake + def _rope_sdpa_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_causal: bool = False, + scale: float = 0.0, + enable_gqa: bool = False, + rope_interleaved: bool = False, + ) -> torch.Tensor: + B, S, H, D = q.shape + return torch.empty(B, H, S, D, dtype=q.dtype, device=q.device) + + sdpa_op_name = f"torchao::fp8_{backend}_sdpa" + + @torch.library.custom_op(sdpa_op_name, mutates_args=()) + def _sdpa_custom_op( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + is_causal: bool = False, + scale: float = 0.0, + enable_gqa: bool = False, + ) -> torch.Tensor: + actual_scale = scale if scale != 0.0 else None + return sdpa_fn( + q, + k, + v, + is_causal=is_causal, + scale=actual_scale, + enable_gqa=enable_gqa, + ) + + @_sdpa_custom_op.register_fake + def _sdpa_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + is_causal: bool = False, + scale: float = 0.0, + enable_gqa: bool = False, + ) -> torch.Tensor: + return torch.empty(q.shape, dtype=q.dtype, device=q.device) + + rope_sdpa_op = getattr( + getattr(torch.ops, "torchao"), f"fp8_{backend}_rope_sdpa" + ).default + fp8_sdpa_op = getattr(getattr(torch.ops, "torchao"), f"fp8_{backend}_sdpa").default + + return RegisteredOps(rope_sdpa_op=rope_sdpa_op, fp8_sdpa_op=fp8_sdpa_op) + + +def make_backend_fn( + ops: RegisteredOps, + backend_name: str, + flash_impl_name: str, + max_head_dim: int = 256, +) -> Callable: + """Return a ``make_fp8_backend(model, fuse_rope_using_torch_compile)`` function for a backend.""" + + def make_fp8_backend( + model: nn.Module, + fuse_rope_using_torch_compile: bool, + ) -> Callable: + from torch._inductor.compile_fx import compile_fx + + strip_causal_mask = detect_causal_mask(model, flash_impl_name=flash_impl_name) + + pass_fn = partial( + _shared_fusion_pass, + rope_sdpa_op=ops.rope_sdpa_op, + fp8_sdpa_op=ops.fp8_sdpa_op, + max_head_dim=max_head_dim, + backend_name=backend_name, + fuse_rope=fuse_rope_using_torch_compile, + strip_causal_mask=strip_causal_mask, + ) + + def fp8_attention_backend(gm, example_inputs): + old_pass = inductor_config.pre_grad_custom_pass + inductor_config.pre_grad_custom_pass = pass_fn + try: + return compile_fx(gm, example_inputs) + finally: + inductor_config.pre_grad_custom_pass = old_pass + + return fp8_attention_backend + + return make_fp8_backend diff --git a/torchao/prototype/attention/shared_utils/fusion_utils.py b/torchao/prototype/attention/shared_utils/fusion_utils.py new file mode 100644 index 0000000000..cf1e85f2a1 --- /dev/null +++ b/torchao/prototype/attention/shared_utils/fusion_utils.py @@ -0,0 +1,977 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Shared FX graph pattern detection and fusion utilities for low-precision attention. +""" + +import logging +import operator +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.fx import Graph, Node + +logger = logging.getLogger(__name__) + + +@dataclass +class RoPEMatch: + """Result of detecting a RoPE pattern on a tensor.""" + + pre_rope_input: Node # tensor before RoPE: "x" in "x * cos + rotate_half(x) * sin" + cos_node: Node # cosine frequencies, traced back to [S, D] source + sin_node: Node # sine frequencies, traced back to [S, D] source + rope_interleaved: bool # True = FLUX interleaved, False = NeoX half-split + + +# FX Node Utilities + + +def _is_op(node: Node, *targets) -> bool: + """Check if an FX node matches one of the given targets.""" + if node.op in ("call_function", "call_method"): + return node.target in targets + return False + + +def _get_fake_tensor(node: Node) -> Optional[torch.Tensor]: + """Get the FakeTensor metadata from a node (pre-grad or post-grad).""" + for key in ("val", "example_value"): + if key in node.meta: + val = node.meta[key] + if isinstance(val, torch.Tensor): + return val + return None + + +def _get_node_shape(node: Node) -> Optional[Tuple[int, ...]]: + fake = _get_fake_tensor(node) + if fake is not None: + return tuple(fake.shape) + return None + + +def _reshape_cos_sin_to_2d( + graph: Graph, + cos_node: Node, + sin_node: Node, + insert_before: Node, +) -> Optional[Tuple[Node, Node]]: + """Reshape cos/sin nodes to 2D [S, D] if they have leading size-1 dims. + + HuggingFace models produce cos/sin with shape [B, S, D] or [1, 1, S, D]. + """ + cos_shape = _get_node_shape(cos_node) + sin_shape = _get_node_shape(sin_node) + + if cos_shape is None or sin_shape is None: + return cos_node, sin_node + + if len(cos_shape) == 2 and len(sin_shape) == 2: + return cos_node, sin_node + + for name, shape in [("cos", cos_shape), ("sin", sin_shape)]: + if len(shape) < 2: + logger.debug("RoPE %s has fewer than 2 dims: shape=%s", name, shape) + return None + for dim in shape[:-2]: + if dim != 1: + logger.debug( + "RoPE %s has non-unit leading dim: shape=%s", + name, + shape, + ) + return None + + s, d = cos_shape[-2], cos_shape[-1] + with graph.inserting_before(insert_before): + cos_2d = graph.call_function( + torch.ops.aten.view.default, + args=(cos_node, [s, d]), + ) + sin_2d = graph.call_function( + torch.ops.aten.view.default, + args=(sin_node, [s, d]), + ) + return cos_2d, sin_2d + + +def _trace_through_views(node: Node) -> Node: + """Trace backward through view-like ops (unsqueeze, expand, clone, to.dtype, etc.).""" + current = node + while isinstance(current, Node): + if current.op == "call_function" and current.target in ( + torch.ops.aten.unsqueeze.default, + torch.ops.aten.clone.default, + torch.ops.aten.contiguous.default, + torch.ops.aten.expand.default, + torch.ops.aten.to.dtype, + torch.ops.aten._to_copy.default, + ): + current = current.args[0] + elif current.op == "call_method" and current.target in ( + "unsqueeze", + "clone", + "contiguous", + "expand", + "to", + "float", + "half", + "bfloat16", + ): + current = current.args[0] + elif ( + current.op == "call_function" + and current.target is operator.getitem + and len(current.args) >= 2 + and isinstance(current.args[1], tuple) + and all(i is None or i == slice(None) for i in current.args[1]) + ): + current = current.args[0] + else: + break + return current + + +# Transpose Detection + + +def _unwrap_transpose(node: Node) -> Optional[Node]: + """If node is transpose(1,2) or permute([0,2,1,3]), return its input. + + Also looks through contiguous()/clone() wrappers. + """ + if not isinstance(node, Node): + return None + + current = node + while isinstance(current, Node) and _is_op( + current, + torch.ops.aten.contiguous.default, + torch.ops.aten.clone.default, + "contiguous", + "clone", + ): + current = current.args[0] + + if not isinstance(current, Node): + return None + + # aten.transpose.int(tensor, 1, 2) + if _is_op(current, torch.ops.aten.transpose.int): + if len(current.args) >= 3: + dim0, dim1 = current.args[1], current.args[2] + if (dim0 == 1 and dim1 == 2) or (dim0 == 2 and dim1 == 1): + return current.args[0] + + # aten.permute.default(tensor, [0, 2, 1, 3]) + if _is_op(current, torch.ops.aten.permute.default): + if len(current.args) >= 2: + perm = current.args[1] + if list(perm) == [0, 2, 1, 3]: + return current.args[0] + + # call_method transpose(1, 2) + if _is_op(current, "transpose"): + if len(current.args) >= 3: + dim0, dim1 = current.args[1], current.args[2] + if (dim0 == 1 and dim1 == 2) or (dim0 == 2 and dim1 == 1): + return current.args[0] + + # call_method permute(0, 2, 1, 3) + if _is_op(current, "permute"): + if len(current.args) >= 5: + perm = list(current.args[1:5]) + if perm == [0, 2, 1, 3]: + return current.args[0] + elif len(current.args) >= 2 and isinstance(current.args[1], (list, tuple)): + perm = current.args[1] + if list(perm) == [0, 2, 1, 3]: + return current.args[0] + + return None + + +def _unwrap_repeat_kv(node: Node) -> Optional[Node]: + """Unwrap HuggingFace's repeat_kv (GQA head repetition) pattern. + + Pattern: reshape <- expand <- getitem(unsqueeze). + Returns the pre-repetition tensor, or None. + """ + if not isinstance(node, Node): + return None + + if not _is_op( + node, + torch.ops.aten.reshape.default, + torch.ops.aten.view.default, + "reshape", + "view", + ): + return None + + inner = node.args[0] if node.args else None + if not isinstance(inner, Node): + return None + + if not _is_op(inner, torch.ops.aten.expand.default, "expand"): + return None + + inner2 = inner.args[0] if inner.args else None + if not isinstance(inner2, Node): + return None + + if _is_op(inner2, torch.ops.aten.unsqueeze.default, "unsqueeze"): + return inner2.args[0] if inner2.args else None + + if inner2.op == "call_function" and inner2.target is operator.getitem: + if len(inner2.args) >= 2 and isinstance(inner2.args[1], tuple): + if any(i is None for i in inner2.args[1]): + return inner2.args[0] if inner2.args else None + + return None + + +# SDPA Detection and Parameter Extraction + + +def _is_sdpa_node(node: Node) -> bool: + return _is_op( + node, + torch.ops.aten.scaled_dot_product_attention.default, + torch._C._nn.scaled_dot_product_attention, + ) + + +def _is_lower_triangular_bool_mask(mask: torch.Tensor) -> bool: + """Check if a tensor is a bool, square lower-triangular (causal) mask.""" + if mask.dtype != torch.bool or mask.ndim < 2: + return False + q_len, kv_len = mask.shape[-2], mask.shape[-1] + if q_len != kv_len: + return False + ref = torch.tril(torch.ones(q_len, kv_len, dtype=torch.bool, device=mask.device)) + return torch.equal(mask.broadcast_to(mask.shape), ref.expand_as(mask)) + + +def detect_causal_mask( + model: nn.Module, + sample_input_ids=None, + flash_impl_name: str | None = None, +) -> bool: + """Run one forward pass to detect whether the model uses causal masks. + + Returns True when every SDPA call used causal attention (either via + a materialized lower-triangular bool mask, or via is_causal=True). + """ + from torch.nn.attention import ( + activate_flash_attention_impl, + restore_flash_attention_impl, + ) + + try: + device = next(model.parameters()).device + except StopIteration: + return False + + if sample_input_ids is None: + vocab_size = getattr(getattr(model, "config", None), "vocab_size", None) + if vocab_size is None: + return False + sample_input_ids = torch.randint(0, vocab_size, (1, 16), device=device) + + all_causal: list[bool] = [] + saw_any_sdpa = False + + original_sdpa = F.scaled_dot_product_attention + + def _hook(*args, **kwargs): + nonlocal saw_any_sdpa + saw_any_sdpa = True + attn_mask = args[3] if len(args) > 3 else kwargs.get("attn_mask", None) + is_causal = kwargs.get("is_causal", False) if len(args) <= 5 else args[5] + + if attn_mask is not None and not is_causal: + all_causal.append(_is_lower_triangular_bool_mask(attn_mask)) + elif attn_mask is None and is_causal: + all_causal.append(True) + + return original_sdpa(*args, **kwargs) + + F.scaled_dot_product_attention = _hook + if flash_impl_name is not None: + activate_flash_attention_impl(flash_impl_name) + try: + with torch.no_grad(): + model(sample_input_ids) + except Exception: + logger.debug("detect_causal_mask: forward pass failed", exc_info=True) + return False + finally: + F.scaled_dot_product_attention = original_sdpa + if flash_impl_name is not None: + restore_flash_attention_impl() + + if not saw_any_sdpa: + return False + + return all(all_causal) + + +def _sdpa_is_fusible(node: Node, strip_causal_mask: bool = False) -> Tuple[bool, bool]: + """Check if an SDPA node is compatible with our FP8 fused kernel. + + Returns (is_fusible, needs_mask_strip). + """ + args = node.args + kwargs = node.kwargs + + attn_mask = args[3] if len(args) > 3 else kwargs.get("attn_mask", None) + is_causal = args[5] if len(args) > 5 else kwargs.get("is_causal", False) + + needs_mask_strip = False + if attn_mask is not None: + if not is_causal and strip_causal_mask and isinstance(attn_mask, Node): + needs_mask_strip = True + else: + return False, False + + dropout_p = args[4] if len(args) > 4 else kwargs.get("dropout_p", 0.0) + if dropout_p != 0.0: + return False, False + + return True, needs_mask_strip + + +def _strip_causal_mask(node: Node) -> None: + """Strip a materialized causal mask from an SDPA node.""" + args = list(node.args) + kwargs = dict(node.kwargs) + + if len(args) > 3: + args[3] = None + elif "attn_mask" in kwargs: + kwargs["attn_mask"] = None + + if len(args) > 5: + args[5] = True + elif "is_causal" in kwargs: + kwargs["is_causal"] = True + else: + kwargs["is_causal"] = True + + node.args = tuple(args) + node.kwargs = kwargs + + logger.info("Stripped causal mask from SDPA node: %s", node.name) + + +def _get_sdpa_params(node: Node) -> Tuple[bool, float, bool]: + """Extract is_causal, scale, and enable_gqa from an SDPA node. + + Scale uses 0.0 as sentinel for "default" (1/sqrt(D)). + """ + args = node.args + kwargs = node.kwargs + + is_causal = args[5] if len(args) > 5 else kwargs.get("is_causal", False) + scale = args[6] if len(args) > 6 else kwargs.get("scale", None) + enable_gqa = args[7] if len(args) > 7 else kwargs.get("enable_gqa", False) + + if scale is None: + scale = 0.0 + + return is_causal, scale, enable_gqa + + +def _get_sdpa_qkv(node: Node) -> Optional[Tuple[Node, Node, Node]]: + """Extract Q, K, V input nodes from an SDPA node.""" + args = node.args + kwargs = node.kwargs + + q = args[0] if len(args) > 0 else kwargs.get("query", None) + k = args[1] if len(args) > 1 else kwargs.get("key", None) + v = args[2] if len(args) > 2 else kwargs.get("value", None) + + if not all(isinstance(n, Node) for n in (q, k, v)): + return None + + return q, k, v + + +# NeoX/LLaMA RoPE Pattern Detection +# +# NeoX/LLaMA RoPE: +# rotate_half(x) = cat(-x[..., D//2:], x[..., :D//2], dim=-1) +# apply_rope(x, cos, sin) = x * cos + rotate_half(x) * sin + + +def _detect_rotate_half(cat_node: Node) -> Optional[Node]: + """Detect rotate_half(x) = cat(-x[..., D//2:], x[..., :D//2], dim=-1). + + Returns the source tensor x, or None. + """ + if not _is_op(cat_node, torch.ops.aten.cat.default, torch.cat): + return None + + if len(cat_node.args) < 1: + return None + + tensors_list = cat_node.args[0] + + if len(cat_node.args) >= 2: + cat_dim = cat_node.args[1] + else: + cat_dim = cat_node.kwargs.get("dim", 0) + + if cat_dim not in (-1, 3): + return None + + if not isinstance(tensors_list, (list, tuple)) or len(tensors_list) != 2: + return None + + neg_part = tensors_list[0] + pos_part = tensors_list[1] + + if not isinstance(neg_part, Node) or not isinstance(pos_part, Node): + return None + + if not _is_op(neg_part, torch.ops.aten.neg.default, operator.neg, torch.neg): + return None + + neg_input = neg_part.args[0] + if not isinstance(neg_input, Node): + return None + + return _match_rotate_half_slices(neg_input, pos_part) + + +def _match_rotate_half_slices(neg_input: Node, pos_part: Node) -> Optional[Node]: + """Match the slice patterns in rotate_half. Returns the source tensor x, or None.""" + # ATen slice pattern + if _is_op(neg_input, torch.ops.aten.slice.Tensor) and _is_op( + pos_part, torch.ops.aten.slice.Tensor + ): + slice_neg_source = neg_input.args[0] + slice_pos_source = pos_part.args[0] + + if slice_neg_source is not slice_pos_source: + return None + + slice_neg_dim = neg_input.args[1] if len(neg_input.args) > 1 else 0 + slice_pos_dim = pos_part.args[1] if len(pos_part.args) > 1 else 0 + + if slice_neg_dim not in (-1, 3) or slice_pos_dim not in (-1, 3): + return None + + pos_start = pos_part.args[2] if len(pos_part.args) > 2 else None + pos_end = pos_part.args[3] if len(pos_part.args) > 3 else None + neg_start = neg_input.args[2] if len(neg_input.args) > 2 else None + + if pos_start != 0: + return None + if neg_start is None or pos_end is None: + return None + if neg_start != pos_end: + return None + + return slice_neg_source + + # Dynamo getitem pattern + if _is_op(neg_input, operator.getitem) and _is_op(pos_part, operator.getitem): + slice_neg_source = neg_input.args[0] + slice_pos_source = pos_part.args[0] + + if slice_neg_source is not slice_pos_source: + return None + + neg_idx = neg_input.args[1] + pos_idx = pos_part.args[1] + + neg_slice = _extract_last_dim_slice(neg_idx) + pos_slice = _extract_last_dim_slice(pos_idx) + + if neg_slice is None or pos_slice is None: + return None + + if pos_slice.start not in (0, None): + return None + if pos_slice.stop is None: + return None + if neg_slice.start is None: + return None + if neg_slice.start != pos_slice.stop: + return None + + return slice_neg_source + + return None + + +def _extract_last_dim_slice(idx) -> Optional[slice]: + """Extract the slice on the last dimension from a getitem index.""" + if isinstance(idx, tuple): + if len(idx) >= 2 and idx[0] is Ellipsis and isinstance(idx[1], slice): + return idx[1] + if len(idx) >= 1 and isinstance(idx[-1], slice): + for i in range(len(idx) - 1): + if idx[i] is not Ellipsis and idx[i] != slice(None): + return None + return idx[-1] + elif isinstance(idx, slice): + return idx + return None + + +def _detect_interleaved_rotation(node: Node) -> Optional[Node]: + """Detect the FLUX-style interleaved rotation pattern. + + Pattern: x.reshape(..., -1, 2).unbind(-1) -> stack([-x_imag, x_real], dim=-1).flatten(3) + Returns the source tensor x, or None. + """ + if not _is_op(node, torch.ops.aten.flatten.using_ints, "flatten"): + return None + if len(node.args) < 1: + return None + + stack_node = node.args[0] + if not isinstance(stack_node, Node): + return None + + if not _is_op(stack_node, torch.ops.aten.stack.default, torch.stack): + return None + if len(stack_node.args) < 1: + return None + + tensors_list = stack_node.args[0] + if len(stack_node.args) >= 2: + stack_dim = stack_node.args[1] + else: + stack_dim = stack_node.kwargs.get("dim", 0) + if stack_dim != -1: + return None + + if not isinstance(tensors_list, (list, tuple)) or len(tensors_list) != 2: + return None + + neg_part = tensors_list[0] + pos_part = tensors_list[1] + if not isinstance(neg_part, Node) or not isinstance(pos_part, Node): + return None + + if not _is_op(neg_part, torch.ops.aten.neg.default, operator.neg, torch.neg): + return None + + x_imag = neg_part.args[0] + if not isinstance(x_imag, Node): + return None + + if not _is_op(pos_part, operator.getitem) or not _is_op(x_imag, operator.getitem): + return None + + real_source = pos_part.args[0] + imag_source = x_imag.args[0] + if real_source is not imag_source: + return None + if pos_part.args[1] != 0 or x_imag.args[1] != 1: + return None + + unbind_node = real_source + if not isinstance(unbind_node, Node): + return None + if not _is_op(unbind_node, torch.ops.aten.unbind.int, "unbind"): + return None + + if len(unbind_node.args) >= 2: + unbind_dim = unbind_node.args[1] + else: + unbind_dim = unbind_node.kwargs.get("dim", 0) + if unbind_dim != -1: + return None + + reshape_node = unbind_node.args[0] + if not isinstance(reshape_node, Node): + return None + if not _is_op( + reshape_node, + torch.ops.aten.reshape.default, + torch.ops.aten.view.default, + "reshape", + "view", + ): + return None + + source_x = reshape_node.args[0] + return source_x if isinstance(source_x, Node) else None + + +def _detect_rotation(node: Node) -> Optional[Tuple[Node, bool]]: + """Detect any supported rotation pattern. + + Returns (source_tensor, is_interleaved) or None. + """ + result = _detect_rotate_half(node) + if result is not None: + return result, False + + result = _detect_interleaved_rotation(node) + if result is not None: + return result, True + + return None + + +def _detect_neox_rope(node: Node) -> Optional[RoPEMatch]: + """Detect the NeoX/LLaMA RoPE pattern: add(mul(x, cos), mul(rotate_half(x), sin)).""" + if not _is_op(node, torch.ops.aten.add.Tensor, operator.add): + return None + if len(node.args) < 2: + return None + + left = node.args[0] + right = node.args[1] + if not isinstance(left, Node) or not isinstance(right, Node): + return None + + if not _is_op(left, torch.ops.aten.mul.Tensor, operator.mul): + return None + if not _is_op(right, torch.ops.aten.mul.Tensor, operator.mul): + return None + + def _try_match(x_mul: Node, rot_mul: Node) -> Optional[RoPEMatch]: + rot_a, rot_b = rot_mul.args[0], rot_mul.args[1] + + for rot_candidate, sin_candidate in [(rot_a, rot_b), (rot_b, rot_a)]: + if not isinstance(rot_candidate, Node): + continue + + rot_unwrapped = _trace_through_views(rot_candidate) + rotation_result = _detect_rotation(rot_unwrapped) + if rotation_result is None: + continue + x_from_rotate, is_interleaved = rotation_result + + x_a, x_b = x_mul.args[0], x_mul.args[1] + for x_candidate, cos_candidate in [(x_a, x_b), (x_b, x_a)]: + if not isinstance(x_candidate, Node): + continue + + x_traced = _trace_through_views(x_candidate) + if x_traced is not x_from_rotate: + continue + + cos_source = _trace_through_views(cos_candidate) + sin_source = _trace_through_views(sin_candidate) + + return RoPEMatch( + pre_rope_input=x_from_rotate, + cos_node=cos_source, + sin_node=sin_source, + rope_interleaved=is_interleaved, + ) + + return None + + # Try both orderings of the add (commutative). + result = _try_match(left, right) + if result is not None: + return result + return _try_match(right, left) + + +def _detect_rope(node: Node) -> Optional[RoPEMatch]: + """Detect any supported RoPE variant at a given node.""" + return _detect_neox_rope(node) + + +# Graph Surgery + + +def _replace_with_fused_op( + graph: Graph, + sdpa_node: Node, + pre_rope_q: Node, + pre_rope_k: Node, + v_input: Node, + cos_node: Node, + sin_node: Node, + is_causal: bool, + scale: float, + enable_gqa: bool, + rope_interleaved: bool, + rope_sdpa_op, +) -> None: + """Replace an SDPA node with a fused RoPE+SDPA custom op.""" + with graph.inserting_before(sdpa_node): + fused_node = graph.call_function( + rope_sdpa_op, + args=(pre_rope_q, pre_rope_k, v_input, cos_node, sin_node), + kwargs={ + "is_causal": is_causal, + "scale": scale, + "enable_gqa": enable_gqa, + "rope_interleaved": rope_interleaved, + }, + ) + + fused_node.meta = sdpa_node.meta.copy() + sdpa_node.replace_all_uses_with(fused_node) + + logger.info( + "Fused RoPE + SDPA: replaced %s with %s", + sdpa_node.name, + fused_node.name, + ) + + +def _replace_sdpa_with_fp8( + graph: Graph, + sdpa_node: Node, + q_node: Node, + k_node: Node, + v_node: Node, + is_causal: bool, + scale: float, + enable_gqa: bool, + fp8_sdpa_op, +) -> None: + """Replace a plain SDPA node with an FP8 SDPA op (no RoPE fusion).""" + with graph.inserting_before(sdpa_node): + fp8_node = graph.call_function( + fp8_sdpa_op, + args=(q_node, k_node, v_node), + kwargs={ + "is_causal": is_causal, + "scale": scale, + "enable_gqa": enable_gqa, + }, + ) + + fp8_node.meta = sdpa_node.meta.copy() + sdpa_node.replace_all_uses_with(fp8_node) + + logger.info( + "Replaced SDPA with FP8: %s -> %s", + sdpa_node.name, + fp8_node.name, + ) + + +# Main Fusion Pass + + +def rope_sdpa_fusion_pass( + graph: Graph, + rope_sdpa_op, + fp8_sdpa_op, + max_head_dim: int = 256, + backend_name: str = "FP8", + fuse_rope: bool = True, + strip_causal_mask: bool = False, +) -> None: + """Detect and replace SDPA patterns in the FX graph. + + For each fusible SDPA node: + - Pattern A (RoPE -> transpose -> SDPA): fuse with rope_sdpa custom op + - Pattern B (transpose -> RoPE -> SDPA): fuse with rope_sdpa custom op + - No RoPE: replace with fp8_sdpa custom op + + Note: KV caching must be disabled before compilation. + DynamicCache.update() inserts torch.cat nodes that break pattern matching. + """ + sdpa_nodes = [n for n in graph.nodes if _is_sdpa_node(n)] + + if not sdpa_nodes: + logger.debug("RoPE + SDPA fusion: found 0 SDPA nodes in graph") + return + + fused_count = 0 + fp8_count = 0 + + for sdpa_node in sdpa_nodes: + is_fusible, needs_mask_strip = _sdpa_is_fusible( + sdpa_node, strip_causal_mask=strip_causal_mask + ) + if not is_fusible: + logger.debug("Skipping non-fusible SDPA: %s", sdpa_node.name) + continue + + if needs_mask_strip: + _strip_causal_mask(sdpa_node) + + is_causal, scale, enable_gqa = _get_sdpa_params(sdpa_node) + + qkv = _get_sdpa_qkv(sdpa_node) + if qkv is None: + continue + q_node, k_node, v_node = qkv + + # Try RoPE fusion + if fuse_rope: + v_pre_transpose = _unwrap_transpose(v_node) + + # Pattern A: RoPE -> transpose -> SDPA (FLUX-style) + q_pre_transpose = _unwrap_transpose(q_node) + k_pre_transpose = _unwrap_transpose(k_node) + + if q_pre_transpose is not None and k_pre_transpose is not None: + q_pre_cast = _trace_through_views(q_pre_transpose) + k_pre_cast = _trace_through_views(k_pre_transpose) + + q_rope = _detect_rope(q_pre_cast) + k_rope = _detect_rope(k_pre_cast) + + if q_rope is not None and k_rope is not None: + pre_rope_q = _trace_through_views(q_rope.pre_rope_input) + pre_rope_k = _trace_through_views(k_rope.pre_rope_input) + + if v_pre_transpose is None: + logger.debug( + "Pattern A: V has no transpose, skipping: %s", + sdpa_node.name, + ) + continue + + cos_sin = _reshape_cos_sin_to_2d( + graph, + q_rope.cos_node, + q_rope.sin_node, + sdpa_node, + ) + if cos_sin is None: + logger.debug( + "Pattern A: cos/sin shape incompatible, skipping: %s", + sdpa_node.name, + ) + continue + cos_2d, sin_2d = cos_sin + + _replace_with_fused_op( + graph=graph, + sdpa_node=sdpa_node, + pre_rope_q=pre_rope_q, + pre_rope_k=pre_rope_k, + v_input=v_pre_transpose, + cos_node=cos_2d, + sin_node=sin_2d, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + rope_interleaved=q_rope.rope_interleaved, + rope_sdpa_op=rope_sdpa_op, + ) + fused_count += 1 + continue + + # Pattern B: transpose -> RoPE -> SDPA (HuggingFace-style) + # For GQA, K may go through repeat_kv after RoPE. + q_rope = _detect_rope(_trace_through_views(q_node)) + + k_rope = _detect_rope(_trace_through_views(k_node)) + gqa_unwrapped = False + if k_rope is None: + k_pre_repeat = _unwrap_repeat_kv(k_node) + if k_pre_repeat is not None: + k_rope = _detect_rope(_trace_through_views(k_pre_repeat)) + if k_rope is not None: + gqa_unwrapped = True + + if q_rope is not None and k_rope is not None: + q_bshd = _unwrap_transpose(_trace_through_views(q_rope.pre_rope_input)) + k_bshd = _unwrap_transpose(_trace_through_views(k_rope.pre_rope_input)) + + if q_bshd is not None and k_bshd is not None: + v_for_fusion = v_node + if gqa_unwrapped: + v_pre_repeat = _unwrap_repeat_kv(v_node) + if v_pre_repeat is not None: + v_for_fusion = v_pre_repeat + + v_bshd = _unwrap_transpose(v_for_fusion) + if v_bshd is None: + logger.debug( + "Pattern B: V has no transpose, skipping: %s", + sdpa_node.name, + ) + continue + + cos_sin = _reshape_cos_sin_to_2d( + graph, + q_rope.cos_node, + q_rope.sin_node, + sdpa_node, + ) + if cos_sin is None: + logger.debug( + "Pattern B: cos/sin shape incompatible, skipping: %s", + sdpa_node.name, + ) + continue + cos_2d, sin_2d = cos_sin + + fused_enable_gqa = True if gqa_unwrapped else enable_gqa + + _replace_with_fused_op( + graph=graph, + sdpa_node=sdpa_node, + pre_rope_q=q_bshd, + pre_rope_k=k_bshd, + v_input=v_bshd, + cos_node=cos_2d, + sin_node=sin_2d, + is_causal=is_causal, + scale=scale, + enable_gqa=fused_enable_gqa, + rope_interleaved=q_rope.rope_interleaved, + rope_sdpa_op=rope_sdpa_op, + ) + fused_count += 1 + continue + + # No RoPE detected (or fuse_rope=False) — replace with non-rope FP8 SDPA + q_shape = _get_node_shape(q_node) + if q_shape is not None and q_shape[-1] > max_head_dim: + logger.debug( + "Skipping FP8 replacement: head_dim=%d > %d for %s", + q_shape[-1], + max_head_dim, + sdpa_node.name, + ) + continue + + _replace_sdpa_with_fp8( + graph=graph, + sdpa_node=sdpa_node, + q_node=q_node, + k_node=k_node, + v_node=v_node, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + fp8_sdpa_op=fp8_sdpa_op, + ) + fp8_count += 1 + + replaced_count = fused_count + fp8_count + logger.info( + "Found %d SDPA node(s): %d RoPE-fused, %d FP8-replaced (backend: %s)", + len(sdpa_nodes), + fused_count, + fp8_count, + backend_name, + ) + + if replaced_count > 0: + graph.eliminate_dead_code() + logger.info( + "Fusion pass complete: %d RoPE-fused, %d FP8-replaced", + fused_count, + fp8_count, + ) diff --git a/torchao/prototype/attention/shared_utils/setup.py b/torchao/prototype/attention/shared_utils/setup.py index 73e76798bc..9f58134377 100644 --- a/torchao/prototype/attention/shared_utils/setup.py +++ b/torchao/prototype/attention/shared_utils/setup.py @@ -4,10 +4,14 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +import warnings + import torch.nn as nn +from torchao.prototype.attention.shared_utils.fusion_utils import detect_causal_mask from torchao.prototype.attention.shared_utils.wrapper import ( _FP8FlashAttentionMonkeyPatchWrapper, + _FP8FlashAttentionWrapper, _make_causal_aware_sdpa, ) @@ -17,20 +21,32 @@ def setup_fp8_backend( flash_impl_name: str, fuse_rope_using_torch_compile: bool, ) -> nn.Module: - if fuse_rope_using_torch_compile: - raise NotImplementedError( - "fuse_rope_using_torch_compile requires the RoPE fusion path, " - "which is not available in this version." - ) if flash_impl_name == "FA3": from torchao.prototype.attention.fp8_fa3.attention import ( fp8_fa3_sdpa as sdpa_fn, ) + from torchao.prototype.attention.fp8_fa3.fusion_pass import make_fp8_backend else: raise ValueError(f"Unknown flash_impl_name: {flash_impl_name}") + if fuse_rope_using_torch_compile: + wrapper = _FP8FlashAttentionWrapper(model, flash_impl_name=flash_impl_name) + wrapper.compile_backend = make_fp8_backend(model, fuse_rope_using_torch_compile) + warnings.warn( + "fuse_rope_using_torch_compile=True: you must call " + "torch.compile(model, backend=model.compile_backend) for the " + "RoPE + FP8 fusion to take effect. Without it the model runs " + "eagerly with no fusion. " + "Note: this path uses torch._inductor.config.pre_grad_custom_pass, " + "an unstable internal API that may change across PyTorch versions.", + UserWarning, + stacklevel=3, + ) + return wrapper + + strip_causal_mask = detect_causal_mask(model, flash_impl_name=flash_impl_name) return _FP8FlashAttentionMonkeyPatchWrapper( model, flash_impl_name=flash_impl_name, - sdpa_patch_fn=_make_causal_aware_sdpa(sdpa_fn, strip_causal_mask=False), + sdpa_patch_fn=_make_causal_aware_sdpa(sdpa_fn, strip_causal_mask), ) diff --git a/torchao/prototype/attention/shared_utils/wrapper.py b/torchao/prototype/attention/shared_utils/wrapper.py index 00b3e96c64..506d10e80c 100644 --- a/torchao/prototype/attention/shared_utils/wrapper.py +++ b/torchao/prototype/attention/shared_utils/wrapper.py @@ -28,6 +28,21 @@ def __getattr__(self, name: str): return getattr(self._orig_mod, name) +class _FP8FlashAttentionWrapper(_LowPrecisionAttentionWrapper): + """Compile path wrapper. Activates the flash impl around the module forward.""" + + def __init__(self, orig_mod: nn.Module, flash_impl_name: str): + super().__init__(orig_mod) + self._flash_impl_name = flash_impl_name + + def forward(self, *args, **kwargs): + activate_flash_attention_impl(self._flash_impl_name) + try: + return self._orig_mod(*args, **kwargs) + finally: + restore_flash_attention_impl() + + class _FP8FlashAttentionMonkeyPatchWrapper(_LowPrecisionAttentionWrapper): """Monkey-patch path wrapper. Replaces ``F.scaled_dot_product_attention`` with the FP8 backend for the duration of each forward call.