-
Notifications
You must be signed in to change notification settings - Fork 463
Add FP8 FA3 low-precision attention with monkey-patch SDPA path #3959
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 9 commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
ae6ccf2
Update
howardzhang-cv 6c26be4
Update (base update)
howardzhang-cv 42277d8
Update
howardzhang-cv 51017d2
Update (base update)
howardzhang-cv d79b6c5
Update
howardzhang-cv 980d112
Update
howardzhang-cv f8f2f53
Update (base update)
howardzhang-cv 8750752
Update
howardzhang-cv 989cf48
Update
howardzhang-cv 1e0f2c3
Update
howardzhang-cv 8ed4023
Update (base update)
howardzhang-cv bdde034
Update
howardzhang-cv e714719
Update (base update)
howardzhang-cv 0f86fed
Update
howardzhang-cv 354cd87
Update
howardzhang-cv 4f463a8
Update
howardzhang-cv 699c526
Update
howardzhang-cv File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,262 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD 3-Clause license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| """ | ||
| Tests for FP8 low-precision attention (FA3 backend). | ||
|
|
||
| Tests are gated on Hopper (SM 9.x) with flash-attn installed. | ||
| When the backend is not available on the current hardware, tests are | ||
| automatically skipped. | ||
| """ | ||
|
|
||
| import unittest | ||
| from dataclasses import dataclass | ||
| from typing import Callable, List | ||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
| import torch.nn.functional as F | ||
|
|
||
| from torchao.utils import torch_version_at_least | ||
|
|
||
| _TORCH_VERSION_AT_LEAST_2_11 = torch_version_at_least("2.11.0") | ||
|
|
||
| if _TORCH_VERSION_AT_LEAST_2_11: | ||
| from torch.nn.attention import ( | ||
| activate_flash_attention_impl, | ||
| restore_flash_attention_impl, | ||
| ) | ||
|
|
||
| from torch.testing._internal import common_utils | ||
| from torch.testing._internal.common_utils import ( | ||
| TestCase, | ||
| run_tests, | ||
| ) | ||
|
|
||
| from torchao.prototype.attention import ( | ||
| AttentionBackend, | ||
| LowPrecisionAttentionConfig, | ||
| apply_low_precision_attention, | ||
| ) | ||
| from torchao.prototype.attention.utils import ( | ||
| _is_fa3_available, | ||
| _is_hopper, | ||
| ) | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Backend configuration | ||
| # --------------------------------------------------------------------------- | ||
@dataclass
class BackendConfig:
    """Configuration for a single backend under test.

    Bundles everything the parametrized tests need about one
    low-precision attention backend: how to activate it, how to call
    it directly, and whether it is usable on the current hardware.
    """

    # Human-readable backend name, used in skip/assert messages.
    name: str
    # Identifier passed to activate_flash_attention_impl, e.g. "FA3".
    flash_impl: str  # "FA3"
    # Enum value handed to LowPrecisionAttentionConfig for the API tests.
    attention_backend: AttentionBackend
    # Direct quantized SDPA entry point (fp8_fa3_sdpa); None when the
    # backend is unavailable on this machine (see _build_backend_configs).
    sdpa_fn: Callable  # fp8_fa3_sdpa
    # Can run direct sdpa calls (verified by _probe_eager_quantized_sdpa).
    available_eager: bool  # Can run direct sdpa calls
    # Can run via apply_low_precision_attention.
    available_compiled: bool  # Can run via apply_low_precision_attention
    # Message shown when tests for this backend are skipped.
    skip_msg: str
|
||
|
|
||
def _probe_eager_quantized_sdpa(sdpa_fn, flash_impl: str) -> bool:
    """Check whether a tiny quantized SDPA call works in eager mode.

    FA3 uses _scaled_dot_product_attention_quantized internally, which
    requires FA3 activation. This probe catches mismatches between what
    the hardware reports and what actually runs.

    Returns True only when activation, the probe call, and restoration
    all succeed; any failure along the way is reported as False.
    """
    try:
        activate_flash_attention_impl(flash_impl)
        try:
            probe = torch.randn(1, 1, 4, 64, device="cuda", dtype=torch.bfloat16)
            with torch.no_grad():
                sdpa_fn(probe, probe, probe, is_causal=False)
        except RuntimeError:
            # The backend is activated but cannot execute the call
            # (e.g. kernel unavailable for this arch).
            return False
        finally:
            # Always undo the activation, even when the probe fails.
            restore_flash_attention_impl()
        return True
    except Exception:
        # Activation or restoration itself blew up; treat as unavailable.
        return False
