Added new API for low precision fp8 attention using FA3 #3857
Merged. Diff shown from 38 of 45 commits.
All 45 commits are by howardzhang-cv (cf8280f through 94d9200). Aside from the final merge commit 94d9200 "Merge branch 'main' into gh/howardzhang-cv/16/head", each is titled "Update" or "Update (base update)" (stacked-PR base updates).
One file in the diff is empty.
New file: the test suite (+248 lines).

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""Tests for FP8 low-precision attention (FA3 backend on Hopper)."""

import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import TestCase, run_tests

from torchao.prototype.attention import (
    AttentionBackend,
    apply_low_precision_attention,
)
from torchao.prototype.attention.utils import _is_fa3_available, _is_hopper
from torchao.utils import torch_version_at_least

if torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available():
    from torch.nn.attention import (
        activate_flash_attention_impl,
        restore_flash_attention_impl,
    )

    from torchao.prototype.attention.fp8_fa3.attention import (
        fp8_fa3_rope_sdpa,
        fp8_fa3_sdpa,
    )
from torchao.quantization.utils import compute_error


def _rope_cos_sin(S, D, device):
    freqs = 1.0 / (10000.0 ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
    angles = torch.outer(torch.arange(S, dtype=torch.float32), freqs)
    cos_half = torch.cos(angles)
    sin_half = torch.sin(angles)
    cos = torch.cat([cos_half, cos_half], dim=-1).to(device)
    sin = torch.cat([sin_half, sin_half], dim=-1).to(device)
    return cos, sin


def _apply_rope(x, cos, sin):
    """NeoX rotate-half RoPE. x: [B, S, H, D], cos/sin: [S, D]."""
    D_HALF = x.shape[-1] // 2
    rotate = torch.cat([-x[..., D_HALF:], x[..., :D_HALF]], dim=-1)
    return (
        x * cos.unsqueeze(0).unsqueeze(2) + rotate * sin.unsqueeze(0).unsqueeze(2)
    ).to(x.dtype)


class SimpleAttentionModel(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, x):
        B, S, _ = x.shape
        q = self.q_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.out_proj(attn_out.transpose(1, 2).contiguous().view(B, S, -1))


class SimpleRoPEAttentionModel(nn.Module):
    """Applies RoPE to Q and K immediately before SDPA (Pattern A: RoPE → transpose → SDPA)."""

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, x, cos, sin):
        B, S, _ = x.shape
        q = self.q_proj(x).view(B, S, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(B, S, self.num_heads, self.head_dim)
        v = self.v_proj(x).view(B, S, self.num_heads, self.head_dim)
        q = _apply_rope(q, cos, sin).transpose(1, 2)
        k = _apply_rope(k, cos, sin).transpose(1, 2)
        v = v.transpose(1, 2)
        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.out_proj(attn_out.transpose(1, 2).contiguous().view(B, S, -1))


@common_utils.instantiate_parametrized_tests
class TestFP8FA3Attention(TestCase):
    @unittest.skipUnless(
        torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available(),
        "Requires PyTorch >= 2.11, Hopper GPU, and FA3",
    )
    @common_utils.parametrize("shape", [(2, 8, 1024, 64), (1, 16, 1024, 128)])
    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
    def test_sdpa_accuracy(self, shape, dtype):
        B, H, S, D = shape
        q = torch.randn(B, H, S, D, device="cuda", dtype=dtype)
        k = torch.randn(B, H, S, D, device="cuda", dtype=dtype)
        v = torch.randn(B, H, S, D, device="cuda", dtype=dtype)

        with torch.no_grad():
            out_ref = F.scaled_dot_product_attention(q, k, v, is_causal=False)

        activate_flash_attention_impl("FA3")
        try:
            with torch.no_grad():
                out_fp8 = fp8_fa3_sdpa(q, k, v, is_causal=False)
        finally:
            restore_flash_attention_impl()

        sqnr = compute_error(out_ref, out_fp8)
        self.assertGreater(
            sqnr.item(),
            25.0,
            f"SQNR {sqnr.item():.2f} dB below 25 dB for shape={shape}, dtype={dtype}",
        )

    @unittest.skipUnless(
        torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available(),
        "Requires PyTorch >= 2.11, Hopper GPU, and FA3",
    )
    @common_utils.parametrize("shape", [(2, 1024, 8, 64), (1, 1024, 16, 128)])
    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
    def test_rope_sdpa_accuracy(self, shape, dtype):
        B, S, H, D = shape
        q = torch.randn(B, S, H, D, device="cuda", dtype=dtype)
        k = torch.randn(B, S, H, D, device="cuda", dtype=dtype)
        v = torch.randn(B, S, H, D, device="cuda", dtype=dtype)
        cos, sin = _rope_cos_sin(S, D, "cuda")

        with torch.no_grad():
            out_ref = F.scaled_dot_product_attention(
                _apply_rope(q, cos, sin).transpose(1, 2),
                _apply_rope(k, cos, sin).transpose(1, 2),
                v.transpose(1, 2),
                is_causal=False,
            )

        activate_flash_attention_impl("FA3")
        try:
            with torch.no_grad():
                out_fp8 = fp8_fa3_rope_sdpa(q, k, v, cos, sin, is_causal=False)
        finally:
            restore_flash_attention_impl()

        sqnr = compute_error(out_ref, out_fp8)
        self.assertGreater(
            sqnr.item(),
            25.0,
            f"SQNR {sqnr.item():.2f} dB below 25 dB for shape={shape}, dtype={dtype}",
        )

    @unittest.skipUnless(
        torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available(),
        "Requires PyTorch >= 2.11, Hopper GPU, and FA3",
    )
    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
    def test_monkey_patch_model(self, dtype):
        embed_dim, num_heads = 512, 8
        model = (
            SimpleAttentionModel(embed_dim, num_heads)
            .to(device="cuda", dtype=dtype)
            .eval()
        )
        x = torch.randn(2, 128, embed_dim, device="cuda", dtype=dtype)

        with torch.no_grad():
            out_ref = model(x)

        fp8_model = (
            SimpleAttentionModel(embed_dim, num_heads)
            .to(device="cuda", dtype=dtype)
            .eval()
        )
        fp8_model.load_state_dict(model.state_dict())
        fp8_model = apply_low_precision_attention(
            fp8_model,
            backend=AttentionBackend.FP8_FA3,
            fuse_rope_using_torch_compile=False,
        )

        with torch.no_grad():
            out_fp8 = fp8_model(x)

        sqnr = compute_error(out_ref, out_fp8)
        self.assertGreater(
            sqnr.item(),
            20.0,
            f"SQNR {sqnr.item():.2f} dB below 20 dB for dtype={dtype}",
        )

    @unittest.skipUnless(
        torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available(),
        "Requires PyTorch >= 2.11, Hopper GPU, and FA3",
    )
    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
    def test_rope_fusion_model(self, dtype):
        embed_dim, num_heads = 512, 8
        model = (
            SimpleRoPEAttentionModel(embed_dim, num_heads)
            .to(device="cuda", dtype=dtype)
            .eval()
        )
        S = 128
        x = torch.randn(2, S, embed_dim, device="cuda", dtype=dtype)
        cos, sin = _rope_cos_sin(S, embed_dim // num_heads, "cuda")

        with torch.no_grad():
            out_ref = model(x, cos, sin)

        fp8_model = (
            SimpleRoPEAttentionModel(embed_dim, num_heads)
            .to(device="cuda", dtype=dtype)
            .eval()
        )
        fp8_model.load_state_dict(model.state_dict())
        fp8_model = apply_low_precision_attention(
            fp8_model,
            backend=AttentionBackend.FP8_FA3,
            fuse_rope_using_torch_compile=True,
        )
        fp8_model = torch.compile(fp8_model, backend=fp8_model.compile_backend)

        with torch.no_grad():
            out_fp8 = fp8_model(x, cos, sin)

        sqnr = compute_error(out_ref, out_fp8)
        self.assertGreater(
            sqnr.item(),
            20.0,
            f"SQNR {sqnr.item():.2f} dB below 20 dB for dtype={dtype}",
        )


if __name__ == "__main__":
    run_tests()
```
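The `_rope_cos_sin` / `_apply_rope` helpers in the tests implement NeoX-style rotate-half RoPE. A minimal pure-Python sketch of the same math on a single head vector (the `rope_rotate_half` name is hypothetical, not part of this PR) makes the per-pair rotation explicit:

```python
import math


def rope_rotate_half(x, pos, base=10000.0):
    """Rotate-half RoPE on one head vector x (length D, D even), mirroring
    the test helpers: out = x * cos + rotate_half(x) * sin, where cos/sin
    are built by duplicating the half-dim tables (cat([half, half]))."""
    D = len(x)
    half = D // 2
    # freqs[i] = base^(-2i/D), matching torch.arange(0, D, 2) / D
    freqs = [base ** (-(2 * i) / D) for i in range(half)]
    cos = [math.cos(pos * f) for f in freqs] * 2
    sin = [math.sin(pos * f) for f in freqs] * 2
    # rotate_half: cat([-x[half:], x[:half]])
    rotate = [-x[half + i] for i in range(half)] + [x[i] for i in range(half)]
    return [x[i] * cos[i] + rotate[i] * sin[i] for i in range(D)]
```

Each pair (x[i], x[half+i]) is rotated by the angle pos * freqs[i], so position 0 is the identity and the vector norm is preserved at any position.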
New file: the prototype package `__init__` (+21 lines).

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""
Low-precision attention for inference.

Only supports forward pass — backward is not supported by the underlying backends.
"""

from torchao.prototype.attention.api import (
    AttentionBackend,
    apply_low_precision_attention,
)

__all__ = [
    "AttentionBackend",
    "apply_low_precision_attention",
]
```
New file: the user-facing API (+93 lines).

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""User-facing API for low-precision attention."""

from enum import Enum
from typing import Optional

import torch
import torch._dynamo
import torch.nn as nn

from torchao.prototype.attention.shared_utils.setup import setup_fp8_backend
from torchao.prototype.attention.shared_utils.wrapper import (
    _LowPrecisionAttentionWrapper,
)
from torchao.prototype.attention.utils import _is_fa3_available, _is_hopper
from torchao.utils import torch_version_at_least


class AttentionBackend(str, Enum):
    """Backend kernel for computing attention."""

    FP8_FA3 = "FP8_FA3"  # Requires SM90+ (Hopper)


def _get_available_backend() -> AttentionBackend:
    if not torch.cuda.is_available():
        raise RuntimeError("Low-precision attention requires CUDA.")
    capability = torch.cuda.get_device_capability()
    if _is_hopper() and _is_fa3_available():
        return AttentionBackend.FP8_FA3
    raise RuntimeError(f"No compatible backend for SM{capability[0]}{capability[1]}.")


def _check_backend_available(backend: AttentionBackend) -> None:
    if not torch.cuda.is_available():
        raise RuntimeError(f"{backend} backend requires CUDA.")
    capability = torch.cuda.get_device_capability()
    if backend == AttentionBackend.FP8_FA3:
        if not _is_hopper():
            raise RuntimeError(
                f"FP8_FA3 requires Hopper (SM 9.x), got SM{capability[0]}{capability[1]}."
            )
        if not _is_fa3_available():
            raise RuntimeError(
                "FP8_FA3 requires the flash-attn package with FA3 support."
            )
    else:
        raise ValueError(f"Unknown backend: {backend}")


def apply_low_precision_attention(
    model: nn.Module,
    backend: Optional[AttentionBackend] = None,
    fuse_rope_using_torch_compile: bool = False,
) -> nn.Module:
    """Apply low-precision attention to a model.

    Must be called before ``torch.compile``. KV caching should be
    disabled before calling (e.g., ``config.use_cache = False`` for
    HuggingFace models).

    When ``fuse_rope_using_torch_compile=True``, the returned wrapper
    exposes a ``compile_backend`` attribute. You must compile with it to get
    the RoPE fusion::

        model = apply_low_precision_attention(model, fuse_rope_using_torch_compile=True)
        model = torch.compile(model, backend=model.compile_backend)
    """
    if not torch_version_at_least("2.11.0"):
        raise RuntimeError("Low-precision attention requires PyTorch 2.11+.")
    if isinstance(model, _LowPrecisionAttentionWrapper):
        raise RuntimeError(
            "apply_low_precision_attention has already been applied to this module."
        )
    if isinstance(model, torch._dynamo.OptimizedModule):
        raise RuntimeError(
            "apply_low_precision_attention must be called before torch.compile."
        )

    if backend is None:
        backend = _get_available_backend()
    else:
        _check_backend_available(backend)

    if backend == AttentionBackend.FP8_FA3:
        return setup_fp8_backend(model, "FA3", fuse_rope_using_torch_compile)

    raise ValueError(f"Unknown backend: {backend}")
```
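The backend-selection logic above only runs on CUDA machines. A device-independent sketch of the same control flow, with the `torch.cuda` queries replaced by plain parameters (the `pick_backend` name and signature are hypothetical, for illustration only), can be exercised anywhere:

```python
from enum import Enum


class Backend(str, Enum):
    # Mirrors AttentionBackend above; FP8_FA3 requires SM90+ (Hopper).
    FP8_FA3 = "FP8_FA3"


def pick_backend(cuda_available: bool, sm: tuple, fa3_installed: bool) -> Backend:
    """Hypothetical stand-in for _get_available_backend: in the real code,
    cuda_available and sm come from torch.cuda, and fa3_installed from the
    flash-attn package check."""
    if not cuda_available:
        raise RuntimeError("Low-precision attention requires CUDA.")
    if sm[0] == 9 and fa3_installed:  # _is_hopper() checks SM 9.x
        return Backend.FP8_FA3
    raise RuntimeError(f"No compatible backend for SM{sm[0]}{sm[1]}.")
```

For example, `pick_backend(True, (9, 0), True)` selects `Backend.FP8_FA3`, while an Ampere device `(8, 0)` falls through to the "no compatible backend" error.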
New file: the `fp8_fa3` package `__init__` (+21 lines).

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""
FP8 attention using FA3 backend.
"""

from torchao.prototype.attention.fp8_fa3.attention import (
    fp8_fa3_rope_sdpa,
    fp8_fa3_sdpa,
)
from torchao.prototype.attention.quantization import _fp8_sdpa_quantize

__all__ = [
    "fp8_fa3_sdpa",
    "fp8_fa3_rope_sdpa",
    "_fp8_sdpa_quantize",
]
```
Review comment: I think this should be explicit that parts of torch.compile are used to do the logic swap.
Author reply: Added, and also added a warning to be even more explicit.
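For readers unfamiliar with the mechanism this exchange refers to: a custom `torch.compile` backend receives the traced FX graph and may rewrite it before returning a callable, which is how RoPE fusion and kernel swapping can be done at compile time. The sketch below is illustrative only (the no-op rewrite and `swap_backend` name are hypothetical; the PR's actual `compile_backend` is not shown here):

```python
import torch


def swap_backend(gm: torch.fx.GraphModule, example_inputs):
    # A Dynamo backend gets the captured FX graph; a real implementation
    # could match scaled_dot_product_attention call sites here and replace
    # them with an FP8 kernel. This sketch leaves the graph unchanged.
    for node in gm.graph.nodes:
        if node.op == "call_function":
            pass  # pattern-match and rewrite would happen here
    gm.recompile()
    return gm.forward


@torch.compile(backend=swap_backend)
def f(x):
    return x * 2 + 1
```

Calling `f` triggers tracing, hands the graph to `swap_backend`, and runs the (here unmodified) compiled function.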