
Conversation

@akasharidas

  • Add a Transformer Engine wrapper to enable calling Flash Attention on ROCm.
  • Add a config and launch script to demonstrate training Fuji-v2-70B on a single MI300 node.

@akasharidas akasharidas marked this pull request as ready for review April 16, 2025 19:22
@akasharidas akasharidas requested review from a team, markblee and ruomingp as code owners April 16, 2025 19:22
@ruomingp ruomingp requested a review from kelvin-zou April 16, 2025 19:45
Contributor

@ruomingp ruomingp left a comment

Thanks!

"importlab==0.8.1", # breaks pytype on 0.8
"jax==0.4.38",
"jaxlib==0.4.38",
"jax==0.4.35",
Contributor

Is this intentional to downgrade the jax version?

Contributor

I am guessing they are relying on a fixed version of transformer_engine, which has to be compatible with the JAX version.

Author

This is because 0.4.35 is the closest available ROCm JAX release (https://github.com/ROCm/jax/releases). The next available version is 0.5, which we are open to upgrading to whenever axlearn decides to use it.

_allow_explicit_bias = True


class ROCmTransformerEngineFlashAttention(BaseFlashAttention):
Contributor

Add a unittest for this layer?
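
A minimal sketch of the kind of unit test being asked for, modeled on how the other GPU flash-attention backends are exercised in axlearn/common/flash_attention/gpu_attention_test.py. The layer construction is a placeholder here (`run_rocm_te_flash_attention` is hypothetical), and the shapes and tolerances are assumptions for illustration:

```
import jax
import jax.numpy as jnp
import numpy as np
import pytest


def _reference_attention(q, k, v):
    # Plain softmax attention over [batch, seq, heads, head_dim] inputs,
    # used as the numerical reference.
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
    probs = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", probs, v)


def run_rocm_te_flash_attention(q, k, v):
    # Placeholder: the real test would build ROCmTransformerEngineFlashAttention the
    # same way gpu_attention_test.py builds the existing backends and call it here.
    raise NotImplementedError("Wire this to the ROCmTransformerEngineFlashAttention layer under test.")


@pytest.mark.skipif(
    not any(d.platform in ("gpu", "rocm") for d in jax.devices()),
    reason="Requires a GPU (ROCm) backend.",
)
def test_rocm_te_flash_attention_matches_reference():
    batch, seq, heads, dim = 2, 128, 4, 64
    keys = jax.random.split(jax.random.PRNGKey(0), 3)
    q, k, v = [jax.random.normal(key, (batch, seq, heads, dim), dtype=jnp.bfloat16) for key in keys]
    out = run_rocm_te_flash_attention(q, k, v)
    expected = _reference_attention(q, k, v)
    np.testing.assert_allclose(
        np.asarray(out, dtype=np.float32),
        np.asarray(expected, dtype=np.float32),
        atol=2e-2,
        rtol=2e-2,
    )
```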

--module=text.gpt.c4_trainer --config=fuji-70B-v2-flash-single-host \
--trainer_dir=/tmp/gpt_c4_test --data_dir=gs://axlearn-public/tensorflow_datasets \
--jax_backend=gpu \
--mesh_selector="amd-mi300-single-node"
Contributor

How can we reproduce your experiments?

Contributor

@kelvin-zou kelvin-zou left a comment

Thanks

from jax.ad_checkpoint import checkpoint_name
from jax.experimental import pallas as pl
try:
    from transformer_engine.jax.flax.transformer import DotProductAttention
Contributor

Is it possible to extract the lower-level API from Transformer Engine? We waited to integrate cuDNN attention from JAX natively precisely to avoid a dependency on transformer_engine, which makes JAX upgrades and version control much more difficult.
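
For reference, a sketch of the JAX-native path referred to here, which avoids the transformer_engine dependency entirely. On CUDA builds of jaxlib, `implementation="cudnn"` dispatches to the fused cuDNN attention; whether a ROCm build routes this API to an equivalent fused kernel is the open question in this thread, not something this sketch asserts:

```
import jax
import jax.numpy as jnp


def native_dot_product_attention(q, k, v, *, causal: bool = False):
    # q, k, v: [batch, seq_len, num_heads, head_dim].
    return jax.nn.dot_product_attention(q, k, v, is_causal=causal, implementation="cudnn")


# Example usage:
# q = k = v = jnp.zeros((2, 128, 4, 64), dtype=jnp.bfloat16)
# out = native_dot_product_attention(q, k, v, causal=True)
```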

Contributor

FWIW, we are also open to plumbing this through an XLA custom call.

@@ -0,0 +1,208 @@
# Copyright © 2023 Apple Inc.
Contributor

Suggested change:
- # Copyright © 2023 Apple Inc.
+ # Copyright © 2025 Apple Inc.

Contributor

And feel free to make it AMD if you intend to maintain it.

gpu=[
    GPUDecoding,
    # For GPU, prefer cuDNN (without bias) whenever possible, as it's the fastest.
    ROCmTransformerEngineFlashAttention,
Contributor

How is the ROCmTransformerEngineFlashAttention code path selected here?

Author

Added a check in is_supported that reports unsupported when not on a ROCm backend. Let me know if this approach is fine.
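
A sketch of what that check could look like. The import path for BaseFlashAttention, the is_supported signature, and the ROCm-detection heuristic (reading device_kind) are all assumptions for illustration:

```
import jax

# Assumed import path, following axlearn's flash_attention package layout.
from axlearn.common.flash_attention.common import BaseFlashAttention


def _on_rocm_backend() -> bool:
    # Assumption: ROCm builds of jaxlib report GPU devices whose device_kind
    # contains an AMD name (e.g. "AMD Instinct MI300X").
    devices = jax.devices()
    return (
        bool(devices)
        and devices[0].platform in ("gpu", "rocm")
        and "amd" in devices[0].device_kind.lower()
    )


class ROCmTransformerEngineFlashAttention(BaseFlashAttention):
    def is_supported(self, *args, **kwargs) -> bool:
        # Report unsupported (rather than raising) when not on ROCm, so the dispatcher
        # falls through to the next GPU flash-attention backend in the list.
        if not _on_rocm_backend():
            return False
        return super().is_supported(*args, **kwargs)
```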

Currently tested on MI300. To run tests in parallel on a multi-GPU machine, use this:
```
PARALLEL_GPU_TEST=1 pytest -n 8 axlearn/common/flash_attention/gpu_attention_test.py
```

This file path needs to be revised to the current file.

@changlan
Member

Closing this PR due to inactivity. Feel free to reopen if you would like to continue the work.

@changlan changlan closed this Jul 26, 2025