
Add FA4 monkey-patch path for low-precision attention #3960

Draft
howardzhang-cv wants to merge 16 commits into gh/howardzhang-cv/24/base from gh/howardzhang-cv/24/head

Conversation

@howardzhang-cv (Contributor) commented Feb 27, 2026

Stack from ghstack (oldest at bottom):

Summary

  • Added FA4 FP8 low-precision attention with a simple SDPA replacement path, mirroring the FA3 design
  • New elementary block: fp8_fa4_sdpa — a direct drop-in replacement for F.scaled_dot_product_attention using the FA4 backend. Reuses the shared FP8 quantization kernels.
  • Simple wrapper support via apply_low_precision_attention with AttentionBackend.FP8_FA4 — no torch.compile required.
  • FA4 supports both Hopper (SM 9.x) and Blackwell (SM 10.x) hardware
  • Added _is_blackwell() and _is_fa4_available() hardware detection utilities
  • Added FA4 backend config and numerical accuracy tests (eager SDPA and model-level API)

New Files

  • fp8_fa4/__init__.py: Exports fp8_fa4_sdpa
  • fp8_fa4/attention.py: fp8_fa4_sdpa elementary block
  • fp8_fa4/setup.py: Thin wrapper calling setup_fp8_backend with FA4 parameters

Modified Files

  • config.py: Added FP8_FA4 to AttentionBackend enum
  • utils.py: Added _is_blackwell(), _is_fa4_available(), FA4 support in _get_available_backend() and _check_backend_available()
  • api.py: Added FA4 dispatch path
  • test_fp8_attention.py: Added FA4 backend config, numerical accuracy tests for FA4

Test Plan

python -m pytest test/prototype/attention/test_fp8_attention.py -v

Example Usage

  from torchao.prototype.attention import (
      AttentionBackend,
      LowPrecisionAttentionConfig,
      apply_low_precision_attention,
  )

  model = MyModel()

  # Simple SDPA replacement using FA4 — no torch.compile needed
  config = LowPrecisionAttentionConfig(backend=AttentionBackend.FP8_FA4)
  model = apply_low_precision_attention(model, config)

  # Flash activation is handled internally by the wrapper
  output = model(inputs)

Results

Single-Layer Results

Results directly comparing FA4 SDPA versus FA4 FP8 SDPA (including quantization time):
[benchmark image]

Llama3 Model Results

Results comparing the Llama3 model with FA4 SDPA versus Llama3 using the FA4 FP8 wrapper. RoPE fusion is not used.
Perplexity: 6.19 -> 6.25
[benchmark image]

@pytorch-bot

pytorch-bot bot commented Feb 27, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3960

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 8 Pending

As of commit 6611347 with merge base 5ebd10d (image):

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Feb 27, 2026
@howardzhang-cv howardzhang-cv marked this pull request as draft February 27, 2026 08:09
@howardzhang-cv howardzhang-cv added the topic: new feature Use this tag if this PR adds a new feature label Feb 27, 2026
howardzhang-cv added a commit to howardzhang-cv/ao that referenced this pull request Feb 27, 2026
Adds FlashAttention 4 backend support for the monkey-patch SDPA path.
FA4 supports both Hopper (SM 9.x) and Blackwell (SM 10.x) hardware
with the flash_attn.cute.interface package.

Key additions:
- FP8_FA4 enum value in AttentionBackend
- _is_blackwell(), _is_fa4_available() hardware/library checks
- FA4 dispatch in apply_low_precision_attention
- fp8_fa4/ directory with fp8_fa4_sdpa entry point
- FA4 backend config in test suite with eager probe guard
- RoPE fusion placeholder (fuse_rope=True raises NotImplementedError)


ghstack-source-id: fa10283
Pull-Request: pytorch#3960
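The RoPE fusion placeholder described in the commit message might be shaped like this. This is an illustrative stub, not the actual fp8_fa4/attention.py code:

```python
def fp8_fa4_sdpa(q, k, v, *, fuse_rope: bool = False, **kwargs):
    """Illustrative stub of the FA4 entry point: fuse_rope is reserved
    and must raise until RoPE fusion is implemented for this backend."""
    if fuse_rope:
        raise NotImplementedError(
            "RoPE fusion is not yet supported for the FA4 backend"
        )
    # Quantize q/k/v to FP8 and call the FA4 kernel (omitted in this sketch).
    ...
```

Raising `NotImplementedError` eagerly keeps the keyword argument in the signature (so callers match the FA3 API) without silently producing wrong results.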
@howardzhang-cv
Contributor Author

howardzhang-cv commented Feb 28, 2026

This FA4 FP8 low-precision backend currently has an intermittent issue with the Llama 3 model: NaNs appear in the quantized QKV tensors (specifically, in torchao/prototype/attention/shared_utils/attention.py inside the _fp8_sdpa function, the q_fp8 tensor is the first to become NaN). The issue disappears when the Triton quantization kernel is replaced with plain PyTorch ops, and also under --compile, and it only reproduces on Blackwell. I spent a lot of time debugging this but could not find the root cause.
