Add ROCm support #1112

Open · akasharidas wants to merge 7 commits into main

Conversation

akasharidas

  • Add a Transformer Engine wrapper to enable calling Flash Attention on ROCm (see the optional-import sketch after this list).
  • Add a config and launch script to demonstrate training Fuji-v2-70B on a single MI300 node.
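
Since transformer_engine is only present in ROCm builds here, the new attention path has to import cleanly in environments where the package is missing. Below is a minimal sketch of that optional-import guard, assuming the import path shown in the diff further down; the `_HAVE_TRANSFORMER_ENGINE` flag is illustrative and not the PR's actual code.

```python
# Sketch only: guard the optional transformer_engine dependency so the module
# still imports on CUDA-only or CPU installs. The import path matches the one
# in this PR's diff; the flag name is made up for illustration.
try:
    from transformer_engine.jax.flax.transformer import DotProductAttention

    _HAVE_TRANSFORMER_ENGINE = True
except ImportError:
    DotProductAttention = None  # transformer_engine is not installed.
    _HAVE_TRANSFORMER_ENGINE = False
```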

@akasharidas akasharidas marked this pull request as ready for review April 16, 2025 19:22
@akasharidas akasharidas requested review from ruomingp, markblee and a team as code owners April 16, 2025 19:22
@ruomingp ruomingp requested a review from kelvin-zou April 16, 2025 19:45
Contributor

@ruomingp left a comment

Thanks!

@@ -23,8 +23,8 @@ core = [
     "absl-py==2.1.0",
     "chex==0.1.88",
     "importlab==0.8.1", # breaks pytype on 0.8
-    "jax==0.4.38",
-    "jaxlib==0.4.38",
+    "jax==0.4.35",
Contributor

Is this intentional to downgrade the jax version?

Contributor

I am guessing they are relying on a pinned version of transformer_engine, which has to be compatible with the JAX version.

Author

This is because 0.4.35 is the closest available ROCm JAX release (https://github.com/ROCm/jax/releases). The next available version is 0.5, which we are open to upgrading to whenever axlearn decides to adopt it.

@@ -985,3 +989,95 @@ class CuDNNGPUFlashAttentionWithExplicitBias(CuDNNGPUFlashAttention):
     """

     _allow_explicit_bias = True
+
+
+class ROCmTransformerEngineFlashAttention(BaseFlashAttention):
Contributor

Add a unittest for this layer?
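
A common way to unit-test such a layer is to compare the fused kernel's output against a plain softmax-attention reference in pure JAX and skip when no ROCm device is available. The snippet below is a self-contained sketch of that reference computation only; it is not part of the PR, and the surrounding test wiring is assumed.

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v, causal: bool = True):
    """Plain softmax attention in pure JAX, usable as the 'expected' output
    in a unittest for a flash-attention layer.

    Inputs are [batch, seq, num_heads, head_dim].
    """
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) * (q.shape[-1] ** -0.5)
    if causal:
        seq = q.shape[1]
        mask = jnp.tril(jnp.ones((seq, seq), dtype=bool))
        logits = jnp.where(mask[None, None, :, :], logits, jnp.finfo(logits.dtype).min)
    probs = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", probs, v)
```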

--module=text.gpt.c4_trainer --config=fuji-70B-v2-flash-single-host \
--trainer_dir=/tmp/gpt_c4_test --data_dir=gs://axlearn-public/tensorflow_datasets \
--jax_backend=gpu \
--mesh_selector="amd-mi300-single-node" \
Contributor

How can we reproduce your experiments?

Contributor

@kelvin-zou left a comment

Thanks

@@ -43,6 +43,10 @@
 )
 from jax.ad_checkpoint import checkpoint_name
 from jax.experimental import pallas as pl
+try:
+    from transformer_engine.jax.flax.transformer import DotProductAttention
Contributor

Is it possible to extract the lower-level API from transformer engine? We waited to integrate cuDNN via native JAX precisely to avoid a dependency on transformer_engine, which makes JAX upgrades and version control much more difficult.

Contributor

Fwiw, we are also open to plumbing this through an XLA custom call.
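
For reference on the "native JAX" route mentioned above: recent JAX releases expose jax.nn.dot_product_attention, which can dispatch to a fused cuDNN kernel on supported NVIDIA builds with no extra package. Whether an equivalent fused path exists on ROCm backends is not claimed here; this is only a sketch of the dependency-free call the reviewers allude to.

```python
import jax


def native_attention(q, k, v):
    """Dependency-free attention via jax.nn.dot_product_attention.

    Inputs are [batch, seq, num_heads, head_dim]. Leaving `implementation`
    unset lets JAX pick between the XLA and fused backends where available.
    """
    return jax.nn.dot_product_attention(q, k, v, is_causal=True)
```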

@@ -0,0 +1,208 @@
# Copyright © 2023 Apple Inc.
Contributor

Suggested change:
-# Copyright © 2023 Apple Inc.
+# Copyright © 2025 Apple Inc.

Contributor

And feel free to make it AMD if you intend to maintain it.

@@ -24,6 +25,7 @@
 gpu=[
     GPUDecoding,
     # For GPU, prefer cuDNN (without bias) whenever possible, as it's the fastest.
+    ROCmTransformerEngineFlashAttention,
Contributor

How is the ROCmTransformerEngineFlashAttention code path selected here?

Author

Added a check in is_supported which reports unsupported if not on a ROCm backend. Let me know if this approach is fine.
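
For illustration only, a rough sketch of what such a backend check could look like: JAX exposes each device's device_kind string, and ROCm builds report AMD device names. The helper name and the vendor-substring match are assumptions, not the PR's actual is_supported logic.

```python
import jax


def _running_on_rocm() -> bool:
    """Best-effort check for an AMD/ROCm GPU backend (illustrative only)."""
    try:
        devices = jax.devices("gpu")
    except RuntimeError:
        # No GPU backend is registered at all.
        return False
    # device_kind is a free-form name such as "AMD Instinct MI300X" on ROCm
    # builds; matching on the vendor substring is a heuristic, not a stable API.
    return any("amd" in d.device_kind.lower() for d in devices)
```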


Currently tested on MI300. To run tests in parallel on a multi-GPU machine, use this:
```
PARALLEL_GPU_TEST=1 pytest -n 8 axlearn/common/flash_attention/gpu_attention_test.py
```

This file path needs to be updated to point to the current file.
