Skip to content
Merged
Show file tree
Hide file tree
Changes from 38 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
cf8280f
Update
howardzhang-cv Feb 11, 2026
9acfc52
Update (base update)
howardzhang-cv Feb 11, 2026
88dff89
Update
howardzhang-cv Feb 11, 2026
11e7cad
Update
howardzhang-cv Feb 12, 2026
95cccd5
Update (base update)
howardzhang-cv Feb 12, 2026
fdf88ac
Update
howardzhang-cv Feb 12, 2026
ad075ac
Update (base update)
howardzhang-cv Feb 12, 2026
878b464
Update
howardzhang-cv Feb 13, 2026
3be7bbb
Update
howardzhang-cv Feb 13, 2026
3eea34a
Update (base update)
howardzhang-cv Feb 13, 2026
333e08c
Update
howardzhang-cv Feb 13, 2026
8e227d0
Update
howardzhang-cv Feb 13, 2026
d85dcc2
Update
howardzhang-cv Feb 21, 2026
56ba611
Update (base update)
howardzhang-cv Feb 21, 2026
aac4e70
Update
howardzhang-cv Feb 24, 2026
9756826
Update (base update)
howardzhang-cv Feb 24, 2026
32858e9
Update
howardzhang-cv Feb 25, 2026
548d7ef
Update
howardzhang-cv Feb 25, 2026
e3c6014
Update (base update)
howardzhang-cv Feb 25, 2026
97eafd5
Update
howardzhang-cv Feb 27, 2026
0a042ad
Update (base update)
howardzhang-cv Feb 27, 2026
b6e59d0
Update
howardzhang-cv Feb 28, 2026
44a7429
Update (base update)
howardzhang-cv Feb 28, 2026
a64a978
Update
howardzhang-cv Feb 28, 2026
411886b
Update (base update)
howardzhang-cv Feb 28, 2026
264d2bd
Update
howardzhang-cv Feb 28, 2026
74f3cfd
Update (base update)
howardzhang-cv Feb 28, 2026
708547f
Update
howardzhang-cv Mar 2, 2026
d60829a
Update (base update)
howardzhang-cv Mar 2, 2026
1d26fd8
Update
howardzhang-cv Mar 3, 2026
68efede
Update (base update)
howardzhang-cv Mar 3, 2026
e5a8c5a
Update
howardzhang-cv Mar 3, 2026
fec81e6
Update (base update)
howardzhang-cv Mar 5, 2026
669829e
Update
howardzhang-cv Mar 5, 2026
edb1f38
Update (base update)
howardzhang-cv Mar 6, 2026
7db5ce9
Update
howardzhang-cv Mar 6, 2026
58b0e6a
Update (base update)
howardzhang-cv Mar 6, 2026
d18f997
Update
howardzhang-cv Mar 6, 2026
100382a
Update (base update)
howardzhang-cv Mar 6, 2026
58c838f
Update
howardzhang-cv Mar 6, 2026
c348a9f
Update (base update)
howardzhang-cv Mar 7, 2026
a719b90
Update
howardzhang-cv Mar 7, 2026
f140854
Update (base update)
howardzhang-cv Mar 7, 2026
ed23fd0
Update
howardzhang-cv Mar 7, 2026
94d9200
Merge branch 'main' into gh/howardzhang-cv/16/head
howardzhang-cv Mar 9, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
248 changes: 248 additions & 0 deletions test/prototype/attention/test_fp8_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""Tests for FP8 low-precision attention (FA3 backend on Hopper)."""

import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import TestCase, run_tests

from torchao.prototype.attention import (
AttentionBackend,
apply_low_precision_attention,
)
from torchao.prototype.attention.utils import _is_fa3_available, _is_hopper
from torchao.utils import torch_version_at_least

if torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available():
from torch.nn.attention import (
activate_flash_attention_impl,
restore_flash_attention_impl,
)

from torchao.prototype.attention.fp8_fa3.attention import (
fp8_fa3_rope_sdpa,
fp8_fa3_sdpa,
)
from torchao.quantization.utils import compute_error


def _rope_cos_sin(S, D, device):
freqs = 1.0 / (10000.0 ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
angles = torch.outer(torch.arange(S, dtype=torch.float32), freqs)
cos_half = torch.cos(angles)
sin_half = torch.sin(angles)
cos = torch.cat([cos_half, cos_half], dim=-1).to(device)
sin = torch.cat([sin_half, sin_half], dim=-1).to(device)
return cos, sin


def _apply_rope(x, cos, sin):
"""NeoX rotate-half RoPE. x: [B, S, H, D], cos/sin: [S, D]."""
D_HALF = x.shape[-1] // 2
rotate = torch.cat([-x[..., D_HALF:], x[..., :D_HALF]], dim=-1)
return (
x * cos.unsqueeze(0).unsqueeze(2) + rotate * sin.unsqueeze(0).unsqueeze(2)
).to(x.dtype)


class SimpleAttentionModel(nn.Module):
    """Minimal causal self-attention module used as a monkey-patch target."""

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Register the four projections under the conventional attribute names.
        for name in ("q_proj", "k_proj", "v_proj", "out_proj"):
            setattr(self, name, nn.Linear(embed_dim, embed_dim, bias=False))

    def forward(self, x):
        B, S, _ = x.shape

        def split_heads(t):
            # [B, S, E] -> [B, H, S, D]
            return t.view(B, S, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj(x))
        k = split_heads(self.k_proj(x))
        v = split_heads(self.v_proj(x))
        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        merged = attn_out.transpose(1, 2).contiguous().view(B, S, -1)
        return self.out_proj(merged)


class SimpleRoPEAttentionModel(nn.Module):
    """Applies RoPE to Q and K immediately before SDPA (Pattern A: RoPE → transpose → SDPA)."""

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Register the four projections under the conventional attribute names.
        for name in ("q_proj", "k_proj", "v_proj", "out_proj"):
            setattr(self, name, nn.Linear(embed_dim, embed_dim, bias=False))

    def forward(self, x, cos, sin):
        B, S, _ = x.shape
        heads_shape = (B, S, self.num_heads, self.head_dim)
        # RoPE is applied in [B, S, H, D] layout, then heads are moved to dim 1.
        q = _apply_rope(self.q_proj(x).view(heads_shape), cos, sin).transpose(1, 2)
        k = _apply_rope(self.k_proj(x).view(heads_shape), cos, sin).transpose(1, 2)
        v = self.v_proj(x).view(heads_shape).transpose(1, 2)
        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        merged = attn_out.transpose(1, 2).contiguous().view(B, S, -1)
        return self.out_proj(merged)


@common_utils.instantiate_parametrized_tests
class TestFP8FA3Attention(TestCase):
    """Accuracy tests comparing FP8 FA3 attention against full-precision SDPA."""

    def _assert_sqnr(self, reference, candidate, threshold, detail):
        # Shared accuracy check: fail when SQNR (in dB) drops below *threshold*.
        sqnr = compute_error(reference, candidate)
        self.assertGreater(
            sqnr.item(),
            threshold,
            f"SQNR {sqnr.item():.2f} dB below {threshold:g} dB for {detail}",
        )

    @unittest.skipUnless(
        torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available(),
        "Requires PyTorch >= 2.11, Hopper GPU, and FA3",
    )
    @common_utils.parametrize("shape", [(2, 8, 1024, 64), (1, 16, 1024, 128)])
    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
    def test_sdpa_accuracy(self, shape, dtype):
        B, H, S, D = shape
        q, k, v = (
            torch.randn(B, H, S, D, device="cuda", dtype=dtype) for _ in range(3)
        )

        with torch.no_grad():
            reference = F.scaled_dot_product_attention(q, k, v, is_causal=False)

        activate_flash_attention_impl("FA3")
        try:
            with torch.no_grad():
                candidate = fp8_fa3_sdpa(q, k, v, is_causal=False)
        finally:
            # Always undo the global FA3 activation, even on failure.
            restore_flash_attention_impl()

        self._assert_sqnr(reference, candidate, 25.0, f"shape={shape}, dtype={dtype}")

    @unittest.skipUnless(
        torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available(),
        "Requires PyTorch >= 2.11, Hopper GPU, and FA3",
    )
    @common_utils.parametrize("shape", [(2, 1024, 8, 64), (1, 1024, 16, 128)])
    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
    def test_rope_sdpa_accuracy(self, shape, dtype):
        B, S, H, D = shape
        q, k, v = (
            torch.randn(B, S, H, D, device="cuda", dtype=dtype) for _ in range(3)
        )
        cos, sin = _rope_cos_sin(S, D, "cuda")

        # Reference path: RoPE applied eagerly, then standard SDPA.
        with torch.no_grad():
            reference = F.scaled_dot_product_attention(
                _apply_rope(q, cos, sin).transpose(1, 2),
                _apply_rope(k, cos, sin).transpose(1, 2),
                v.transpose(1, 2),
                is_causal=False,
            )

        activate_flash_attention_impl("FA3")
        try:
            with torch.no_grad():
                candidate = fp8_fa3_rope_sdpa(q, k, v, cos, sin, is_causal=False)
        finally:
            restore_flash_attention_impl()

        self._assert_sqnr(reference, candidate, 25.0, f"shape={shape}, dtype={dtype}")

    @unittest.skipUnless(
        torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available(),
        "Requires PyTorch >= 2.11, Hopper GPU, and FA3",
    )
    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
    def test_monkey_patch_model(self, dtype):
        embed_dim, num_heads = 512, 8

        def fresh_model():
            m = SimpleAttentionModel(embed_dim, num_heads)
            return m.to(device="cuda", dtype=dtype).eval()

        model = fresh_model()
        x = torch.randn(2, 128, embed_dim, device="cuda", dtype=dtype)

        with torch.no_grad():
            reference = model(x)

        # Second instance with identical weights, then swap in FP8 attention.
        fp8_model = fresh_model()
        fp8_model.load_state_dict(model.state_dict())
        fp8_model = apply_low_precision_attention(
            fp8_model,
            backend=AttentionBackend.FP8_FA3,
            fuse_rope_using_torch_compile=False,
        )

        with torch.no_grad():
            candidate = fp8_model(x)

        self._assert_sqnr(reference, candidate, 20.0, f"dtype={dtype}")

    @unittest.skipUnless(
        torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available(),
        "Requires PyTorch >= 2.11, Hopper GPU, and FA3",
    )
    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
    def test_rope_fusion_model(self, dtype):
        embed_dim, num_heads = 512, 8
        S = 128

        def fresh_model():
            m = SimpleRoPEAttentionModel(embed_dim, num_heads)
            return m.to(device="cuda", dtype=dtype).eval()

        model = fresh_model()
        x = torch.randn(2, S, embed_dim, device="cuda", dtype=dtype)
        cos, sin = _rope_cos_sin(S, embed_dim // num_heads, "cuda")

        with torch.no_grad():
            reference = model(x, cos, sin)

        fp8_model = fresh_model()
        fp8_model.load_state_dict(model.state_dict())
        fp8_model = apply_low_precision_attention(
            fp8_model,
            backend=AttentionBackend.FP8_FA3,
            fuse_rope_using_torch_compile=True,
        )
        # RoPE fusion only happens when compiling with the wrapper's backend.
        fp8_model = torch.compile(fp8_model, backend=fp8_model.compile_backend)

        with torch.no_grad():
            candidate = fp8_model(x, cos, sin)

        self._assert_sqnr(reference, candidate, 20.0, f"dtype={dtype}")


# Dispatch to PyTorch's test runner, which expands the parametrized tests above.
if __name__ == "__main__":
    run_tests()
21 changes: 21 additions & 0 deletions torchao/prototype/attention/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""
Low-precision attention for inference.

Only supports forward pass — backward is not supported by the underlying backends.
"""