|
|
||
|
|
||
def _build_backend_configs() -> List[BackendConfig]:
    """Assemble the backend configurations for this test run.

    Backend-specific modules are imported lazily so that merely loading
    this test file never fails on unsupported hardware; unavailable
    backends are still listed, just marked unavailable.
    """
    # FA3 is only usable on Hopper (SM 9.x) with flash-attn installed
    # and a PyTorch build exposing the FA3 activation APIs.
    fa3_supported = (
        _TORCH_VERSION_AT_LEAST_2_11 and _is_hopper() and _is_fa3_available()
    )
    if fa3_supported:
        from torchao.prototype.attention.fp8_fa3.attention import fp8_fa3_sdpa

        fa3_sdpa = fp8_fa3_sdpa
        # Even when supported on paper, verify the call path actually runs.
        fa3_eager_ok = _probe_eager_quantized_sdpa(fa3_sdpa, "FA3")
    else:
        fa3_sdpa = None
        fa3_eager_ok = False

    fa3_config = BackendConfig(
        name="FA3",
        flash_impl="FA3",
        attention_backend=AttentionBackend.FP8_FA3,
        sdpa_fn=fa3_sdpa,
        available_eager=fa3_eager_ok,
        available_compiled=fa3_eager_ok,
        skip_msg=(
            "FP8 FA3 requires Hopper (SM 9.x), flash-attn installed, "
            "and PyTorch with FA3 activation APIs"
        ),
    )
    return [fa3_config]
|
|
||
|
|
||
# Module-level backend discovery, run once at import time.  The derived
# flags drive the skipIf decorators below.
_BACKEND_CONFIGS = _build_backend_configs()
# Backends usable via direct sdpa calls vs. via apply_low_precision_attention.
_EAGER_BACKENDS = [c for c in _BACKEND_CONFIGS if c.available_eager]
_COMPILED_BACKENDS = [c for c in _BACKEND_CONFIGS if c.available_compiled]
_ANY_EAGER_AVAILABLE = len(_EAGER_BACKENDS) > 0
_ANY_COMPILED_AVAILABLE = len(_COMPILED_BACKENDS) > 0
_NO_EAGER_SKIP_MSG = "No FP8 attention backend available for eager mode"
_NO_COMPILED_SKIP_MSG = "No FP8 attention backend available for compiled mode"

