Add ROCm support #1112

Open · wants to merge 7 commits into main (showing changes from 4 commits)

96 changes: 96 additions & 0 deletions axlearn/common/flash_attention/gpu_attention.py
@@ -43,6 +43,10 @@
)
from jax.ad_checkpoint import checkpoint_name
from jax.experimental import pallas as pl
try:
    from transformer_engine.jax.flax.transformer import DotProductAttention
Contributor: Is it possible to extract the lower-level API from Transformer Engine? We waited to integrate cuDNN from JAX native exactly to avoid a dependency on transformer_engine, which makes JAX upgrades and version control much more difficult.

Contributor: Fwiw, we are also open to plumbing through an XLA custom call.

except ImportError:
    # Transformer Engine is optional; availability is re-checked in is_supported().
    pass

from axlearn.common.attention_bias import (
NEG_INF,
@@ -985,3 +989,95 @@ class CuDNNGPUFlashAttentionWithExplicitBias(CuDNNGPUFlashAttention):
"""

_allow_explicit_bias = True


class ROCmTransformerEngineFlashAttention(BaseFlashAttention):
Contributor: Add a unittest for this layer? (See the standalone sketch after this file's diff.)

"""Wraps Transformer Engine DotProductAttention to enable Flash Attention on ROCm.

Currently it has only been tested on standard Llama-like configs, and training only.
"""

_allow_explicit_bias = False

    def is_supported(
        self, *, query: Tensor, key: Tensor, value: Tensor, bias: BaseAttentionBias
    ) -> bool:
        """See `BaseFlashAttention.is_supported`."""
        if not super().is_supported(query=query, key=key, value=value, bias=bias):
            return False

        try:
            from transformer_engine.jax.flax.transformer import DotProductAttention
        except ImportError:
            return self._log_unsupported("could not import Transformer Engine")

        if self.cfg.is_decoding:
            return self._log_unsupported(
                "currently only training has been tested with this attention implementation."
            )
        # The fused attention has no concept of block size; it only requires the length of
        # query and key/value to be even.
        if not self._check_block_size(query=query, key=key, block_size=2):
            return False

        if query.dtype not in (jnp.float16, jnp.bfloat16):
            return self._log_unsupported(
                f"{query.dtype=} is not supported. Only supports float16 and bfloat16."
            )

        if jax.default_backend() == "cpu":
            return self._log_unsupported("we're on CPU emulation.")

        head_dim = query.shape[-1]
        if head_dim % 8 != 0:
            return self._log_unsupported(f"{head_dim=} is not divisible by 8.")
        if head_dim > 128:
            return self._log_unsupported(f"{head_dim=} > 128")

        _, sliding, explicit_bias = split(bias, CausalAttentionBias, SlidingWindowAttentionBias)
        if sliding.has_value():
            return self._log_unsupported("sliding window attention has not been tested yet.")
        if explicit_bias.has_value() and not self._allow_explicit_bias:
            return self._log_unsupported("we don't allow explicit bias at this stage.")

        logging.info("Using %s.", self.name())
        return True

    @functools.partial(jax.jit, static_argnames=["self"])
    def __call__(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        bias: BaseAttentionBias,
        prng_key: Optional[Tensor] = None,
    ) -> Tensor:
        """See `BaseFlashAttention.__call__`."""
        # Transformer Engine consumes query/key/value in BSHD layout; repeat_kv_heads
        # broadcasts the KV heads up to the number of query heads.
        args = dict(
            query=query,
            key=repeat_kv_heads(query.shape[2], key),
            value=repeat_kv_heads(query.shape[2], value),
        )

        _, _, num_query_heads, head_dim = query.shape
        _, _, num_kv_heads, _ = key.shape

        causal, _, _ = split(bias, CausalAttentionBias, SlidingWindowAttentionBias)
        mask_type = "causal" if causal.has_value() else "no_mask"

        rocm_te_dot_product_attention = DotProductAttention(
            head_dim=head_dim,
            num_attention_heads=num_query_heads,
            num_gqa_groups=num_kv_heads,
            attn_mask_type=mask_type,
            attn_bias_type="no_bias",
            attention_dropout=self.cfg.dropout_rate,
            dtype=query.dtype,
            qkv_layout="BSHD_BSHD_BSHD",
            transpose_batch_sequence=False,
            scale_factor=self.cfg.softmax_scale,
        )

        return rocm_te_dot_product_attention.apply({}, rngs={"dropout": prng_key}, **args)
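
As a starting point for the unit test requested above, a minimal standalone sketch that drives Transformer Engine's DotProductAttention the same way the wrapper does might look like the following. The shapes, dtype, and absence of GQA are illustrative assumptions, and it requires a GPU build of Transformer Engine for JAX (e.g. on ROCm); it is not part of this PR.

import jax
import jax.numpy as jnp
from transformer_engine.jax.flax.transformer import DotProductAttention

batch, seq_len, num_heads, head_dim = 2, 128, 8, 64
q_key, k_key, v_key, dropout_key = jax.random.split(jax.random.PRNGKey(0), 4)
q = jax.random.normal(q_key, (batch, seq_len, num_heads, head_dim), dtype=jnp.bfloat16)
k = jax.random.normal(k_key, (batch, seq_len, num_heads, head_dim), dtype=jnp.bfloat16)
v = jax.random.normal(v_key, (batch, seq_len, num_heads, head_dim), dtype=jnp.bfloat16)

# Mirror the constructor arguments used by ROCmTransformerEngineFlashAttention above.
attn = DotProductAttention(
    head_dim=head_dim,
    num_attention_heads=num_heads,
    num_gqa_groups=num_heads,  # no GQA in this toy example
    attn_mask_type="causal",
    attn_bias_type="no_bias",
    attention_dropout=0.0,
    dtype=q.dtype,
    qkv_layout="BSHD_BSHD_BSHD",
    transpose_batch_sequence=False,
    scale_factor=1.0 / head_dim**0.5,
)
out = attn.apply({}, rngs={"dropout": dropout_key}, query=q, key=k, value=v)
# A real test would compare `out` against a reference attention implementation.
assert out.shape[0] == batch and out.shape[1] == seq_len
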
2 changes: 2 additions & 0 deletions axlearn/common/flash_attention/utils.py
@@ -11,6 +11,7 @@
CuDNNGPUFlashAttention,
CuDNNGPUFlashAttentionWithExplicitBias,
PallasGPUFlashAttention,
ROCmTransformerEngineFlashAttention,
)
from axlearn.common.flash_attention.gpu_decoding import GPUDecoding
from axlearn.common.flash_attention.tpu_attention import LegacyTPUFlashAttention, TPUSplashAttention
@@ -24,6 +25,7 @@
gpu=[
GPUDecoding,
# For GPU, prefer cuDNN (without bias) whenever possible, as it's the fastest.
ROCmTransformerEngineFlashAttention,
Contributor: How is the ROCmTransformerEngineFlashAttention code path selected here?

Author: I added a check in is_supported that reports unsupported when we are not on a ROCm backend. Let me know if this approach is fine. (A sketch of such a check follows after this file's diff.)

CuDNNGPUFlashAttention,
# Fallbacks to Pallas if cuDNN cannot be used without instantiating bias tensors.
PallasGPUFlashAttention,
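
Regarding the selection question above: the diff shown in this commit does not yet include the ROCm-backend check the author mentions. A hypothetical sketch of such a check (illustrative only, not the PR's actual code) is below; JAX reports both CUDA and ROCm GPUs as backend "gpu", so it inspects the device kind of the first local device, and the exact strings vary by platform and JAX version.

import jax

def _is_rocm_backend() -> bool:
    """Best-effort ROCm detection; illustrative sketch only."""
    if jax.default_backend() != "gpu":
        return False
    # AMD accelerators typically report device kinds such as "AMD Instinct MI300X".
    device_kind = jax.local_devices()[0].device_kind.lower()
    return "amd" in device_kind or "mi300" in device_kind
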
20 changes: 19 additions & 1 deletion axlearn/experiments/text/gpt/fuji.py
@@ -707,6 +707,24 @@
],
),
),
(
    "amd-mi300-single-node",
    ChainConfigModifier.default_config().set(
        config_modifiers=[
            MeshShapeModifier.default_config().set(
                mesh_shape=mesh_shape_from_axes(fsdp=-1)
            ),
            RematSpecModifier.default_config().set(
                remat_policies={
                    "model.decoder.transformer.layer": RematSpec(
                        prevent_cse=False,
                        policy=jax_remat_policies.nothing_saveable
                    ),
                }
            ),
        ],
    ),
),
),
)
else:
@@ -857,7 +875,7 @@ def wrapper(config_name: str = config_name):
arch=arch, model_size="golden-run-test", version=f"v{version.value}"
)
] = wrapper
if model_size in ("1B", "3B", "7B", "8B"):
if model_size in ("1B", "3B", "7B", "8B", "70B"):

def make_single_host_config(base_config_name: str) -> SpmdTrainer.Config:
"""Make a single-host variant of the base config.
28 changes: 28 additions & 0 deletions launch_70B_single_node.sh
@@ -0,0 +1,28 @@
export LD_LIBRARY_PATH=/opt/rocm/lib:$LD_LIBRARY_PATH
export XLA_FLAGS="--xla_gpu_enable_cublaslt=True --xla_gpu_graph_level=0 --xla_gpu_autotune_level=0 --xla_gpu_enable_latency_hiding_scheduler=true"
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.975
export HSA_FORCE_FINE_GRAIN_PCIE=1
export GPU_MAX_HW_QUEUES=2
export HIP_FORCE_DEV_KERNARG=1
export NVTE_ALLOW_NONDETERMINISTIC_ALGO=1
export NVTE_FUSED_ATTN=1
export NVTE_FUSED_ATTN_CK=1
export NVTE_FUSED_ATTN_AOTRITON=1
export NVTE_CK_EXT_ASM=1
export NVTE_CK_ASM_ATOMIC_FP32=0
export NVTE_CK_ASM_NO_COEX=0
export NVTE_CK_ASM_RTZ_CVT=1
export NVTE_CK_BWD_V3=1
export NVTE_CK_V3_RTZ_CVT=2
export NVTE_CK_USES_BWD_V3=1
export NVTE_CK_IS_V3_ATOMIC_FP32=0
export NVTE_CK_IS_V3_SPEC=1
export NVTE_CK_HOW_V3_BF16_CVT=2


mkdir -p /tmp/gpt_c4_test; \
python3 -m axlearn.common.launch_trainer_main \
--module=text.gpt.c4_trainer --config=fuji-70B-v2-flash-single-host \
--trainer_dir=/tmp/gpt_c4_test --data_dir=gs://axlearn-public/tensorflow_datasets \
--jax_backend=gpu \
--mesh_selector="amd-mi300-single-node" \
Contributor: How can we reproduce your experiments?

4 changes: 2 additions & 2 deletions pyproject.toml
@@ -23,8 +23,8 @@ core = [
"absl-py==2.1.0",
"chex==0.1.88",
"importlab==0.8.1", # breaks pytype on 0.8
"jax==0.4.38",
"jaxlib==0.4.38",
"jax==0.4.35",
Contributor: Is it intentional to downgrade the jax version?

Contributor: I am guessing they are relying on a fixed version of transformer_engine, which has to be compatible with the Jax version.

Author: This is because 0.4.35 is the closest available ROCm JAX release (https://github.com/ROCm/jax/releases). The next version available is 0.5, which we are open to upgrading to whenever axlearn decides to use it.

"jaxlib==0.4.35",
"ml-dtypes==0.4.1",
"msgpack==1.1.0", # for checkpointing.
"nltk==3.7", # for text preprocessing