diff --git a/benchmarks/prototype/attention/eval_flux_model.py b/benchmarks/prototype/attention/eval_flux_model.py new file mode 100644 index 0000000000..fee11edd37 --- /dev/null +++ b/benchmarks/prototype/attention/eval_flux_model.py @@ -0,0 +1,460 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Benchmark attention backends on FLUX.1-schnell. + +Compares backends using LPIPS (perceptual similarity) on DrawBench prompts. + +Usage: + python eval_flux_model.py --baseline fa3 --test fa3_fp8 --debug_prompt "A red car" +""" + +import argparse +import gc +import random +from typing import Optional + +import lpips +import numpy as np +import torch +import torch._dynamo +from datasets import load_dataset +from diffusers import FluxPipeline +from PIL import Image +from torch.nn.attention import ( + activate_flash_attention_impl, + restore_flash_attention_impl, +) + +from torchao.prototype.attention import ( + AttentionBackend, + apply_low_precision_attention, +) + +BACKENDS = { + "fa2": {"flash_impl": None, "fp8": False}, + "fa3": {"flash_impl": "FA3", "fp8": False}, + "fa3_fp8": { + "flash_impl": "FA3", + "fp8": True, + "fp8_backend": AttentionBackend.FP8_FA3, + }, + "fa4": {"flash_impl": "FA4", "fp8": False}, + "fa4_fp8": { + "flash_impl": "FA4", + "fp8": True, + "fp8_backend": AttentionBackend.FP8_FA4, + }, +} + +IMAGE_SIZE = (512, 512) # (width, height) - resize for consistent LPIPS +RANDOM_SEED = 42 +MODEL_ID = "black-forest-labs/FLUX.1-schnell" + + +def cleanup_gpu(): + """Free GPU memory between benchmark phases.""" + gc.collect() + torch.cuda.empty_cache() + torch._dynamo.reset() + + +def setup_backend( + pipe, + backend_name, + compile_flag, + orig_transformer, + fuse_rope_using_torch_compile=False, +): + """Set up a backend and return the flash_impl name.""" + cfg = BACKENDS[backend_name] 
+ pipe.transformer = orig_transformer + + if cfg["fp8"]: + print(f"Applying low-precision FP8 attention ({backend_name})...") + pipe.transformer = apply_low_precision_attention( + pipe.transformer, + backend=cfg["fp8_backend"], + fuse_rope_using_torch_compile=fuse_rope_using_torch_compile, + ) + if fuse_rope_using_torch_compile: + print( + f"Compiling transformer with torch.compile ({backend_name}, FP8 backend)..." + ) + pipe.transformer = torch.compile( + pipe.transformer, backend=pipe.transformer.compile_backend + ) + elif compile_flag: + print(f"Compiling transformer with torch.compile ({backend_name})...") + pipe.transformer = torch.compile(pipe.transformer) + return cfg["flash_impl"] + else: + if compile_flag: + print(f"Compiling transformer with torch.compile ({backend_name})...") + pipe.transformer = torch.compile(pipe.transformer) + return cfg["flash_impl"] + + +def pil_to_lpips_tensor(img: Image.Image, device: str) -> torch.Tensor: + """Convert a PIL Image to a tensor suitable for LPIPS computation.""" + t = ( + torch.from_numpy( + ( + torch.ByteTensor(torch.ByteStorage.from_buffer(img.tobytes())) + .view(img.size[1], img.size[0], 3) + .numpy() + ) + ).float() + / 255.0 + ) + t = t.permute(2, 0, 1).unsqueeze(0) + t = t * 2.0 - 1.0 + return t.to(device) + + +def generate_image( + pipe, + prompt: str, + seed: int, + device: str, + num_inference_steps: int, + height: int = 2048, + width: int = 2048, + flash_impl: Optional[str] = None, +) -> Image.Image: + """Generate an image from a prompt with deterministic seed.""" + generator = torch.Generator(device=device).manual_seed(seed) + + if flash_impl: + activate_flash_attention_impl(flash_impl) + try: + image = pipe( + prompt=prompt, + num_inference_steps=num_inference_steps, + guidance_scale=3.5, + height=height, + width=width, + generator=generator, + ).images[0] + finally: + if flash_impl: + restore_flash_attention_impl() + + if IMAGE_SIZE is not None: + image = image.resize(IMAGE_SIZE, Image.BICUBIC) + + 
return image + + +@torch.inference_mode() +def run_benchmark( + baseline_backend: str = "fa3", + test_backend: str = "fa3_fp8", + num_prompts: int = 50, + num_inference_steps: int = 20, + height: int = 2048, + width: int = 2048, + debug_prompt: Optional[str] = None, + warmup_iters: int = 2, + compile: bool = False, + fuse_rope_using_torch_compile: bool = False, +): + """Run the attention backend benchmark on FLUX.1-schnell.""" + compile_str = " + torch.compile" if compile else "" + print("=" * 80) + print("Attention Backend Benchmark for FLUX.1-schnell") + print(f"Baseline: {baseline_backend} | Test: {test_backend}{compile_str}") + print("=" * 80) + + torch.manual_seed(RANDOM_SEED) + torch.cuda.manual_seed_all(RANDOM_SEED) + random.seed(RANDOM_SEED) + np.random.seed(RANDOM_SEED) + + device = "cuda" + + # ----- Load prompts ----- + if debug_prompt is not None: + prompts = [debug_prompt] + print(f"Using debug prompt: {debug_prompt}") + else: + print("Loading DrawBench dataset...") + dataset = load_dataset("sayakpaul/drawbench", split="train") + all_prompts = [item["Prompts"] for item in dataset] + prompts = all_prompts[:num_prompts] + print( + f"Using {len(prompts)} prompts from DrawBench " + f"(total available: {len(all_prompts)})" + ) + + # ----- Load model and LPIPS ----- + print(f"\nLoading FLUX.1-schnell from {MODEL_ID}...") + pipe = FluxPipeline.from_pretrained( + MODEL_ID, + torch_dtype=torch.bfloat16, + ) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=True) + + print("Loading LPIPS model (VGG)...") + loss_fn = lpips.LPIPS(net="vgg").to(device) + + orig_transformer = pipe.transformer + + if compile: + pipe.vae.decode = torch.compile(pipe.vae.decode) + + # ----- Phase 1: Baseline backend ----- + print("\n" + "-" * 80) + print(f"Phase 1: Generating images ({baseline_backend})") + print("-" * 80) + + baseline_flash_impl = setup_backend( + pipe, + baseline_backend, + compile, + orig_transformer, + 
fuse_rope_using_torch_compile=fuse_rope_using_torch_compile, + ) + + print(f"Warming up {baseline_backend} with {warmup_iters} iterations...") + warmup_prompt = prompts[0] + for i in range(warmup_iters): + _ = generate_image( + pipe, + warmup_prompt, + RANDOM_SEED, + device, + num_inference_steps, + height=height, + width=width, + flash_impl=baseline_flash_impl, + ) + print(f" Warmup {i + 1}/{warmup_iters} complete") + + baseline_data = [] + baseline_times_ms = [] + + for idx, prompt in enumerate(prompts): + print(f"[{idx + 1}/{len(prompts)}] {baseline_backend}: {prompt[:50]}...") + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + baseline_img = generate_image( + pipe, + prompt, + RANDOM_SEED, + device, + num_inference_steps, + height=height, + width=width, + flash_impl=baseline_flash_impl, + ) + end_event.record() + torch.cuda.synchronize() + elapsed_ms = start_event.elapsed_time(end_event) + + baseline_tensor = pil_to_lpips_tensor(baseline_img, device) + # Store tensors on CPU to free GPU memory for the test phase. + baseline_data.append((prompt, baseline_tensor.cpu())) + baseline_times_ms.append(elapsed_ms) + + avg_baseline_ms = sum(baseline_times_ms) / len(baseline_times_ms) + print( + f"\n{baseline_backend} complete. 
Avg time per image: {avg_baseline_ms:.1f} ms" + ) + + # ----- Cleanup before test phase ----- + cleanup_gpu() + + # ----- Phase 2: Test backend ----- + print("\n" + "-" * 80) + print(f"Phase 2: Generating images ({test_backend})") + print("-" * 80) + + test_flash_impl = setup_backend( + pipe, + test_backend, + compile, + orig_transformer, + fuse_rope_using_torch_compile=fuse_rope_using_torch_compile, + ) + + print(f"Warming up {test_backend} with {warmup_iters} iterations...") + for i in range(warmup_iters): + _ = generate_image( + pipe, + warmup_prompt, + RANDOM_SEED, + device, + num_inference_steps, + height=height, + width=width, + flash_impl=test_flash_impl, + ) + print(f" Warmup {i + 1}/{warmup_iters} complete") + + lpips_values = [] + test_times_ms = [] + + for idx, (prompt, baseline_tensor_cpu) in enumerate(baseline_data): + print(f"[{idx + 1}/{len(prompts)}] {test_backend}: {prompt[:50]}...") + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + test_img = generate_image( + pipe, + prompt, + RANDOM_SEED, + device, + num_inference_steps, + height=height, + width=width, + flash_impl=test_flash_impl, + ) + end_event.record() + torch.cuda.synchronize() + elapsed_ms = start_event.elapsed_time(end_event) + test_times_ms.append(elapsed_ms) + + test_tensor = pil_to_lpips_tensor(test_img, device) + lpips_value = loss_fn(baseline_tensor_cpu.to(device), test_tensor).item() + lpips_values.append(lpips_value) + + print(f" LPIPS: {lpips_value:.4f}, Time: {elapsed_ms:.1f} ms") + + avg_test_ms = sum(test_times_ms) / len(test_times_ms) + + # ----- Results ----- + print("\n" + "=" * 80) + print("BENCHMARK RESULTS") + print("=" * 80) + + avg_lpips = sum(lpips_values) / len(lpips_values) + max_lpips = max(lpips_values) + min_lpips = min(lpips_values) + std_lpips = np.std(lpips_values) + + print("\nLPIPS Statistics (lower is better, 0 = identical):") + print(f" Average LPIPS: {avg_lpips:.4f}") + 
print(f" Std Dev: {std_lpips:.4f}") + print(f" Min LPIPS: {min_lpips:.4f}") + print(f" Max LPIPS: {max_lpips:.4f}") + + print("\nTiming Statistics:") + print(f" Avg {baseline_backend} time: {avg_baseline_ms:.1f} ms per image") + print(f" Avg {test_backend} time: {avg_test_ms:.1f} ms per image") + print(f" Speedup: {avg_baseline_ms / avg_test_ms:.2f}x") + + print("\nBenchmark Configuration:") + print(f" Baseline backend: {baseline_backend}") + print(f" Test backend: {test_backend}") + print(f" torch.compile: {compile}") + print(f" Model: {MODEL_ID}") + print(f" Prompts tested: {len(prompts)}") + print(f" Inference steps: {num_inference_steps}") + print(f" Generation size: {width}x{height}") + print(f" LPIPS resize: {IMAGE_SIZE[0]}x{IMAGE_SIZE[1]}") + print(f" Random seed: {RANDOM_SEED}") + print("=" * 80) + + return { + "avg_lpips": avg_lpips, + "std_lpips": std_lpips, + "min_lpips": min_lpips, + "max_lpips": max_lpips, + "speedup": avg_baseline_ms / avg_test_ms, + "lpips_values": lpips_values, + } + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark attention backends on FLUX.1-schnell" + ) + parser.add_argument( + "--baseline", + type=str, + default="fa3", + choices=list(BACKENDS.keys()), + help="Baseline attention backend", + ) + parser.add_argument( + "--test", + type=str, + default="fa3_fp8", + choices=list(BACKENDS.keys()), + help="Test attention backend", + ) + parser.add_argument( + "--num_prompts", + type=int, + default=200, + help="Number of prompts to use (50 for quick, 200 for full benchmark)", + ) + parser.add_argument( + "--num_inference_steps", + type=int, + default=4, + help="Number of diffusion inference steps", + ) + parser.add_argument( + "--debug_prompt", + type=str, + default=None, + help="Use a single debug prompt instead of DrawBench", + ) + parser.add_argument( + "--warmup_iters", + type=int, + default=2, + help="Number of warmup iterations", + ) + parser.add_argument( + "--compile", + action="store_true", + 
help="Wrap the model with torch.compile for both backends", + ) + parser.add_argument( + "--fuse_rope_using_torch_compile", + action="store_true", + help="Fuse RoPE into the FP8 kernel (compile path, off by default)", + ) + parser.add_argument( + "--height", + type=int, + default=2048, + help="Generated image height in pixels (default: 2048)", + ) + parser.add_argument( + "--width", + type=int, + default=2048, + help="Generated image width in pixels (default: 2048)", + ) + + args = parser.parse_args() + + run_benchmark( + baseline_backend=args.baseline, + test_backend=args.test, + num_prompts=args.num_prompts, + num_inference_steps=args.num_inference_steps, + height=args.height, + width=args.width, + debug_prompt=args.debug_prompt, + warmup_iters=args.warmup_iters, + compile=args.compile, + fuse_rope_using_torch_compile=args.fuse_rope_using_torch_compile, + ) + + +if __name__ == "__main__": + main() diff --git a/test/prototype/attention/__init__.py b/test/prototype/attention/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/prototype/attention/test_fp8_attention.py b/test/prototype/attention/test_fp8_attention.py new file mode 100644 index 0000000000..8d8062022d --- /dev/null +++ b/test/prototype/attention/test_fp8_attention.py @@ -0,0 +1,250 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Tests for FP8 low-precision attention (FA3 backend on Hopper).""" + +import unittest + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.testing._internal import common_utils +from torch.testing._internal.common_utils import TestCase, run_tests + +from torchao.quantization.utils import compute_error +from torchao.utils import torch_version_at_least + +if torch_version_at_least("2.11.0"): + from torchao.prototype.attention.utils import _is_fa3_available, _is_hopper + + if _is_hopper() and _is_fa3_available(): + from torch.nn.attention import ( + activate_flash_attention_impl, + restore_flash_attention_impl, + ) + + from torchao.prototype.attention import ( + AttentionBackend, + apply_low_precision_attention, + ) + from torchao.prototype.attention.fp8_fa3.attention import ( + fp8_fa3_rope_sdpa, + fp8_fa3_sdpa, + ) + + +def _rope_cos_sin(S, D, device): + freqs = 1.0 / (10000.0 ** (torch.arange(0, D, 2, dtype=torch.float32) / D)) + angles = torch.outer(torch.arange(S, dtype=torch.float32), freqs) + cos_half = torch.cos(angles) + sin_half = torch.sin(angles) + cos = torch.cat([cos_half, cos_half], dim=-1).to(device) + sin = torch.cat([sin_half, sin_half], dim=-1).to(device) + return cos, sin + + +def _apply_rope(x, cos, sin): + """NeoX rotate-half RoPE. 
x: [B, S, H, D], cos/sin: [S, D].""" + D_HALF = x.shape[-1] // 2 + rotate = torch.cat([-x[..., D_HALF:], x[..., :D_HALF]], dim=-1) + return ( + x * cos.unsqueeze(0).unsqueeze(2) + rotate * sin.unsqueeze(0).unsqueeze(2) + ).to(x.dtype) + + +class SimpleAttentionModel(nn.Module): + def __init__(self, embed_dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False) + + def forward(self, x): + B, S, _ = x.shape + q = self.q_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2) + k = self.k_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2) + v = self.v_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2) + attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True) + return self.out_proj(attn_out.transpose(1, 2).contiguous().view(B, S, -1)) + + +class SimpleRoPEAttentionModel(nn.Module): + """Applies RoPE to Q and K immediately before SDPA (Pattern A: RoPE → transpose → SDPA).""" + + def __init__(self, embed_dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False) + + def forward(self, x, cos, sin): + B, S, _ = x.shape + q = self.q_proj(x).view(B, S, self.num_heads, self.head_dim) + k = self.k_proj(x).view(B, S, self.num_heads, self.head_dim) + v = self.v_proj(x).view(B, S, self.num_heads, self.head_dim) + q = _apply_rope(q, cos, sin).transpose(1, 2) + k = _apply_rope(k, cos, sin).transpose(1, 2) + v = v.transpose(1, 2) + attn_out = 
F.scaled_dot_product_attention(q, k, v, is_causal=True) + return self.out_proj(attn_out.transpose(1, 2).contiguous().view(B, S, -1)) + + +@common_utils.instantiate_parametrized_tests +class TestFP8FA3Attention(TestCase): + @unittest.skipUnless( + torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available(), + "Requires PyTorch >= 2.11, Hopper GPU, and FA3", + ) + @common_utils.parametrize("shape", [(2, 8, 1024, 64), (1, 16, 1024, 128)]) + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_sdpa_accuracy(self, shape, dtype): + B, H, S, D = shape + q = torch.randn(B, H, S, D, device="cuda", dtype=dtype) + k = torch.randn(B, H, S, D, device="cuda", dtype=dtype) + v = torch.randn(B, H, S, D, device="cuda", dtype=dtype) + + with torch.no_grad(): + out_ref = F.scaled_dot_product_attention(q, k, v, is_causal=False) + + activate_flash_attention_impl("FA3") + try: + with torch.no_grad(): + out_fp8 = fp8_fa3_sdpa(q, k, v, is_causal=False) + finally: + restore_flash_attention_impl() + + sqnr = compute_error(out_ref, out_fp8) + self.assertGreater( + sqnr.item(), + 25.0, + f"SQNR {sqnr.item():.2f} dB below 25 dB for shape={shape}, dtype={dtype}", + ) + + @unittest.skipUnless( + torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available(), + "Requires PyTorch >= 2.11, Hopper GPU, and FA3", + ) + @common_utils.parametrize("shape", [(2, 1024, 8, 64), (1, 1024, 16, 128)]) + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_rope_sdpa_accuracy(self, shape, dtype): + B, S, H, D = shape + q = torch.randn(B, S, H, D, device="cuda", dtype=dtype) + k = torch.randn(B, S, H, D, device="cuda", dtype=dtype) + v = torch.randn(B, S, H, D, device="cuda", dtype=dtype) + cos, sin = _rope_cos_sin(S, D, "cuda") + + with torch.no_grad(): + out_ref = F.scaled_dot_product_attention( + _apply_rope(q, cos, sin).transpose(1, 2), + _apply_rope(k, cos, sin).transpose(1, 2), + v.transpose(1, 2), + is_causal=False, + ) + + 
activate_flash_attention_impl("FA3") + try: + with torch.no_grad(): + out_fp8 = fp8_fa3_rope_sdpa(q, k, v, cos, sin, is_causal=False) + finally: + restore_flash_attention_impl() + + sqnr = compute_error(out_ref, out_fp8) + self.assertGreater( + sqnr.item(), + 25.0, + f"SQNR {sqnr.item():.2f} dB below 25 dB for shape={shape}, dtype={dtype}", + ) + + @unittest.skipUnless( + torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available(), + "Requires PyTorch >= 2.11, Hopper GPU, and FA3", + ) + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_monkey_patch_model(self, dtype): + embed_dim, num_heads = 512, 8 + model = ( + SimpleAttentionModel(embed_dim, num_heads) + .to(device="cuda", dtype=dtype) + .eval() + ) + x = torch.randn(2, 128, embed_dim, device="cuda", dtype=dtype) + + with torch.no_grad(): + out_ref = model(x) + + fp8_model = ( + SimpleAttentionModel(embed_dim, num_heads) + .to(device="cuda", dtype=dtype) + .eval() + ) + fp8_model.load_state_dict(model.state_dict()) + fp8_model = apply_low_precision_attention( + fp8_model, + backend=AttentionBackend.FP8_FA3, + fuse_rope_using_torch_compile=False, + ) + + with torch.no_grad(): + out_fp8 = fp8_model(x) + + sqnr = compute_error(out_ref, out_fp8) + self.assertGreater( + sqnr.item(), + 20.0, + f"SQNR {sqnr.item():.2f} dB below 20 dB for dtype={dtype}", + ) + + @unittest.skipUnless( + torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available(), + "Requires PyTorch >= 2.11, Hopper GPU, and FA3", + ) + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_rope_fusion_model(self, dtype): + embed_dim, num_heads = 512, 8 + model = ( + SimpleRoPEAttentionModel(embed_dim, num_heads) + .to(device="cuda", dtype=dtype) + .eval() + ) + S = 128 + x = torch.randn(2, S, embed_dim, device="cuda", dtype=dtype) + cos, sin = _rope_cos_sin(S, embed_dim // num_heads, "cuda") + + with torch.no_grad(): + out_ref = model(x, cos, sin) + + fp8_model = ( + 
SimpleRoPEAttentionModel(embed_dim, num_heads) + .to(device="cuda", dtype=dtype) + .eval() + ) + fp8_model.load_state_dict(model.state_dict()) + fp8_model = apply_low_precision_attention( + fp8_model, + backend=AttentionBackend.FP8_FA3, + fuse_rope_using_torch_compile=True, + ) + fp8_model = torch.compile(fp8_model, backend=fp8_model.compile_backend) + + with torch.no_grad(): + out_fp8 = fp8_model(x, cos, sin) + + sqnr = compute_error(out_ref, out_fp8) + self.assertGreater( + sqnr.item(), + 20.0, + f"SQNR {sqnr.item():.2f} dB below 20 dB for dtype={dtype}", + ) + + +if __name__ == "__main__": + run_tests() diff --git a/torchao/prototype/attention/__init__.py b/torchao/prototype/attention/__init__.py new file mode 100644 index 0000000000..60531e9cf0 --- /dev/null +++ b/torchao/prototype/attention/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Low-precision attention for inference. + +Only supports forward pass — backward is not supported by the underlying backends. +""" + +from torchao.prototype.attention.api import ( + AttentionBackend, + apply_low_precision_attention, +) + +__all__ = [ + "AttentionBackend", + "apply_low_precision_attention", +] diff --git a/torchao/prototype/attention/api.py b/torchao/prototype/attention/api.py new file mode 100644 index 0000000000..3879029b04 --- /dev/null +++ b/torchao/prototype/attention/api.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""User-facing API for low-precision attention.""" + +from enum import Enum +from typing import Optional + +import torch +import torch._dynamo +import torch.nn as nn + +from torchao.prototype.attention.utils import _is_fa3_available, _is_hopper +from torchao.utils import torch_version_at_least + +if torch_version_at_least("2.11.0"): + from torchao.prototype.attention.shared_utils.setup import setup_fp8_backend + from torchao.prototype.attention.shared_utils.wrapper import ( + _LowPrecisionAttentionWrapper, + ) +else: + raise ImportError("Low-precision attention requires PyTorch 2.11+.") + + +class AttentionBackend(str, Enum): + """Backend kernel for computing attention.""" + + FP8_FA3 = "FP8_FA3" # Requires SM90+ (Hopper) + + +def _get_available_backend() -> AttentionBackend: + if not torch.cuda.is_available(): + raise RuntimeError("Low-precision attention requires CUDA.") + capability = torch.cuda.get_device_capability() + if _is_hopper() and _is_fa3_available(): + return AttentionBackend.FP8_FA3 + raise RuntimeError(f"No compatible backend for SM{capability[0]}{capability[1]}.") + + +def _check_backend_available(backend: AttentionBackend) -> None: + if not torch.cuda.is_available(): + raise RuntimeError(f"{backend} backend requires CUDA.") + capability = torch.cuda.get_device_capability() + if backend == AttentionBackend.FP8_FA3: + if not _is_hopper(): + raise RuntimeError( + f"FP8_FA3 requires Hopper (SM 9.x), got SM{capability[0]}{capability[1]}." + ) + if not _is_fa3_available(): + raise RuntimeError( + "FP8_FA3 requires the flash-attn package with FA3 support." + ) + else: + raise ValueError(f"Unknown backend: {backend}") + + +def apply_low_precision_attention( + model: nn.Module, + backend: Optional[AttentionBackend] = None, + fuse_rope_using_torch_compile: bool = False, +) -> nn.Module: + """Apply low-precision attention to a model. + + Must be called before ``torch.compile``. 
KV caching should be + disabled before calling (e.g., ``config.use_cache = False`` for + HuggingFace models). + + When ``fuse_rope_using_torch_compile=True``, the returned wrapper + exposes a ``compile_backend`` attribute. You must compile with it to get + the RoPE fusion:: + + model = apply_low_precision_attention(model, fuse_rope_using_torch_compile=True) + model = torch.compile(model, backend=model.compile_backend) + """ + if isinstance(model, _LowPrecisionAttentionWrapper): + raise RuntimeError( + "apply_low_precision_attention has already been applied to this module." + ) + if isinstance(model, torch._dynamo.OptimizedModule): + raise RuntimeError( + "apply_low_precision_attention must be called before torch.compile." + ) + + if backend is None: + backend = _get_available_backend() + else: + _check_backend_available(backend) + + if backend == AttentionBackend.FP8_FA3: + return setup_fp8_backend(model, "FA3", fuse_rope_using_torch_compile) + + raise ValueError(f"Unknown backend: {backend}") diff --git a/torchao/prototype/attention/fp8_fa3/__init__.py b/torchao/prototype/attention/fp8_fa3/__init__.py new file mode 100644 index 0000000000..4a4487f043 --- /dev/null +++ b/torchao/prototype/attention/fp8_fa3/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +FP8 attention using FA3 backend. 
+""" + +from torchao.prototype.attention.fp8_fa3.attention import ( + fp8_fa3_rope_sdpa, + fp8_fa3_sdpa, +) +from torchao.prototype.attention.quantization import _fp8_sdpa_quantize + +__all__ = [ + "fp8_fa3_sdpa", + "fp8_fa3_rope_sdpa", + "_fp8_sdpa_quantize", +] diff --git a/torchao/prototype/attention/fp8_fa3/attention.py b/torchao/prototype/attention/fp8_fa3/attention.py new file mode 100644 index 0000000000..bcec1e985b --- /dev/null +++ b/torchao/prototype/attention/fp8_fa3/attention.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +FP8 SDPA using FA3 backend. + +When using these functions directly (not through apply_low_precision_attention), +you must activate FA3 yourself:: + + activate_flash_attention_impl("FA3") + try: + out = fp8_fa3_sdpa(q, k, v, is_causal=True) + finally: + restore_flash_attention_impl() +""" + +from functools import partial + +from torchao.prototype.attention.shared_utils.attention import ( + _fp8_rope_sdpa, + _fp8_sdpa, +) + +fp8_fa3_sdpa = partial(_fp8_sdpa, backend_name="FA3") +fp8_fa3_sdpa.__doc__ = _fp8_sdpa.__doc__ +fp8_fa3_sdpa.__name__ = "fp8_fa3_sdpa" +fp8_fa3_sdpa.__qualname__ = "fp8_fa3_sdpa" + +fp8_fa3_rope_sdpa = partial(_fp8_rope_sdpa, backend_name="FA3") +fp8_fa3_rope_sdpa.__doc__ = _fp8_rope_sdpa.__doc__ +fp8_fa3_rope_sdpa.__name__ = "fp8_fa3_rope_sdpa" +fp8_fa3_rope_sdpa.__qualname__ = "fp8_fa3_rope_sdpa" diff --git a/torchao/prototype/attention/fp8_fa3/fusion_pass.py b/torchao/prototype/attention/fp8_fa3/fusion_pass.py new file mode 100644 index 0000000000..37bb6b07df --- /dev/null +++ b/torchao/prototype/attention/fp8_fa3/fusion_pass.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from torchao.prototype.attention.fp8_fa3.attention import ( + fp8_fa3_rope_sdpa, + fp8_fa3_sdpa, +) +from torchao.prototype.attention.shared_utils.custom_ops import ( + make_backend_fn, + register_fp8_attention_ops, +) + +_ops = register_fp8_attention_ops( + backend_name="fa3", + rope_sdpa_fn=fp8_fa3_rope_sdpa, + sdpa_fn=fp8_fa3_sdpa, +) + +make_fp8_backend = make_backend_fn(_ops, backend_name="FA3", flash_impl_name="FA3") diff --git a/torchao/prototype/attention/quantization/__init__.py b/torchao/prototype/attention/quantization/__init__.py new file mode 100644 index 0000000000..945a52bfc1 --- /dev/null +++ b/torchao/prototype/attention/quantization/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from torchao.prototype.attention.quantization.triton_qkv_quantization import ( + triton_fp8_sdpa_quantize as _fp8_sdpa_quantize, +) +from torchao.prototype.attention.quantization.triton_rope_qkv_quantization import ( + triton_fp8_rope_sdpa_quantize as _fp8_rope_sdpa_quantize, +) + +__all__ = [ + "_fp8_sdpa_quantize", + "_fp8_rope_sdpa_quantize", +] diff --git a/torchao/prototype/attention/quantization/triton_qkv_quantization.py b/torchao/prototype/attention/quantization/triton_qkv_quantization.py new file mode 100644 index 0000000000..c0ea2f4fd6 --- /dev/null +++ b/torchao/prototype/attention/quantization/triton_qkv_quantization.py @@ -0,0 +1,466 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +FP8 quantization kernels for Q, K, V. 
+ +Input/output format: [B, H, S, D]. +Supports GQA (different head counts for Q vs K/V). +""" + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + + +def _compute_num_chunks(tensor: torch.Tensor, S: int) -> int: + """Compute optimal number of chunks based on GPU properties.""" + props = torch.cuda.get_device_properties(tensor.device) + num_sms = props.multi_processor_count + B, H = tensor.shape[:2] # [B, H, S, D] + base_parallelism = B * H + # Target 2-4x SMs for good occupancy/latency hiding + target_blocks = num_sms * 4 + num_chunks = max(1, target_blocks // base_parallelism) + # Ensure each chunk has at least 32 S positions for efficiency + num_chunks = min(num_chunks, S // 32) if S >= 32 else 1 + # Cap at reasonable maximum + num_chunks = min(num_chunks, 64) + # Adjust if S is small + num_chunks = min(num_chunks, S) + return num_chunks + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 512}, num_warps=4), + triton.Config({"BLOCK_SIZE": 1024}, num_warps=4), + triton.Config({"BLOCK_SIZE": 2048}, num_warps=8), + triton.Config({"BLOCK_SIZE": 4096}, num_warps=8), + ], + key=["chunk_size", "D"], +) +@triton.jit +def single_phase1_kernel( + # Input tensor [B, H, S, D] + x_ptr, + # Output: partial max values [B * H * num_chunks] + partial_max_ptr, + # Input strides (for [B, H, S, D] layout) + stride_b, + stride_h, + stride_s, + stride_d, + # Dimensions + S, + D, + H, + chunk_size, + num_chunks, + # Block size + BLOCK_SIZE: tl.constexpr, +): + """ + Phase 1 for a single tensor: Compute partial absmax. + + Grid: (B, H, num_chunks) + + Uses linearized iteration over chunk_size * D elements. 
+ """ + pid_b = tl.program_id(axis=0) + pid_h = tl.program_id(axis=1) + pid_chunk = tl.program_id(axis=2) + + # Compute the S range for this chunk + s_start = pid_chunk * chunk_size + s_end = tl.minimum(s_start + chunk_size, S) + chunk_elements = (s_end - s_start) * D + + # Base pointer for input [B, H, S, D] + base_offset = pid_b * stride_b + pid_h * stride_h + + # Initialize max accumulator + x_max = 0.0 + + # Linearized iteration over chunk_size * D elements + for block_start in range(0, chunk_elements, BLOCK_SIZE): + offs = block_start + tl.arange(0, BLOCK_SIZE) + mask = offs < chunk_elements + + # Convert linear offset to (s, d) coordinates + local_s = offs // D + d_idx = offs % D + s_idx = s_start + local_s + + # Input offset [B, H, S, D] + ptr_offset = s_idx * stride_s + d_idx * stride_d + + x_val = tl.load(x_ptr + base_offset + ptr_offset, mask=mask, other=0.0).to( + tl.float32 + ) + x_max = tl.maximum(x_max, tl.max(tl.abs(x_val))) + + # Store partial max + chunk_idx = pid_b * (H * num_chunks) + pid_h * num_chunks + pid_chunk + tl.store(partial_max_ptr + chunk_idx, x_max) + + +@triton.jit +def single_reduce_kernel( + partial_max_ptr, # [B * H * num_chunks] + scale_ptr, + descale_ptr, + H, + num_chunks, +): + """ + Reduce partial maxes and compute scale/descale for a single tensor. 
+ + Grid: (B, H) + """ + pid_b = tl.program_id(axis=0) + pid_h = tl.program_id(axis=1) + + # Reduce across chunks for this (batch, head) + x_max = 0.0 + + base_idx = (pid_b * H + pid_h) * num_chunks + for c in range(num_chunks): + x_max = tl.maximum(x_max, tl.load(partial_max_ptr + base_idx + c)) + + # Compute scale and descale + # FP8 E4M3 max value is 448.0 + FP8_MAX = 448.0 + eps = 1e-12 + scale_idx = pid_b * H + pid_h + + tl.store(scale_ptr + scale_idx, tl.where(x_max > eps, FP8_MAX / x_max, 1.0)) + tl.store(descale_ptr + scale_idx, tl.where(x_max > eps, x_max / FP8_MAX, 1.0)) + + +@triton.jit +def group_reduce_kernel( + partial_max_ptr, # [B * H_q * num_chunks] + scale_ptr, # [B, H_kv] + descale_ptr, # [B, H_kv] + H_q, + H_kv, + groups, # H_q // H_kv + num_chunks, +): + """ + Reduce partial maxes across head groups for GQA Q tensor. + + For each KV group, reduces the max across all Q heads in that group + and all chunks, producing one scale per (batch, kv_head). + + Grid: (B, H_kv) + """ + pid_b = tl.program_id(axis=0) + pid_hkv = tl.program_id(axis=1) + + x_max = 0.0 + + for g in range(groups): + h_q = pid_hkv * groups + g + base_idx = (pid_b * H_q + h_q) * num_chunks + for c in range(num_chunks): + x_max = tl.maximum(x_max, tl.load(partial_max_ptr + base_idx + c)) + + FP8_MAX = 448.0 + eps = 1e-12 + scale_idx = pid_b * H_kv + pid_hkv + + tl.store(scale_ptr + scale_idx, tl.where(x_max > eps, FP8_MAX / x_max, 1.0)) + tl.store(descale_ptr + scale_idx, tl.where(x_max > eps, x_max / FP8_MAX, 1.0)) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 512}, num_warps=4), + triton.Config({"BLOCK_SIZE": 1024}, num_warps=4), + triton.Config({"BLOCK_SIZE": 2048}, num_warps=8), + triton.Config({"BLOCK_SIZE": 4096}, num_warps=8), + ], + key=["chunk_size", "D"], +) +@triton.jit +def single_phase2_kernel( + # Input tensor [B, H, S, D] + x_ptr, + # Output tensor [B, H, S, D] - FP8 quantized + x_out_ptr, + # Precomputed scale [B, H_scale] + scale_ptr, + # 
Strides (for [B, H, S, D] layout) + stride_b, + stride_h, + stride_s, + stride_d, + # Dimensions + S, + D, + H, + chunk_size, + # Scale indexing for GQA: scale has H_scale entries per batch, + # and each group of `groups` heads shares one scale. + # For non-GQA: H_scale = H, groups = 1. + H_scale, + groups, + # Block size + BLOCK_SIZE: tl.constexpr, +): + """ + Phase 2 for a single tensor: Quantize to FP8 using precomputed scale. + + Grid: (B, H, num_chunks) + """ + pid_b = tl.program_id(axis=0) + pid_h = tl.program_id(axis=1) + pid_chunk = tl.program_id(axis=2) + + # Load scale for this head (or head group for GQA) + scale = tl.load(scale_ptr + pid_b * H_scale + pid_h // groups) + + # Compute the S range for this chunk + s_start = pid_chunk * chunk_size + s_end = tl.minimum(s_start + chunk_size, S) + chunk_elements = (s_end - s_start) * D + + # Base pointer + base_offset = pid_b * stride_b + pid_h * stride_h + + # Linearized iteration over chunk_size * D elements + for block_start in range(0, chunk_elements, BLOCK_SIZE): + offs = block_start + tl.arange(0, BLOCK_SIZE) + mask = offs < chunk_elements + + # Convert linear offset to (s, d) coordinates + local_s = offs // D + d_idx = offs % D + s_idx = s_start + local_s + + ptr_offset = base_offset + s_idx * stride_s + d_idx * stride_d + + # Load input value + x_val = tl.load(x_ptr + ptr_offset, mask=mask, other=0.0).to(tl.float32) + + # Quantize to FP8 + x_fp8 = (x_val * scale).to(tl.float8e4nv) + + # Store to output + tl.store(x_out_ptr + ptr_offset, x_fp8, mask=mask) + + +def triton_fp8_sdpa_quantize( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + num_chunks: Optional[int] = None, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + """ + Separated FP8 quantization for Q, K, V tensors. + + Quantizes all tensors to FP8 with per-head scaling. 
+ Each of Q, K, V is processed with independent kernel launches, + supporting GQA where Q has more heads than K/V (H_q = groups * H_kv). + + For GQA, Q is quantized with per-KV-group scaling so that q_descale + has shape [B, H_kv] as required by FA3. + + Args: + q: Query tensor of shape [B, H_q, S, D] in bf16/fp16 + k: Key tensor of shape [B, H_kv, S, D] in bf16/fp16 + v: Value tensor of shape [B, H_kv, S, D] in bf16/fp16 + num_chunks: Number of chunks to split S dimension into. + If None, automatically selects based on GPU SM count. + + Returns: + q_fp8: Quantized query, shape [B, H_q, S, D] in fp8 + k_fp8: Quantized key, shape [B, H_kv, S, D] in fp8 + v_fp8: Quantized value, shape [B, H_kv, S, D] in fp8 + q_descale: Query descale factors, shape [B, H_kv] in fp32 + k_descale: Key descale factors, shape [B, H_kv] in fp32 + v_descale: Value descale factors, shape [B, H_kv] in fp32 + """ + assert q.dim() == 4, f"Expected 4D tensor [B, H, S, D], got {q.dim()}D" + assert k.dim() == 4, f"Expected 4D tensor [B, H, S, D], got {k.dim()}D" + assert v.dim() == 4, f"Expected 4D tensor [B, H, S, D], got {v.dim()}D" + assert k.shape == v.shape, ( + f"K and V must have the same shape, got {k.shape} vs {v.shape}" + ) + assert q.shape[0] == k.shape[0], ( + f"Batch size mismatch: {q.shape[0]} vs {k.shape[0]}" + ) + assert q.shape[2] == k.shape[2], ( + f"Sequence length mismatch: {q.shape[2]} vs {k.shape[2]}" + ) + assert q.shape[3] == k.shape[3], f"Head dim mismatch: {q.shape[3]} vs {k.shape[3]}" + assert q.shape[1] % k.shape[1] == 0, ( + f"Q heads ({q.shape[1]}) must be a multiple of K heads ({k.shape[1]})" + ) + + B, H_q, S, D = q.shape + H_kv = k.shape[1] + groups = H_q // H_kv + + # Make tensors contiguous if needed + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + + # Compute number of chunks + if num_chunks is None: + num_chunks = _compute_num_chunks(q, S) + chunk_size = (S + num_chunks - 1) // num_chunks + + # Allocate output tensors + q_fp8 = 
torch.empty_like(q, dtype=torch.float8_e4m3fn) + k_fp8 = torch.empty_like(k, dtype=torch.float8_e4m3fn) + v_fp8 = torch.empty_like(v, dtype=torch.float8_e4m3fn) + + # Allocate partial max buffers (one per tensor) + q_partial_max = torch.empty( + B * H_q * num_chunks, dtype=torch.float32, device=q.device + ) + k_partial_max = torch.empty( + B * H_kv * num_chunks, dtype=torch.float32, device=q.device + ) + v_partial_max = torch.empty( + B * H_kv * num_chunks, dtype=torch.float32, device=q.device + ) + + # Allocate scale/descale tensors. + # For GQA, Q scale/descale are [B, H_kv] (per KV group). + # K and V are always [B, H_kv] (per head). + q_scale = torch.empty(B, H_kv, dtype=torch.float32, device=q.device) + k_scale = torch.empty(B, H_kv, dtype=torch.float32, device=q.device) + v_scale = torch.empty(B, H_kv, dtype=torch.float32, device=q.device) + q_descale = torch.empty(B, H_kv, dtype=torch.float32, device=q.device) + k_descale = torch.empty(B, H_kv, dtype=torch.float32, device=q.device) + v_descale = torch.empty(B, H_kv, dtype=torch.float32, device=q.device) + + q_grid_chunked = (B, H_q, num_chunks) + kv_grid_chunked = (B, H_kv, num_chunks) + + # ---- Phase 1: Max for Q ---- + single_phase1_kernel[q_grid_chunked]( + q, + q_partial_max, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + S, + D, + H_q, + chunk_size, + num_chunks, + ) + + # ---- Phase 1: Max for K ---- + single_phase1_kernel[kv_grid_chunked]( + k, + k_partial_max, + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + S, + D, + H_kv, + chunk_size, + num_chunks, + ) + + # ---- Phase 1: Max for V ---- + single_phase1_kernel[kv_grid_chunked]( + v, + v_partial_max, + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + S, + D, + H_kv, + chunk_size, + num_chunks, + ) + + # ---- Reduce ---- + # Q: group reduce across `groups` Q heads per KV head + group_reduce_kernel[(B, H_kv)]( + q_partial_max, q_scale, q_descale, H_q, H_kv, groups, num_chunks + ) + # K, V: per-head reduce + 
single_reduce_kernel[(B, H_kv)](k_partial_max, k_scale, k_descale, H_kv, num_chunks) + single_reduce_kernel[(B, H_kv)](v_partial_max, v_scale, v_descale, H_kv, num_chunks) + + # ---- Phase 2: Quantize Q ---- + # Q scale is [B, H_kv]; each group of `groups` Q heads shares one scale. + single_phase2_kernel[q_grid_chunked]( + q, + q_fp8, + q_scale, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + S, + D, + H_q, + chunk_size, + H_kv, + groups, + ) + + # ---- Phase 2: Quantize K ---- + # K scale is [B, H_kv]; groups=1 (per-head). + single_phase2_kernel[kv_grid_chunked]( + k, + k_fp8, + k_scale, + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + S, + D, + H_kv, + chunk_size, + H_kv, + 1, + ) + + # ---- Phase 2: Quantize V ---- + # V scale is [B, H_kv]; groups=1 (per-head). + single_phase2_kernel[kv_grid_chunked]( + v, + v_fp8, + v_scale, + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + S, + D, + H_kv, + chunk_size, + H_kv, + 1, + ) + + return q_fp8, k_fp8, v_fp8, q_descale, k_descale, v_descale diff --git a/torchao/prototype/attention/quantization/triton_rope_qkv_quantization.py b/torchao/prototype/attention/quantization/triton_rope_qkv_quantization.py new file mode 100644 index 0000000000..6ec95dd5af --- /dev/null +++ b/torchao/prototype/attention/quantization/triton_rope_qkv_quantization.py @@ -0,0 +1,745 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Fused RoPE + FP8 quantization kernels for Q, K, V. + +Input: [B, S, H, D], output: [B, H, S, D]. +Supports GQA (different head counts for Q vs K/V). 
+""" + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + + +def _compute_num_chunks(tensor: torch.Tensor, S: int) -> int: + """Compute optimal number of chunks based on GPU properties.""" + props = torch.cuda.get_device_properties(tensor.device) + num_sms = props.multi_processor_count + B, _, H, _ = tensor.shape # [B, S, H, D] + base_parallelism = B * H + # Target 2-4x SMs for good occupancy/latency hiding + target_blocks = num_sms * 4 + num_chunks = max(1, target_blocks // base_parallelism) + # Ensure each chunk has at least 32 S positions for efficiency + num_chunks = min(num_chunks, S // 32) if S >= 32 else 1 + # Cap at reasonable maximum + num_chunks = min(num_chunks, 64) + # Adjust if S is small + num_chunks = min(num_chunks, S) + return num_chunks + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 256}, num_warps=4), + triton.Config({"BLOCK_SIZE": 512}, num_warps=4), + triton.Config({"BLOCK_SIZE": 1024}, num_warps=8), + triton.Config({"BLOCK_SIZE": 2048}, num_warps=8), + ], + key=["chunk_size", "D_HALF"], +) +@triton.jit +def rope_single_phase1_kernel( + # Input tensor [B, S, H, D] + x_ptr, + # RoPE frequency tensors [S, D] + cos_ptr, + sin_ptr, + # Intermediate output tensor [B, H, S, D] - stores RoPE'd result + x_rope_ptr, + # Output: partial max values [B * H * num_chunks] + partial_max_ptr, + # Input strides (for [B, S, H, D] layout) + stride_in_b, + stride_in_s, + stride_in_h, + stride_in_d, + # Output strides (for [B, H, S, D] layout) + stride_out_b, + stride_out_h, + stride_out_s, + stride_out_d, + # Dimensions + S, + D, + D_HALF, + H, + chunk_size, + num_chunks, + # Block size + BLOCK_SIZE: tl.constexpr, + ROPE_INTERLEAVED: tl.constexpr, +): + """ + Phase 1 for a single tensor (Q or K): Apply RoPE, store to intermediate, + compute partial max. 
+ + Grid: (B, H, num_chunks) + + Supports two RoPE pairing variants (selected by ROPE_INTERLEAVED constexpr): + - NeoX half-split (ROPE_INTERLEAVED=False): pairs (j, j+D/2) for j in [0, D/2) + - Interleaved (ROPE_INTERLEAVED=True): pairs (2i, 2i+1) for i in [0, D/2) + + Each pair shares the same rotation angle and the 2D rotation formula is: + out[first] = x[first]*cos - x[second]*sin + out[second] = x[second]*cos + x[first]*sin + """ + pid_b = tl.program_id(axis=0) + pid_h = tl.program_id(axis=1) + pid_chunk = tl.program_id(axis=2) + + # Compute the S range for this chunk + s_start = pid_chunk * chunk_size + s_end = tl.minimum(s_start + chunk_size, S) + actual_chunk_size = s_end - s_start + + # Number of pairs to process in this chunk + chunk_pairs = actual_chunk_size * D_HALF + + # Base pointers for input [B, S, H, D] + in_base_b = pid_b * stride_in_b + in_base_h = pid_h * stride_in_h + + # Base pointers for output [B, H, S, D] + out_base_b = pid_b * stride_out_b + out_base_h = pid_h * stride_out_h + + # Initialize max accumulator + x_max = 0.0 + + # Linearized iteration over chunk_size * D_HALF pairs + for block_start in range(0, chunk_pairs, BLOCK_SIZE): + offs = block_start + tl.arange(0, BLOCK_SIZE) + mask = offs < chunk_pairs + + # Convert linear offset to (s, pair_idx) coordinates + local_s = offs // D_HALF + pair_idx = offs % D_HALF + s_idx = s_start + local_s + + # Compute element indices for this pair based on RoPE variant + if ROPE_INTERLEAVED: + # FLUX/GPT-J interleaved: pair (2i, 2i+1) + d_first = pair_idx * 2 + d_second = pair_idx * 2 + 1 + else: + # NeoX/LLaMA half-split: pair (j, j+D/2) + d_first = pair_idx + d_second = pair_idx + D_HALF + + # Input offsets [B, S, H, D] + in_offset_first = ( + in_base_b + s_idx * stride_in_s + in_base_h + d_first * stride_in_d + ) + in_offset_second = ( + in_base_b + s_idx * stride_in_s + in_base_h + d_second * stride_in_d + ) + + # Output offsets [B, H, S, D] + out_offset_first = ( + out_base_b + out_base_h + 
s_idx * stride_out_s + d_first * stride_out_d
+        )
+        out_offset_second = (
+            out_base_b + out_base_h + s_idx * stride_out_s + d_second * stride_out_d
+        )
+
+        # Load input pairs (each element loaded exactly once)
+        x_first = tl.load(x_ptr + in_offset_first, mask=mask, other=0.0).to(tl.float32)
+        x_second = tl.load(x_ptr + in_offset_second, mask=mask, other=0.0).to(
+            tl.float32
+        )
+
+        # Load cos/sin: both elements of a pair share one rotation angle, so a
+        # single load at d_first suffices (half-split: cos[j] == cos[j+D/2];
+        # interleaved: assumes the [S, D] table repeats each angle at 2i, 2i+1 -
+        # NOTE(review): confirm against the caller's cos/sin layout).
+        cos_offset = s_idx * D + d_first
+        cos_val = tl.load(cos_ptr + cos_offset, mask=mask, other=1.0).to(tl.float32)
+        sin_val = tl.load(sin_ptr + cos_offset, mask=mask, other=0.0).to(tl.float32)
+
+        # Apply the 2D RoPE rotation (same formula for both pairing variants):
+        # out[first] = x[first] * cos - x[second] * sin
+        # out[second] = x[second] * cos + x[first] * sin
+        x_rope_first = tl.math.fma(x_first, cos_val, -(x_second * sin_val))
+        x_rope_second = tl.math.fma(x_second, cos_val, x_first * sin_val)
+
+        # Store RoPE'd result to intermediate buffer [B, H, S, D]
+        tl.store(
+            x_rope_ptr + out_offset_first, x_rope_first.to(x_first.dtype), mask=mask
+        )
+        tl.store(
+            x_rope_ptr + out_offset_second, x_rope_second.to(x_first.dtype), mask=mask
+        )
+
+        # Update max values (from RoPE'd values)
+        x_max = tl.maximum(x_max, tl.max(tl.abs(x_rope_first)))
+        x_max = tl.maximum(x_max, tl.max(tl.abs(x_rope_second)))
+
+    # Store partial max
+    chunk_idx = pid_b * (H * num_chunks) + pid_h * num_chunks + pid_chunk
+    tl.store(partial_max_ptr + chunk_idx, x_max)
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_SIZE": 512}, num_warps=4),
+        triton.Config({"BLOCK_SIZE": 1024}, num_warps=4),
+        triton.Config({"BLOCK_SIZE": 2048}, num_warps=8),
+        triton.Config({"BLOCK_SIZE": 4096}, num_warps=8),
+    ],
+    key=["chunk_size", "D"],
+)
+@triton.jit
+def v_phase1_kernel(
+    # Input tensor [B, S, H, D]
+    v_ptr,
+    # Output: partial max values [B * H * num_chunks]
+    partial_max_ptr,
+    # Input strides
(for [B, S, H, D] layout) + stride_in_b, + stride_in_s, + stride_in_h, + stride_in_d, + # Dimensions + S, + D, + H, + chunk_size, + num_chunks, + # Block size + BLOCK_SIZE: tl.constexpr, +): + """ + Phase 1 for V: Compute partial absmax (no RoPE applied). + + Grid: (B, H, num_chunks) + + Uses linearized iteration over chunk_size * D elements. + """ + pid_b = tl.program_id(axis=0) + pid_h = tl.program_id(axis=1) + pid_chunk = tl.program_id(axis=2) + + # Compute the S range for this chunk + s_start = pid_chunk * chunk_size + s_end = tl.minimum(s_start + chunk_size, S) + chunk_elements = (s_end - s_start) * D + + # Base pointers for input [B, S, H, D] + in_base_b = pid_b * stride_in_b + in_base_h = pid_h * stride_in_h + + # Initialize max accumulator + v_max = 0.0 + + # Linearized iteration over chunk_size * D elements + for block_start in range(0, chunk_elements, BLOCK_SIZE): + offs = block_start + tl.arange(0, BLOCK_SIZE) + mask = offs < chunk_elements + + # Convert linear offset to (s, d) coordinates + local_s = offs // D + d_idx = offs % D + s_idx = s_start + local_s + + # Input offset [B, S, H, D] + in_offset = in_base_b + s_idx * stride_in_s + in_base_h + d_idx * stride_in_d + + v_val = tl.load(v_ptr + in_offset, mask=mask, other=0.0).to(tl.float32) + v_max = tl.maximum(v_max, tl.max(tl.abs(v_val))) + + # Store partial max + chunk_idx = pid_b * (H * num_chunks) + pid_h * num_chunks + pid_chunk + tl.store(partial_max_ptr + chunk_idx, v_max) + + +@triton.jit +def single_reduce_kernel( + partial_max_ptr, # [B * H * num_chunks] + scale_ptr, + descale_ptr, + H, + num_chunks, +): + """ + Reduce partial maxes and compute scale/descale for a single tensor. 
+ + Grid: (B, H) + """ + pid_b = tl.program_id(axis=0) + pid_h = tl.program_id(axis=1) + + # Reduce across chunks for this (batch, head) + x_max = 0.0 + + base_idx = (pid_b * H + pid_h) * num_chunks + for c in range(num_chunks): + x_max = tl.maximum(x_max, tl.load(partial_max_ptr + base_idx + c)) + + # Compute scale and descale + # FP8 E4M3 max value is 448.0 + FP8_MAX = 448.0 + eps = 1e-12 + scale_idx = pid_b * H + pid_h + + tl.store(scale_ptr + scale_idx, tl.where(x_max > eps, FP8_MAX / x_max, 1.0)) + tl.store(descale_ptr + scale_idx, tl.where(x_max > eps, x_max / FP8_MAX, 1.0)) + + +@triton.jit +def group_reduce_kernel( + partial_max_ptr, # [B * H_q * num_chunks] + scale_ptr, # [B, H_kv] + descale_ptr, # [B, H_kv] + H_q, + H_kv, + groups, # H_q // H_kv + num_chunks, +): + """ + Reduce partial maxes across head groups for GQA Q tensor. + + For each KV group, reduces the max across all Q heads in that group + and all chunks, producing one scale per (batch, kv_head). + + Grid: (B, H_kv) + """ + pid_b = tl.program_id(axis=0) + pid_hkv = tl.program_id(axis=1) + + x_max = 0.0 + + for g in range(groups): + h_q = pid_hkv * groups + g + base_idx = (pid_b * H_q + h_q) * num_chunks + for c in range(num_chunks): + x_max = tl.maximum(x_max, tl.load(partial_max_ptr + base_idx + c)) + + FP8_MAX = 448.0 + eps = 1e-12 + scale_idx = pid_b * H_kv + pid_hkv + + tl.store(scale_ptr + scale_idx, tl.where(x_max > eps, FP8_MAX / x_max, 1.0)) + tl.store(descale_ptr + scale_idx, tl.where(x_max > eps, x_max / FP8_MAX, 1.0)) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 512}, num_warps=4), + triton.Config({"BLOCK_SIZE": 1024}, num_warps=4), + triton.Config({"BLOCK_SIZE": 2048}, num_warps=8), + triton.Config({"BLOCK_SIZE": 4096}, num_warps=8), + ], + key=["chunk_size", "D"], +) +@triton.jit +def rope_single_phase2_kernel( + # Intermediate tensor [B, H, S, D] - already RoPE'd + x_rope_ptr, + # Output tensor [B, H, S, D] - FP8 quantized + x_out_ptr, + # Precomputed scale 
[B, H_scale] + scale_ptr, + # Strides (for [B, H, S, D] layout) - same for intermediate and output + stride_b, + stride_h, + stride_s, + stride_d, + # Dimensions + S, + D, + H, + chunk_size, + # Scale indexing for GQA: scale has H_scale entries per batch, + # and each group of `groups` heads shares one scale. + # For non-GQA: H_scale = H, groups = 1. + H_scale, + groups, + # Block size + BLOCK_SIZE: tl.constexpr, +): + """ + Phase 2 for a single tensor (Q or K): Quantize pre-computed RoPE'd values to FP8. + + Grid: (B, H, num_chunks) + """ + pid_b = tl.program_id(axis=0) + pid_h = tl.program_id(axis=1) + pid_chunk = tl.program_id(axis=2) + + # Load scale for this head (or head group for GQA) + scale = tl.load(scale_ptr + pid_b * H_scale + pid_h // groups) + + # Compute the S range for this chunk + s_start = pid_chunk * chunk_size + s_end = tl.minimum(s_start + chunk_size, S) + chunk_elements = (s_end - s_start) * D + + # Base pointer + base_offset = pid_b * stride_b + pid_h * stride_h + + # Linearized iteration over chunk_size * D elements + for block_start in range(0, chunk_elements, BLOCK_SIZE): + offs = block_start + tl.arange(0, BLOCK_SIZE) + mask = offs < chunk_elements + + # Convert linear offset to (s, d) coordinates + local_s = offs // D + d_idx = offs % D + s_idx = s_start + local_s + + ptr_offset = base_offset + s_idx * stride_s + d_idx * stride_d + + # Load pre-computed RoPE'd value from intermediate buffer + x_val = tl.load(x_rope_ptr + ptr_offset, mask=mask, other=0.0).to(tl.float32) + + # Quantize to FP8 + x_fp8 = (x_val * scale).to(tl.float8e4nv) + + # Store to output + tl.store(x_out_ptr + ptr_offset, x_fp8, mask=mask) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 512}, num_warps=4), + triton.Config({"BLOCK_SIZE": 1024}, num_warps=4), + triton.Config({"BLOCK_SIZE": 2048}, num_warps=8), + triton.Config({"BLOCK_SIZE": 4096}, num_warps=8), + ], + key=["chunk_size", "D"], +) +@triton.jit +def v_phase2_kernel( + # Original V tensor 
[B, S, H, D] - needs transpose only + v_ptr, + # Output tensor [B, H, S, D] - FP8 quantized + v_out_ptr, + # Precomputed scale [B, H] + scale_ptr, + # V input strides (for [B, S, H, D] layout) + stride_v_in_b, + stride_v_in_s, + stride_v_in_h, + stride_v_in_d, + # Output strides (for [B, H, S, D] layout) + stride_out_b, + stride_out_h, + stride_out_s, + stride_out_d, + # Dimensions + S, + D, + H, + chunk_size, + # Block size + BLOCK_SIZE: tl.constexpr, +): + """ + Phase 2 for V: Transpose from [B,S,H,D] to [B,H,S,D] and quantize to FP8. + + Grid: (B, H, num_chunks) + """ + pid_b = tl.program_id(axis=0) + pid_h = tl.program_id(axis=1) + pid_chunk = tl.program_id(axis=2) + + # Load scale for this head + scale = tl.load(scale_ptr + pid_b * H + pid_h) + + # Compute the S range for this chunk + s_start = pid_chunk * chunk_size + s_end = tl.minimum(s_start + chunk_size, S) + chunk_elements = (s_end - s_start) * D + + # Base pointers + v_in_base = pid_b * stride_v_in_b + pid_h * stride_v_in_h + out_base = pid_b * stride_out_b + pid_h * stride_out_h + + # Linearized iteration over chunk_size * D elements + for block_start in range(0, chunk_elements, BLOCK_SIZE): + offs = block_start + tl.arange(0, BLOCK_SIZE) + mask = offs < chunk_elements + + # Convert linear offset to (s, d) coordinates + local_s = offs // D + d_idx = offs % D + s_idx = s_start + local_s + + # V input offset [B, S, H, D] + v_in_offset = v_in_base + s_idx * stride_v_in_s + d_idx * stride_v_in_d + + # Output offset [B, H, S, D] + out_offset = out_base + s_idx * stride_out_s + d_idx * stride_out_d + + # Load V from original input (no RoPE, just transpose) + v_val = tl.load(v_ptr + v_in_offset, mask=mask, other=0.0).to(tl.float32) + + # Quantize to FP8 + v_fp8 = (v_val * scale).to(tl.float8e4nv) + + # Store to output + tl.store(v_out_ptr + out_offset, v_fp8, mask=mask) + + +def triton_fp8_rope_sdpa_quantize( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + 
num_chunks: Optional[int] = None,
+    rope_interleaved: bool = False,
+) -> Tuple[
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+]:
+    """
+    Separated RoPE + FP8 quantization for Q, K, V tensors.
+
+    Applies RoPE to Q and K, then quantizes all tensors to FP8 with per-head scaling.
+    Also performs layout transformation from [B, S, H, D] to [B, H, S, D].
+    Each of Q, K, V is processed with independent kernel launches.
+
+    Supports GQA where Q has more heads than K/V (H_q = groups * H_kv).
+    For GQA, Q is quantized with per-KV-group scaling so that q_descale
+    has shape [B, H_kv] as required by FA3.
+
+    Args:
+        q: Query tensor of shape [B, S, H_q, D] in bf16/fp16
+        k: Key tensor of shape [B, S, H_kv, D] in bf16/fp16
+        v: Value tensor of shape [B, S, H_kv, D] in bf16/fp16
+        cos: Cosine frequencies for RoPE, shape [S, D]
+        sin: Sine frequencies for RoPE, shape [S, D]
+        num_chunks: Number of chunks to split S dimension into.
+            If None, automatically selects based on GPU SM count.
+        rope_interleaved: If True, use interleaved (GPT-J/FLUX) RoPE pairing
+            (2i, 2i+1); if False, use NeoX/LLaMA half-split pairing (j, j+D/2).
+ + Returns: + q_fp8: Quantized query with RoPE, shape [B, H_q, S, D] in fp8 + k_fp8: Quantized key with RoPE, shape [B, H_kv, S, D] in fp8 + v_fp8: Quantized value, shape [B, H_kv, S, D] in fp8 + q_descale: Query descale factors, shape [B, H_kv] in fp32 + k_descale: Key descale factors, shape [B, H_kv] in fp32 + v_descale: Value descale factors, shape [B, H_kv] in fp32 + """ + assert q.dim() == 4, f"Expected 4D tensor [B, S, H, D], got {q.dim()}D" + assert k.dim() == 4, f"Expected 4D tensor [B, S, H, D], got {k.dim()}D" + assert v.dim() == 4, f"Expected 4D tensor [B, S, H, D], got {v.dim()}D" + assert k.shape == v.shape, ( + f"K and V must have the same shape, got {k.shape} vs {v.shape}" + ) + assert q.shape[0] == k.shape[0], ( + f"Batch size mismatch: {q.shape[0]} vs {k.shape[0]}" + ) + assert q.shape[1] == k.shape[1], ( + f"Sequence length mismatch: {q.shape[1]} vs {k.shape[1]}" + ) + assert q.shape[3] == k.shape[3], f"Head dim mismatch: {q.shape[3]} vs {k.shape[3]}" + assert q.shape[2] % k.shape[2] == 0, ( + f"Q heads ({q.shape[2]}) must be a multiple of K heads ({k.shape[2]})" + ) + assert cos.dim() == 2, f"Expected 2D cos tensor [S, D], got {cos.dim()}D" + assert sin.dim() == 2, f"Expected 2D sin tensor [S, D], got {sin.dim()}D" + + B, S, H_q, D = q.shape + H_kv = k.shape[2] + groups = H_q // H_kv + + assert D % 2 == 0, f"Head dimension D must be even for RoPE, got D={D}" + assert cos.shape == (S, D), f"Expected cos shape [{S}, {D}], got {cos.shape}" + assert sin.shape == (S, D), f"Expected sin shape [{S}, {D}], got {sin.shape}" + + D_HALF = D // 2 + + # Make tensors contiguous if needed + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + cos = cos.contiguous() + sin = sin.contiguous() + + # Compute number of chunks + if num_chunks is None: + num_chunks = _compute_num_chunks(q, S) + chunk_size = (S + num_chunks - 1) // num_chunks + + # Allocate output tensors in [B, H, S, D] layout for SDPA + q_fp8 = torch.empty(B, H_q, S, D, 
dtype=torch.float8_e4m3fn, device=q.device) + k_fp8 = torch.empty(B, H_kv, S, D, dtype=torch.float8_e4m3fn, device=q.device) + v_fp8 = torch.empty(B, H_kv, S, D, dtype=torch.float8_e4m3fn, device=q.device) + + # Allocate intermediate buffers for RoPE'd Q, K in [B, H, S, D] layout + q_rope_intermediate = torch.empty(B, H_q, S, D, dtype=q.dtype, device=q.device) + k_rope_intermediate = torch.empty(B, H_kv, S, D, dtype=k.dtype, device=q.device) + + # Allocate partial max buffers (one per tensor) + q_partial_max = torch.empty( + B * H_q * num_chunks, dtype=torch.float32, device=q.device + ) + k_partial_max = torch.empty( + B * H_kv * num_chunks, dtype=torch.float32, device=q.device + ) + v_partial_max = torch.empty( + B * H_kv * num_chunks, dtype=torch.float32, device=q.device + ) + + # Allocate scale/descale tensors. + # For GQA, Q scale/descale are [B, H_kv] (per KV group). + # K and V are always [B, H_kv] (per head). + q_scale = torch.empty(B, H_kv, dtype=torch.float32, device=q.device) + k_scale = torch.empty(B, H_kv, dtype=torch.float32, device=q.device) + v_scale = torch.empty(B, H_kv, dtype=torch.float32, device=q.device) + q_descale = torch.empty(B, H_kv, dtype=torch.float32, device=q.device) + k_descale = torch.empty(B, H_kv, dtype=torch.float32, device=q.device) + v_descale = torch.empty(B, H_kv, dtype=torch.float32, device=q.device) + + q_grid_chunked = (B, H_q, num_chunks) + kv_grid_chunked = (B, H_kv, num_chunks) + + # ---- Phase 1: RoPE + max for Q ---- + rope_single_phase1_kernel[q_grid_chunked]( + q, + cos, + sin, + q_rope_intermediate, + q_partial_max, + # Input strides [B, S, H_q, D] + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + # Output strides [B, H_q, S, D] + q_rope_intermediate.stride(0), + q_rope_intermediate.stride(1), + q_rope_intermediate.stride(2), + q_rope_intermediate.stride(3), + S, + D, + D_HALF, + H_q, + chunk_size, + num_chunks, + ROPE_INTERLEAVED=rope_interleaved, + ) + + # ---- Phase 1: RoPE + max for K ---- + 
rope_single_phase1_kernel[kv_grid_chunked]( + k, + cos, + sin, + k_rope_intermediate, + k_partial_max, + # Input strides [B, S, H_kv, D] + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + # Output strides [B, H_kv, S, D] + k_rope_intermediate.stride(0), + k_rope_intermediate.stride(1), + k_rope_intermediate.stride(2), + k_rope_intermediate.stride(3), + S, + D, + D_HALF, + H_kv, + chunk_size, + num_chunks, + ROPE_INTERLEAVED=rope_interleaved, + ) + + # ---- Phase 1: Max for V (no RoPE) ---- + v_phase1_kernel[kv_grid_chunked]( + v, + v_partial_max, + # Input strides [B, S, H_kv, D] + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + S, + D, + H_kv, + chunk_size, + num_chunks, + ) + + # ---- Reduce ---- + # Q: group reduce across `groups` Q heads per KV head + group_reduce_kernel[(B, H_kv)]( + q_partial_max, q_scale, q_descale, H_q, H_kv, groups, num_chunks + ) + # K, V: per-head reduce + single_reduce_kernel[(B, H_kv)](k_partial_max, k_scale, k_descale, H_kv, num_chunks) + single_reduce_kernel[(B, H_kv)](v_partial_max, v_scale, v_descale, H_kv, num_chunks) + + # ---- Phase 2: Quantize Q from intermediate ---- + # Q scale is [B, H_kv]; each group of `groups` Q heads shares one scale. + rope_single_phase2_kernel[q_grid_chunked]( + q_rope_intermediate, + q_fp8, + q_scale, + # Strides [B, H_q, S, D] + q_fp8.stride(0), + q_fp8.stride(1), + q_fp8.stride(2), + q_fp8.stride(3), + S, + D, + H_q, + chunk_size, + H_kv, + groups, + ) + + # ---- Phase 2: Quantize K from intermediate ---- + # K scale is [B, H_kv]; groups=1 (per-head). 
+ rope_single_phase2_kernel[kv_grid_chunked]( + k_rope_intermediate, + k_fp8, + k_scale, + # Strides [B, H_kv, S, D] + k_fp8.stride(0), + k_fp8.stride(1), + k_fp8.stride(2), + k_fp8.stride(3), + S, + D, + H_kv, + chunk_size, + H_kv, + 1, + ) + + # ---- Phase 2: Transpose + quantize V ---- + v_phase2_kernel[kv_grid_chunked]( + v, + v_fp8, + v_scale, + # V input strides [B, S, H_kv, D] + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + # Output strides [B, H_kv, S, D] + v_fp8.stride(0), + v_fp8.stride(1), + v_fp8.stride(2), + v_fp8.stride(3), + S, + D, + H_kv, + chunk_size, + ) + + return q_fp8, k_fp8, v_fp8, q_descale, k_descale, v_descale diff --git a/torchao/prototype/attention/shared_utils/__init__.py b/torchao/prototype/attention/shared_utils/__init__.py new file mode 100644 index 0000000000..ecddc88b5f --- /dev/null +++ b/torchao/prototype/attention/shared_utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchao/prototype/attention/shared_utils/attention.py b/torchao/prototype/attention/shared_utils/attention.py new file mode 100644 index 0000000000..248668639a --- /dev/null +++ b/torchao/prototype/attention/shared_utils/attention.py @@ -0,0 +1,135 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Shared FP8 scaled dot-product attention implementation. + +Backend-specific modules (``fp8_fa3/attention.py``, etc.) provide thin +named wrappers around these functions via ``functools.partial``. 
"""
Shared FP8 scaled dot-product attention implementation.

Backend-specific modules (``fp8_fa3/attention.py``, etc.) provide thin
named wrappers around these functions via ``functools.partial``.
"""

from typing import Optional

import torch

from torchao.utils import torch_version_at_least

_TORCH_VERSION_AT_LEAST_2_11 = torch_version_at_least("2.11.0")

# The quantized SDPA entry point only exists on PyTorch 2.11+; guard the
# import so this module still imports (and raises lazily) on older versions.
if _TORCH_VERSION_AT_LEAST_2_11:
    from torch.nn.attention import SDPBackend, sdpa_kernel
    from torch.nn.attention.experimental._scaled_dot_product_attention_quantized import (
        _scaled_dot_product_attention_quantized,
    )

from torchao.prototype.attention.quantization import (
    _fp8_rope_sdpa_quantize,
    _fp8_sdpa_quantize,
)


def _check_fp8_sdpa_args(
    attn_mask: Optional[torch.Tensor],
    dropout_p: float,
    backend_name: str,
) -> None:
    """Validate the constraints shared by ``_fp8_sdpa`` and ``_fp8_rope_sdpa``.

    Raises:
        RuntimeError: when running on PyTorch older than 2.11.
        ValueError: when ``attn_mask`` is given or ``dropout_p`` is nonzero;
            the quantized flash kernels support neither.
    """
    if not _TORCH_VERSION_AT_LEAST_2_11:
        raise RuntimeError("Low-precision attention requires PyTorch 2.11+.")
    if attn_mask is not None:
        raise ValueError(f"attn_mask not supported for FP8 {backend_name}")
    if dropout_p != 0.0:
        raise ValueError(
            f"dropout_p must be 0.0 for FP8 {backend_name}, got {dropout_p}"
        )


def _fp8_sdpa(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
    *,
    backend_name: str = "FP8",
) -> torch.Tensor:
    """FP8 SDPA shared by all backends.

    The correct flash attention implementation (e.g. FA3) must be
    activated before calling this function. The high-level
    ``apply_low_precision_attention`` API handles this automatically.

    Input/output layout: [B, H, S, D].

    Raises:
        RuntimeError: on PyTorch < 2.11.
        ValueError: if ``attn_mask`` is given or ``dropout_p`` != 0.0.
    """
    _check_fp8_sdpa_args(attn_mask, dropout_p, backend_name)

    # NOTE(review): enable_gqa is accepted for signature compatibility with
    # F.scaled_dot_product_attention but is not forwarded to the quantized
    # kernel — confirm the kernel infers GQA from the head counts.
    input_dtype = query.dtype

    q_fp8, k_fp8, v_fp8, descale_q, descale_k, descale_v = _fp8_sdpa_quantize(
        query, key, value
    )

    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        out = _scaled_dot_product_attention_quantized(
            q_fp8,
            k_fp8,
            v_fp8,
            is_causal=is_causal,
            scale=scale,
            q_descale=descale_q,
            k_descale=descale_k,
            v_descale=descale_v,
        )

    # Restore the caller's input dtype.
    return out.to(input_dtype)


def _fp8_rope_sdpa(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
    rope_interleaved: bool = False,
    *,
    backend_name: str = "FP8",
) -> torch.Tensor:
    """Fused RoPE + FP8 SDPA shared by all backends.

    Input layout: [B, S, H, D] (pre-transpose). The fused quantization
    kernel handles the transpose to [B, H, S, D] internally.
    Output layout: [B, H, S, D].

    Raises:
        RuntimeError: on PyTorch < 2.11.
        ValueError: if ``attn_mask`` is given or ``dropout_p`` != 0.0.
    """
    _check_fp8_sdpa_args(attn_mask, dropout_p, backend_name)

    input_dtype = query.dtype

    # cos/sin may live on CPU (precomputed tables); move them next to Q.
    cos = cos.to(query.device)
    sin = sin.to(query.device)

    q_fp8, k_fp8, v_fp8, descale_q, descale_k, descale_v = _fp8_rope_sdpa_quantize(
        query, key, value, cos, sin, rope_interleaved=rope_interleaved
    )

    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        out = _scaled_dot_product_attention_quantized(
            q_fp8,
            k_fp8,
            v_fp8,
            is_causal=is_causal,
            scale=scale,
            q_descale=descale_q,
            k_descale=descale_k,
            v_descale=descale_v,
        )

    return out.to(input_dtype)
"""
Shared custom op registration and compile helpers for FP8 attention backends.

Custom ops are needed because our FP8 SDPA functions call Triton kernels
that are not traceable by torch.compile. Registering them as custom_ops
tells the compiler to treat them as opaque nodes with known shapes/dtypes.
"""

from functools import partial
from typing import Callable, NamedTuple

import torch
import torch._inductor.config as inductor_config
import torch.nn as nn

from torchao.prototype.attention.shared_utils.fusion_utils import (
    detect_causal_mask,
)
from torchao.prototype.attention.shared_utils.fusion_utils import (
    rope_sdpa_fusion_pass as _shared_fusion_pass,
)


class RegisteredOps(NamedTuple):
    """Resolved torch.ops overloads for one backend.

    ``rope_sdpa_op`` is the fused RoPE + FP8 SDPA op; ``fp8_sdpa_op`` is the
    plain FP8 SDPA op (no RoPE fusion).
    """

    rope_sdpa_op: object
    fp8_sdpa_op: object


def register_fp8_attention_ops(
    backend_name: str,
    rope_sdpa_fn: Callable,
    sdpa_fn: Callable,
) -> RegisteredOps:
    """Register the RoPE+SDPA and plain SDPA custom ops for a backend."""
    # The backend name is embedded in the op name, e.g. "torchao::fp8_fa3_sdpa".
    backend = backend_name.lower()

    rope_op_name = f"torchao::fp8_{backend}_rope_sdpa"

    @torch.library.custom_op(rope_op_name, mutates_args=())
    def _rope_sdpa_custom_op(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        is_causal: bool = False,
        scale: float = 0.0,
        enable_gqa: bool = False,
        rope_interleaved: bool = False,
    ) -> torch.Tensor:
        # scale == 0.0 is the sentinel for "use the default 1/sqrt(D)";
        # presumably chosen because the op schema carries a plain float —
        # TODO(review) confirm against the fusion pass's _get_sdpa_params.
        actual_scale = scale if scale != 0.0 else None
        return rope_sdpa_fn(
            q,
            k,
            v,
            cos,
            sin,
            is_causal=is_causal,
            scale=actual_scale,
            enable_gqa=enable_gqa,
            rope_interleaved=rope_interleaved,
        )

    @_rope_sdpa_custom_op.register_fake
    def _rope_sdpa_fake(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        is_causal: bool = False,
        scale: float = 0.0,
        enable_gqa: bool = False,
        rope_interleaved: bool = False,
    ) -> torch.Tensor:
        # Fused op takes [B, S, H, D] input and emits [B, H, S, D] output.
        B, S, H, D = q.shape
        return torch.empty(B, H, S, D, dtype=q.dtype, device=q.device)

    sdpa_op_name = f"torchao::fp8_{backend}_sdpa"

    @torch.library.custom_op(sdpa_op_name, mutates_args=())
    def _sdpa_custom_op(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        is_causal: bool = False,
        scale: float = 0.0,
        enable_gqa: bool = False,
    ) -> torch.Tensor:
        # Same 0.0-means-default scale convention as the RoPE op above.
        actual_scale = scale if scale != 0.0 else None
        return sdpa_fn(
            q,
            k,
            v,
            is_causal=is_causal,
            scale=actual_scale,
            enable_gqa=enable_gqa,
        )

    @_sdpa_custom_op.register_fake
    def _sdpa_fake(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        is_causal: bool = False,
        scale: float = 0.0,
        enable_gqa: bool = False,
    ) -> torch.Tensor:
        # Plain op is layout-preserving: output shape matches Q.
        return torch.empty(q.shape, dtype=q.dtype, device=q.device)

    # Resolve the freshly registered ops to their .default OpOverloads so the
    # fusion pass can insert them as call_function targets.
    rope_sdpa_op = getattr(
        getattr(torch.ops, "torchao"), f"fp8_{backend}_rope_sdpa"
    ).default
    fp8_sdpa_op = getattr(getattr(torch.ops, "torchao"), f"fp8_{backend}_sdpa").default

    return RegisteredOps(rope_sdpa_op=rope_sdpa_op, fp8_sdpa_op=fp8_sdpa_op)


def make_backend_fn(
    ops: RegisteredOps,
    backend_name: str,
    flash_impl_name: str,
    max_head_dim: int = 256,
) -> Callable:
    """Return a ``make_fp8_backend(model, fuse_rope_using_torch_compile)`` function for a backend."""

    def make_fp8_backend(
        model: nn.Module,
        fuse_rope_using_torch_compile: bool,
    ) -> Callable:
        # Imported lazily: compile_fx pulls in the full inductor stack.
        from torch._inductor.compile_fx import compile_fx

        # Probe the model once so the pass knows whether materialized causal
        # masks can be stripped in favor of native is_causal handling.
        strip_causal_mask = detect_causal_mask(model, flash_impl_name=flash_impl_name)

        pass_fn = partial(
            _shared_fusion_pass,
            rope_sdpa_op=ops.rope_sdpa_op,
            fp8_sdpa_op=ops.fp8_sdpa_op,
            max_head_dim=max_head_dim,
            backend_name=backend_name,
            fuse_rope=fuse_rope_using_torch_compile,
            strip_causal_mask=strip_causal_mask,
        )

        def fp8_attention_backend(gm, example_inputs):
            # pre_grad_custom_pass is a process-global inductor knob (and an
            # unstable internal API); swap it in only for this compilation
            # and always restore the previous value.
            old_pass = inductor_config.pre_grad_custom_pass
            inductor_config.pre_grad_custom_pass = pass_fn
            try:
                return compile_fx(gm, example_inputs)
            finally:
                inductor_config.pre_grad_custom_pass = old_pass

        return fp8_attention_backend

    return make_fp8_backend
"""
Shared FX graph pattern detection and fusion utilities for low-precision attention.
"""

import logging
import operator
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.fx import Graph, Node

logger = logging.getLogger(__name__)

# Explicit star-export of these internal helpers keeps them reachable from
# test harnesses that use ``import *``.
__all__ = globals().get("__all__", []) + [
    "RoPEMatch",
    "_is_op",
    "_get_fake_tensor",
    "_get_node_shape",
]


@dataclass
class RoPEMatch:
    """Everything learned from matching ``x * cos + rotate(x) * sin``."""

    # Tensor feeding the RoPE computation, i.e. "x" above.
    pre_rope_input: Node
    # Cosine / sine frequency tables, traced back to their [S, D] source.
    cos_node: Node
    sin_node: Node
    # True = FLUX interleaved rotation, False = NeoX half-split.
    rope_interleaved: bool


# FX Node Utilities


def _is_op(node: Node, *targets) -> bool:
    """Return True when *node* is a call whose target is one of *targets*."""
    return node.op in ("call_function", "call_method") and node.target in targets


def _get_fake_tensor(node: Node) -> Optional[torch.Tensor]:
    """Return the FakeTensor metadata attached to *node*, if any.

    The tensor is stored under ``val`` or ``example_value`` depending on the
    compilation phase; the first key holding an actual tensor wins.
    """
    for meta_key in ("val", "example_value"):
        candidate = node.meta.get(meta_key)
        if isinstance(candidate, torch.Tensor):
            return candidate
    return None


def _get_node_shape(node: Node) -> Optional[Tuple[int, ...]]:
    """Shape of *node*'s fake tensor as a plain tuple, or None if absent."""
    meta_tensor = _get_fake_tensor(node)
    return None if meta_tensor is None else tuple(meta_tensor.shape)
__all__ = globals().get("__all__", []) + [
    "_trace_through_views",
    "_unwrap_transpose",
    "_unwrap_repeat_kv",
    "_is_sdpa_node",
    "_is_lower_triangular_bool_mask",
]


def _trace_through_views(node: Node) -> Node:
    """Walk backward past ops that only re-view or re-materialize a tensor.

    Stops at (and returns) the first node that is not an unsqueeze / clone /
    contiguous / expand / dtype-cast, or an indexing op that only inserts
    singleton dims and full slices.
    """
    viewlike_fns = (
        torch.ops.aten.unsqueeze.default,
        torch.ops.aten.clone.default,
        torch.ops.aten.contiguous.default,
        torch.ops.aten.expand.default,
        torch.ops.aten.to.dtype,
        torch.ops.aten._to_copy.default,
    )
    viewlike_methods = (
        "unsqueeze",
        "clone",
        "contiguous",
        "expand",
        "to",
        "float",
        "half",
        "bfloat16",
    )

    cur = node
    while isinstance(cur, Node):
        if cur.op == "call_function" and cur.target in viewlike_fns:
            cur = cur.args[0]
        elif cur.op == "call_method" and cur.target in viewlike_methods:
            cur = cur.args[0]
        elif (
            cur.op == "call_function"
            and cur.target is operator.getitem
            and len(cur.args) >= 2
            and isinstance(cur.args[1], tuple)
            and all(ix is None or ix == slice(None) for ix in cur.args[1])
        ):
            # x[:, None, :] style: no data selected, only dims inserted.
            cur = cur.args[0]
        else:
            break
    return cur


# Transpose Detection


def _unwrap_transpose(node: Node) -> Optional[Node]:
    """Return the input of a (1, 2) transpose / [0, 2, 1, 3] permute, else None.

    contiguous()/clone() wrappers sitting on top of the transpose are looked
    through first.
    """
    if not isinstance(node, Node):
        return None

    cur = node
    while (
        isinstance(cur, Node)
        and cur.op in ("call_function", "call_method")
        and cur.target
        in (
            torch.ops.aten.contiguous.default,
            torch.ops.aten.clone.default,
            "contiguous",
            "clone",
        )
    ):
        cur = cur.args[0]

    if not isinstance(cur, Node) or cur.op not in ("call_function", "call_method"):
        return None

    target, args = cur.target, cur.args

    # transpose(t, 1, 2) — aten overload or tensor method, either dim order.
    if target in (torch.ops.aten.transpose.int, "transpose"):
        if len(args) >= 3 and (args[1], args[2]) in ((1, 2), (2, 1)):
            return args[0]
        return None

    # aten.permute.default(t, [0, 2, 1, 3]) — dims arrive as one list.
    if target is torch.ops.aten.permute.default:
        if len(args) >= 2 and list(args[1]) == [0, 2, 1, 3]:
            return args[0]
        return None

    # t.permute(0, 2, 1, 3) or t.permute([0, 2, 1, 3]).
    if target == "permute":
        if len(args) >= 5:
            if list(args[1:5]) == [0, 2, 1, 3]:
                return args[0]
        elif len(args) >= 2 and isinstance(args[1], (list, tuple)):
            if list(args[1]) == [0, 2, 1, 3]:
                return args[0]
        return None

    return None


def _unwrap_repeat_kv(node: Node) -> Optional[Node]:
    """Undo HuggingFace's repeat_kv (GQA head repetition) pattern.

    Matches reshape/view <- expand <- unsqueeze (or a None-inserting getitem)
    and returns the pre-repetition tensor, or None.
    """
    if not isinstance(node, Node):
        return None

    def _calls(n: Node, *targets) -> bool:
        return n.op in ("call_function", "call_method") and n.target in targets

    if not _calls(
        node,
        torch.ops.aten.reshape.default,
        torch.ops.aten.view.default,
        "reshape",
        "view",
    ):
        return None

    expand_node = node.args[0] if node.args else None
    if not isinstance(expand_node, Node):
        return None
    if not _calls(expand_node, torch.ops.aten.expand.default, "expand"):
        return None

    repeat_src = expand_node.args[0] if expand_node.args else None
    if not isinstance(repeat_src, Node):
        return None

    if _calls(repeat_src, torch.ops.aten.unsqueeze.default, "unsqueeze"):
        return repeat_src.args[0] if repeat_src.args else None

    # x[:, :, None] style: getitem whose index tuple inserts a new axis.
    if repeat_src.op == "call_function" and repeat_src.target is operator.getitem:
        if len(repeat_src.args) >= 2 and isinstance(repeat_src.args[1], tuple):
            if any(ix is None for ix in repeat_src.args[1]):
                return repeat_src.args[0] if repeat_src.args else None

    return None


# SDPA Detection and Parameter Extraction


def _is_sdpa_node(node: Node) -> bool:
    """True for either form SDPA takes in a traced graph."""
    return node.op in ("call_function", "call_method") and node.target in (
        torch.ops.aten.scaled_dot_product_attention.default,
        torch._C._nn.scaled_dot_product_attention,
    )


def _is_lower_triangular_bool_mask(mask: torch.Tensor) -> bool:
    """True when *mask* is a square, boolean, lower-triangular (causal) mask."""
    if mask.dtype != torch.bool or mask.ndim < 2:
        return False
    rows, cols = mask.shape[-2], mask.shape[-1]
    if rows != cols:
        return False
    causal_ref = torch.tril(
        torch.ones(rows, cols, dtype=torch.bool, device=mask.device)
    )
    return torch.equal(mask.broadcast_to(mask.shape), causal_ref.expand_as(mask))
+ """ + from torch.nn.attention import ( + activate_flash_attention_impl, + restore_flash_attention_impl, + ) + + try: + device = next(model.parameters()).device + except StopIteration: + return False + + if sample_input_ids is None: + vocab_size = getattr(getattr(model, "config", None), "vocab_size", None) + if vocab_size is None: + return False + sample_input_ids = torch.randint(0, vocab_size, (1, 16), device=device) + + all_causal: list[bool] = [] + saw_any_sdpa = False + + original_sdpa = F.scaled_dot_product_attention + + def _hook(*args, **kwargs): + nonlocal saw_any_sdpa + saw_any_sdpa = True + attn_mask = args[3] if len(args) > 3 else kwargs.get("attn_mask", None) + is_causal = kwargs.get("is_causal", False) if len(args) <= 5 else args[5] + + if attn_mask is not None and not is_causal: + all_causal.append(_is_lower_triangular_bool_mask(attn_mask)) + elif attn_mask is None and is_causal: + all_causal.append(True) + + return original_sdpa(*args, **kwargs) + + F.scaled_dot_product_attention = _hook + if flash_impl_name is not None: + activate_flash_attention_impl(flash_impl_name) + try: + with torch.no_grad(): + model(sample_input_ids) + except Exception: + logger.debug("detect_causal_mask: forward pass failed", exc_info=True) + return False + finally: + F.scaled_dot_product_attention = original_sdpa + if flash_impl_name is not None: + restore_flash_attention_impl() + + if not saw_any_sdpa: + return False + + return all(all_causal) + + +def _sdpa_is_fusible(node: Node, strip_causal_mask: bool = False) -> Tuple[bool, bool]: + """Check if an SDPA node is compatible with our FP8 fused kernel. + + Returns (is_fusible, needs_mask_strip). 
# Module logger (same name/value as the one defined at the top of this file).
logger = logging.getLogger(__name__)

__all__ = globals().get("__all__", []) + [
    "_strip_causal_mask",
    "_get_sdpa_params",
    "_get_sdpa_qkv",
]


def _strip_causal_mask(node: Node) -> None:
    """Rewrite an SDPA node in place: drop its attn_mask, set is_causal=True.

    Used when a materialized lower-triangular mask was detected, so the flash
    kernel can rely on native causal handling instead.
    """
    new_args = list(node.args)
    new_kwargs = dict(node.kwargs)

    if len(new_args) > 3:
        new_args[3] = None  # positional attn_mask
    elif "attn_mask" in new_kwargs:
        new_kwargs["attn_mask"] = None

    if len(new_args) > 5:
        new_args[5] = True  # positional is_causal
    else:
        new_kwargs["is_causal"] = True

    node.args = tuple(new_args)
    node.kwargs = new_kwargs

    logger.info("Stripped causal mask from SDPA node: %s", node.name)


def _get_sdpa_params(node: Node) -> Tuple[bool, float, bool]:
    """Pull (is_causal, scale, enable_gqa) out of an SDPA call node.

    A scale of 0.0 stands in for "default" (1/sqrt(D)), matching the custom
    ops' sentinel convention.
    """
    args, kwargs = node.args, node.kwargs

    is_causal = args[5] if len(args) > 5 else kwargs.get("is_causal", False)
    scale = args[6] if len(args) > 6 else kwargs.get("scale", None)
    enable_gqa = args[7] if len(args) > 7 else kwargs.get("enable_gqa", False)

    return is_causal, (0.0 if scale is None else scale), enable_gqa


def _get_sdpa_qkv(node: Node) -> Optional[Tuple[Node, Node, Node]]:
    """Return the (query, key, value) nodes feeding an SDPA call, or None."""
    args, kwargs = node.args, node.kwargs

    qkv = (
        args[0] if len(args) > 0 else kwargs.get("query", None),
        args[1] if len(args) > 1 else kwargs.get("key", None),
        args[2] if len(args) > 2 else kwargs.get("value", None),
    )
    if all(isinstance(entry, Node) for entry in qkv):
        return qkv
    return None
__all__ = globals().get("__all__", []) + [
    "_match_rotate_half_slices",
    "_extract_last_dim_slice",
]


def _match_rotate_half_slices(neg_input: Node, pos_part: Node) -> Optional[Node]:
    """Match the two half-slices of rotate_half and return their shared source.

    rotate_half slices x into x[..., D//2:] (the negated half) and
    x[..., :D//2]; both slices must read the same tensor on the last dim and
    meet at the same split point. Handles both the ATen ``slice.Tensor`` form
    and the Dynamo ``getitem`` form. Returns the source tensor x, or None.

    Fix: the ATen branch now accepts ``start=None`` for the positive half —
    ``aten.slice.Tensor`` treats a None start as 0, and the getitem branch
    below already accepted ``start in (0, None)``.
    """

    def _is_call(n: Node, *targets) -> bool:
        # Local equivalent of _is_op, kept here so the matcher stands alone.
        return n.op in ("call_function", "call_method") and n.target in targets

    # ATen slice pattern: slice.Tensor(x, dim, start, end)
    if _is_call(neg_input, torch.ops.aten.slice.Tensor) and _is_call(
        pos_part, torch.ops.aten.slice.Tensor
    ):
        neg_source = neg_input.args[0]
        pos_source = pos_part.args[0]
        if neg_source is not pos_source:
            return None

        neg_dim = neg_input.args[1] if len(neg_input.args) > 1 else 0
        pos_dim = pos_part.args[1] if len(pos_part.args) > 1 else 0
        # Only a last-dim split is RoPE; 3 covers the explicit [B, H, S, D] dim.
        if neg_dim not in (-1, 3) or pos_dim not in (-1, 3):
            return None

        pos_start = pos_part.args[2] if len(pos_part.args) > 2 else None
        pos_end = pos_part.args[3] if len(pos_part.args) > 3 else None
        neg_start = neg_input.args[2] if len(neg_input.args) > 2 else None

        # A missing/None start means "from 0" (consistent with getitem below).
        if pos_start not in (0, None):
            return None
        # The halves must meet exactly at the split point.
        if pos_end is None or neg_start is None:
            return None
        if neg_start != pos_end:
            return None

        return neg_source

    # Dynamo getitem pattern: x[..., a:] / x[..., :a]
    if _is_call(neg_input, operator.getitem) and _is_call(pos_part, operator.getitem):
        neg_source = neg_input.args[0]
        pos_source = pos_part.args[0]
        if neg_source is not pos_source:
            return None

        neg_slice = _extract_last_dim_slice(neg_input.args[1])
        pos_slice = _extract_last_dim_slice(pos_part.args[1])
        if neg_slice is None or pos_slice is None:
            return None

        if pos_slice.start not in (0, None):
            return None
        if pos_slice.stop is None or neg_slice.start is None:
            return None
        if neg_slice.start != pos_slice.stop:
            return None

        return neg_source

    return None


def _extract_last_dim_slice(idx) -> Optional[slice]:
    """Return the slice applied to the last dim of a getitem index, or None.

    Accepts a bare slice, ``(..., slice)``, or a tuple whose leading entries
    are all Ellipsis / full slices; anything else is rejected.
    """
    if isinstance(idx, slice):
        return idx
    if not isinstance(idx, tuple):
        return None
    if len(idx) >= 2 and idx[0] is Ellipsis and isinstance(idx[1], slice):
        return idx[1]
    if idx and isinstance(idx[-1], slice):
        # Only full slices / Ellipsis may precede the final slice — anything
        # narrower selects data on another dim and is not a pure last-dim split.
        for entry in idx[:-1]:
            if entry is not Ellipsis and entry != slice(None):
                return None
        return idx[-1]
    return None
def _detect_rotation(node: Node) -> Optional[Tuple[Node, bool]]:
    """Detect any supported rotation pattern at *node*.

    Tries the NeoX half-split form first, then the FLUX interleaved form.
    Returns (source_tensor, is_interleaved) or None.
    """
    for detector, interleaved in (
        (_detect_rotate_half, False),
        (_detect_interleaved_rotation, True),
    ):
        source = detector(node)
        if source is not None:
            return source, interleaved
    return None
def _detect_rope(node: Node) -> Optional[RoPEMatch]:
    """Detect any supported RoPE variant at a given node."""
    # The add(mul, mul) entry point covers both variants: _detect_neox_rope
    # delegates to _detect_rotation, which reports interleaved vs half-split.
    return _detect_neox_rope(node)


# Graph Surgery


def _replace_with_fused_op(
    graph: Graph,
    sdpa_node: Node,
    pre_rope_q: Node,
    pre_rope_k: Node,
    v_input: Node,
    cos_node: Node,
    sin_node: Node,
    is_causal: bool,
    scale: float,
    enable_gqa: bool,
    rope_interleaved: bool,
    rope_sdpa_op,
) -> None:
    """Replace an SDPA node with a fused RoPE+SDPA custom op."""
    with graph.inserting_before(sdpa_node):
        fused_node = graph.call_function(
            rope_sdpa_op,
            args=(pre_rope_q, pre_rope_k, v_input, cos_node, sin_node),
            kwargs={
                "is_causal": is_causal,
                "scale": scale,
                "enable_gqa": enable_gqa,
                "rope_interleaved": rope_interleaved,
            },
        )

    # Preserve shape/dtype metadata so downstream passes still see a valid node.
    fused_node.meta = sdpa_node.meta.copy()
    sdpa_node.replace_all_uses_with(fused_node)

    logger.info(
        "Fused RoPE + SDPA: replaced %s with %s",
        sdpa_node.name,
        fused_node.name,
    )


def _replace_sdpa_with_fp8(
    graph: Graph,
    sdpa_node: Node,
    q_node: Node,
    k_node: Node,
    v_node: Node,
    is_causal: bool,
    scale: float,
    enable_gqa: bool,
    fp8_sdpa_op,
) -> None:
    """Replace a plain SDPA node with an FP8 SDPA op (no RoPE fusion)."""
    with graph.inserting_before(sdpa_node):
        fp8_node = graph.call_function(
            fp8_sdpa_op,
            args=(q_node, k_node, v_node),
            kwargs={
                "is_causal": is_causal,
                "scale": scale,
                "enable_gqa": enable_gqa,
            },
        )

    fp8_node.meta = sdpa_node.meta.copy()
    sdpa_node.replace_all_uses_with(fp8_node)

    logger.info(
        "Replaced SDPA with FP8: %s -> %s",
        sdpa_node.name,
        fp8_node.name,
    )


# Main Fusion Pass


def rope_sdpa_fusion_pass(
    graph: Graph,
    rope_sdpa_op,
    fp8_sdpa_op,
    max_head_dim: int = 256,
    backend_name: str = "FP8",
    fuse_rope: bool = True,
    strip_causal_mask: bool = False,
) -> None:
    """Detect and replace SDPA patterns in the FX graph.

    For each fusible SDPA node:
    - Pattern A (RoPE -> transpose -> SDPA): fuse with rope_sdpa custom op
    - Pattern B (transpose -> RoPE -> SDPA): fuse with rope_sdpa custom op
    - No RoPE: replace with fp8_sdpa custom op

    Note: KV caching must be disabled before compilation.
    DynamicCache.update() inserts torch.cat nodes that break pattern matching.
    """
    sdpa_nodes = [n for n in graph.nodes if _is_sdpa_node(n)]

    if not sdpa_nodes:
        logger.debug("RoPE + SDPA fusion: found 0 SDPA nodes in graph")
        return

    fused_count = 0
    fp8_count = 0

    for sdpa_node in sdpa_nodes:
        is_fusible, needs_mask_strip = _sdpa_is_fusible(
            sdpa_node, strip_causal_mask=strip_causal_mask
        )
        if not is_fusible:
            logger.debug("Skipping non-fusible SDPA: %s", sdpa_node.name)
            continue

        # Materialized lower-triangular mask: drop it so the kernel can use
        # native is_causal handling.
        if needs_mask_strip:
            _strip_causal_mask(sdpa_node)

        is_causal, scale, enable_gqa = _get_sdpa_params(sdpa_node)

        qkv = _get_sdpa_qkv(sdpa_node)
        if qkv is None:
            continue
        q_node, k_node, v_node = qkv

        # Try RoPE fusion
        if fuse_rope:
            v_pre_transpose = _unwrap_transpose(v_node)

            # Pattern A: RoPE -> transpose -> SDPA (FLUX-style)
            q_pre_transpose = _unwrap_transpose(q_node)
            k_pre_transpose = _unwrap_transpose(k_node)

            if q_pre_transpose is not None and k_pre_transpose is not None:
                q_pre_cast = _trace_through_views(q_pre_transpose)
                k_pre_cast = _trace_through_views(k_pre_transpose)

                q_rope = _detect_rope(q_pre_cast)
                k_rope = _detect_rope(k_pre_cast)

                if q_rope is not None and k_rope is not None:
                    pre_rope_q = _trace_through_views(q_rope.pre_rope_input)
                    pre_rope_k = _trace_through_views(k_rope.pre_rope_input)

                    # NOTE(review): this `continue` also skips the plain-FP8
                    # fallback below for a RoPE'd node whose V lacks a
                    # transpose — confirm that is intended.
                    if v_pre_transpose is None:
                        logger.debug(
                            "Pattern A: V has no transpose, skipping: %s",
                            sdpa_node.name,
                        )
                        continue

                    cos_sin = _reshape_cos_sin_to_2d(
                        graph,
                        q_rope.cos_node,
                        q_rope.sin_node,
                        sdpa_node,
                    )
                    if cos_sin is None:
                        logger.debug(
                            "Pattern A: cos/sin shape incompatible, skipping: %s",
                            sdpa_node.name,
                        )
                        continue
                    cos_2d, sin_2d = cos_sin

                    _replace_with_fused_op(
                        graph=graph,
                        sdpa_node=sdpa_node,
                        pre_rope_q=pre_rope_q,
                        pre_rope_k=pre_rope_k,
                        v_input=v_pre_transpose,
                        cos_node=cos_2d,
                        sin_node=sin_2d,
                        is_causal=is_causal,
                        scale=scale,
                        enable_gqa=enable_gqa,
                        rope_interleaved=q_rope.rope_interleaved,
                        rope_sdpa_op=rope_sdpa_op,
                    )
                    fused_count += 1
                    continue

            # Pattern B: transpose -> RoPE -> SDPA (HuggingFace-style)
            # For GQA, K may go through repeat_kv after RoPE.
            q_rope = _detect_rope(_trace_through_views(q_node))

            k_rope = _detect_rope(_trace_through_views(k_node))
            gqa_unwrapped = False
            if k_rope is None:
                # RoPE may be hidden underneath the repeat_kv expansion.
                k_pre_repeat = _unwrap_repeat_kv(k_node)
                if k_pre_repeat is not None:
                    k_rope = _detect_rope(_trace_through_views(k_pre_repeat))
                    if k_rope is not None:
                        gqa_unwrapped = True

            if q_rope is not None and k_rope is not None:
                q_bshd = _unwrap_transpose(_trace_through_views(q_rope.pre_rope_input))
                k_bshd = _unwrap_transpose(_trace_through_views(k_rope.pre_rope_input))

                if q_bshd is not None and k_bshd is not None:
                    # If K was un-repeated, V almost certainly was repeated the
                    # same way; feed the pre-repeat V to the fused op as well.
                    v_for_fusion = v_node
                    if gqa_unwrapped:
                        v_pre_repeat = _unwrap_repeat_kv(v_node)
                        if v_pre_repeat is not None:
                            v_for_fusion = v_pre_repeat

                    v_bshd = _unwrap_transpose(v_for_fusion)
                    if v_bshd is None:
                        logger.debug(
                            "Pattern B: V has no transpose, skipping: %s",
                            sdpa_node.name,
                        )
                        continue

                    cos_sin = _reshape_cos_sin_to_2d(
                        graph,
                        q_rope.cos_node,
                        q_rope.sin_node,
                        sdpa_node,
                    )
                    if cos_sin is None:
                        logger.debug(
                            "Pattern B: cos/sin shape incompatible, skipping: %s",
                            sdpa_node.name,
                        )
                        continue
                    cos_2d, sin_2d = cos_sin

                    # Fused Q/K/V bypass repeat_kv, so the kernel must handle
                    # GQA itself when we unwrapped the repetition.
                    fused_enable_gqa = True if gqa_unwrapped else enable_gqa

                    _replace_with_fused_op(
                        graph=graph,
                        sdpa_node=sdpa_node,
                        pre_rope_q=q_bshd,
                        pre_rope_k=k_bshd,
                        v_input=v_bshd,
                        cos_node=cos_2d,
                        sin_node=sin_2d,
                        is_causal=is_causal,
                        scale=scale,
                        enable_gqa=fused_enable_gqa,
                        rope_interleaved=q_rope.rope_interleaved,
                        rope_sdpa_op=rope_sdpa_op,
                    )
                    fused_count += 1
                    continue

        # No RoPE detected (or fuse_rope=False) — replace with non-rope FP8 SDPA
        # NOTE(review): the head_dim guard is applied only on this path, not on
        # the RoPE-fused paths above — confirm the fused kernel has no such limit.
        q_shape = _get_node_shape(q_node)
        if q_shape is not None and q_shape[-1] > max_head_dim:
            logger.debug(
                "Skipping FP8 replacement: head_dim=%d > %d for %s",
                q_shape[-1],
                max_head_dim,
                sdpa_node.name,
            )
            continue

        _replace_sdpa_with_fp8(
            graph=graph,
            sdpa_node=sdpa_node,
            q_node=q_node,
            k_node=k_node,
            v_node=v_node,
            is_causal=is_causal,
            scale=scale,
            enable_gqa=enable_gqa,
            fp8_sdpa_op=fp8_sdpa_op,
        )
        fp8_count += 1

    replaced_count = fused_count + fp8_count
    logger.info(
        "Found %d SDPA node(s): %d RoPE-fused, %d FP8-replaced (backend: %s)",
        len(sdpa_nodes),
        fused_count,
        fp8_count,
        backend_name,
    )

    if replaced_count > 0:
        # Replacement leaves the old SDPA (and any bypassed RoPE/transpose
        # chains) dangling; sweep them out.
        graph.eliminate_dead_code()
        logger.info(
            "Fusion pass complete: %d RoPE-fused, %d FP8-replaced",
            fused_count,
            fp8_count,
        )
import warnings

import torch.nn as nn

from torchao.prototype.attention.shared_utils.fusion_utils import detect_causal_mask
from torchao.prototype.attention.shared_utils.wrapper import (
    _FP8FlashAttentionMonkeyPatchWrapper,
    _FP8FlashAttentionWrapper,
    _make_causal_aware_sdpa,
)


def setup_fp8_backend(
    model: nn.Module,
    flash_impl_name: str,
    fuse_rope_using_torch_compile: bool,
) -> nn.Module:
    """Wrap ``model`` so its SDPA calls are served by the FP8 flash backend.

    Args:
        model: Module whose attention should run through FP8 flash attention.
        flash_impl_name: Flash-attention implementation name. Only ``"FA3"``
            is currently wired up here.
        fuse_rope_using_torch_compile: When True, return a compile-path
            wrapper carrying a custom backend (``wrapper.compile_backend``)
            that fuses RoPE + FP8 SDPA; the caller must still invoke
            ``torch.compile`` with that backend. When False, return an eager
            wrapper that monkey-patches ``F.scaled_dot_product_attention``
            for the duration of each forward call.

    Returns:
        The wrapped module; ``model`` itself is left untouched.

    Raises:
        ValueError: If ``flash_impl_name`` is not a supported implementation.
    """
    # Guard clause: fail fast on unsupported implementations.
    if flash_impl_name != "FA3":
        raise ValueError(f"Unknown flash_impl_name: {flash_impl_name}")

    # Imported lazily so the FA3 extension is only required on this path.
    from torchao.prototype.attention.fp8_fa3.attention import (
        fp8_fa3_sdpa as sdpa_fn,
    )
    from torchao.prototype.attention.fp8_fa3.fusion_pass import make_fp8_backend

    if not fuse_rope_using_torch_compile:
        # Eager path: swap F.scaled_dot_product_attention per forward call,
        # stripping materialized causal masks when the model uses them.
        strip_causal_mask = detect_causal_mask(model, flash_impl_name=flash_impl_name)
        patched_sdpa = _make_causal_aware_sdpa(sdpa_fn, strip_causal_mask)
        return _FP8FlashAttentionMonkeyPatchWrapper(
            model,
            flash_impl_name=flash_impl_name,
            sdpa_patch_fn=patched_sdpa,
        )

    # Compile path: attach the custom inductor backend; fusion only takes
    # effect once the caller actually compiles with it, hence the warning.
    wrapper = _FP8FlashAttentionWrapper(model, flash_impl_name=flash_impl_name)
    wrapper.compile_backend = make_fp8_backend(model, fuse_rope_using_torch_compile)
    warnings.warn(
        "fuse_rope_using_torch_compile=True: you must call "
        "torch.compile(model, backend=model.compile_backend) for the "
        "RoPE + FP8 fusion to take effect. Without it the model runs "
        "eagerly with no fusion. "
        "Note: this path uses torch._inductor.config.pre_grad_custom_pass, "
        "an unstable internal API that may change across PyTorch versions.",
        UserWarning,
        stacklevel=3,
    )
    return wrapper
+# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable + +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.attention import ( + activate_flash_attention_impl, + restore_flash_attention_impl, +) + + +class _LowPrecisionAttentionWrapper(nn.Module): + """Base wrapper. Proxies attribute access to the original module.""" + + def __init__(self, orig_mod: nn.Module): + super().__init__() + self._orig_mod = orig_mod + + def __getattr__(self, name: str): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self._orig_mod, name) + + +class _FP8FlashAttentionWrapper(_LowPrecisionAttentionWrapper): + """Compile path wrapper. Activates the flash impl around the module forward.""" + + def __init__(self, orig_mod: nn.Module, flash_impl_name: str): + super().__init__(orig_mod) + self._flash_impl_name = flash_impl_name + + def forward(self, *args, **kwargs): + activate_flash_attention_impl(self._flash_impl_name) + try: + return self._orig_mod(*args, **kwargs) + finally: + restore_flash_attention_impl() + + +class _FP8FlashAttentionMonkeyPatchWrapper(_LowPrecisionAttentionWrapper): + """Monkey-patch path wrapper. Replaces ``F.scaled_dot_product_attention`` + with the FP8 backend for the duration of each forward call. 
+ """ + + def __init__( + self, orig_mod: nn.Module, flash_impl_name: str, sdpa_patch_fn: Callable + ): + super().__init__(orig_mod) + self._flash_impl_name = flash_impl_name + self._sdpa_patch_fn = sdpa_patch_fn + + def forward(self, *args, **kwargs): + activate_flash_attention_impl(self._flash_impl_name) + try: + original_sdpa = F.scaled_dot_product_attention + F.scaled_dot_product_attention = self._sdpa_patch_fn + try: + return self._orig_mod(*args, **kwargs) + finally: + F.scaled_dot_product_attention = original_sdpa + finally: + restore_flash_attention_impl() + + +def _make_causal_aware_sdpa(fp8_sdpa_fn: Callable, strip_causal_mask: bool) -> Callable: + """Wrap an FP8 SDPA function to strip materialized causal masks.""" + if strip_causal_mask: + + def _patched( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=None, + enable_gqa=False, + ): + if attn_mask is not None: + attn_mask = None + is_causal = True + return fp8_sdpa_fn( + query, + key, + value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + + return _patched + return fp8_sdpa_fn diff --git a/torchao/prototype/attention/utils.py b/torchao/prototype/attention/utils.py new file mode 100644 index 0000000000..721a11f8ad --- /dev/null +++ b/torchao/prototype/attention/utils.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
+ +import importlib + +import torch + + +def _is_hopper() -> bool: + if not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major == 9 + + +def _is_blackwell() -> bool: + if not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major == 10 + + +def _is_fa3_available() -> bool: + try: + importlib.import_module("flash_attn_interface") + return True + except ModuleNotFoundError: + return False + + +def _is_fa4_available() -> bool: + try: + importlib.import_module("flash_attn.cute.interface") + return True + except ModuleNotFoundError: + return False