
Add ROCm support #1112


Open · wants to merge 7 commits into main (showing changes from 5 commits)
96 changes: 96 additions & 0 deletions axlearn/common/flash_attention/gpu_attention.py
@@ -43,6 +43,10 @@
)
from jax.ad_checkpoint import checkpoint_name
from jax.experimental import pallas as pl
try:
from transformer_engine.jax.flax.transformer import DotProductAttention
Contributor: Is it possible to extract the lower-level API from Transformer Engine? We deliberately waited for JAX's native cuDNN integration exactly to avoid a dependency on transformer_engine, which makes JAX upgrades and version control much more difficult.

Contributor: FWIW, we are also open to plumbing this through an XLA custom call.
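For context, a minimal sketch of the JAX-native cuDNN path the reviewers refer to, assuming a recent JAX release that exposes `jax.nn.dot_product_attention` with an `implementation` argument; the `native_fused_attention` wrapper name is illustrative only, and on ROCm builds this particular `implementation` value is not expected to apply:

```
# Sketch only: the JAX-native fused-attention entry point the reviewers mention,
# which avoids importing transformer_engine directly. Assumes a recent jax/jaxlib.
import jax


def native_fused_attention(query, key, value, *, causal: bool, softmax_scale: float):
    # query/key/value: [batch, seq_len, num_heads, head_dim], fp16 or bf16.
    return jax.nn.dot_product_attention(
        query,
        key,
        value,
        scale=softmax_scale,
        is_causal=causal,
        implementation="cudnn",  # lowers to the cuDNN fused-attention custom call on NVIDIA GPUs
    )
```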

except ImportError:
pass

from axlearn.common.attention_bias import (
NEG_INF,
@@ -985,3 +989,95 @@ class CuDNNGPUFlashAttentionWithExplicitBias(CuDNNGPUFlashAttention):
"""

_allow_explicit_bias = True


class ROCmTransformerEngineFlashAttention(BaseFlashAttention):
Contributor: Add a unit test for this layer?

"""Wraps Transformer Engine DotProductAttention to enable Flash Attention on ROCm.

    Currently it has only been tested on standard Llama-like configs, and only for training.
"""

_allow_explicit_bias = False

def is_supported(
self, *, query: Tensor, key: Tensor, value: Tensor, bias: BaseAttentionBias
) -> bool:
"""See `BaseFlashAttention.is_supported`."""
if not super().is_supported(query=query, key=key, value=value, bias=bias):
return False

try:
from transformer_engine.jax.flax.transformer import DotProductAttention
except ImportError:
return self._log_unsupported("could not import Transformer Engine")

if self.cfg.is_decoding:
return self._log_unsupported("currently only training has been tested with this attention implementation.")
else:
            # TE fused attention has no concept of block size; it only requires the
            # query and key/value lengths to be even.
if not self._check_block_size(query=query, key=key, block_size=2):
return False

if query.dtype not in (jnp.float16, jnp.bfloat16):
return self._log_unsupported(
f"{query.dtype=} is not supported. Only supports float16 and bfloat16."
)

if jax.default_backend() == "cpu":
return self._log_unsupported("we're on CPU emulation.")

head_dim = query.shape[-1]
if head_dim % 8 != 0:
return self._log_unsupported(f"{head_dim=} is not divisible by 8.")
if head_dim > 128:
return self._log_unsupported(f"{head_dim=} > 128")
_, sliding, explicit_bias = split(bias, CausalAttentionBias, SlidingWindowAttentionBias)

if sliding.has_value():
return self._log_unsupported("sliding window attention has not been tested currently.")
if explicit_bias.has_value() and not self._allow_explicit_bias:
return self._log_unsupported("we don't allow explicit bias at this stage.")

logging.info("Using %s.", self.name())
return True

@functools.partial(jax.jit, static_argnames=["self"])
def __call__(
self,
query: Tensor,
key: Tensor,
value: Tensor,
bias: BaseAttentionBias,
prng_key: Optional[Tensor] = None,
) -> Tensor:
"""See `BaseFlashAttention.__call__`."""

args = dict(
query=query,
key=repeat_kv_heads(query.shape[2], key),
value=repeat_kv_heads(query.shape[2], value),
)

_, _, num_query_heads, head_dim = query.shape
_, _, num_kv_heads, _ = key.shape

causal, _, _ = split(
bias, CausalAttentionBias, SlidingWindowAttentionBias
)
mask_type = "causal" if causal.has_value() else "no_mask"

rocm_te_dot_product_attention = DotProductAttention(
head_dim=head_dim,
num_attention_heads=num_query_heads,
num_gqa_groups=num_kv_heads,
attn_mask_type=mask_type,
attn_bias_type="no_bias",
attention_dropout=self.cfg.dropout_rate,
dtype=query.dtype,
qkv_layout="BSHD_BSHD_BSHD",
transpose_batch_sequence=False,
scale_factor=self.cfg.softmax_scale,
)

return rocm_te_dot_product_attention.apply({}, rngs={'dropout': prng_key}, **args)
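For reference, a minimal usage sketch of the new wrapper, assuming a ROCm GPU backend with Transformer Engine installed; it mirrors the helpers and call pattern of the test file added below rather than prescribing an official API:

```
# Sketch only: instantiate and call the ROCm TE flash attention wrapper.
import jax
import jax.numpy as jnp

from axlearn.common.attention_bias import causal_mask
from axlearn.common.flash_attention.gpu_attention import ROCmTransformerEngineFlashAttention
from axlearn.common.flash_attention.test_utils import generate_attention_data

# batch=2, seq_len=1024, kv_seq_len=1024, heads=8, head_dim=128.
q, k, v, bias = generate_attention_data(
    2, 1024, 1024, 8, 128, mask_fn=causal_mask, dtype=jnp.bfloat16
)
attn = (
    ROCmTransformerEngineFlashAttention.default_config()
    .set(softmax_scale=q.shape[-1] ** -0.5)
    .instantiate()
)
assert attn.is_supported(query=q, key=k, value=v, bias=bias)
out = attn(q, k, v, bias, prng_key=jax.random.PRNGKey(0))  # [batch, seq, heads, head_dim]
```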
208 changes: 208 additions & 0 deletions axlearn/common/flash_attention/gpu_attention_rocm_test.py
@@ -0,0 +1,208 @@
# Copyright © 2023 Apple Inc.
Contributor (suggested change): "# Copyright © 2023 Apple Inc." → "# Copyright © 2025 Apple Inc."

Contributor: And feel free to make it AMD if you intend to maintain it.
#
# Some of the code in this file is adapted from:
#
# jax-ml/jax-triton:
# Copyright 2023 The jax_triton Authors.
# Licensed under the Apache License, Version 2.0 (the "License").

"""Tests GPU FlashAttention kernels.

Currently tested on MI300. To run tests in parallel on a multi-GPU machine, use this:
```
PARALLEL_GPU_TEST=1 pytest -n 8 axlearn/common/flash_attention/gpu_attention_test.py

```
"""

Reviewer: This file path needs to be revised to the current file.

import functools
from typing import Any, Callable, Literal, Optional

import chex
import jax
import jax.numpy as jnp
import jax.random
import pytest

from axlearn.common.attention_bias import (
CausalAttentionBias,
MaskFn,
ZeroAttentionBias,
causal_mask,
)
from axlearn.common.flash_attention.common import ReferenceMHA
from axlearn.common.flash_attention.gpu_attention import (
ROCmTransformerEngineFlashAttention,
)
from axlearn.common.flash_attention.test_utils import generate_attention_data
from axlearn.common.utils import Tensor

if jax.default_backend() not in ("gpu", "cpu"):
pytest.skip(reason="Incompatible hardware", allow_module_level=True)


def _default_tol_fn(backend, dtype):
del backend
if dtype == jnp.bfloat16:
return dict(atol=0.05, rtol=1e-2)
if dtype == jnp.float16:
return dict(atol=0.05, rtol=1e-5)
if dtype == jnp.float32:
return dict(atol=0.025, rtol=1e-5)
raise ValueError(f"Unsupported dtype: {dtype}")


TestFn = Callable[[Tensor, Tensor, Tensor], Tensor]
TolFn = Callable[[str, Any], dict[str, float]]


def _test_forward_and_backward(
q: Tensor,
k: Tensor,
v: Tensor,
bias,
*,
ref_fn: TestFn,
test_fn: TestFn,
forward_tol_fn: Callable = _default_tol_fn,
backward_tol_fn: Callable = _default_tol_fn,
):
ref_fn = jax.jit(ref_fn)
test_fn = jax.jit(test_fn)
prng_key = jax.random.PRNGKey(44)
jax_out = test_fn(q, k, v, bias, prng_key)
jax_ref_out = ref_fn(q, k, v, bias, prng_key)
backend = jax.default_backend()
chex.assert_trees_all_close(jax_out, jax_ref_out, **forward_tol_fn(backend, q.dtype))

# Compare gradients.
    jax_grads = jax.grad(lambda *args: test_fn(*args).mean(), argnums=(0, 1, 2))(
        q, k, v, bias, prng_key
    )
    jax_ref_grads = jax.grad(lambda *args: ref_fn(*args).mean(), argnums=(0, 1, 2))(
        q, k, v, bias, prng_key
    )
chex.assert_trees_all_close(jax_grads, jax_ref_grads, **backward_tol_fn(backend, q.dtype))


def common_attn_test_params(func):
params = [
pytest.mark.parametrize("kv_len", [None, 512]),
pytest.mark.parametrize("dropout_rate", [0, 0.1]),
pytest.mark.parametrize("attention_bias_type", [None, "2d", "4d"]),
pytest.mark.parametrize("with_segment_ids", [True, False]),
pytest.mark.parametrize("block_size", [128]), # Triton broken for block size !=128.
pytest.mark.parametrize("mask_fn", [causal_mask, None]),
pytest.mark.parametrize("dtype", [jnp.float16, jnp.float32]),
]
# Apply in reverse order to stack correctly.
for param in reversed(params):
func = param(func)
return func


# We test the ROCm TE DotProductAttention against the reference flash_attention.
# Due to its algorithmic equivalence, the outputs should be close in both fp16 and bfloat16.
@pytest.mark.parametrize(
"batch_size,num_heads,seq_len,per_head_dim",
[
(1, 2, 1024, 128),
(2, 2, 1024, 128),
(1, 4, 2048, 128),
(2, 8, 2048, 128),
],
)
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float16])
def test_rocmte_against_xla_ref(
batch_size: int,
num_heads: int,
seq_len: int,
per_head_dim: int,
causal: bool,
dtype: jnp.dtype,
):
if jax.default_backend() == "cpu":
pytest.skip(reason="ROCm function needs GPU.")

q, k, v, bias = generate_attention_data(
batch_size,
seq_len,
seq_len,
num_heads,
per_head_dim,
mask_fn=causal_mask if causal else None,
dtype=dtype,
)

cfg = dict(
softmax_scale=q.shape[-1] ** -0.5,
)

# Compare outputs.
test_fn = ROCmTransformerEngineFlashAttention.default_config().set(**cfg).instantiate()
chex.assert_equal(test_fn.is_supported(query=q, key=k, value=v, bias=bias), True)
ref_fn = ReferenceMHA.default_config().set(**cfg).instantiate()

def forward_tol_fn(backend, dtype):
del backend
if dtype == jnp.bfloat16:
return dict(atol=0.02, rtol=1e-5)
if dtype == jnp.float16:
return dict(atol=0.005, rtol=1e-5)

_test_forward_and_backward(
q, k, v, bias, ref_fn=ref_fn, test_fn=test_fn, forward_tol_fn=forward_tol_fn
)


def _cudnn_xla_forward_tol_fn(backend, dtype):
del backend
    # The fused attention kernel has a higher diff when compared to non-fused attention in XLA.
if dtype == jnp.bfloat16:
return dict(atol=0.25, rtol=1e-3)
if dtype == jnp.float16:
return dict(atol=0.05, rtol=1e-3)


@pytest.mark.parametrize(
"batch_size,num_heads,seq_len,kv_seq_len,per_head_dim",
[
(1, 1, 378, 676, 72),
(2, 4, 582, 582, 56),
],
)
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("dtype", [jnp.float16, jnp.bfloat16])
def test_rocmte_seqlen_head_support(
batch_size: int,
num_heads: int,
seq_len: int,
kv_seq_len: int,
per_head_dim: int,
causal: bool,
dtype: jnp.dtype,
):
"""Tests that ROCm TE supports any even sequence length and head dim % 8 == 0."""
if jax.default_backend() == "cpu":
pytest.skip(reason="ROCm function needs GPU.")
q, k, v, bias = generate_attention_data(
batch_size,
seq_len,
kv_seq_len,
num_heads,
per_head_dim,
mask_fn=causal_mask if causal else None,
dtype=dtype,
)

cfg = dict(
softmax_scale=q.shape[-1] ** -0.5,
)

# Compare outputs.
test_fn = ROCmTransformerEngineFlashAttention.default_config().set(**cfg).instantiate()
ref_fn = ReferenceMHA.default_config().set(**cfg).instantiate()
chex.assert_equal(test_fn.is_supported(query=q, key=k, value=v, bias=bias), True)

_test_forward_and_backward(
q, k, v, bias, ref_fn=ref_fn, test_fn=test_fn, forward_tol_fn=_cudnn_xla_forward_tol_fn
)
2 changes: 2 additions & 0 deletions axlearn/common/flash_attention/utils.py
@@ -11,6 +11,7 @@
CuDNNGPUFlashAttention,
CuDNNGPUFlashAttentionWithExplicitBias,
PallasGPUFlashAttention,
ROCmTransformerEngineFlashAttention,
)
from axlearn.common.flash_attention.gpu_decoding import GPUDecoding
from axlearn.common.flash_attention.tpu_attention import LegacyTPUFlashAttention, TPUSplashAttention
@@ -24,6 +25,7 @@
gpu=[
GPUDecoding,
# For GPU, prefer cuDNN (without bias) whenever possible, as it's the fastest.
ROCmTransformerEngineFlashAttention,
Contributor: How is the ROCmTransformerEngineFlashAttention code path selected here?

Author: I added a check in is_supported which reports unsupported if we are not on a ROCm backend. Let me know if this approach is fine.
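The backend check described in this reply is not part of the five commits shown above; a minimal sketch of what such a guard could look like, with the device-kind strings treated as assumptions rather than taken from this PR:

```
# Sketch only: one way to gate is_supported on a ROCm GPU backend.
import jax


def _is_rocm_gpu_backend() -> bool:
    if jax.default_backend() != "gpu":
        return False
    # AMD GPUs typically report device kinds such as "AMD Instinct MI300X" or a gfx
    # arch string; adjust these substrings to whatever your JAX ROCm build reports.
    kind = jax.devices()[0].device_kind.lower()
    return "amd" in kind or "gfx" in kind or "mi300" in kind


# Inside ROCmTransformerEngineFlashAttention.is_supported, something like:
#     if not _is_rocm_gpu_backend():
#         return self._log_unsupported("not running on a ROCm GPU backend.")
```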

CuDNNGPUFlashAttention,
# Fallbacks to Pallas if cuDNN cannot be used without instantiating bias tensors.
PallasGPUFlashAttention,
…
20 changes: 19 additions & 1 deletion axlearn/experiments/text/gpt/fuji.py
@@ -707,6 +707,24 @@ def get_trainer_kwargs(
],
),
),
(
"amd-mi300-single-node",
ChainConfigModifier.default_config().set(
config_modifiers=[
MeshShapeModifier.default_config().set(
mesh_shape=mesh_shape_from_axes(fsdp=-1)
),
RematSpecModifier.default_config().set(
remat_policies={
"model.decoder.transformer.layer": RematSpec(
prevent_cse=False,
policy=jax_remat_policies.nothing_saveable
),
}
),
],
),
),
),
)
else:
@@ -857,7 +875,7 @@ def wrapper(config_name: str = config_name):
arch=arch, model_size="golden-run-test", version=f"v{version.value}"
)
] = wrapper
if model_size in ("1B", "3B", "7B", "8B"):
if model_size in ("1B", "3B", "7B", "8B", "70B"):

def make_single_host_config(base_config_name: str) -> SpmdTrainer.Config:
"""Make a single-host variant of the base config.
…
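As a side note on the new `amd-mi300-single-node` entry: `mesh_shape_from_axes(fsdp=-1)` places all parallelism on the `fsdp` axis. A rough sketch of what this resolves to on a single 8-GPU node, assuming axlearn's standard six mesh axes; the axis order is an assumption, not taken from this diff:

```
# Sketch only: rough illustration of fsdp=-1 on a single 8-GPU MI300 node.
import jax

num_devices = jax.device_count()  # e.g., 8 on one node
# -1 absorbs the remaining devices, so all parallelism lands on the fsdp axis
# (assumed axis order: pipeline, data, expert, fsdp, seq, model).
mesh_shape = (1, 1, 1, num_devices, 1, 1)
```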
28 changes: 28 additions & 0 deletions launch_70B_single_node.sh
@@ -0,0 +1,28 @@
export LD_LIBRARY_PATH=/opt/rocm/lib:$LD_LIBRARY_PATH
export XLA_FLAGS="--xla_gpu_enable_cublaslt=True --xla_gpu_graph_level=0 --xla_gpu_autotune_level=0 --xla_gpu_enable_latency_hiding_scheduler=true"
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.975
export HSA_FORCE_FINE_GRAIN_PCIE=1
export GPU_MAX_HW_QUEUES=2
export HIP_FORCE_DEV_KERNARG=1
export NVTE_ALLOW_NONDETERMINISTIC_ALGO=1
export NVTE_FUSED_ATTN=1
export NVTE_FUSED_ATTN_CK=1
export NVTE_FUSED_ATTN_AOTRITON=1
export NVTE_CK_EXT_ASM=1
export NVTE_CK_ASM_ATOMIC_FP32=0
export NVTE_CK_ASM_NO_COEX=0
export NVTE_CK_ASM_RTZ_CVT=1
export NVTE_CK_BWD_V3=1
export NVTE_CK_V3_RTZ_CVT=2
export NVTE_CK_USES_BWD_V3=1
export NVTE_CK_IS_V3_ATOMIC_FP32=0
export NVTE_CK_IS_V3_SPEC=1
export NVTE_CK_HOW_V3_BF16_CVT=2


mkdir -p /tmp/gpt_c4_test; \
python3 -m axlearn.common.launch_trainer_main \
--module=text.gpt.c4_trainer --config=fuji-70B-v2-flash-single-host \
--trainer_dir=/tmp/gpt_c4_test --data_dir=gs://axlearn-public/tensorflow_datasets \
--jax_backend=gpu \
--mesh_selector="amd-mi300-single-node" \
Contributor: How can we reproduce your experiments?
