Add ROCm support #1112
base: main
Conversation
akasharidas commented on Apr 16, 2025
- Add a Transformer Engine wrapper to enable calling Flash Attention on ROCm (a rough sketch of the idea follows below).
- Add a config and launch script to demonstrate training Fuji-v2-70B on a single MI300 node.
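
For context only, a minimal sketch of the wrapper idea from the first bullet. This is not the PR's code: the `DotProductAttention` constructor and call arguments shown here are assumptions about the Transformer Engine flax API and vary across TE releases.

```python
# Illustrative sketch only. The DotProductAttention arguments below
# (head_dim, num_attention_heads, attn_mask_type, attention_dropout) are
# assumptions about the Transformer Engine flax API; check the installed
# TE version before relying on them.
import jax
from transformer_engine.jax.flax.transformer import DotProductAttention


def te_flash_attention(query: jax.Array, key: jax.Array, value: jax.Array) -> jax.Array:
    """Runs fused (flash) attention via Transformer Engine on a ROCm GPU.

    Inputs use the [batch, seq_len, num_heads, head_dim] layout.
    """
    attn = DotProductAttention(
        head_dim=query.shape[-1],
        num_attention_heads=query.shape[-2],
        attn_mask_type="causal",  # Assumption: causal LM training.
        attention_dropout=0.0,
    )
    # The module holds no trainable parameters, so an empty variable collection
    # suffices; deterministic=True avoids needing a dropout RNG.
    return attn.apply({}, query, key, value, deterministic=True)
```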
Thanks!
@@ -23,8 +23,8 @@ core = [
     "absl-py==2.1.0",
     "chex==0.1.88",
     "importlab==0.8.1",  # breaks pytype on 0.8
-    "jax==0.4.38",
-    "jaxlib==0.4.38",
+    "jax==0.4.35",
Is it intentional to downgrade the jax version?
I am guessing they are relying on a pinned version of transformer_engine, which has to be compatible with the JAX version.
This is because 0.4.35 is the closest available ROCm JAX release (https://github.com/ROCm/jax/releases). The next available version is 0.5, which we are open to upgrading to whenever axlearn decides to adopt it.
@@ -985,3 +989,95 @@ class CuDNNGPUFlashAttentionWithExplicitBias(CuDNNGPUFlashAttention):
     """

     _allow_explicit_bias = True
+
+
+class ROCmTransformerEngineFlashAttention(BaseFlashAttention):
Add a unittest for this layer?
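
A sketch of what such a test could look like, comparing against a plain softmax reference. It reuses the illustrative `te_flash_attention` wrapper sketched earlier in this thread; the import path, shapes, and tolerances are assumptions, not the PR's actual test harness.

```python
# Sketch only: compares the illustrative TE wrapper against a reference
# softmax attention. The import path below is hypothetical.
import jax
import jax.numpy as jnp
from absl.testing import absltest, parameterized

# Hypothetical module holding the wrapper sketched earlier in this thread.
from rocm_te_attention_sketch import te_flash_attention


def reference_attention(q, k, v):
    """Plain causal softmax attention used as ground truth."""
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
    mask = jnp.tril(jnp.ones((q.shape[1], k.shape[1]), dtype=bool))
    logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)
    return jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(logits, axis=-1), v)


class ROCmTEFlashAttentionTest(parameterized.TestCase):
    @parameterized.parameters((2, 128, 8, 64), (1, 256, 4, 128))
    def test_matches_reference(self, batch, seq_len, num_heads, head_dim):
        k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
        shape = (batch, seq_len, num_heads, head_dim)
        q = jax.random.normal(k1, shape, dtype=jnp.bfloat16)
        k = jax.random.normal(k2, shape, dtype=jnp.bfloat16)
        v = jax.random.normal(k3, shape, dtype=jnp.bfloat16)
        expected = reference_attention(q, k, v)
        actual = te_flash_attention(q, k, v)  # The wrapper under test.
        # Loose tolerance since inputs and kernels run in bfloat16.
        self.assertTrue(jnp.allclose(actual, expected, atol=2e-2))


if __name__ == "__main__":
    absltest.main()
```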
--module=text.gpt.c4_trainer --config=fuji-70B-v2-flash-single-host \
--trainer_dir=/tmp/gpt_c4_test --data_dir=gs://axlearn-public/tensorflow_datasets \
--jax_backend=gpu \
--mesh_selector="amd-mi300-single-node" \
How can we reproduce your experiments?
Thanks
@@ -43,6 +43,10 @@
 )
 from jax.ad_checkpoint import checkpoint_name
 from jax.experimental import pallas as pl
+try:
+    from transformer_engine.jax.flax.transformer import DotProductAttention
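
For context, an optional import like the one in this hunk is usually completed with a guard along the lines below; the flag name and fallback behaviour are illustrative, not necessarily how the diff continues.

```python
# Illustrative optional-import guard; the flag name and fallback are assumptions.
try:
    from transformer_engine.jax.flax.transformer import DotProductAttention

    _TE_INSTALLED = True
except ImportError:
    # transformer_engine is only present in ROCm images; fall back gracefully so
    # CUDA/TPU/CPU environments can still import this module.
    DotProductAttention = None
    _TE_INSTALLED = False
```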
Is it possible to extract the lower-level API from transformer engine? We deliberately waited to integrate cuDNN via native JAX precisely to avoid a dependency on transformer_engine, which makes JAX upgrades and version control much more difficult.
FWIW, we are also open to plumbing this through an XLA custom call.
@@ -0,0 +1,208 @@
+# Copyright © 2023 Apple Inc.
Suggested change:
-# Copyright © 2023 Apple Inc.
+# Copyright © 2025 Apple Inc.
And feel free to make it AMD if you intend to maintain it.
@@ -24,6 +25,7 @@
 gpu=[
     GPUDecoding,
     # For GPU, prefer cuDNN (without bias) whenever possible, as it's the fastest.
+    ROCmTransformerEngineFlashAttention,
How is the `ROCmTransformerEngineFlashAttention` code path selected here?
Added a check in `is_supported` which reports unsupported if we are not on a ROCm backend. Let me know if this approach is fine.
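
For illustration, such a check might look roughly like the sketch below. The `is_supported` integration point and the detection heuristic are assumptions: JAX does not expose an official "is this ROCm?" API, so the sketch matches on the device description.

```python
# Rough sketch of a ROCm-backend check; the heuristic is an assumption,
# not an official JAX API guarantee.
import jax


def _is_rocm_backend() -> bool:
    """Heuristically detects whether the default JAX backend is an AMD/ROCm GPU."""
    devices = jax.devices()
    if not devices:
        return False
    dev = devices[0]
    # device_kind is a human-readable string such as "AMD Instinct MI300X",
    # so matching on "amd" is a heuristic.
    return dev.platform in ("gpu", "rocm") and "amd" in dev.device_kind.lower()
```

With a check like this, `is_supported` can report unsupported on CUDA/TPU/CPU so the dispatch list above falls through to the next implementation.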
Currently tested on MI300. To run tests in parallel on a multi-GPU machine, use this:
```
PARALLEL_GPU_TEST=1 pytest -n 8 axlearn/common/flash_attention/gpu_attention_test.py
```
This file path needs to be updated to point to the current test file.