Skip to content

Commit 66bc20d

Browse files
Add FA4 monkey-patch path for low-precision attention
ghstack-source-id: 6a451ce
Pull-Request: #3960
Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
1 parent d3111be commit 66bc20d

File tree

4 files changed

+91
-1
lines changed

4 files changed

+91
-1
lines changed

torchao/prototype/attention/api.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313
import torch._dynamo
1414
import torch.nn as nn
1515

16-
from torchao.prototype.attention.utils import _is_fa3_available, _is_hopper
16+
from torchao.prototype.attention.utils import (
17+
_is_blackwell,
18+
_is_fa3_available,
19+
_is_fa4_available,
20+
_is_hopper,
21+
)
1722
from torchao.utils import torch_version_at_least
1823

1924
_TORCH_VERSION_AT_LEAST_2_11 = torch_version_at_least("2.11.0")
@@ -29,14 +34,19 @@ class AttentionBackend(str, Enum):
2934
"""Backend kernel for computing attention."""
3035

3136
FP8_FA3 = "FP8_FA3" # Requires SM90+ (Hopper)
37+
FP8_FA4 = "FP8_FA4" # Requires SM90+ (Hopper) or SM100+ (Blackwell)
3238

3339

3440
def _get_available_backend() -> AttentionBackend:
    """Select the best low-precision attention backend for the current GPU.

    Preference order: FA4 on Blackwell, FA3 on Hopper, then FA4 on Hopper
    as a fallback.

    Returns:
        The first :class:`AttentionBackend` whose hardware and kernel-package
        requirements are met.

    Raises:
        RuntimeError: if CUDA is unavailable, or no backend is compatible
            with the device's compute capability.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("Low-precision attention requires CUDA.")
    capability = torch.cuda.get_device_capability()
    # Availability checks stay lazy: a backend's package probe only runs
    # once the matching architecture check has passed.
    if _is_blackwell() and _is_fa4_available():
        return AttentionBackend.FP8_FA4
    if _is_hopper():
        if _is_fa3_available():
            return AttentionBackend.FP8_FA3
        if _is_fa4_available():
            return AttentionBackend.FP8_FA4
    raise RuntimeError(f"No compatible backend for SM{capability[0]}{capability[1]}.")
4151

4252

@@ -53,6 +63,16 @@ def _check_backend_available(backend: AttentionBackend) -> None:
5363
raise RuntimeError(
5464
"FP8_FA3 requires the flash-attn package with FA3 support."
5565
)
66+
elif backend == AttentionBackend.FP8_FA4:
67+
if not (_is_hopper() or _is_blackwell()):
68+
raise RuntimeError(
69+
f"FP8_FA4 requires Hopper or Blackwell, got SM{capability[0]}{capability[1]}."
70+
)
71+
if not _is_fa4_available():
72+
raise RuntimeError(
73+
"FP8_FA4 requires the flash-attn package with FA4 support "
74+
"(flash_attn.cute.interface)."
75+
)
5676
else:
5777
raise ValueError(f"Unknown backend: {backend}")
5878

@@ -95,4 +115,7 @@ def apply_low_precision_attention(
95115
if backend == AttentionBackend.FP8_FA3:
96116
return setup_fp8_backend(model, "FA3")
97117

118+
if backend == AttentionBackend.FP8_FA4:
119+
return setup_fp8_backend(model, "FA4")
120+
98121
raise ValueError(f"Unknown backend: {backend}")
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""FP8 attention built on the FA4 flash-attention backend.

Re-exports the FA4-bound SDPA entry points plus the shared quantization
helper so callers can import everything from this package.
"""

from torchao.prototype.attention.fp8_fa4.attention import (
    fp8_fa4_rope_sdpa,
    fp8_fa4_sdpa,
)
from torchao.prototype.attention.quantization import _fp8_sdpa_quantize

# NOTE(review): _fp8_sdpa_quantize is underscore-prefixed yet exported —
# presumably mirroring the fp8_fa3 package's public surface; confirm.
__all__ = [
    "fp8_fa4_sdpa",
    "fp8_fa4_rope_sdpa",
    "_fp8_sdpa_quantize",
]
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""FP8 SDPA using FA4 backend.

When using these functions directly (not through apply_low_precision_attention),
you must activate FA4 yourself::

    activate_flash_attention_impl("FA4")
    try:
        out = fp8_fa4_sdpa(q, k, v, is_causal=True)
    finally:
        restore_flash_attention_impl()
"""

from functools import partial

from torchao.prototype.attention.shared_utils.attention import (
    _fp8_rope_sdpa,
    _fp8_sdpa,
)
from torchao.prototype.attention.shared_utils.custom_ops import (
    register_fp8_attention_ops,
)


def _fa4_variant(generic_fn, public_name):
    """Bind *generic_fn* to the FA4 backend and relabel it for introspection.

    partial objects carry a writable __dict__, so __name__/__qualname__ can
    be attached to make tracebacks and reprs read like a plain function.
    """
    bound = partial(generic_fn, backend_name="FA4")
    bound.__doc__ = generic_fn.__doc__
    bound.__name__ = public_name
    bound.__qualname__ = public_name
    return bound


# Public FA4-specialized SDPA entry points.
fp8_fa4_sdpa = _fa4_variant(_fp8_sdpa, "fp8_fa4_sdpa")
fp8_fa4_rope_sdpa = _fa4_variant(_fp8_rope_sdpa, "fp8_fa4_rope_sdpa")

# Register the custom ops for this backend at import time; the handle is
# kept so the registrations stay alive for the module's lifetime.
_ops = register_fp8_attention_ops(
    backend_name="fa4",
    rope_sdpa_fn=fp8_fa4_rope_sdpa,
    sdpa_fn=fp8_fa4_sdpa,
)

torchao/prototype/attention/shared_utils/setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ def setup_fp8_backend(
2525
) -> nn.Module:
2626
if flash_impl_name == "FA3":
2727
from torchao.prototype.attention.fp8_fa3.attention import _ops
28+
elif flash_impl_name == "FA4":
29+
from torchao.prototype.attention.fp8_fa4.attention import _ops
2830
else:
2931
raise ValueError(f"Unknown flash_impl_name: {flash_impl_name}")
3032

0 commit comments

Comments
 (0)