Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
d26cabd
Update (base update)
howardzhang-cv Feb 12, 2026
eb083c0
Update
howardzhang-cv Feb 12, 2026
08eaf40
Update (base update)
howardzhang-cv Feb 12, 2026
2380610
Update
howardzhang-cv Feb 12, 2026
d372af3
Update (base update)
howardzhang-cv Feb 13, 2026
994a587
Update
howardzhang-cv Feb 13, 2026
8de0af2
Update (base update)
howardzhang-cv Feb 13, 2026
b17b5ec
Update
howardzhang-cv Feb 13, 2026
ad96369
Update (base update)
howardzhang-cv Feb 13, 2026
9589d68
Update
howardzhang-cv Feb 13, 2026
33493e6
Update (base update)
howardzhang-cv Feb 13, 2026
fe36423
Update
howardzhang-cv Feb 13, 2026
3287be2
Update (base update)
howardzhang-cv Feb 21, 2026
d65530f
Update
howardzhang-cv Feb 21, 2026
b3e130c
Update (base update)
howardzhang-cv Feb 24, 2026
365224a
Update
howardzhang-cv Feb 24, 2026
f9a62a3
Update (base update)
howardzhang-cv Feb 25, 2026
624643e
Update
howardzhang-cv Feb 25, 2026
2c93577
Update (base update)
howardzhang-cv Feb 25, 2026
1757ff9
Update
howardzhang-cv Feb 25, 2026
0986316
Update (base update)
howardzhang-cv Feb 27, 2026
a696f1b
Update
howardzhang-cv Feb 27, 2026
48f8c5f
Update (base update)
howardzhang-cv Feb 28, 2026
4c1253d
Update
howardzhang-cv Feb 28, 2026
3049c3e
Update (base update)
howardzhang-cv Feb 28, 2026
982c617
Update
howardzhang-cv Feb 28, 2026
df8919a
Update (base update)
howardzhang-cv Feb 28, 2026
3f22aad
Update
howardzhang-cv Feb 28, 2026
c26c9ec
Update (base update)
howardzhang-cv Mar 2, 2026
2a0aaed
Update
howardzhang-cv Mar 2, 2026
044cc6e
Update (base update)
howardzhang-cv Mar 3, 2026
b696b53
Update
howardzhang-cv Mar 3, 2026
665e1fc
Update (base update)
howardzhang-cv Mar 3, 2026
1341000
Update
howardzhang-cv Mar 3, 2026
57ca528
Update (base update)
howardzhang-cv Mar 5, 2026
e1fd08c
Update
howardzhang-cv Mar 5, 2026
f7599cb
Update (base update)
howardzhang-cv Mar 6, 2026
ed18070
Update
howardzhang-cv Mar 6, 2026
2866109
Update (base update)
howardzhang-cv Mar 6, 2026
f72cce0
Update
howardzhang-cv Mar 6, 2026
ac50eaf
Update (base update)
howardzhang-cv Mar 6, 2026
bb511ef
Update
howardzhang-cv Mar 6, 2026
a6fb839
Update (base update)
howardzhang-cv Mar 7, 2026
4fc8c25
Update
howardzhang-cv Mar 7, 2026
d740be0
Update (base update)
howardzhang-cv Mar 7, 2026
6065db8
Update
howardzhang-cv Mar 7, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
460 changes: 460 additions & 0 deletions benchmarks/prototype/attention/eval_flux_model.py

Large diffs are not rendered by default.

