
Added new API for low precision fp8 attention using FA3#3857

Merged
howardzhang-cv merged 45 commits into main from gh/howardzhang-cv/16/head
Mar 9, 2026

Conversation

howardzhang-cv (Contributor) commented Feb 11, 2026

Stack from ghstack (oldest at bottom):


Summary

  • Added RoPE fusion compile path for FA3 FP8 low-precision attention (fuse_rope=True)
  • New elementary block: fp8_fa3_rope_sdpa — fused RoPE + FP8 quantization + low-precision SDPA
  • New Triton kernel for fused RoPE + QKV quantization with layout transpose ([B,S,H,D] → [B,H,S,D])
  • RoPE fusion method: Custom Inductor backend that traces the FX graph, detects RoPE + SDPA patterns (NeoX half-split and FLUX interleaved formats), and replaces them with fp8_fa3_rope_sdpa custom ops. Falls back to fp8_fa3_sdpa for SDPA nodes without RoPE.
  • Causal mask detection: Pre-flight forward pass identifies HuggingFace-style materialized causal masks so the fusion pass can strip them and use is_causal=True instead.
  • Added compiled model wrapper (_FP8FlashAttentionCompiledWrapper) with @torch._dynamo.disable to prevent re-tracing.
  • Added RoPE SDPA numerical accuracy tests and fuse_rope parametrization on model-level tests.
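At its core, the fusion pass above is a pattern-match-and-replace over a traced graph. The following torch-free sketch illustrates the replacement logic only; the toy `Node` type, `fuse_rope_sdpa` name, and op strings are illustrative stand-ins, not the actual FX API used in the PR:

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    op: str                       # e.g. "rope", "sdpa", "linear"
    inputs: list = field(default_factory=list)

def fuse_rope_sdpa(graph):
    """Replace rope -> sdpa chains with a single fused node; SDPA nodes
    without RoPE producers fall back to a plain low-precision SDPA node.
    `graph` is a list of Nodes in topological order."""
    fused = []
    consumed = set()
    for node in graph:
        if id(node) in consumed:
            continue
        if node.op == "sdpa" and all(inp.op == "rope" for inp in node.inputs[:2]):
            # q and k come straight from RoPE: emit the fused op and
            # drop the now-dead rope nodes from the output graph
            for inp in node.inputs[:2]:
                consumed.add(id(inp))
            fused = [n for n in fused if id(n) not in consumed]
            fused.append(Node("fp8_fa3_rope_sdpa"))
        elif node.op == "sdpa":
            fused.append(Node("fp8_fa3_sdpa"))
        else:
            fused.append(node)
    return fused
```

The real pass additionally has to unwrap transposes between RoPE and SDPA and distinguish the two RoPE layouts, but the shape of the surgery is the same.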

New Files

  • shared_utils/fusion_utils.py: Shared FX graph fusion pass — RoPE pattern detection, SDPA detection, transpose unwrapping, parameterized graph surgery
  • shared_utils/custom_ops.py: Factory functions to register backend-specific custom ops with register_fake, and helpers to build fusion passes and compile functions
  • fp8_fa3/fusion_pass.py: FA3-specific custom op registration, rope_sdpa_fusion_pass, and compile_with_fp8_fusion entry point
  • quantization/triton_rope_qkv_quantization.py: Fused RoPE + QKV FP8 quantization Triton kernel
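For reference, the two RoPE variants the pass recognizes differ only in how rotation pairs are indexed along the head dimension. A minimal numpy sketch (illustrative function names, not the Triton kernel code):

```python
import numpy as np

def rope_half_split(x, theta):
    # NeoX style: element i is paired with element i + d/2
    d = x.shape[-1]
    x1, x2 = x[..., : d // 2], x[..., d // 2 :]
    c, s = np.cos(theta), np.sin(theta)
    return np.concatenate([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1)

def rope_interleaved(x, theta):
    # FLUX-style interleaved: adjacent elements (2i, 2i+1) form a pair
    x1, x2 = x[..., 0::2], x[..., 1::2]
    c, s = np.cos(theta), np.sin(theta)
    out = np.empty_like(x)
    out[..., 0::2] = x1 * c - x2 * s
    out[..., 1::2] = x1 * s + x2 * c
    return out
```

The two layouts agree up to a fixed permutation of the head dimension, which is why a single fused kernel with a layout flag can serve both.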

Modified Files

  • shared_utils/attention.py: Added _fp8_rope_sdpa shared implementation
  • shared_utils/wrapper.py: Added _FP8FlashAttentionCompiledWrapper
  • shared_utils/setup.py: Added compile path routing via compile_fn parameter, moved detect_causal_mask to fusion_utils.py
  • quantization/quantization.py: Added _fp8_rope_sdpa_quantize
  • fp8_fa3/attention.py: Added fp8_fa3_rope_sdpa elementary block
  • fp8_fa3/setup.py: Passes compile_with_fp8_fusion as compile_fn
  • test_fp8_attention.py: Added TestFP8RopeSDPANumericalAccuracy, fuse_rope parametrization on model test

Test Plan

python -m pytest test/prototype/attention/test_fp8_attention.py -v

Example Usage

  from torchao.prototype.attention import (
      AttentionBackend,
      LowPrecisionAttentionConfig,
      apply_low_precision_attention,
  )

  model = MyModel()

  # Compile path with RoPE fusion
  config = LowPrecisionAttentionConfig(
      backend=AttentionBackend.FP8_FA3,
      fuse_rope=True,
  )
  model = apply_low_precision_attention(model, config)

  # Flash activation is handled internally by the wrapper
  output = model(inputs)

Results

Single-Layer Results

Results directly comparing FA3 SDPA versus FA3 fp8 SDPA (including quantization time):

Llama3 Model Results

Results comparing Llama3 model with FA3 SDPA versus Llama3 using the FA3 fp8 wrapper. Uses RoPE fusion.
Perplexity: 6.19 -> 6.24

howardzhang-cv added a commit that referenced this pull request Feb 11, 2026
Summary: Added new folder for low precision attention APIs in torchao/attention

Test Plan: python test/attention/test_fp8_fa3.py

ghstack-source-id: a86eedf
Pull-Request: #3857
pytorch-bot commented Feb 11, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3857

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 94d9200 with merge base aad1018:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Feb 11, 2026
@howardzhang-cv howardzhang-cv added the topic: new feature Use this tag if this PR adds a new feature label Feb 11, 2026
howardzhang-cv (Contributor Author):
Assuming this is going to take some back and forth to land since it's a user-facing change. @drisspg please take a look and let me know if the API looks okay.

original_forward = model.forward

def wrapped_forward(*args, **kwargs):
    with _fp8_fa3_attention_context(config):
Contributor:
Not sure if we talked about this before; what are your thoughts on this vs. following other torchao APIs, e.g.

q = Float8Tensor.from_hp(q)
k = Float8Tensor.from_hp(k)
v = Float8Tensor.from_hp(v)
# dispatch to `_fp8_fa3_sdpa` in Float8Tensor implementation
F.scaled_dot_product_attention(q, k, v, is_causal=True)

?

Contributor Author:

Yeah, ideally we follow other torchao APIs, but it's a bit difficult for attention. This first backend/recipe works well because it is a simple replacement of F.scaled_dot_product_attention, but as we add other features such as RoPE fusion or RoPE + Hadamard it becomes much more difficult, since it becomes model specific (i.e. if there is RoPE followed by SDPA, replace it with a fused RoPE + quantization kernel feeding fp8 SDPA). To future-proof other attention backends and recipes, I think it's better to have it as its own separate API. What are your thoughts?

Contributor:

I thought the fused kernel replacement should happen in Inductor? Or is there a reason why we have to hand-replace these in eager mode?

Contributor Author:

Inductor unfortunately does not fuse RoPE with the quantization kernel

Contributor:

Oh, I mean: is this something that's not possible to do in Inductor, or just something that does not exist right now?

Contributor Author:

Not too sure; I feel like it should be possible in Inductor, but it definitely does not exist right now. Could look into it in the future, but for now I think it would be simpler to move this into prototype and have it be available through that. Will work on moving this over.

Contributor Author:

@jerryzh168 I've gotten the RoPE fusion to work using the Inductor path. I've moved everything to a prototype folder. Due to the way attention works, I think it's better as its own API, since it doesn't fit cleanly with the others for now (especially with RoPE fusion and Hadamard requiring the Inductor path). What do you think?

if q.shape[3] != k.shape[3]:
    raise ValueError(f"Head dim mismatch: {q.shape[3]} vs {k.shape[3]}")

if torch.compiler.is_compiling():
Contributor:

Is there a test comparing the numerics between these two paths?

Contributor Author:

The numerics are exactly the same between the two paths. I tested the runtime, and the Triton implementation was much faster (~100 ms difference on a Llama3 model with 124k sequence length).
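The parity check being asked about can be sketched as follows, with numpy standing in for the real kernels; `rope_then_quantize` and `fused_rope_quantize` are hypothetical names, not functions from this PR:

```python
import numpy as np

F8_MAX = 448.0  # max magnitude representable in float8_e4m3fn

def rope(x, theta):
    # half-split rotary embedding on the last dimension
    d = x.shape[-1]
    x1, x2 = x[..., : d // 2], x[..., d // 2 :]
    c, s = np.cos(theta), np.sin(theta)
    return np.concatenate([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1)

def rope_then_quantize(x, theta):
    # reference path: RoPE kernel followed by a separate quantization kernel
    y = rope(x, theta)
    scale = F8_MAX / np.abs(y).max()
    return np.round(y * scale), scale  # round() stands in for the fp8 cast

def fused_rope_quantize(x, theta):
    # "fused" path: same math performed in one kernel
    y = rope(x, theta)
    scale = F8_MAX / np.abs(y).max()
    return np.round(y * scale), scale
```

Since both paths perform the same floating-point operations in the same order, such a test can assert exactly equal outputs rather than approximate closeness.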


from torchao.prototype.attention import apply_low_precision_attention

model = MyTransformer()
Contributor:

We should specify what kind of syntax is automatically converted to low precision here. F.scaled_dot_product_attention? Something else?

Contributor Author:

Added

)

def fp8_attention_backend(gm, example_inputs):
    """Custom Inductor backend that applies the RoPE + FP8 fusion pass."""
Contributor:

This should specify what kind of user syntax counts as "attention".

Contributor Author:

Added
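For context, a custom torch.compile backend has roughly the following shape. The body here is deliberately a placeholder; in the real pass, only aten scaled_dot_product_attention calls (optionally fed by a RoPE subgraph) count as "attention" and get rewritten:

```python
import torch

def fp8_attention_backend_sketch(gm: torch.fx.GraphModule, example_inputs):
    # A real pass would scan gm.graph here for scaled_dot_product_attention
    # nodes and splice in the fused fp8 custom ops; this sketch just
    # returns the traced graph unchanged.
    for node in gm.graph.nodes:
        pass  # pattern matching / graph surgery would go here
    return gm.forward  # a backend returns a callable for the compiled region

# the backend is plugged in via torch.compile(..., backend=...)
fn = torch.compile(lambda x: x + 1, backend=fp8_attention_backend_sketch)
```

Because the backend receives the whole FX graph, it can see RoPE and SDPA together, which is what makes the cross-op fusion possible at all.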

Different backends have different hardware requirements and capabilities.
"""

FP8_FA3 = "fa3"
Contributor:

make the string match the enum value name

Contributor Author:

Fixed

restore_flash_attention_impl()


def apply_low_precision_attention(
Contributor:

I think this should be explicit that parts of torch.compile are used to do the logic swap.

Contributor Author:

Added; I also added a warning to be even more explicit.

howardzhang-cv added a commit to howardzhang-cv/ao that referenced this pull request Feb 23, 2026
Summary: Added new folder for low precision attention APIs in torchao/prototype/attention

Test Plan: python test/prototype/attention/test_fp8_fa3.py

ghstack-source-id: 25b6a97
Pull-Request: pytorch#3857
howardzhang-cv added a commit to howardzhang-cv/ao that referenced this pull request Feb 28, 2026
Adds the compile path (fuse_rope=True) which compiles the model with a
custom Inductor backend that fuses RoPE + FP8 quantization + SDPA into
optimized kernels via FX graph pattern matching.

Key additions:
- shared_utils/fusion_utils.py: FX graph RoPE/SDPA pattern detection and
  parameterized graph surgery (NeoX + FLUX interleaved RoPE variants)
- shared_utils/custom_ops.py: custom op registration factory with
  register_fake for torch.compile traceability
- fp8_fa3/fusion_pass.py: FA3-specific custom ops and compile helper
- quantization/triton_rope_qkv_quantization.py: fused RoPE + FP8
  quantization Triton kernels with layout transpose
- _FP8FlashAttentionCompiledWrapper with @dynamo.disable boundary
- _fp8_rope_sdpa shared implementation + fp8_fa3_rope_sdpa entry point
- Tests parametrized over fuse_rope={True, False}

ghstack-source-id: 4e11b16
Pull-Request: pytorch#3857
@howardzhang-cv howardzhang-cv requested a review from vkuzo March 2, 2026 19:28
@howardzhang-cv howardzhang-cv added the module: inference quantize_ api inference flow label Mar 6, 2026
@howardzhang-cv howardzhang-cv changed the base branch from gh/howardzhang-cv/16/base to main March 9, 2026 17:30
@howardzhang-cv howardzhang-cv merged commit 2ec82b3 into main Mar 9, 2026
36 of 40 checks passed
@howardzhang-cv howardzhang-cv deleted the gh/howardzhang-cv/16/head branch March 9, 2026 22:03