# compute_error (SQNR helper) is only needed when at least one backend can
# actually run; importing lazily keeps collection cheap on unsupported hosts.
if _ANY_EAGER_AVAILABLE or _ANY_COMPILED_AVAILABLE:
    from torchao.quantization.utils import compute_error
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Simple model for API-level tests | ||
| # --------------------------------------------------------------------------- | ||
class SimpleAttentionModel(nn.Module):
    """A minimal model that calls F.scaled_dot_product_attention.

    Separate Q/K/V/output projections around a causal SDPA call — just
    enough structure for apply_low_precision_attention to find and patch
    the attention op.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, x):
        batch, seq_len, _ = x.shape

        def split_heads(t):
            # (B, S, E) -> (B, H, S, D) for multi-head attention.
            return t.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj(x))
        k = split_heads(self.k_proj(x))
        v = split_heads(self.v_proj(x))
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # Merge heads back: (B, H, S, D) -> (B, S, E).
        attn = attn.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.out_proj(attn)
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Numerical accuracy tests | ||
| # --------------------------------------------------------------------------- | ||
@common_utils.instantiate_parametrized_tests
class TestFP8SDPANumericalAccuracy(TestCase):
    """SQNR-based numerical accuracy tests for FP8 SDPA."""

    def setUp(self):
        # Track which backend (if any) is currently activated so that
        # tearDown can restore the default implementation after failures.
        self._active_backend = None

    def tearDown(self):
        if self._active_backend is not None:
            restore_flash_attention_impl()

    def _activate(self, backend: BackendConfig):
        """Activate *backend* and remember it for cleanup."""
        activate_flash_attention_impl(backend.flash_impl)
        self._active_backend = backend

    @unittest.skipIf(not _ANY_EAGER_AVAILABLE, _NO_EAGER_SKIP_MSG)
    @common_utils.parametrize(
        "shape",
        [
            (2, 8, 1024, 64),
            (1, 16, 1024, 128),
        ],
    )
    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
    def test_sdpa_accuracy(self, shape, dtype):
        """FP8 SDPA output matches regular SDPA within acceptable SQNR."""
        batch, heads, seq, head_dim = shape
        q = torch.randn(batch, heads, seq, head_dim, device="cuda", dtype=dtype)
        k = torch.randn(batch, heads, seq, head_dim, device="cuda", dtype=dtype)
        v = torch.randn(batch, heads, seq, head_dim, device="cuda", dtype=dtype)

        # High-precision reference from the default SDPA implementation.
        with torch.no_grad():
            reference = F.scaled_dot_product_attention(q, k, v, is_causal=False)

        for backend in _EAGER_BACKENDS:
            self._activate(backend)
            with torch.no_grad():
                quantized = backend.sdpa_fn(q, k, v, is_causal=False)
            # Restore immediately so later assertions never run with a
            # non-default implementation active.
            restore_flash_attention_impl()
            self._active_backend = None

            sqnr_db = compute_error(reference, quantized).item()
            self.assertGreater(
                sqnr_db,
                25.0,
                f"[{backend.name}] SQNR {sqnr_db:.2f} dB below threshold "
                f"of 25 dB for shape={shape}, dtype={dtype}",
            )
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # API-level model tests | ||
| # --------------------------------------------------------------------------- | ||
@common_utils.instantiate_parametrized_tests
class TestFP8ModelAPI(TestCase):
    """API-level tests using apply_low_precision_attention on a model."""

    @unittest.skipIf(not _ANY_COMPILED_AVAILABLE, _NO_COMPILED_SKIP_MSG)
    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
    def test_apply_to_model_accuracy(self, dtype):
        """apply_low_precision_attention produces output close to original model."""
        embed_dim, num_heads = 256, 8

        def make_model():
            return SimpleAttentionModel(embed_dim, num_heads).to(
                device="cuda", dtype=dtype
            )

        reference_model = make_model()
        reference_model.eval()

        x = torch.randn(2, 128, embed_dim, device="cuda", dtype=dtype)

        with torch.no_grad():
            out_ref = reference_model(x)

        for backend in _COMPILED_BACKENDS:
            # apply_low_precision_attention modifies the model, so each
            # backend gets its own copy initialized from the reference
            # model's weights.
            candidate = make_model()
            candidate.load_state_dict(reference_model.state_dict())
            candidate.eval()

            cfg = LowPrecisionAttentionConfig(
                backend=backend.attention_backend,
            )
            candidate = apply_low_precision_attention(candidate, cfg)

            with torch.no_grad():
                out_fp8 = candidate(x)

            sqnr = compute_error(out_ref, out_fp8)
            self.assertGreater(
                sqnr.item(),
                20.0,
                f"[{backend.name}] SQNR "
                f"{sqnr.item():.2f} dB below threshold "
                f"for model-level test, dtype={dtype}",
            )
|
|
||
|
|
||
# Standard PyTorch test entry point; run_tests handles parametrization
# and the usual test-runner flags.
if __name__ == "__main__":
    run_tests()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,41 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD 3-Clause license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| """ | ||
| Low-precision attention for inference. | ||
|
|
||
| This module provides APIs for running attention with reduced precision | ||
| (e.g., FP8) for faster inference. It can be extended to support different | ||
| quantization strategies and different PyTorch core attention backends. | ||
|
|
||
| Note: Low-precision attention only supports inference (forward pass). | ||
| Backward pass is not supported by the underlying backends. | ||
|
|
||
| Note: apply_low_precision_attention replaces all F.scaled_dot_product_attention | ||
| calls within the model with the configured low-precision backend. | ||
|
|
||
| Example:: | ||
|
|
||
| from torchao.prototype.attention import apply_low_precision_attention | ||
|
|
||
| model = MyTransformer() | ||
| model = apply_low_precision_attention(model) | ||
| output = model(inputs) | ||
| """ | ||
|
|
||
| from torchao.prototype.attention.api import apply_low_precision_attention | ||
| from torchao.prototype.attention.config import ( | ||
| AttentionBackend, | ||
| LowPrecisionAttentionConfig, | ||
| ) | ||
|
|
||
# Public API of torchao.prototype.attention; everything else is internal.
__all__ = [
    # Config
    "LowPrecisionAttentionConfig",
    "AttentionBackend",
    # API
    "apply_low_precision_attention",
]
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you run a deslop pass on some of this, I kind of doubt you need code like this