220 changes: 220 additions & 0 deletions benchmarks/prototype/attention/benchmark_sdpa.py
@@ -0,0 +1,220 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""
Benchmark two attention backends against each other for a single layer,
sweeping sequence lengths and measuring runtime and SQNR.

Usage: python benchmarks/prototype/attention/benchmark_sdpa.py --baseline fa2 --test fa3_fp8
"""

import argparse
from contextlib import contextmanager
from functools import partial

import torch
import torch.nn.functional as F
from torch.nn.attention import (
    SDPBackend,
    activate_flash_attention_impl,
    restore_flash_attention_impl,
    sdpa_kernel,
)

from torchao.prototype.attention.fp8_fa3.attention import fp8_fa3_sdpa
from torchao.quantization.utils import compute_error as compute_sqnr

BACKENDS = ["fa2", "fa3", "fa3_fp8"]

BACKEND_LABELS = {
    "fa2": "FA2 BF16",
    "fa3": "FA3 BF16",
    "fa3_fp8": "FA3 FP8",
}


@contextmanager
def _activate_backend(backend: str):
    """Context manager that activates the appropriate flash attention impl."""
    if backend in ("fa3", "fa3_fp8"):
        activate_flash_attention_impl("FA3")
    # fa2 is the default; no activation needed.
    try:
        yield
    finally:
        if backend in ("fa3", "fa3_fp8"):
            restore_flash_attention_impl()


def _run_attention(backend: str, q, k, v, is_causal: bool):
    """Run a single attention call for the given backend."""
    if backend == "fa3_fp8":
        return fp8_fa3_sdpa(q, k, v, is_causal=is_causal)
    else:
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)


def benchmark_fn(fn, num_warmup, num_iters):
    """Benchmark a function, returning the median runtime in ms."""
    for _ in range(num_warmup):
        fn()
    torch.cuda.synchronize()

    # CUDA events time the kernels on-device; recording one pair per iteration
    # keeps host-side launch overhead out of the measurement.
    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_iters)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_iters)]

    for i in range(num_iters):
        start_events[i].record()
        fn()
        end_events[i].record()
    torch.cuda.synchronize()

    # Report the median rather than the mean so stragglers don't skew results.
    times = [s.elapsed_time(e) for s, e in zip(start_events, end_events)]
    times.sort()
    return times[num_iters // 2]


@torch.inference_mode()
def run_benchmark(
    baseline: str = "fa2",
    test: str = "fa3_fp8",
    is_causal: bool = False,
    num_warmup: int = 5,
    num_iters: int = 20,
):
    B = 1
    H = 32
    D = 128
    SEQ_LENGTHS = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]

    device = "cuda"
    dtype = torch.bfloat16

    baseline_label = BACKEND_LABELS[baseline]
    test_label = BACKEND_LABELS[test]

    print("=" * 90)
    print(f"Benchmark: {baseline_label} vs {test_label} — Single Attention Layer")
    print(f" Shape: (B={B}, H={H}, S=variable, D={D})")
    print(f" Causal: {is_causal}")
    print(f" Warmup: {num_warmup}, Iters: {num_iters}")
    print(f" Device: {torch.cuda.get_device_name()}")
    print("=" * 90)

    col_baseline = f"{baseline_label} (ms)"
    col_test = f"{test_label} (ms)"
    col_w = max(len(col_baseline), len(col_test), 12)

    header = (
        f"{'SeqLen':>8} | "
        f"{col_baseline:>{col_w}} | "
        f"{col_test:>{col_w}} | "
        f"{'Speedup':>8} | "
        f"{'SQNR (dB)':>10}"
    )
    print(header)
    print("-" * len(header))

    results = []

    for S in SEQ_LENGTHS:
        q = torch.randn(B, H, S, D, device=device, dtype=dtype)
        k = torch.randn(B, H, S, D, device=device, dtype=dtype)
        v = torch.randn(B, H, S, D, device=device, dtype=dtype)

        # --- Baseline ---
        with _activate_backend(baseline):
            baseline_fn = partial(_run_attention, baseline, q, k, v, is_causal)
            baseline_time = benchmark_fn(baseline_fn, num_warmup, num_iters)
            ref_out = _run_attention(baseline, q, k, v, is_causal)

        # --- Test ---
        with _activate_backend(test):
            test_fn = partial(_run_attention, test, q, k, v, is_causal)
            test_time = benchmark_fn(test_fn, num_warmup, num_iters)
            test_out = _run_attention(test, q, k, v, is_causal)

        sqnr = compute_sqnr(ref_out, test_out)
        speedup = baseline_time / test_time

        print(
            f"{S:>8} | "
            f"{baseline_time:>{col_w}.3f} | "
            f"{test_time:>{col_w}.3f} | "
            f"{speedup:>7.2f}x | "
            f"{sqnr:>10.2f}"
        )

        results.append(
            {
                "seq_len": S,
                "baseline_ms": baseline_time,
                "test_ms": test_time,
                "speedup": speedup,
                "sqnr_db": sqnr,
            }
        )

        del q, k, v, ref_out, test_out
        torch.cuda.empty_cache()

    print("-" * len(header))
    print()

    return results
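
# run_benchmark can also be driven programmatically; a sketch (hypothetical
# usage, field names match the result dicts built above):
#
#     results = run_benchmark(baseline="fa3", test="fa3_fp8", is_causal=True)
#     worst = min(results, key=lambda r: r["sqnr_db"])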


def main():
    parser = argparse.ArgumentParser(
        description="Benchmark any two attention backends for a single layer"
    )
    parser.add_argument(
        "--baseline",
        type=str,
        default="fa2",
        choices=BACKENDS,
        help="Baseline attention backend (default: fa2)",
    )
    parser.add_argument(
        "--test",
        type=str,
        default="fa3_fp8",
        choices=BACKENDS,
        help="Test attention backend to compare against baseline (default: fa3_fp8)",
    )
    parser.add_argument(
        "--causal",
        action="store_true",
        help="Use causal attention masking",
    )
    parser.add_argument(
        "--num_warmup",
        type=int,
        default=5,
        help="Number of warmup iterations",
    )
    parser.add_argument(
        "--num_iters",
        type=int,
        default=20,
        help="Number of timed iterations",
    )
    args = parser.parse_args()

    run_benchmark(
        baseline=args.baseline,
        test=args.test,
        is_causal=args.causal,
        num_warmup=args.num_warmup,
        num_iters=args.num_iters,
    )


if __name__ == "__main__":
    main()
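
For reference, the SQNR column is torchao's compute_error. A minimal standalone
sketch of the metric, assuming the conventional SQNR-in-dB definition (this
mirrors, but is not guaranteed to match, torchao's exact implementation):

    import torch

    def sqnr_db(ref: torch.Tensor, test: torch.Tensor) -> float:
        # 20 * log10(||ref|| / ||ref - test||): higher means the test
        # backend's output is closer to the reference output.
        signal = torch.linalg.norm(ref.float())
        noise = torch.linalg.norm(ref.float() - test.float())
        return (20 * torch.log10(signal / noise)).item()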