Empty file.
250 changes: 250 additions & 0 deletions test/prototype/attention/test_fp8_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""Tests for FP8 low-precision attention (FA3 backend on Hopper)."""

import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import TestCase, run_tests

from torchao.quantization.utils import compute_error
from torchao.utils import torch_version_at_least

if torch_version_at_least("2.11.0"):
from torchao.prototype.attention.utils import _is_fa3_available, _is_hopper

if _is_hopper() and _is_fa3_available():
from torch.nn.attention import (
activate_flash_attention_impl,
restore_flash_attention_impl,
)

from torchao.prototype.attention import (
AttentionBackend,
apply_low_precision_attention,
)
from torchao.prototype.attention.fp8_fa3.attention import (
fp8_fa3_rope_sdpa,
fp8_fa3_sdpa,
)


def _rope_cos_sin(S, D, device):
freqs = 1.0 / (10000.0 ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
angles = torch.outer(torch.arange(S, dtype=torch.float32), freqs)
cos_half = torch.cos(angles)
sin_half = torch.sin(angles)
cos = torch.cat([cos_half, cos_half], dim=-1).to(device)
sin = torch.cat([sin_half, sin_half], dim=-1).to(device)
return cos, sin


def _apply_rope(x, cos, sin):
    """NeoX rotate-half RoPE. x: [B, S, H, D], cos/sin: [S, D]."""
    # Rotate-half: negate the second half of the head dim and move it in
    # front of the first half.
    D_HALF = x.shape[-1] // 2
    rotate = torch.cat([-x[..., D_HALF:], x[..., :D_HALF]], dim=-1)
    # cos/sin broadcast [S, D] -> [1, S, 1, D] against x's [B, S, H, D];
    # the float32 tables promote the math, so cast back to x.dtype at the end.
    # NOTE(review): this exact op sequence appears to be what the
    # fuse_rope_using_torch_compile backend pattern-matches in
    # test_rope_fusion_model — keep the ordering intact when editing.
    return (
        x * cos.unsqueeze(0).unsqueeze(2) + rotate * sin.unsqueeze(0).unsqueeze(2)
    ).to(x.dtype)


class SimpleAttentionModel(nn.Module):
    """Minimal causal self-attention block: Q/K/V projections -> SDPA -> output projection."""

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Registration order (q, k, v, out) matters for seeded weight init.
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, x):
        bsz, seqlen, _ = x.shape

        def split_heads(t):
            # [B, S, E] -> [B, H, S, D]
            return t.view(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj(x))
        k = split_heads(self.k_proj(x))
        v = split_heads(self.v_proj(x))
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # [B, H, S, D] -> [B, S, E], then project back to the embedding dim.
        merged = attn.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.out_proj(merged)


class SimpleRoPEAttentionModel(nn.Module):
    """Applies RoPE to Q and K immediately before SDPA (Pattern A: RoPE → transpose → SDPA).

    NOTE(review): the "Pattern A" op sequence in forward() looks like the
    exact graph shape the fuse_rope_using_torch_compile backend matches
    (see test_rope_fusion_model) — preserve the RoPE → transpose → SDPA
    ordering when editing.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, x, cos, sin):
        B, S, _ = x.shape
        # Project and split heads, keeping the [B, S, H, D] layout that
        # _apply_rope expects (transpose happens *after* RoPE).
        q = self.q_proj(x).view(B, S, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(B, S, self.num_heads, self.head_dim)
        v = self.v_proj(x).view(B, S, self.num_heads, self.head_dim)
        # RoPE on Q/K only, then move heads forward for SDPA: [B, H, S, D].
        q = _apply_rope(q, cos, sin).transpose(1, 2)
        k = _apply_rope(k, cos, sin).transpose(1, 2)
        v = v.transpose(1, 2)
        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.out_proj(attn_out.transpose(1, 2).contiguous().view(B, S, -1))


@common_utils.instantiate_parametrized_tests
class TestFP8FA3Attention(TestCase):
    """Accuracy tests for FP8 attention via the FA3 backend.

    Every test is hardware-gated (PyTorch >= 2.11, Hopper/SM 9.x GPU, FA3
    build available) and compares the FP8 output against a full-precision
    ``F.scaled_dot_product_attention`` reference using SQNR in dB (higher
    is better). The skip conditions rely on ``and`` short-circuiting so
    ``_is_hopper``/``_is_fa3_available`` are only called when the version
    gate (and thus their import) has passed.
    """

    @unittest.skipUnless(
        torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available(),
        "Requires PyTorch >= 2.11, Hopper GPU, and FA3",
    )
    @common_utils.parametrize("shape", [(2, 8, 1024, 64), (1, 16, 1024, 128)])
    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
    def test_sdpa_accuracy(self, shape, dtype):
        """Direct kernel test: fp8_fa3_sdpa vs. SDPA on [B, H, S, D] inputs."""
        B, H, S, D = shape
        q = torch.randn(B, H, S, D, device="cuda", dtype=dtype)
        k = torch.randn(B, H, S, D, device="cuda", dtype=dtype)
        v = torch.randn(B, H, S, D, device="cuda", dtype=dtype)

        with torch.no_grad():
            out_ref = F.scaled_dot_product_attention(q, k, v, is_causal=False)

        # Switch the process-global flash-attention implementation to FA3
        # only for the FP8 call; restore in finally so a failure here cannot
        # leak the FA3 impl into later tests.
        activate_flash_attention_impl("FA3")
        try:
            with torch.no_grad():
                out_fp8 = fp8_fa3_sdpa(q, k, v, is_causal=False)
        finally:
            restore_flash_attention_impl()

        sqnr = compute_error(out_ref, out_fp8)
        self.assertGreater(
            sqnr.item(),
            25.0,
            f"SQNR {sqnr.item():.2f} dB below 25 dB for shape={shape}, dtype={dtype}",
        )

    @unittest.skipUnless(
        torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available(),
        "Requires PyTorch >= 2.11, Hopper GPU, and FA3",
    )
    @common_utils.parametrize("shape", [(2, 1024, 8, 64), (1, 1024, 16, 128)])
    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
    def test_rope_sdpa_accuracy(self, shape, dtype):
        """Fused RoPE+SDPA kernel vs. eager RoPE followed by SDPA.

        Inputs here are [B, S, H, D] (heads after sequence) because the
        fused kernel applies RoPE before the head transpose.
        """
        B, S, H, D = shape
        q = torch.randn(B, S, H, D, device="cuda", dtype=dtype)
        k = torch.randn(B, S, H, D, device="cuda", dtype=dtype)
        v = torch.randn(B, S, H, D, device="cuda", dtype=dtype)
        cos, sin = _rope_cos_sin(S, D, "cuda")

        # Reference: eager RoPE on Q/K, then transpose to [B, H, S, D] for SDPA.
        with torch.no_grad():
            out_ref = F.scaled_dot_product_attention(
                _apply_rope(q, cos, sin).transpose(1, 2),
                _apply_rope(k, cos, sin).transpose(1, 2),
                v.transpose(1, 2),
                is_causal=False,
            )

        activate_flash_attention_impl("FA3")
        try:
            with torch.no_grad():
                out_fp8 = fp8_fa3_rope_sdpa(q, k, v, cos, sin, is_causal=False)
        finally:
            restore_flash_attention_impl()

        sqnr = compute_error(out_ref, out_fp8)
        self.assertGreater(
            sqnr.item(),
            25.0,
            f"SQNR {sqnr.item():.2f} dB below 25 dB for shape={shape}, dtype={dtype}",
        )

    @unittest.skipUnless(
        torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available(),
        "Requires PyTorch >= 2.11, Hopper GPU, and FA3",
    )
    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
    def test_monkey_patch_model(self, dtype):
        """End-to-end: apply_low_precision_attention on a model (no RoPE fusion)."""
        embed_dim, num_heads = 512, 8
        model = (
            SimpleAttentionModel(embed_dim, num_heads)
            .to(device="cuda", dtype=dtype)
            .eval()
        )
        x = torch.randn(2, 128, embed_dim, device="cuda", dtype=dtype)

        with torch.no_grad():
            out_ref = model(x)

        # Second, independently constructed model with the same weights so
        # the FP8 path is compared against an untouched reference model.
        fp8_model = (
            SimpleAttentionModel(embed_dim, num_heads)
            .to(device="cuda", dtype=dtype)
            .eval()
        )
        fp8_model.load_state_dict(model.state_dict())
        fp8_model = apply_low_precision_attention(
            fp8_model,
            backend=AttentionBackend.FP8_FA3,
            fuse_rope_using_torch_compile=False,
        )

        with torch.no_grad():
            out_fp8 = fp8_model(x)

        sqnr = compute_error(out_ref, out_fp8)
        # Looser threshold than the kernel tests (20 vs 25 dB): error here
        # also includes the surrounding projections.
        self.assertGreater(
            sqnr.item(),
            20.0,
            f"SQNR {sqnr.item():.2f} dB below 20 dB for dtype={dtype}",
        )

    @unittest.skipUnless(
        torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available(),
        "Requires PyTorch >= 2.11, Hopper GPU, and FA3",
    )
    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
    def test_rope_fusion_model(self, dtype):
        """End-to-end: RoPE fusion via torch.compile with the wrapper's backend."""
        embed_dim, num_heads = 512, 8
        model = (
            SimpleRoPEAttentionModel(embed_dim, num_heads)
            .to(device="cuda", dtype=dtype)
            .eval()
        )
        S = 128
        x = torch.randn(2, S, embed_dim, device="cuda", dtype=dtype)
        cos, sin = _rope_cos_sin(S, embed_dim // num_heads, "cuda")

        with torch.no_grad():
            out_ref = model(x, cos, sin)

        fp8_model = (
            SimpleRoPEAttentionModel(embed_dim, num_heads)
            .to(device="cuda", dtype=dtype)
            .eval()
        )
        fp8_model.load_state_dict(model.state_dict())
        fp8_model = apply_low_precision_attention(
            fp8_model,
            backend=AttentionBackend.FP8_FA3,
            fuse_rope_using_torch_compile=True,
        )
        # Per the API contract, fusion only happens when compiling with the
        # wrapper-provided backend.
        fp8_model = torch.compile(fp8_model, backend=fp8_model.compile_backend)

        with torch.no_grad():
            out_fp8 = fp8_model(x, cos, sin)

        sqnr = compute_error(out_ref, out_fp8)
        self.assertGreater(
            sqnr.item(),
            20.0,
            f"SQNR {sqnr.item():.2f} dB below 20 dB for dtype={dtype}",
        )


# Allow running this test file directly (outside a pytest invocation).
if __name__ == "__main__":
    run_tests()
21 changes: 21 additions & 0 deletions torchao/prototype/attention/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""
Low-precision attention for inference.

Only supports forward pass — backward is not supported by the underlying backends.
"""

from torchao.prototype.attention.api import (
AttentionBackend,
apply_low_precision_attention,
)

__all__ = [
"AttentionBackend",
"apply_low_precision_attention",
]
95 changes: 95 additions & 0 deletions torchao/prototype/attention/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""User-facing API for low-precision attention."""

from enum import Enum
from typing import Optional

import torch
import torch._dynamo
import torch.nn as nn

from torchao.prototype.attention.utils import _is_fa3_available, _is_hopper
from torchao.utils import torch_version_at_least

# The backend setup/wrapper helpers depend on PyTorch 2.11+ APIs, so fail
# fast at import time on older versions rather than at first use.
if torch_version_at_least("2.11.0"):
    from torchao.prototype.attention.shared_utils.setup import setup_fp8_backend
    from torchao.prototype.attention.shared_utils.wrapper import (
        _LowPrecisionAttentionWrapper,
    )
else:
    raise ImportError("Low-precision attention requires PyTorch 2.11+.")


class AttentionBackend(str, Enum):
    """Backend kernel for computing attention."""

    # FP8 attention via FlashAttention-3; requires SM90+ (Hopper).
    FP8_FA3 = "FP8_FA3"


def _get_available_backend() -> AttentionBackend:
    """Auto-select a backend supported by the current GPU.

    Returns:
        The best available :class:`AttentionBackend`.

    Raises:
        RuntimeError: If CUDA is unavailable or no backend supports this GPU.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("Low-precision attention requires CUDA.")
    major, minor = torch.cuda.get_device_capability()
    # Currently FP8_FA3 (Hopper + FA3) is the only supported backend.
    if _is_hopper() and _is_fa3_available():
        return AttentionBackend.FP8_FA3
    raise RuntimeError(f"No compatible backend for SM{major}{minor}.")


def _check_backend_available(backend: AttentionBackend) -> None:
    """Validate that an explicitly requested backend can run on this machine.

    Args:
        backend: The backend the caller asked for.

    Raises:
        RuntimeError: If CUDA is unavailable or the GPU/package requirements
            of ``backend`` are not met.
        ValueError: If ``backend`` is not a recognized backend.
    """
    if not torch.cuda.is_available():
        raise RuntimeError(f"{backend} backend requires CUDA.")
    major, minor = torch.cuda.get_device_capability()
    # Guard-clause style: reject unknown backends first, then check the
    # FP8_FA3 hardware/package requirements one at a time.
    if backend != AttentionBackend.FP8_FA3:
        raise ValueError(f"Unknown backend: {backend}")
    if not _is_hopper():
        raise RuntimeError(
            f"FP8_FA3 requires Hopper (SM 9.x), got SM{major}{minor}."
        )
    if not _is_fa3_available():
        raise RuntimeError(
            "FP8_FA3 requires the flash-attn package with FA3 support."
        )


def apply_low_precision_attention(
    model: nn.Module,
    backend: Optional[AttentionBackend] = None,
    fuse_rope_using_torch_compile: bool = False,
) -> nn.Module:
    """Apply low-precision attention to a model.

    Must be called before ``torch.compile``. KV caching should be
    disabled before calling (e.g., ``config.use_cache = False`` for
    HuggingFace models).

    Args:
        model: The module to transform.
        backend: Backend to use; auto-detected from the GPU when ``None``.
        fuse_rope_using_torch_compile: When ``True``, the returned wrapper
            exposes a ``compile_backend`` attribute that must be passed to
            ``torch.compile`` to get RoPE fusion::

                model = apply_low_precision_attention(model, fuse_rope_using_torch_compile=True)
                model = torch.compile(model, backend=model.compile_backend)

    Returns:
        The wrapped model.

    Raises:
        RuntimeError: If already applied, or called after ``torch.compile``,
            or the (requested/detected) backend is unavailable.
        ValueError: If an unknown backend is passed.
    """
    # Reject double application and post-compile application up front.
    if isinstance(model, _LowPrecisionAttentionWrapper):
        raise RuntimeError(
            "apply_low_precision_attention has already been applied to this module."
        )
    if isinstance(model, torch._dynamo.OptimizedModule):
        raise RuntimeError(
            "apply_low_precision_attention must be called before torch.compile."
        )

    # Auto-detect when unspecified; otherwise verify the explicit choice.
    if backend is None:
        backend = _get_available_backend()
    else:
        _check_backend_available(backend)

    if backend == AttentionBackend.FP8_FA3:
        return setup_fp8_backend(model, "FA3", fuse_rope_using_torch_compile)

    # Unreachable for known enum members; kept as a safety net.
    raise ValueError(f"Unknown backend: {backend}")
21 changes: 21 additions & 0 deletions torchao/prototype/attention/fp8_fa3/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""
FP8 attention using FA3 backend.
"""

from torchao.prototype.attention.fp8_fa3.attention import (
fp8_fa3_rope_sdpa,
fp8_fa3_sdpa,
)
from torchao.prototype.attention.quantization import _fp8_sdpa_quantize

__all__ = [
"fp8_fa3_sdpa",
"fp8_fa3_rope_sdpa",
"_fp8_sdpa_quantize",
]
Loading
Loading