
Add FP8 FA3 low-precision attention with monkey-patch SDPA path#3959

Merged
howardzhang-cv merged 17 commits into main from gh/howardzhang-cv/23/head
Mar 9, 2026

Conversation

howardzhang-cv (Contributor) commented Feb 27, 2026

Stack from ghstack (oldest at bottom):


Summary

  • Added new folder for low-precision attention APIs in torchao/prototype/attention
  • New API for FP8 FA3 low-precision attention with two components:
    • Elementary block: fp8_fa3_sdpa — a direct drop-in replacement for F.scaled_dot_product_attention that users can integrate into their model manually. Performs per-head FP8 quantization of Q, K, V followed by low-precision SDPA.
    • Simple wrapper: apply_low_precision_attention — wraps any model to automatically replace all SDPA calls with the FP8 variant. No torch.compile required.
  • New Triton kernel for fused QKV FP8 quantization (3-phase: absmax reduction, scale computation, quantize)
  • Causal mask detection: Pre-flight forward pass identifies HuggingFace-style materialized causal masks so the wrapper can strip them and use is_causal=True instead.
  • Flash attention activation is handled internally by the wrapper — no manual activate_flash_attention_impl / restore_flash_attention_impl calls needed.
  • Added new test folder for low-precision attention APIs in test/prototype/attention

Folder Breakdown

  • torchao/prototype/attention: new folder for low-precision attention APIs
    • __init__.py: public exports (apply_low_precision_attention, AttentionBackend, LowPrecisionAttentionConfig)
    • api.py: user-facing entry point that validates config and dispatches to the correct backend
    • config.py: AttentionBackend enum and LowPrecisionAttentionConfig dataclass
    • utils.py: hardware capability checks, backend availability detection
    • shared_utils/: shared infrastructure used by backend implementations
      • attention.py: shared _fp8_sdpa implementation (quantize + SDPA)
      • wrapper.py: _FP8FlashAttentionMonkeyPatchWrapper — replaces F.scaled_dot_product_attention during forward, manages flash activation internally
      • setup.py: setup_fp8_backend — builds the wrapper with causal mask detection
    • fp8_fa3/: FA3-specific backend
      • attention.py: fp8_fa3_sdpa elementary block
      • setup.py: thin wrapper calling setup_fp8_backend with FA3 parameters
    • quantization/: shared FP8 quantization kernels
      • quantization.py: _fp8_sdpa_quantize — calls fused Triton kernels for per-head Q, K, V quantization
      • triton_qkv_quantization.py: fused QKV FP8 quantization Triton kernel
  • test/prototype/attention: new folder for low-precision attention API tests
    • test_fp8_attention.py: numerical accuracy tests (eager SDPA) and model-level API tests (simple wrapper)
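The core idea behind the monkey-patch wrapper can be sketched in a few lines (a minimal illustration, not the actual `_FP8FlashAttentionMonkeyPatchWrapper`, which additionally manages flash activation and causal-mask stripping): swap `F.scaled_dot_product_attention` for a replacement only while the wrapped model's forward runs, and always restore it afterwards.

```python
import contextlib
import torch
import torch.nn.functional as F

class SDPAMonkeyPatchWrapper(torch.nn.Module):
    """Replace F.scaled_dot_product_attention with `sdpa_replacement`
    for the duration of the wrapped model's forward pass."""

    def __init__(self, model: torch.nn.Module, sdpa_replacement):
        super().__init__()
        self.model = model
        self.sdpa_replacement = sdpa_replacement

    @contextlib.contextmanager
    def _patched_sdpa(self):
        original = F.scaled_dot_product_attention
        F.scaled_dot_product_attention = self.sdpa_replacement
        try:
            yield
        finally:
            F.scaled_dot_product_attention = original  # always restore

    def forward(self, *args, **kwargs):
        with self._patched_sdpa():
            return self.model(*args, **kwargs)
```

Because the swap happens at call time, any module that looks up `F.scaled_dot_product_attention` inside its forward picks up the replacement with no `torch.compile` involved.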

Test Plan

python -m pytest test/prototype/attention/test_fp8_attention.py -v

Example Usage

  from torchao.prototype.attention import (
      AttentionBackend,
      LowPrecisionAttentionConfig,
      apply_low_precision_attention,
  )

  model = MyModel()

  # Simple SDPA replacement — no torch.compile needed
  config = LowPrecisionAttentionConfig(backend=AttentionBackend.FP8_FA3)
  model = apply_low_precision_attention(model, config)

  # Flash activation is handled internally by the wrapper
  output = model(inputs)

Results

Single-Layer Results

Results directly comparing FA3 SDPA versus FA3 fp8 SDPA (including quantization time):

Llama3 Model Results

Results comparing Llama3 model with FA3 SDPA versus Llama3 using the FA3 fp8 wrapper. Does not use RoPE fusion.
Perplexity: 6.19 -> 6.25


pytorch-bot bot commented Feb 27, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3959

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 699c526 with merge base 42bcdc4:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Feb 27, 2026
@howardzhang-cv howardzhang-cv marked this pull request as draft February 27, 2026 08:09
@howardzhang-cv howardzhang-cv added the topic: new feature Use this tag if this PR adds a new feature label Feb 27, 2026
howardzhang-cv added a commit to howardzhang-cv/ao that referenced this pull request Feb 27, 2026
Introduces a new `torchao.prototype.attention` module providing FP8
attention via FlashAttention 3. The default path monkey-patches
`F.scaled_dot_product_attention` at call time — no `torch.compile`
needed. Flash attention activation is managed internally by the wrapper.

Key components:
- `apply_low_precision_attention()`: user-facing API
- `shared_utils/`: backend-agnostic wrapper, setup, and SDPA logic
- `fp8_fa3/`: FA3-specific thin wrappers
- `quantization/`: fused Triton FP8 QKV quantization kernels
- Causal mask detection for HuggingFace model compatibility


ghstack-source-id: 4f8d3fb
Pull-Request: pytorch#3959
howardzhang-cv added a commit to howardzhang-cv/ao that referenced this pull request Feb 28, 2026
ghstack-source-id: 4f8d3fb
Pull-Request: pytorch#3959
howardzhang-cv added a commit to howardzhang-cv/ao that referenced this pull request Feb 28, 2026
ghstack-source-id: 4f8d3fb
Pull-Request: pytorch#3959
@howardzhang-cv howardzhang-cv marked this pull request as ready for review February 28, 2026 20:22
howardzhang-cv added a commit to howardzhang-cv/ao that referenced this pull request Feb 28, 2026
ghstack-source-id: 921d691
Pull-Request: pytorch#3959
@howardzhang-cv howardzhang-cv requested review from drisspg and vkuzo March 2, 2026 19:28
"""
Apply low-precision attention to a model.

Depending on the configuration, the model is either:
Contributor:

this is a bit confusing, can we make the distinction more clear? either with a boolean in this API, or just have two APIs, etc

- **Monkey-patch path** (``fuse_rope=False``, default): wraps the model
so that ``F.scaled_dot_product_attention`` is replaced with the FP8
backend at call time. No ``torch.compile`` is needed.
- **Compile path** (``fuse_rope=True``): internally calls
Contributor:

if fuse_rope also changes compile settings, let's call it something like fuse_rope_using_torch_compile?

Author:

Yeah that makes sense, I changed the name

Args:
    backend: Attention backend to use. If None (default), automatically
        selected based on hardware capabilities.
    use_hadamard: Apply Hadamard transform. Options:
Contributor:

this looks like a boolean but is actually an enum, maybe the name can be clearer?

Author:

Changed to hadamard_mode

- None: No Hadamard transform (default)
- "v": Apply Hadamard to V only
- "qkv": Apply Hadamard to Q, K, and V
fuse_rope: If True, the model is compiled with a custom Inductor
Contributor:

this seems to be about fusing rope as well as implementation details of compile usage, can the name reflect this?

Author:

Changed to fuse_rope_using_torch_compile

Apply low-precision attention to a model.

Depending on the configuration, the model is either:
- **Monkey-patch path** (``fuse_rope=False``, default): wraps the model
Contributor:

if user starts with monkey patch path and calls torch.compile on it, do they get the compile path? if yes, why not just delete the setting for compile path from this API and have it automatically work if compile is on?

Author:

No, if the user starts with monkey patch path and calls torch.compile, it will still just be a torch.compiled model but the F.SDPA is replaced with the triton quant kernel + low precision SDPA. There will be no rope fusion like the compile path. I originally had it set so that the compile path is always on, and it would simply fallback to the simpler path if no RoPE is detected, but I figured that wasn't good since the compile path uses the pre-grad IR path, and the default should be a simpler path that doesn't use that.

Contributor:

where does this stand post convo?

Author:

We'll keep the higher-level API with the --fuse_rope option. If the user enables fuse_rope, we send a warning stating that they must explicitly run torch.compile to enable the rope fusion and low precision attention pathway. No more magic behind-the-scenes compile

howardzhang-cv added a commit to howardzhang-cv/ao that referenced this pull request Mar 2, 2026
ghstack-source-id: e7282fa
Pull-Request: pytorch#3959
howardzhang-cv added a commit to howardzhang-cv/ao that referenced this pull request Mar 3, 2026
ghstack-source-id: e7282fa
Pull-Request: pytorch#3959
howardzhang-cv added a commit to howardzhang-cv/ao that referenced this pull request Mar 3, 2026
ghstack-source-id: e7282fa
Pull-Request: pytorch#3959
FA3 uses _scaled_dot_product_attention_quantized internally,
which requires FA3 activation. This probe catches mismatches.
"""
try:
Contributor:

can you run a deslop pass on some of this, I kind of doubt you need code like this

try:
    q = torch.randn(1, 1, 4, 64, device="cuda", dtype=torch.bfloat16)
    with torch.no_grad():
        sdpa_fn(q, q, q, is_causal=False)
Contributor:

also we should just know right?

_scaled_dot_product_attention_quantized,
)

_MIN_VERSION_ERROR = (
Contributor:

I feel like these global error strings are weird Claude coding; we could just inline them

# In the monkey-patch path (no torch.compile), accelerate's hooks move
# tensors to the correct device but don't call torch.cuda.set_device().
# Triton dispatches based on current_device(), not tensor device, so
# without this guard the kernel launches on the wrong GPU's stream.
Contributor:

this is kinda weird to me .. hmmm

Author:

Oh yeah, I was actually planning on asking you this later in the 1:1. This was weird to me too, but for some reason without this guard, Triton does actually dispatch to the wrong device, leading to NaNs in multi-GPU setups.
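The guard discussed here boils down to pinning the CUDA current device to the input tensor's device before launching the kernel. A minimal sketch of the pattern (illustrative helper, not the actual torchao code):

```python
import torch

def launch_on_tensor_device(kernel_fn, *tensors, **kwargs):
    """Run `kernel_fn` with the CUDA current device pinned to the first
    input's device. Triton dispatches on torch.cuda.current_device()
    rather than on the tensors' device, so without this the kernel can
    land on the wrong GPU's stream in multi-GPU setups."""
    device = tensors[0].device
    if device.type == "cuda":
        # The context manager sets the current device and restores the
        # previous one on exit.
        with torch.cuda.device(device):
            return kernel_fn(*tensors, **kwargs)
    return kernel_fn(*tensors, **kwargs)
```

On CPU tensors the guard is a no-op pass-through, which keeps the helper usable in device-agnostic code paths.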

return torch.equal(mask.broadcast_to(mask.shape), ref.expand_as(mask))


def detect_causal_mask(
Contributor:

is this really the best way?

Author:

Yeah, it's surprisingly difficult to have a general way to see if the attention mask is just a causal mask. Claude suggested a bunch of ideas that imo either leaned too heavily on user interaction or were not robust enough. We can't do it in the pre-grad IR pass because we don't have access to the attention mask values to see if it's causal, so I came up with this method where we do a fake forward pass to see if the masks are causal and get rid of them. I could definitely just be doing something dumb though; there likely is a better way.
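The pre-flight check amounts to comparing the materialized mask against a broadcast lower-triangular reference. A minimal version of that comparison (illustrative only; it covers just the boolean and additive-float mask conventions):

```python
import torch

def looks_like_causal_mask(mask: torch.Tensor) -> bool:
    """True if `mask` is equivalent to a plain causal mask over its last
    two dims: a lower-triangular boolean mask, or an additive mask that
    is 0 where attended and dtype-min where masked (the HuggingFace
    convention)."""
    s_q, s_k = mask.shape[-2], mask.shape[-1]
    tril = torch.ones(s_q, s_k, dtype=torch.bool, device=mask.device).tril()
    if mask.dtype == torch.bool:
        ref = tril
    else:
        ref = torch.zeros(s_q, s_k, dtype=mask.dtype, device=mask.device)
        ref = ref.masked_fill(~tril, torch.finfo(mask.dtype).min)
    return torch.equal(mask, ref.expand_as(mask))
```

When this returns True, the wrapper can safely drop the mask and pass `is_causal=True` instead, which is what FA3 requires.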

def _make_causal_aware_sdpa(fp8_sdpa_fn: Callable, strip_causal_mask: bool) -> Callable:
"""Wrap an FP8 SDPA function to strip materialized causal masks.

HuggingFace models (e.g. LLaMA) pass a materialized lower-triangular
Contributor:

I'm surprised this is true.. are you sure ?

Author:

Yeah I had to verify to see if it was true as well. In "SDPA" attention mode, huggingface materializes an attention mask instead of just using causal. Since FA3/FA4 aren't compatible with attention masks and only work with causal, I had to do the whole sketchy "detect mask" thing. On Blackwell though, it doesn't do this, only on Hopper.

howardzhang-cv added a commit to howardzhang-cv/ao that referenced this pull request Mar 5, 2026
ghstack-source-id: e7282fa
Pull-Request: pytorch#3959
- None: No Hadamard transform (default)
- "v": Apply Hadamard to V only
- "qkv": Apply Hadamard to Q, K, and V
fuse_rope_using_torch_compile: If True, fuse RoPE + quantization + SDPA into optimized
Contributor:

if user is using torch.compile, do we still need this? shouldn't we just pick the best available fusion in torch.compile land?

Author:

Unfortunately we'll still be needing this flag, since we need to use the pre-grad IR pass, which is a global setting, and there is no way to selectively use the pre-grad IR pass for just the selected module (as far as I know but I could totally be wrong). It's one of the reasons why I first chose to have the behind the scenes compile, so we can set the global flag, compile that one module, and reset it to default.

What I have now is we return a compile backend, and in the user warning, we tell them they must run torch.compile(module, backend=returned_backend) to compile that module with the specific backend that they need (which is different depending on FA3/FA4). We still need the flag, and the user cannot just call torch.compile on the whole model with us automatically picking the best available fusion for that specific module, but I think it's probably still better than having a magic compile the user doesn't know about?
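The usage pattern described here, where the user explicitly compiles the module with a backend the API hands back, looks like this in miniature. `my_prototype_backend` is a hypothetical stand-in for the returned backend; in the real flow it would install the pre-grad IR RoPE + quantize + SDPA fusion pass before lowering.

```python
import torch

def my_prototype_backend(gm: torch.fx.GraphModule, example_inputs):
    """Hypothetical stand-in for the backend returned when rope fusion is
    requested. A real backend would register the pre-grad IR fusion pass;
    this no-op version just runs the captured graph unchanged."""
    return gm.forward

model = torch.nn.Linear(4, 4)
# Per the warning described above, the user runs torch.compile explicitly
# with the backend they were handed:
compiled = torch.compile(model, backend=my_prototype_backend)
```

This keeps compilation visible to the user: no behind-the-scenes compile, at the cost of one extra explicit step.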


def apply_low_precision_attention(
model: nn.Module,
config: Optional[LowPrecisionAttentionConfig] = None,
Contributor:

thoughts about removing the config and just adding arguments here? will be more discoverable

Author:

Yup, that makes sense to me. I originally had it as a config because I want it to be usable as we start adding more and more parameters. But, I think for what it is now, it makes more sense to have it as arguments. I have made the change.

howardzhang-cv added a commit to howardzhang-cv/ao that referenced this pull request Mar 5, 2026
ghstack-source-id: e7282fa
Pull-Request: pytorch#3959
@howardzhang-cv howardzhang-cv added the module: inference quantize_ api inference flow label Mar 6, 2026
@howardzhang-cv howardzhang-cv changed the base branch from gh/howardzhang-cv/23/base to main March 9, 2026 17:24
@howardzhang-cv howardzhang-cv merged commit c32dea9 into main Mar 9, 2026
36 checks passed
@howardzhang-cv howardzhang-cv deleted the gh/howardzhang-cv/23/head branch March 9, 2026 22:03

Labels

  • CLA Signed: This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
  • module: inference: quantize_ api inference flow
  • topic: new feature: Use this tag if this PR adds a new feature
