diff --git a/test/prototype/attention/__init__.py b/test/prototype/attention/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/prototype/attention/test_fp8_attention.py b/test/prototype/attention/test_fp8_attention.py new file mode 100644 index 0000000000..f494de07fa --- /dev/null +++ b/test/prototype/attention/test_fp8_attention.py @@ -0,0 +1,127 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for FP8 low-precision attention (FA3 backend on Hopper).""" + +import unittest + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.testing._internal import common_utils +from torch.testing._internal.common_utils import TestCase, run_tests + +from torchao.quantization.utils import compute_error +from torchao.utils import torch_version_at_least + +if torch_version_at_least("2.11.0"): + from torchao.prototype.attention.utils import _is_fa3_available, _is_hopper + + if _is_hopper() and _is_fa3_available(): + from torch.nn.attention import ( + activate_flash_attention_impl, + restore_flash_attention_impl, + ) + + from torchao.prototype.attention import ( + AttentionBackend, + apply_low_precision_attention, + ) + from torchao.prototype.attention.fp8_fa3.attention import fp8_fa3_sdpa + + +class SimpleAttentionModel(nn.Module): + def __init__(self, embed_dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False) + + def forward(self, x): + B, S, _ = x.shape + q = self.q_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2) + k = self.k_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2) + v = self.v_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2) + attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True) + return self.out_proj(attn_out.transpose(1, 2).contiguous().view(B, S, -1)) + + +@common_utils.instantiate_parametrized_tests +class TestFP8FA3Attention(TestCase): + @unittest.skipUnless( + torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available(), + "Requires PyTorch >= 2.11, Hopper GPU, and FA3", + ) + @common_utils.parametrize("shape", [(2, 8, 1024, 64), (1, 16, 1024, 128)]) + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_sdpa_accuracy(self, shape, dtype): + B, H, S, D = shape + q = torch.randn(B, H, S, D, device="cuda", dtype=dtype) + k = torch.randn(B, H, S, D, device="cuda", dtype=dtype) + v = torch.randn(B, H, S, D, device="cuda", dtype=dtype) + + with torch.no_grad(): + out_ref = F.scaled_dot_product_attention(q, k, v, is_causal=False) + + activate_flash_attention_impl("FA3") + try: + with torch.no_grad(): + out_fp8 = fp8_fa3_sdpa(q, k, v, is_causal=False) + finally: + restore_flash_attention_impl() + + sqnr = compute_error(out_ref, out_fp8) + self.assertGreater( + sqnr.item(), + 25.0, + f"SQNR {sqnr.item():.2f} dB below 25 dB for shape={shape}, dtype={dtype}", + ) + + @unittest.skipUnless( + torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available(), + "Requires PyTorch >= 2.11, Hopper GPU, and FA3", + ) + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_monkey_patch_model(self, dtype): + embed_dim, num_heads = 512, 8 + model = ( + SimpleAttentionModel(embed_dim, num_heads) + .to(device="cuda", dtype=dtype) + .eval() + ) + x = torch.randn(2, 128, embed_dim, device="cuda", dtype=dtype) + + with torch.no_grad(): + out_ref = model(x) + + fp8_model = ( + SimpleAttentionModel(embed_dim, num_heads) + .to(device="cuda", dtype=dtype) + .eval() + ) + fp8_model.load_state_dict(model.state_dict()) + fp8_model = apply_low_precision_attention( + fp8_model, + backend=AttentionBackend.FP8_FA3, + fuse_rope_using_torch_compile=False, + ) + + with torch.no_grad(): + out_fp8 = fp8_model(x) + + sqnr = compute_error(out_ref, out_fp8) + self.assertGreater( + sqnr.item(), + 20.0, + f"SQNR {sqnr.item():.2f} dB below 20 dB for dtype={dtype}", + ) + + +if __name__ == "__main__": + run_tests() diff --git a/torchao/prototype/attention/__init__.py b/torchao/prototype/attention/__init__.py new file mode 100644 index 0000000000..60531e9cf0 --- /dev/null +++ b/torchao/prototype/attention/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Low-precision attention for inference. + +Only supports forward pass — backward is not supported by the underlying backends. +""" + +from torchao.prototype.attention.api import ( + AttentionBackend, + apply_low_precision_attention, +) + +__all__ = [ + "AttentionBackend", + "apply_low_precision_attention", +] diff --git a/torchao/prototype/attention/api.py b/torchao/prototype/attention/api.py new file mode 100644 index 0000000000..8bbe8c9a34 --- /dev/null +++ b/torchao/prototype/attention/api.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +"""User-facing API for low-precision attention.""" + +from enum import Enum +from typing import Optional + +import torch +import torch._dynamo +import torch.nn as nn + +from torchao.prototype.attention.utils import _is_fa3_available, _is_hopper +from torchao.utils import torch_version_at_least + +if torch_version_at_least("2.11.0"): + from torchao.prototype.attention.shared_utils.setup import setup_fp8_backend + from torchao.prototype.attention.shared_utils.wrapper import ( + _LowPrecisionAttentionWrapper, + ) +else: + raise ImportError("Low-precision attention requires PyTorch 2.11+.") + + +class AttentionBackend(str, Enum): + """Backend kernel for computing attention.""" + + FP8_FA3 = "FP8_FA3" # Requires SM90+ (Hopper) + + +def _get_available_backend() -> AttentionBackend: + if not torch.cuda.is_available(): + raise RuntimeError("Low-precision attention requires CUDA.") + capability = torch.cuda.get_device_capability() + if _is_hopper() and _is_fa3_available(): + return AttentionBackend.FP8_FA3 + raise RuntimeError(f"No compatible backend for SM{capability[0]}{capability[1]}.") + + +def _check_backend_available(backend: AttentionBackend) -> None: + if not torch.cuda.is_available(): + raise RuntimeError(f"{backend} backend requires CUDA.") + capability = torch.cuda.get_device_capability() + if backend == AttentionBackend.FP8_FA3: + if not _is_hopper(): + raise RuntimeError( + f"FP8_FA3 requires Hopper (SM 9.x), got SM{capability[0]}{capability[1]}." + ) + if not _is_fa3_available(): + raise RuntimeError( + "FP8_FA3 requires the flash-attn package with FA3 support." + ) + else: + raise ValueError(f"Unknown backend: {backend}") + + +def apply_low_precision_attention( + model: nn.Module, + backend: Optional[AttentionBackend] = None, + fuse_rope_using_torch_compile: bool = False, +) -> nn.Module: + """Apply low-precision attention to a model. + + Must be called before ``torch.compile``. KV caching should be + disabled before calling (e.g., ``config.use_cache = False`` for + HuggingFace models). + """ + if isinstance(model, _LowPrecisionAttentionWrapper): + raise RuntimeError( + "apply_low_precision_attention has already been applied to this module." + ) + if isinstance(model, torch._dynamo.OptimizedModule): + raise RuntimeError( + "apply_low_precision_attention must be called before torch.compile." + ) + + if backend is None: + backend = _get_available_backend() + else: + _check_backend_available(backend) + + if backend == AttentionBackend.FP8_FA3: + return setup_fp8_backend(model, "FA3", fuse_rope_using_torch_compile) + + raise ValueError(f"Unknown backend: {backend}") diff --git a/torchao/prototype/attention/fp8_fa3/__init__.py b/torchao/prototype/attention/fp8_fa3/__init__.py new file mode 100644 index 0000000000..c0e3007c4e --- /dev/null +++ b/torchao/prototype/attention/fp8_fa3/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +"""FP8 attention using FA3 backend.""" + +from torchao.prototype.attention.fp8_fa3.attention import fp8_fa3_sdpa +from torchao.prototype.attention.quantization import _fp8_sdpa_quantize + +__all__ = [ + "fp8_fa3_sdpa", + "_fp8_sdpa_quantize", +] diff --git a/torchao/prototype/attention/fp8_fa3/attention.py b/torchao/prototype/attention/fp8_fa3/attention.py new file mode 100644 index 0000000000..eee06ed77f --- /dev/null +++ b/torchao/prototype/attention/fp8_fa3/attention.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +"""FP8 SDPA using FA3 backend. + +Thin wrapper around ``shared_utils/attention.py``. When using directly, +activate the FA3 flash attention implementation before calling. +""" + +from functools import partial + +from torchao.prototype.attention.shared_utils.attention import ( + _fp8_sdpa, +) + +fp8_fa3_sdpa = partial(_fp8_sdpa, backend_name="FA3") +fp8_fa3_sdpa.__doc__ = _fp8_sdpa.__doc__ +fp8_fa3_sdpa.__name__ = "fp8_fa3_sdpa" +fp8_fa3_sdpa.__qualname__ = "fp8_fa3_sdpa" diff --git a/torchao/prototype/attention/fp8_fa3/setup.py b/torchao/prototype/attention/fp8_fa3/setup.py new file mode 100644 index 0000000000..d9849c7b1e --- /dev/null +++ b/torchao/prototype/attention/fp8_fa3/setup.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +"""FP8 FA3 backend setup.""" + +import torch.nn as nn + +from torchao.prototype.attention.config import LowPrecisionAttentionConfig +from torchao.prototype.attention.shared_utils.setup import setup_fp8_backend + + +def setup_fp8_fa3( + model: nn.Module, + config: LowPrecisionAttentionConfig, +) -> nn.Module: + """Set up FP8 FA3 attention on *model* and wrap it.""" + from torchao.prototype.attention.fp8_fa3.attention import fp8_fa3_sdpa + + return setup_fp8_backend( + model, + config, + flash_impl_name="FA3", + sdpa_fn=fp8_fa3_sdpa, + ) diff --git a/torchao/prototype/attention/quantization/__init__.py b/torchao/prototype/attention/quantization/__init__.py new file mode 100644 index 0000000000..05bb8bcb74 --- /dev/null +++ b/torchao/prototype/attention/quantization/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +"""Shared FP8 quantization kernels for low-precision attention.""" + +from torchao.prototype.attention.quantization.quantization import ( + _fp8_sdpa_quantize, +) + +__all__ = [ + "_fp8_sdpa_quantize", +] diff --git a/torchao/prototype/attention/quantization/quantization.py b/torchao/prototype/attention/quantization/quantization.py new file mode 100644 index 0000000000..84a255ead7 --- /dev/null +++ b/torchao/prototype/attention/quantization/quantization.py @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +FP8 quantization for attention inputs. +""" + +from typing import Tuple + +import torch + + +def _fp8_sdpa_quantize( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + """Quantize Q, K, V to FP8 with per-head scaling.""" + if q.dim() != 4: + raise ValueError(f"Expected 4D tensor for q, got {q.dim()}D") + if k.dim() != 4: + raise ValueError(f"Expected 4D tensor for k, got {k.dim()}D") + if v.dim() != 4: + raise ValueError(f"Expected 4D tensor for v, got {v.dim()}D") + if k.shape != v.shape: + raise ValueError(f"K and V shape mismatch: {k.shape} vs {v.shape}") + if q.shape[0] != k.shape[0]: + raise ValueError(f"Batch size mismatch: {q.shape[0]} vs {k.shape[0]}") + if q.shape[1] % k.shape[1] != 0: + raise ValueError( + f"Q head count ({q.shape[1]}) must be a multiple of K head count ({k.shape[1]})" + ) + if q.shape[3] != k.shape[3]: + raise ValueError(f"Head dim mismatch: {q.shape[3]} vs {k.shape[3]}") + + from torchao.prototype.attention.quantization.triton_qkv_quantization import ( + triton_fp8_sdpa_quantize, + ) + + return triton_fp8_sdpa_quantize(q, k, v) diff --git a/torchao/prototype/attention/quantization/triton_qkv_quantization.py b/torchao/prototype/attention/quantization/triton_qkv_quantization.py new file mode 100644 index 0000000000..c0ea2f4fd6 --- /dev/null +++ b/torchao/prototype/attention/quantization/triton_qkv_quantization.py @@ -0,0 +1,466 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +FP8 quantization kernels for Q, K, V. + +Input/output format: [B, H, S, D]. +Supports GQA (different head counts for Q vs K/V). +""" + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + + +def _compute_num_chunks(tensor: torch.Tensor, S: int) -> int: + """Compute optimal number of chunks based on GPU properties.""" + props = torch.cuda.get_device_properties(tensor.device) + num_sms = props.multi_processor_count + B, H = tensor.shape[:2] # [B, H, S, D] + base_parallelism = B * H + # Target 2-4x SMs for good occupancy/latency hiding + target_blocks = num_sms * 4 + num_chunks = max(1, target_blocks // base_parallelism) + # Ensure each chunk has at least 32 S positions for efficiency + num_chunks = min(num_chunks, S // 32) if S >= 32 else 1 + # Cap at reasonable maximum + num_chunks = min(num_chunks, 64) + # Adjust if S is small + num_chunks = min(num_chunks, S) + return num_chunks + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 512}, num_warps=4), + triton.Config({"BLOCK_SIZE": 1024}, num_warps=4), + triton.Config({"BLOCK_SIZE": 2048}, num_warps=8), + triton.Config({"BLOCK_SIZE": 4096}, num_warps=8), + ], + key=["chunk_size", "D"], +) +@triton.jit +def single_phase1_kernel( + # Input tensor [B, H, S, D] + x_ptr, + # Output: partial max values [B * H * num_chunks] + partial_max_ptr, + # Input strides (for [B, H, S, D] layout) + stride_b, + stride_h, + stride_s, + stride_d, + # Dimensions + S, + D, + H, + chunk_size, + num_chunks, + # Block size + BLOCK_SIZE: tl.constexpr, +): + """ + Phase 1 for a single tensor: Compute partial absmax. + + Grid: (B, H, num_chunks) + + Uses linearized iteration over chunk_size * D elements. + """ + pid_b = tl.program_id(axis=0) + pid_h = tl.program_id(axis=1) + pid_chunk = tl.program_id(axis=2) + + # Compute the S range for this chunk + s_start = pid_chunk * chunk_size + s_end = tl.minimum(s_start + chunk_size, S) + chunk_elements = (s_end - s_start) * D + + # Base pointer for input [B, H, S, D] + base_offset = pid_b * stride_b + pid_h * stride_h + + # Initialize max accumulator + x_max = 0.0 + + # Linearized iteration over chunk_size * D elements + for block_start in range(0, chunk_elements, BLOCK_SIZE): + offs = block_start + tl.arange(0, BLOCK_SIZE) + mask = offs < chunk_elements + + # Convert linear offset to (s, d) coordinates + local_s = offs // D + d_idx = offs % D + s_idx = s_start + local_s + + # Input offset [B, H, S, D] + ptr_offset = s_idx * stride_s + d_idx * stride_d + + x_val = tl.load(x_ptr + base_offset + ptr_offset, mask=mask, other=0.0).to( + tl.float32 + ) + x_max = tl.maximum(x_max, tl.max(tl.abs(x_val))) + + # Store partial max + chunk_idx = pid_b * (H * num_chunks) + pid_h * num_chunks + pid_chunk + tl.store(partial_max_ptr + chunk_idx, x_max) + + +@triton.jit +def single_reduce_kernel( + partial_max_ptr, # [B * H * num_chunks] + scale_ptr, + descale_ptr, + H, + num_chunks, +): + """ + Reduce partial maxes and compute scale/descale for a single tensor. + + Grid: (B, H) + """ + pid_b = tl.program_id(axis=0) + pid_h = tl.program_id(axis=1) + + # Reduce across chunks for this (batch, head) + x_max = 0.0 + + base_idx = (pid_b * H + pid_h) * num_chunks + for c in range(num_chunks): + x_max = tl.maximum(x_max, tl.load(partial_max_ptr + base_idx + c)) + + # Compute scale and descale + # FP8 E4M3 max value is 448.0 + FP8_MAX = 448.0 + eps = 1e-12 + scale_idx = pid_b * H + pid_h + + tl.store(scale_ptr + scale_idx, tl.where(x_max > eps, FP8_MAX / x_max, 1.0)) + tl.store(descale_ptr + scale_idx, tl.where(x_max > eps, x_max / FP8_MAX, 1.0)) + + +@triton.jit +def group_reduce_kernel( + partial_max_ptr, # [B * H_q * num_chunks] + scale_ptr, # [B, H_kv] + descale_ptr, # [B, H_kv] + H_q, + H_kv, + groups, # H_q // H_kv + num_chunks, +): + """ + Reduce partial maxes across head groups for GQA Q tensor. + + For each KV group, reduces the max across all Q heads in that group + and all chunks, producing one scale per (batch, kv_head). + + Grid: (B, H_kv) + """ + pid_b = tl.program_id(axis=0) + pid_hkv = tl.program_id(axis=1) + + x_max = 0.0 + + for g in range(groups): + h_q = pid_hkv * groups + g + base_idx = (pid_b * H_q + h_q) * num_chunks + for c in range(num_chunks): + x_max = tl.maximum(x_max, tl.load(partial_max_ptr + base_idx + c)) + + FP8_MAX = 448.0 + eps = 1e-12 + scale_idx = pid_b * H_kv + pid_hkv + + tl.store(scale_ptr + scale_idx, tl.where(x_max > eps, FP8_MAX / x_max, 1.0)) + tl.store(descale_ptr + scale_idx, tl.where(x_max > eps, x_max / FP8_MAX, 1.0)) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 512}, num_warps=4), + triton.Config({"BLOCK_SIZE": 1024}, num_warps=4), + triton.Config({"BLOCK_SIZE": 2048}, num_warps=8), + triton.Config({"BLOCK_SIZE": 4096}, num_warps=8), + ], + key=["chunk_size", "D"], +) +@triton.jit +def single_phase2_kernel( + # Input tensor [B, H, S, D] + x_ptr, + # Output tensor [B, H, S, D] - FP8 quantized + x_out_ptr, + # Precomputed scale [B, H_scale] + scale_ptr, + # Strides (for [B, H, S, D] layout) + stride_b, + stride_h, + stride_s, + stride_d, + # Dimensions + S, + D, + H, + chunk_size, + # Scale indexing for GQA: scale has H_scale entries per batch, + # and each group of `groups` heads shares one scale. + # For non-GQA: H_scale = H, groups = 1. + H_scale, + groups, + # Block size + BLOCK_SIZE: tl.constexpr, +): + """ + Phase 2 for a single tensor: Quantize to FP8 using precomputed scale. + + Grid: (B, H, num_chunks) + """ + pid_b = tl.program_id(axis=0) + pid_h = tl.program_id(axis=1) + pid_chunk = tl.program_id(axis=2) + + # Load scale for this head (or head group for GQA) + scale = tl.load(scale_ptr + pid_b * H_scale + pid_h // groups) + + # Compute the S range for this chunk + s_start = pid_chunk * chunk_size + s_end = tl.minimum(s_start + chunk_size, S) + chunk_elements = (s_end - s_start) * D + + # Base pointer + base_offset = pid_b * stride_b + pid_h * stride_h + + # Linearized iteration over chunk_size * D elements + for block_start in range(0, chunk_elements, BLOCK_SIZE): + offs = block_start + tl.arange(0, BLOCK_SIZE) + mask = offs < chunk_elements + + # Convert linear offset to (s, d) coordinates + local_s = offs // D + d_idx = offs % D + s_idx = s_start + local_s + + ptr_offset = base_offset + s_idx * stride_s + d_idx * stride_d + + # Load input value + x_val = tl.load(x_ptr + ptr_offset, mask=mask, other=0.0).to(tl.float32) + + # Quantize to FP8 + x_fp8 = (x_val * scale).to(tl.float8e4nv) + + # Store to output + tl.store(x_out_ptr + ptr_offset, x_fp8, mask=mask) + + +def triton_fp8_sdpa_quantize( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + num_chunks: Optional[int] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + """ + Separated FP8 quantization for Q, K, V tensors. + + Quantizes all tensors to FP8 with per-head scaling. + Each of Q, K, V is processed with independent kernel launches, + supporting GQA where Q has more heads than K/V (H_q = groups * H_kv). + + For GQA, Q is quantized with per-KV-group scaling so that q_descale + has shape [B, H_kv] as required by FA3. + + Args: + q: Query tensor of shape [B, H_q, S, D] in bf16/fp16 + k: Key tensor of shape [B, H_kv, S, D] in bf16/fp16 + v: Value tensor of shape [B, H_kv, S, D] in bf16/fp16 + num_chunks: Number of chunks to split S dimension into. + If None, automatically selects based on GPU SM count. + + Returns: + q_fp8: Quantized query, shape [B, H_q, S, D] in fp8 + k_fp8: Quantized key, shape [B, H_kv, S, D] in fp8 + v_fp8: Quantized value, shape [B, H_kv, S, D] in fp8 + q_descale: Query descale factors, shape [B, H_kv] in fp32 + k_descale: Key descale factors, shape [B, H_kv] in fp32 + v_descale: Value descale factors, shape [B, H_kv] in fp32 + """ + assert q.dim() == 4, f"Expected 4D tensor [B, H, S, D], got {q.dim()}D" + assert k.dim() == 4, f"Expected 4D tensor [B, H, S, D], got {k.dim()}D" + assert v.dim() == 4, f"Expected 4D tensor [B, H, S, D], got {v.dim()}D" + assert k.shape == v.shape, ( + f"K and V must have the same shape, got {k.shape} vs {v.shape}" + ) + assert q.shape[0] == k.shape[0], ( + f"Batch size mismatch: {q.shape[0]} vs {k.shape[0]}" + ) + assert q.shape[2] == k.shape[2], ( + f"Sequence length mismatch: {q.shape[2]} vs {k.shape[2]}" + ) + assert q.shape[3] == k.shape[3], f"Head dim mismatch: {q.shape[3]} vs {k.shape[3]}" + assert q.shape[1] % k.shape[1] == 0, ( + f"Q heads ({q.shape[1]}) must be a multiple of K heads ({k.shape[1]})" + ) + + B, H_q, S, D = q.shape + H_kv = k.shape[1] + groups = H_q // H_kv + + # Make tensors contiguous if needed + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + + # Compute number of chunks + if num_chunks is None: + num_chunks = _compute_num_chunks(q, S) + chunk_size = (S + num_chunks - 1) // num_chunks + + # Allocate output tensors + q_fp8 = torch.empty_like(q, dtype=torch.float8_e4m3fn) + k_fp8 = torch.empty_like(k, dtype=torch.float8_e4m3fn) + v_fp8 = torch.empty_like(v, dtype=torch.float8_e4m3fn) + + # Allocate partial max buffers (one per tensor) + q_partial_max = torch.empty( + B * H_q * num_chunks, dtype=torch.float32, device=q.device + ) + k_partial_max = torch.empty( + B * H_kv * num_chunks, dtype=torch.float32, device=q.device + ) + v_partial_max = torch.empty( + B * H_kv * num_chunks, dtype=torch.float32, device=q.device + ) + + # Allocate scale/descale tensors. + # For GQA, Q scale/descale are [B, H_kv] (per KV group). + # K and V are always [B, H_kv] (per head). + q_scale = torch.empty(B, H_kv, dtype=torch.float32, device=q.device) + k_scale = torch.empty(B, H_kv, dtype=torch.float32, device=q.device) + v_scale = torch.empty(B, H_kv, dtype=torch.float32, device=q.device) + q_descale = torch.empty(B, H_kv, dtype=torch.float32, device=q.device) + k_descale = torch.empty(B, H_kv, dtype=torch.float32, device=q.device) + v_descale = torch.empty(B, H_kv, dtype=torch.float32, device=q.device) + + q_grid_chunked = (B, H_q, num_chunks) + kv_grid_chunked = (B, H_kv, num_chunks) + + # ---- Phase 1: Max for Q ---- + single_phase1_kernel[q_grid_chunked]( + q, + q_partial_max, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + S, + D, + H_q, + chunk_size, + num_chunks, + ) + + # ---- Phase 1: Max for K ---- + single_phase1_kernel[kv_grid_chunked]( + k, + k_partial_max, + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + S, + D, + H_kv, + chunk_size, + num_chunks, + ) + + # ---- Phase 1: Max for V ---- + single_phase1_kernel[kv_grid_chunked]( + v, + v_partial_max, + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + S, + D, + H_kv, + chunk_size, + num_chunks, + ) + + # ---- Reduce ---- + # Q: group reduce across `groups` Q heads per KV head + group_reduce_kernel[(B, H_kv)]( + q_partial_max, q_scale, q_descale, H_q, H_kv, groups, num_chunks + ) + # K, V: per-head reduce + single_reduce_kernel[(B, H_kv)](k_partial_max, k_scale, k_descale, H_kv, num_chunks) + single_reduce_kernel[(B, H_kv)](v_partial_max, v_scale, v_descale, H_kv, num_chunks) + + # ---- Phase 2: Quantize Q ---- + # Q scale is [B, H_kv]; each group of `groups` Q heads shares one scale. + single_phase2_kernel[q_grid_chunked]( + q, + q_fp8, + q_scale, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + S, + D, + H_q, + chunk_size, + H_kv, + groups, + ) + + # ---- Phase 2: Quantize K ---- + # K scale is [B, H_kv]; groups=1 (per-head). + single_phase2_kernel[kv_grid_chunked]( + k, + k_fp8, + k_scale, + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + S, + D, + H_kv, + chunk_size, + H_kv, + 1, + ) + + # ---- Phase 2: Quantize V ---- + # V scale is [B, H_kv]; groups=1 (per-head). + single_phase2_kernel[kv_grid_chunked]( + v, + v_fp8, + v_scale, + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + S, + D, + H_kv, + chunk_size, + H_kv, + 1, + ) + + return q_fp8, k_fp8, v_fp8, q_descale, k_descale, v_descale diff --git a/torchao/prototype/attention/shared_utils/__init__.py b/torchao/prototype/attention/shared_utils/__init__.py new file mode 100644 index 0000000000..ecddc88b5f --- /dev/null +++ b/torchao/prototype/attention/shared_utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchao/prototype/attention/shared_utils/attention.py b/torchao/prototype/attention/shared_utils/attention.py new file mode 100644 index 0000000000..15361b8e37 --- /dev/null +++ b/torchao/prototype/attention/shared_utils/attention.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +"""Shared FP8 scaled dot-product attention implementation.""" + +from typing import Optional + +import torch + +from torchao.utils import torch_version_at_least + +_TORCH_VERSION_AT_LEAST_2_11 = torch_version_at_least("2.11.0") + +if _TORCH_VERSION_AT_LEAST_2_11: + from torch.nn.attention import SDPBackend, sdpa_kernel + from torch.nn.attention.experimental._scaled_dot_product_attention_quantized import ( + _scaled_dot_product_attention_quantized, + ) + +from torchao.prototype.attention.quantization import ( + _fp8_sdpa_quantize, +) + + +def _fp8_sdpa( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + *, + backend_name: str = "FP8", +) -> torch.Tensor: + """FP8 SDPA implementation shared by all backends. + + Quantizes Q, K, V to FP8 with per-head scaling, then calls + ``_scaled_dot_product_attention_quantized`` under ``SDPBackend.FLASH_ATTENTION``. + """ + if not _TORCH_VERSION_AT_LEAST_2_11: + raise RuntimeError("Low-precision attention requires PyTorch 2.11+.") + if attn_mask is not None: + raise ValueError(f"attn_mask not supported for FP8 {backend_name}") + if dropout_p != 0.0: + raise ValueError( + f"dropout_p must be 0.0 for FP8 {backend_name}, got {dropout_p}" + ) + + input_dtype = query.dtype + + # Ensure Triton kernels launch on the correct device. + # In the monkey-patch path (no torch.compile), accelerate's hooks move + # tensors to the correct device but don't call torch.cuda.set_device(). + # Triton dispatches based on current_device(), not tensor device, so + # without this guard the kernel launches on the wrong GPU's stream. + _prev_device = torch.cuda.current_device() + if query.device.index is not None and query.device.index != _prev_device: + torch.cuda.set_device(query.device) + + q_fp8, k_fp8, v_fp8, descale_q, descale_k, descale_v = _fp8_sdpa_quantize( + query, key, value + ) + + with sdpa_kernel(SDPBackend.FLASH_ATTENTION): + out = _scaled_dot_product_attention_quantized( + q_fp8, + k_fp8, + v_fp8, + is_causal=is_causal, + scale=scale, + q_descale=descale_q, + k_descale=descale_k, + v_descale=descale_v, + ) + + # Restore previous device to avoid side effects on the caller. + if query.device.index is not None and query.device.index != _prev_device: + torch.cuda.set_device(_prev_device) + + return out.to(input_dtype) diff --git a/torchao/prototype/attention/shared_utils/setup.py b/torchao/prototype/attention/shared_utils/setup.py new file mode 100644 index 0000000000..73e76798bc --- /dev/null +++ b/torchao/prototype/attention/shared_utils/setup.py @@ -0,0 +1,36 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn as nn + +from torchao.prototype.attention.shared_utils.wrapper import ( + _FP8FlashAttentionMonkeyPatchWrapper, + _make_causal_aware_sdpa, +) + + +def setup_fp8_backend( + model: nn.Module, + flash_impl_name: str, + fuse_rope_using_torch_compile: bool, +) -> nn.Module: + if fuse_rope_using_torch_compile: + raise NotImplementedError( + "fuse_rope_using_torch_compile requires the RoPE fusion path, " + "which is not available in this version." + ) + if flash_impl_name == "FA3": + from torchao.prototype.attention.fp8_fa3.attention import ( + fp8_fa3_sdpa as sdpa_fn, + ) + else: + raise ValueError(f"Unknown flash_impl_name: {flash_impl_name}") + + return _FP8FlashAttentionMonkeyPatchWrapper( + model, + flash_impl_name=flash_impl_name, + sdpa_patch_fn=_make_causal_aware_sdpa(sdpa_fn, strip_causal_mask=False), + ) diff --git a/torchao/prototype/attention/shared_utils/wrapper.py b/torchao/prototype/attention/shared_utils/wrapper.py new file mode 100644 index 0000000000..00b3e96c64 --- /dev/null +++ b/torchao/prototype/attention/shared_utils/wrapper.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable + +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.attention import ( + activate_flash_attention_impl, + restore_flash_attention_impl, +) + + +class _LowPrecisionAttentionWrapper(nn.Module): + """Base wrapper. Proxies attribute access to the original module.""" + + def __init__(self, orig_mod: nn.Module): + super().__init__() + self._orig_mod = orig_mod + + def __getattr__(self, name: str): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self._orig_mod, name) + + +class _FP8FlashAttentionMonkeyPatchWrapper(_LowPrecisionAttentionWrapper): + """Monkey-patch path wrapper. Replaces ``F.scaled_dot_product_attention`` + with the FP8 backend for the duration of each forward call. + """ + + def __init__( + self, orig_mod: nn.Module, flash_impl_name: str, sdpa_patch_fn: Callable + ): + super().__init__(orig_mod) + self._flash_impl_name = flash_impl_name + self._sdpa_patch_fn = sdpa_patch_fn + + def forward(self, *args, **kwargs): + activate_flash_attention_impl(self._flash_impl_name) + try: + original_sdpa = F.scaled_dot_product_attention + F.scaled_dot_product_attention = self._sdpa_patch_fn + try: + return self._orig_mod(*args, **kwargs) + finally: + F.scaled_dot_product_attention = original_sdpa + finally: + restore_flash_attention_impl() + + +def _make_causal_aware_sdpa(fp8_sdpa_fn: Callable, strip_causal_mask: bool) -> Callable: + """Wrap an FP8 SDPA function to strip materialized causal masks.""" + if strip_causal_mask: + + def _patched( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=None, + enable_gqa=False, + ): + if attn_mask is not None: + attn_mask = None + is_causal = True + return fp8_sdpa_fn( + query, + key, + value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + + return _patched + return fp8_sdpa_fn diff --git a/torchao/prototype/attention/utils.py b/torchao/prototype/attention/utils.py new file mode 100644 index 0000000000..721a11f8ad --- /dev/null +++ b/torchao/prototype/attention/utils.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import importlib + +import torch + + +def _is_hopper() -> bool: + if not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major == 9 + + +def _is_blackwell() -> bool: + if not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major == 10 + + +def _is_fa3_available() -> bool: + try: + importlib.import_module("flash_attn_interface") + return True + except ModuleNotFoundError: + return False + + +def _is_fa4_available() -> bool: + try: + importlib.import_module("flash_attn.cute.interface") + return True + except ModuleNotFoundError: + return False