from torchao.prototype.attention.api import (
AttentionBackend,
apply_low_precision_attention,
)

__all__ = [
"AttentionBackend",
"apply_low_precision_attention",
]
93 changes: 93 additions & 0 deletions torchao/prototype/attention/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""User-facing API for low-precision attention."""

from enum import Enum
from typing import Optional

import torch
import torch._dynamo
import torch.nn as nn

from torchao.prototype.attention.shared_utils.setup import setup_fp8_backend
from torchao.prototype.attention.shared_utils.wrapper import (
_LowPrecisionAttentionWrapper,
)
from torchao.prototype.attention.utils import _is_fa3_available, _is_hopper
from torchao.utils import torch_version_at_least


class AttentionBackend(str, Enum):
    """Backend kernel for computing attention.

    The ``str`` mixin makes members compare equal to (and format as)
    their plain string values.
    """

    # FP8 attention via the FA3 kernel; requires SM90+ (Hopper).
    FP8_FA3 = "FP8_FA3"  # Requires SM90+ (Hopper)


def _get_available_backend() -> AttentionBackend:
    """Pick a backend supported by the current GPU, raising if none applies."""
    if not torch.cuda.is_available():
        raise RuntimeError("Low-precision attention requires CUDA.")
    major, minor = torch.cuda.get_device_capability()
    if _is_hopper() and _is_fa3_available():
        return AttentionBackend.FP8_FA3
    raise RuntimeError(f"No compatible backend for SM{major}{minor}.")


def _check_backend_available(backend: AttentionBackend) -> None:
    """Raise ``RuntimeError``/``ValueError`` if *backend* cannot run here."""
    if not torch.cuda.is_available():
        raise RuntimeError(f"{backend} backend requires CUDA.")
    major, minor = torch.cuda.get_device_capability()
    if backend != AttentionBackend.FP8_FA3:
        raise ValueError(f"Unknown backend: {backend}")
    # FP8_FA3 requirements: Hopper hardware plus the FA3 kernel package.
    if not _is_hopper():
        raise RuntimeError(
            f"FP8_FA3 requires Hopper (SM 9.x), got SM{major}{minor}."
        )
    if not _is_fa3_available():
        raise RuntimeError(
            "FP8_FA3 requires the flash-attn package with FA3 support."
        )


def apply_low_precision_attention(
    model: nn.Module,
    backend: Optional[AttentionBackend] = None,
    fuse_rope_using_torch_compile: bool = False,
) -> nn.Module:
    """Swap eligible attention calls in *model* for low-precision kernels.

    Must be called before ``torch.compile``, and KV caching should be
    disabled first (e.g. ``config.use_cache = False`` for HuggingFace
    models). With ``backend=None`` the best backend for the current GPU
    is auto-detected.

    When ``fuse_rope_using_torch_compile=True``, the returned wrapper
    exposes a ``compile_backend`` attribute; you must compile with it to
    get the RoPE fusion::

        model = apply_low_precision_attention(model, fuse_rope_using_torch_compile=True)
        model = torch.compile(model, backend=model.compile_backend)
    """
    # Guard clauses: version floor, double application, post-compile misuse.
    if not torch_version_at_least("2.11.0"):
        raise RuntimeError("Low-precision attention requires PyTorch 2.11+.")
    if isinstance(model, _LowPrecisionAttentionWrapper):
        raise RuntimeError(
            "apply_low_precision_attention has already been applied to this module."
        )
    if isinstance(model, torch._dynamo.OptimizedModule):
        raise RuntimeError(
            "apply_low_precision_attention must be called before torch.compile."
        )

    if backend is None:
        backend = _get_available_backend()
    else:
        _check_backend_available(backend)

    if backend != AttentionBackend.FP8_FA3:
        raise ValueError(f"Unknown backend: {backend}")
    return setup_fp8_backend(model, "FA3", fuse_rope_using_torch_compile)
21 changes: 21 additions & 0 deletions torchao/prototype/attention/fp8_fa3/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""
FP8 attention using FA3 backend.
"""

from torchao.prototype.attention.fp8_fa3.attention import (
fp8_fa3_rope_sdpa,
fp8_fa3_sdpa,
)
from torchao.prototype.attention.quantization import _fp8_sdpa_quantize

__all__ = [
"fp8_fa3_sdpa",
"fp8_fa3_rope_sdpa",
"_fp8_sdpa_quantize",
]
Loading
Loading