
SageAttention AMD version #332

Open
eliotwang wants to merge 5 commits into thu-ml:main from eliotwang:amd

Conversation

@eliotwang
Copy link

This is a rocWMMA-based implementation of SageAttention. It uses the rocWMMA::fragment API rather than the PTX-based approach of the CUDA implementation.

Performance: taking CogVideoX-2B as an example, it delivers ~30% better end-to-end performance on the 9070 compared to SageAttention V1.

@woct0rdho
Copy link

woct0rdho commented Dec 20, 2025

Just curious, how does the performance compare to the Triton kernel (maybe needs some tuning of block sizes) on AMD GPU? Ah, I see, you compared with SageAttention V1.

@eliotwang
Copy link
Author

Just curious, how does the performance compare to the Triton kernel (maybe needs some tuning of block sizes) on AMD GPU?

The original SageAttention v2 can only run on NVIDIA GPUs; the implementation available for AMD GPUs is the Triton-based one. Our goal was to develop a ROCm implementation based on SageAttention v2. The test results above were obtained using Triton's default configuration, with no tuning performed.

@jammm
Copy link

jammm commented Dec 26, 2025

This is great! A few high-level comments before a deeper review can happen (note that I'm not the maintainer of this repo, so feel free to ignore):

  1. Lots of code seems (potentially) duplicated. For example, all the reduction ops should be usable as-is, with the only difference being FINAL_MASK 64 bit on gfx9 vs. 32 bit on gfx11/12. It would be good to try and port in-place within the .cu code so any changes on the CUDA code will benefit both. You can do this with #ifdef __HIPCC__ for example. This was done on llama.cpp for example. You also don't need hipLaunchKernelGGL and can instead use the same <<< >>> notation like CUDA.
  2. This should technically work on gfx12 too, and rocWMMA supports it already, so you should be able to use it as-is.
  3. "MFMA" is gfx9 specific, while "WMMA" is gfx11/12 specific. I would rename things accordingly, or simply unify them both under a single "mma" name prefix or something like that.
  4. I noticed the rocWMMA tile shape was 16x16x32 for gfx11. Have you also tried the 16 and 64 K dims? Seems like rocWMMA supports any multiple of K for the fragments.
  5. (nitpick) would be nice to have all the comments written in English, as it can be helpful for folks who don't understand kanji (like me).
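The cross-porting idea in item 1 can be sketched roughly as follows. This is a hedged sketch, not code from the PR: the macro and function names (`FINAL_MASK`, `SHFL_XOR`, `warpReduceSum`) are illustrative, and the wavefront-size macro availability may vary by compiler version.

```cpp
// Hedged sketch of sharing one reduction source between CUDA and HIP.
// FINAL_MASK is the full-wave lane mask: 64-bit on gfx9 (wave64),
// 32-bit on gfx11/12 (wave32) and on NVIDIA warps.
#ifdef __HIPCC__
  #include <hip/hip_runtime.h>
  #if defined(__AMDGCN_WAVEFRONT_SIZE__) && __AMDGCN_WAVEFRONT_SIZE__ == 64
    #define FINAL_MASK 0xffffffffffffffffull   // gfx9, wave64
  #else
    #define FINAL_MASK 0xffffffffu             // gfx11/12, wave32
  #endif
  // HIP's __shfl_xor takes no mask argument.
  #define SHFL_XOR(val, lane) __shfl_xor((val), (lane))
#else
  #define FINAL_MASK 0xffffffffu               // CUDA warp = 32 lanes
  #define SHFL_XOR(val, lane) __shfl_xor_sync(FINAL_MASK, (val), (lane))
#endif

template <typename T>
__device__ __forceinline__ T warpReduceSum(T val) {
  // Butterfly reduction across the full warp/wavefront.
  for (int offset = warpSize / 2; offset > 0; offset >>= 1)
    val = val + SHFL_XOR(val, offset);
  return val;
}
```

With this pattern the same `.cu` source compiles under both nvcc and hipcc, so fixes land in one place.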

@0xDELUXA
Copy link

This is great! A few high-level comments before a deeper review can happen (note that I'm not the maintainer of this repo, so feel free to ignore)

@jammm I'm surprised to see an AMD dev here. Hope @eliotwang considers your points so us AMD users can benefit from SageAttention V2 in the future.

@patientx
Copy link

Can this be made to work with RDNA2 too? Previously SageAttention 1 worked with RDNA2, but not anymore with ROCm 7.x.

@jammm
Copy link

jammm commented Dec 26, 2025

Can this be made to work with RDNA2 too? Previously SageAttention 1 worked with RDNA2, but not anymore with ROCm 7.x.

rocWMMA is specific to those GPUs which support WMMA or MFMA, so probably not.

@0xDELUXA
Copy link

0xDELUXA commented Dec 26, 2025

rocWMMA is specific to those GPUs which support WMMA or MFMA, so probably not.

Tried it on my gfx1200 with ROCm 7, but as I have seen, it includes the rocwmma_coop.hpp header, which is deprecated and has been removed in recent versions.

To be clear: This can’t possibly run on Windows, right? Not because of rocminfo (we have hipinfo there) but because of rocWMMA itself. It doesn’t come bundled with the TheRock wheels, and I couldn’t find any way to install it on Windows, apart from cloning it from the rocm-libraries repo, which doesn’t really sound right.

@jammm
Copy link

jammm commented Dec 27, 2025

rocWMMA is specific to those GPUs which support WMMA or MFMA, so probably not.

Tried it on my gfx1200 with ROCm 7, but as I have seen, it includes the rocwmma_coop.hpp header, which is deprecated and has been removed in recent versions.

To be clear: This can’t possibly run on Windows, right? Not because of rocminfo (we have hipinfo there) but because of rocWMMA itself. It doesn’t come bundled with the TheRock wheels, and I couldn’t find any way to install it on Windows, apart from cloning it from the rocm-libraries repo, which doesn’t really sound right.

TheRock does bundle rocWMMA now (not for all archs though; see https://github.com/ROCm/TheRock/blob/main/cmake/therock_amdgpu_targets.cmake for rocWMMA listed under EXCLUDE_TARGET_PROJECTS in specific archs, which implies that rocWMMA is excluded from building for that arch), but not sure about this coop header.

@0xDELUXA
Copy link

0xDELUXA commented Dec 27, 2025

TheRock does bundle rocWMMA now, but not sure about this coop header.

Oh. I couldn’t find a folder/file whose name contains "rocwmma" anywhere on my PC using a program like WizFile. How am I supposed to locate it?
As for that header, I could only find it in the deprecated rocWMMA repository, in the "release/rocm-rel-6.4.2.2" branch. The next version, 7.0, as well as subsequent versions, no longer include it.

@jammm
Copy link

jammm commented Dec 27, 2025

Oh. I couldn’t find a folder named “rocwmma” anywhere on my PC using a program like WizFile. How am I supposed to locate it? As for that header, I could only find it in the deprecated rocWMMA repository, in the "release/rocm-rel-6.4.2.2" branch. The next version, 7.0, as well as subsequent versions, no longer include it.

It was added a while ago ROCm/TheRock#1938. rocWMMA has been moved to rocm-libraries now.
You should find it in <venv>/.../site-packages/_rocm_sdk_core/include/rocwmma after running rocm-sdk init. On the latest nightly wheels (assuming your GPU arch supports it. Otherwise you won't find it).
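If you'd rather not hunt for the directory by hand, a small search helper can locate it. This is a hedged sketch; the `_rocm_sdk_core`/`_rocm_sdk_devel` layout is taken from the comments here, and `find_rocwmma` is an illustrative name, not part of any SDK.

```python
from pathlib import Path

def find_rocwmma(root: str) -> list[Path]:
    """Return every directory named 'rocwmma' under root (e.g. site-packages)."""
    return sorted(p for p in Path(root).rglob("rocwmma") if p.is_dir())

# Usage sketch (path is illustrative):
# import site
# print(find_rocwmma(site.getsitepackages()[0]))
```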

@0xDELUXA
Copy link

0xDELUXA commented Dec 27, 2025

It was added a while ago ROCm/TheRock#1938. rocWMMA has been moved to rocm-libraries now.
You should find it in <venv>/.../site-packages/_rocm_sdk_core/include/rocwmma after running rocm-sdk init. On the latest nightly wheels (assuming your GPU arch supports it. Otherwise you won't find it).

I see, thanks for the explanation. I suppose the gfx1200 should be supported, so I’ll take another look.

Yeah, so rocm-sdk init didn’t run for me because it requires rocm-sdk[devel], and python -m pip install --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/ torch torchvision torchaudio doesn’t install it automatically.
So I tried running python -m pip install --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/ rocm-sdk[devel], but now I’m facing another problem:

ERROR: Cannot install rocm-sdk[devel]==6.5.0rc20250601, rocm-sdk[devel]==6.5.0rc20250602, rocm-sdk[devel]==6.5.0rc20250603, rocm-sdk[devel]==6.5.0rc20250604, rocm-sdk[devel]==6.5.0rc20250605, rocm-sdk[devel]==6.5.0rc20250606, rocm-sdk[devel]==6.5.0rc20250607, rocm-sdk[devel]==6.5.0rc20250608 and rocm-sdk[devel]==6.5.0rc20250609 because these package versions have conflicting dependencies.

The conflict is caused by:
    rocm-sdk[devel] 6.5.0rc20250609 depends on rocm-sdk-core==6.5.0rc20250609
    rocm-sdk[devel] 6.5.0rc20250608 depends on rocm-sdk-core==6.5.0rc20250608
    rocm-sdk[devel] 6.5.0rc20250607 depends on rocm-sdk-core==6.5.0rc20250607
    rocm-sdk[devel] 6.5.0rc20250606 depends on rocm-sdk-core==6.5.0rc20250606
    rocm-sdk[devel] 6.5.0rc20250605 depends on rocm-sdk-core==6.5.0rc20250605
    rocm-sdk[devel] 6.5.0rc20250604 depends on rocm-sdk-core==6.5.0rc20250604
    rocm-sdk[devel] 6.5.0rc20250603 depends on rocm-sdk-core==6.5.0rc20250603
    rocm-sdk[devel] 6.5.0rc20250602 depends on rocm-sdk-core==6.5.0rc20250602
    rocm-sdk[devel] 6.5.0rc20250601 depends on rocm-sdk-core==6.5.0rc20250601

To fix this you could try to:
1. loosen the range of package versions you've specified
2. remove package versions to allow pip attempt to solve the dependency conflict

ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts

So basically, there isn’t yet a rocm-sdk[devel] package on Windows for rocm-sdk-core > 6.5.0rc20250609. For example, mine is 7.11.0a20251226.

@NeuralFault
Copy link

It was added a while ago ROCm/TheRock#1938. rocWMMA has been moved to rocm-libraries now.
You should find it in <venv>/.../site-packages/_rocm_sdk_core/include/rocwmma after running rocm-sdk init. On the latest nightly wheels (assuming your GPU arch supports it. Otherwise you won't find it).

I see, thanks for the explanation. I suppose the gfx1200 should be supported, so I’ll take another look.

Yeah, so rocm-sdk init didn’t run for me because it requires rocm-sdk[devel], and python -m pip install --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/ torch torchvision torchaudio doesn’t install it automatically.
So I tried running python -m pip install --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/ rocm-sdk[devel], but now I’m facing another problem:

[pip dependency-conflict output omitted; quoted in full in the previous comment]

So basically, there isn’t yet a rocm-sdk[devel] package on Windows for rocm-sdk-core > 6.5.0rc20250609. For example, mine is 7.11.0a20251226.

Add --pre to your pip arguments?
The package is there, but the latest builds have pre-release status in the index and require that flag.
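A minimal sketch of that, assuming the gfx120X nightly index from the commands above:

```shell
# Pre-release nightlies are skipped by pip's default resolver; --pre opts in.
python -m pip install --pre \
  --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/ \
  "rocm-sdk[devel]"
# Then unpack the devel tree (headers, including rocwmma if your arch ships it):
rocm-sdk init
```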

@0xDELUXA
Copy link

0xDELUXA commented Dec 27, 2025

We have rocm-sdk-devel-7.11.0a20251226, but somehow rocm-sdk init still doesn't work; it needs rocm-sdk[devel], i.e. a package from https://rocm.nightlies.amd.com/v2/gfx120X-all/rocm-sdk/

Nevermind:

(venv) E:\sdnext>rocm-sdk init
Devel contents expanded to 'E:\sdnext\venv\Lib\site-packages\_rocm_sdk_devel'

Now I actually have rocwmma in ...\venv\Lib\site-packages\_rocm_sdk_devel\include\rocwmma

@0xDELUXA
Copy link

0xDELUXA commented Dec 27, 2025

Once again got:

In file included from E:\sdnext\SageAttention\csrc\qattn\rocm\launch_sgattn.hip:3:
E:\sdnext\SageAttention\csrc\qattn\rocm\sgattn.hip:36:10: fatal error: 'rocwmma/rocwmma_coop.hpp' file not found
   36 | #include <rocwmma/rocwmma_coop.hpp>
      |          ^~~~~~~~~~~~~~~~~~~~~~~~~~

Because of:

#include <rocwmma/rocwmma_coop.hpp>

This file is no longer included with rocWMMA 2 for ROCm 7.
I assume this PR was created using rocWMMA 1.7 for ROCm 6.4 or earlier.

Additionally, there are several other compatibility issues stemming from AMD’s changes to rocWMMA 2.

@rwfsmith
Copy link

There's a guide here on migrating to 2.0, if that helps

https://rocm.docs.amd.com/projects/rocWMMA/en/latest/conceptual/migration-guide.html

@eliotwang
Copy link
Author

There's a guide here on migrating to 2.0, if that helps

https://rocm.docs.amd.com/projects/rocWMMA/en/latest/conceptual/migration-guide.html

In fact, I tried implementing SageAttention with rocWMMA 2.0. However, frustratingly, using CogVideoX-2B as an example: on the 9070, the end-to-end performance dropped from being ~30% faster than SageAttention v1 to essentially on par with SageAttention v1; on MI300X, the end-to-end performance slightly regressed. The reason for the slowdown is not clear yet. From a performance perspective, I’d still recommend basing SageAttention on rocWMMA 1.7.

@0xDELUXA
Copy link

0xDELUXA commented Dec 29, 2025

In fact, I tried implementing SageAttention with rocWMMA 2.0. However, frustratingly, using CogVideoX-2B as an example: on the 9070, the end-to-end performance dropped from being ~30% faster than SageAttention v1 to essentially on par with SageAttention v1; on MI300X, the end-to-end performance slightly regressed. The reason for the slowdown is not clear yet. From a performance perspective, I’d still recommend basing SageAttention on rocWMMA 1.7.

I see. Sadly, I'm Windows-only, and we don't have ROCm 6.4 from TheRock there. The earliest available version is 7, which comes with rocWMMA 2.0.
On Windows, Flash Attention is supported, but no Sage v1, because of things like import triton not working (AFAIK). I’d definitely give your Sage v2 for rocWMMA 2 a try there. I think it would be an improvement (like 2x the perf.) over what we currently have. Unfortunately, I can’t really code it myself.

I'm kind of curious what an AMD dev would say about this performance regression, because theoretically it shouldn't happen.

@rwfsmith
Copy link

I'm kind of curious what an AMD dev would say about this performance regression, because theoretically it shouldn't happen.

maybe @jammm can help out with that :)

@eliotwang
Copy link
Author

eliotwang commented Dec 30, 2025

In fact, I tried implementing SageAttention with rocWMMA 2.0. However, frustratingly, using CogVideoX-2B as an example: on the 9070, the end-to-end performance dropped from being ~30% faster than SageAttention v1 to essentially on par with SageAttention v1; on MI300X, the end-to-end performance slightly regressed. The reason for the slowdown is not clear yet. From a performance perspective, I’d still recommend basing SageAttention on rocWMMA 1.7.

I see. Sadly, I'm Windows-only, and we don't have ROCm 6.4 from TheRock there. The earliest available version is 7, which comes with rocWMMA 2.0. On Windows, Flash Attention is supported, but no Sage v1, because of things like import triton not working (AFAIK). I’d definitely give your Sage v2 for rocWMMA 2 a try there. I think it would be an improvement (like 2x the perf.) over what we currently have. Unfortunately, I can’t really code it myself.

I'm kind of curious what an AMD dev would say about this performance regression, because theoretically it shouldn't happen.

Okay, this is SageAttention implemented using rocWMMA 2.0 (gfx12-only support and not clean code). Hope it helps you: https://github.com/eliotwang/sgattn_rocwmma2.0

@0xDELUXA
Copy link

Okay, this is SageAttention implemented using rocWMMA 2.0 (gfx12-only support and not clean code). Hope it helps you: https://github.com/eliotwang/sgattn_rocwmma2.0

WoW

Huge thanks for taking the time to do this!

@sorasoras
Copy link

I guess this couldn't work for RDNA3 yet.

@0xDELUXA
Copy link

0xDELUXA commented Dec 30, 2025

I was able to build Sage 2 with rocWMMA 2 on Windows (ROCm 7) using #332 (comment), but cosine similarity is too low.
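As a side note for anyone validating a ported kernel: the usual accuracy check is cosine similarity between the reference attention output and the quantized kernel's output, flattened. A minimal pure-Python sketch (illustrative, not the repo's benchmark code):

```python
import math

def cosine_similarity(ref, out):
    """Cosine similarity of two equal-length flat sequences of floats."""
    dot = sum(a * b for a, b in zip(ref, out))
    norm = math.sqrt(sum(a * a for a in ref)) * math.sqrt(sum(b * b for b in out))
    return dot / norm

# Values very close to 1.0 mean the port matches the reference;
# "too low" typically points at a layout, scale, or masking bug.
```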

@eliotwang
Copy link
Author

I guess this couldn't work for RDNA3 yet.

Yeah, according to the documentation, only gfx12 (RDNA4) supports FP8. https://rocm.docs.amd.com/projects/rocWMMA/en/latest/api-reference/api-reference-guide.html

@sorasoras
Copy link

I guess this couldn't work for RDNA3 yet.

Yeah, according to the documentation, only gfx12 (RDNA4) supports FP8. https://rocm.docs.amd.com/projects/rocWMMA/en/latest/api-reference/api-reference-guide.html

I guess it could work like Ampere with INT8 and INT4, as in SageAttention2. If it could utilize INT4, there would be quite a bit of improvement as well.

@0xDELUXA
Copy link

I guess this couldn't work for RDNA3 yet.

You can try this: thu-ml/SpargeAttn#108

@eliotwang
Copy link
Author

This is great! A few high-level comments before a deeper review can happen (note that I'm not the maintainer of this repo, so feel free to ignore):

1. Lots of code seems (potentially) duplicated. For example, all the reduction ops should be usable as-is, with the only difference being FINAL_MASK 64 bit on gfx9 vs. 32 bit on gfx11/12. It would be good to try and port in-place within the .cu code so any changes on the CUDA code will benefit both. You can do this with `#ifdef __HIPCC__` for example. This was done on llama.cpp for example. You also don't need `hipLaunchKernelGGL` and can instead use the same `<<< >>>` notation like CUDA.

2. This should technically work on gfx12 too, and rocWMMA supports it already, so you should be able to use it as-is.

3. "MFMA" is gfx9 specific, while "WMMA" is gfx11/12 specific. I would rename things accordingly, or simply unify them both under a single "mma" name prefix or something like that.

4. I noticed the rocWMMA tile shape was 16x16x32 for gfx11. Have you also tried the 16 and 64 K dims? Seems like rocWMMA supports any multiple of K for the fragments.

5. (nitpick) would be nice to have all the comments written in English, as it can be helpful for folks who don't understand kanji (like me).

Thank you for your suggestions! I've made as many of the changes as possible based on your feedback and have partially unified the CUDA/ROCm code.

Regarding your suggestion to keep the kernel launch style consistent with CUDA: cudaFuncSetAttribute(kernel, ...) has special handling for kernel symbols on CUDA, while HIP does not provide an equivalent implicit conversion, so I'm keeping the original approach.
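For reference, the asymmetry described above looks roughly like this. A hedged sketch: `sgattn_kernel` and `smem_bytes` are illustrative names. cudaFuncSetAttribute has a templated overload that accepts the kernel symbol directly, while hipFuncSetAttribute takes a plain const void* and needs an explicit cast.

```cpp
// Hedged sketch: setting max dynamic shared memory on both runtimes.
#ifdef __HIPCC__
  // HIP: no implicit conversion from kernel symbol to const void*.
  hipFuncSetAttribute(reinterpret_cast<const void*>(&sgattn_kernel),
                      hipFuncAttributeMaxDynamicSharedMemorySize,
                      smem_bytes);
#else
  // CUDA: the templated overload accepts the kernel symbol directly.
  cudaFuncSetAttribute(sgattn_kernel,
                       cudaFuncAttributeMaxDynamicSharedMemorySize,
                       smem_bytes);
#endif
  // The launch syntax itself can stay identical on both:
  // sgattn_kernel<<<grid, block, smem_bytes, stream>>>(args...);
```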

@ouco1986
Copy link

In fact, I tried implementing SageAttention with rocWMMA 2.0. However, frustratingly, using CogVideoX-2B as an example: on the 9070, the end-to-end performance dropped from being ~30% faster than SageAttention v1 to essentially on par with SageAttention v1; on MI300X, the end-to-end performance slightly regressed. The reason for the slowdown is not clear yet. From a performance perspective, I’d still recommend basing SageAttention on rocWMMA 1.7.

I see. Sadly, I'm Windows-only, and we don't have ROCm 6.4 from TheRock there. The earliest available version is 7, which comes with rocWMMA 2.0. On Windows, Flash Attention is supported, but no Sage v1, because of things like import triton not working (AFAIK). I’d definitely give your Sage v2 for rocWMMA 2 a try there. I think it would be an improvement (like 2x the perf.) over what we currently have. Unfortunately, I can’t really code it myself.
I'm kind of curious what an AMD dev would say about this performance regression, because theoretically it shouldn't happen.

Okay, this is SageAttention implemented using rocWMMA 2.0 (gfx12-only support and not clean code). Hope it helps you: https://github.com/eliotwang/sgattn_rocwmma2.0

Hello, I used the code you provided with a 9070 XT graphics card in a Linux ROCm 7.11 environment and successfully compiled Sage 2. This is very exciting. But I always get memory errors when using Sage 2 attention in ComfyUI. Do you know what might be causing this? The error message is below. Thank you very much.

loaded partially; 8125.29 MB usable, 8042.93 MB loaded, 8496.66 MB offloaded, 75.01 MB buffer reserved, lowvram patches: 0
0%| | 0/4 [00:00<?, ?it/s]Memory access fault by GPU node-1 (Agent handle: 0x42b46500) on address 0x7f6e64800000. Reason: Page not present or supervisor privilege.
Failed to write segment data to pipe: Broken pipe
GPU coredump: handler exited with error (status: 1)
GPU core dump failed
