Add ROCm support #1112
base: main
Changes from 5 commits
@@ -43,6 +43,10 @@
)
from jax.ad_checkpoint import checkpoint_name
from jax.experimental import pallas as pl
try:
    from transformer_engine.jax.flax.transformer import DotProductAttention
except ImportError:
    pass

from axlearn.common.attention_bias import (
    NEG_INF,
@@ -985,3 +989,95 @@ class CuDNNGPUFlashAttentionWithExplicitBias(CuDNNGPUFlashAttention):
""" | ||
|
||
_allow_explicit_bias = True | ||
|
||
|
||
class ROCmTransformerEngineFlashAttention(BaseFlashAttention): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a unittest for this layer? |

    """Wraps Transformer Engine DotProductAttention to enable Flash Attention on ROCm.

    It has currently been tested only on standard Llama-like configs, and only for training.
    """

    _allow_explicit_bias = False

    def is_supported(
        self, *, query: Tensor, key: Tensor, value: Tensor, bias: BaseAttentionBias
    ) -> bool:
        """See `BaseFlashAttention.is_supported`."""
        if not super().is_supported(query=query, key=key, value=value, bias=bias):
            return False

        try:
            # Availability check only; the class is used in `__call__`.
            from transformer_engine.jax.flax.transformer import DotProductAttention
        except ImportError:
            return self._log_unsupported("could not import Transformer Engine.")

        if self.cfg.is_decoding:
            return self._log_unsupported(
                "only training has been tested with this attention implementation."
            )
        # The fused kernel has no concept of block size; it only requires the query and
        # key/value lengths to be even.
        if not self._check_block_size(query=query, key=key, block_size=2):
            return False

        if query.dtype not in (jnp.float16, jnp.bfloat16):
            return self._log_unsupported(
                f"{query.dtype=} is not supported. Only float16 and bfloat16 are supported."
            )

        if jax.default_backend() == "cpu":
            return self._log_unsupported("we're on CPU emulation.")

        head_dim = query.shape[-1]
        if head_dim % 8 != 0:
            return self._log_unsupported(f"{head_dim=} is not divisible by 8.")
        if head_dim > 128:
            return self._log_unsupported(f"{head_dim=} > 128.")

        _, sliding, explicit_bias = split(bias, CausalAttentionBias, SlidingWindowAttentionBias)
        if sliding.has_value():
            return self._log_unsupported("sliding window attention has not been tested yet.")
        if explicit_bias.has_value() and not self._allow_explicit_bias:
            return self._log_unsupported("explicit bias is not allowed at this stage.")

        logging.info("Using %s.", self.name())
        return True

    @functools.partial(jax.jit, static_argnames=["self"])
    def __call__(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        bias: BaseAttentionBias,
        prng_key: Optional[Tensor] = None,
    ) -> Tensor:
        """See `BaseFlashAttention.__call__`."""
        # Repeat KV heads so that query/key/value all carry the same number of heads.
        args = dict(
            query=query,
            key=repeat_kv_heads(query.shape[2], key),
            value=repeat_kv_heads(query.shape[2], value),
        )

        _, _, num_query_heads, head_dim = query.shape
        _, _, num_kv_heads, _ = key.shape

        causal, _, _ = split(bias, CausalAttentionBias, SlidingWindowAttentionBias)
        mask_type = "causal" if causal.has_value() else "no_mask"

        rocm_te_dot_product_attention = DotProductAttention(
            head_dim=head_dim,
            num_attention_heads=num_query_heads,
            num_gqa_groups=num_kv_heads,
            attn_mask_type=mask_type,
            attn_bias_type="no_bias",
            attention_dropout=self.cfg.dropout_rate,
            dtype=query.dtype,
            qkv_layout="BSHD_BSHD_BSHD",
            transpose_batch_sequence=False,
            scale_factor=self.cfg.softmax_scale,
        )

        return rocm_te_dot_product_attention.apply({}, rngs={"dropout": prng_key}, **args)
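
For orientation, here is a minimal usage sketch of the new wrapper, mirroring the unit test added later in this PR; the shapes and the PRNG seed are illustrative only.

```
# Usage sketch (mirrors the unit test in this PR); shapes and seed are illustrative.
import jax
import jax.numpy as jnp

from axlearn.common.attention_bias import causal_mask
from axlearn.common.flash_attention.gpu_attention import ROCmTransformerEngineFlashAttention
from axlearn.common.flash_attention.test_utils import generate_attention_data

# batch=2, seq_len=1024, num_heads=8, per_head_dim=128, causal mask, bf16.
q, k, v, bias = generate_attention_data(
    2, 1024, 1024, 8, 128, mask_fn=causal_mask, dtype=jnp.bfloat16
)
attn = (
    ROCmTransformerEngineFlashAttention.default_config()
    .set(softmax_scale=q.shape[-1] ** -0.5)
    .instantiate()
)
assert attn.is_supported(query=q, key=k, value=v, bias=bias)
# Output has the same [batch, seq_len, num_heads, per_head_dim] layout as the query.
out = attn(q, k, v, bias, jax.random.PRNGKey(0))
```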
@@ -0,0 +1,208 @@
# Copyright © 2023 Apple Inc.

Reviewer comment: And feel free to make it AMD if you intend to maintain it.

#
# Some of the code in this file is adapted from:
#
# jax-ml/jax-triton:
# Copyright 2023 The jax_triton Authors.
# Licensed under the Apache License, Version 2.0 (the "License").

"""Tests GPU FlashAttention kernels.

Currently tested on MI300. To run tests in parallel on a multi-GPU machine, use this:
```
PARALLEL_GPU_TEST=1 pytest -n 8 axlearn/common/flash_attention/gpu_attention_test.py
```
"""

Reviewer comment: This file path needs to be revised to the current file.

import functools
from typing import Any, Callable, Literal, Optional

import chex
import jax
import jax.numpy as jnp
import jax.random
import pytest

from axlearn.common.attention_bias import (
    CausalAttentionBias,
    MaskFn,
    ZeroAttentionBias,
    causal_mask,
)
from axlearn.common.flash_attention.common import ReferenceMHA
from axlearn.common.flash_attention.gpu_attention import (
    ROCmTransformerEngineFlashAttention,
)
from axlearn.common.flash_attention.test_utils import generate_attention_data
from axlearn.common.utils import Tensor

if jax.default_backend() not in ("gpu", "cpu"):
    pytest.skip(reason="Incompatible hardware", allow_module_level=True)

def _default_tol_fn(backend, dtype):
    del backend
    if dtype == jnp.bfloat16:
        return dict(atol=0.05, rtol=1e-2)
    if dtype == jnp.float16:
        return dict(atol=0.05, rtol=1e-5)
    if dtype == jnp.float32:
        return dict(atol=0.025, rtol=1e-5)
    raise ValueError(f"Unsupported dtype: {dtype}")


TestFn = Callable[[Tensor, Tensor, Tensor], Tensor]
TolFn = Callable[[str, Any], dict[str, float]]

def _test_forward_and_backward(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    bias,
    *,
    ref_fn: TestFn,
    test_fn: TestFn,
    forward_tol_fn: Callable = _default_tol_fn,
    backward_tol_fn: Callable = _default_tol_fn,
):
    ref_fn = jax.jit(ref_fn)
    test_fn = jax.jit(test_fn)
    prng_key = jax.random.PRNGKey(44)
    jax_out = test_fn(q, k, v, bias, prng_key)
    jax_ref_out = ref_fn(q, k, v, bias, prng_key)
    backend = jax.default_backend()
    chex.assert_trees_all_close(jax_out, jax_ref_out, **forward_tol_fn(backend, q.dtype))

    # Compare gradients.
    jax_grads = jax.grad(lambda *args: test_fn(*args).mean(), argnums=(0, 1, 2))(
        q, k, v, bias, prng_key
    )
    jax_ref_grads = jax.grad(lambda *args: ref_fn(*args).mean(), argnums=(0, 1, 2))(
        q, k, v, bias, prng_key
    )
    chex.assert_trees_all_close(jax_grads, jax_ref_grads, **backward_tol_fn(backend, q.dtype))

def common_attn_test_params(func):
    params = [
        pytest.mark.parametrize("kv_len", [None, 512]),
        pytest.mark.parametrize("dropout_rate", [0, 0.1]),
        pytest.mark.parametrize("attention_bias_type", [None, "2d", "4d"]),
        pytest.mark.parametrize("with_segment_ids", [True, False]),
        pytest.mark.parametrize("block_size", [128]),  # Triton broken for block size != 128.
        pytest.mark.parametrize("mask_fn", [causal_mask, None]),
        pytest.mark.parametrize("dtype", [jnp.float16, jnp.float32]),
    ]
    # Apply in reverse order to stack correctly.
    for param in reversed(params):
        func = param(func)
    return func

# We test the ROCm TE DotProductAttention against the reference flash_attention.
# Due to its algorithmic equivalence, the outputs should be close in both fp16 and bfloat16.
@pytest.mark.parametrize(
    "batch_size,num_heads,seq_len,per_head_dim",
    [
        (1, 2, 1024, 128),
        (2, 2, 1024, 128),
        (1, 4, 2048, 128),
        (2, 8, 2048, 128),
    ],
)
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float16])
def test_rocmte_against_xla_ref(
    batch_size: int,
    num_heads: int,
    seq_len: int,
    per_head_dim: int,
    causal: bool,
    dtype: jnp.dtype,
):
    if jax.default_backend() == "cpu":
        pytest.skip(reason="ROCm function needs GPU.")

    q, k, v, bias = generate_attention_data(
        batch_size,
        seq_len,
        seq_len,
        num_heads,
        per_head_dim,
        mask_fn=causal_mask if causal else None,
        dtype=dtype,
    )

    cfg = dict(
        softmax_scale=q.shape[-1] ** -0.5,
    )

    # Compare outputs.
    test_fn = ROCmTransformerEngineFlashAttention.default_config().set(**cfg).instantiate()
    chex.assert_equal(test_fn.is_supported(query=q, key=k, value=v, bias=bias), True)
    ref_fn = ReferenceMHA.default_config().set(**cfg).instantiate()

    def forward_tol_fn(backend, dtype):
        del backend
        if dtype == jnp.bfloat16:
            return dict(atol=0.02, rtol=1e-5)
        if dtype == jnp.float16:
            return dict(atol=0.005, rtol=1e-5)

    _test_forward_and_backward(
        q, k, v, bias, ref_fn=ref_fn, test_fn=test_fn, forward_tol_fn=forward_tol_fn
    )

def _cudnn_xla_forward_tol_fn(backend, dtype):
    del backend
    # cuDNN has higher diff when compared to non-fused attention in XLA.
    if dtype == jnp.bfloat16:
        return dict(atol=0.25, rtol=1e-3)
    if dtype == jnp.float16:
        return dict(atol=0.05, rtol=1e-3)


@pytest.mark.parametrize(
    "batch_size,num_heads,seq_len,kv_seq_len,per_head_dim",
    [
        (1, 1, 378, 676, 72),
        (2, 4, 582, 582, 56),
    ],
)
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("dtype", [jnp.float16, jnp.bfloat16])
def test_rocmte_seqlen_head_support(
    batch_size: int,
    num_heads: int,
    seq_len: int,
    kv_seq_len: int,
    per_head_dim: int,
    causal: bool,
    dtype: jnp.dtype,
):
    """Tests that ROCm TE supports any even sequence length and head dim % 8 == 0."""
    if jax.default_backend() == "cpu":
        pytest.skip(reason="ROCm function needs GPU.")
    q, k, v, bias = generate_attention_data(
        batch_size,
        seq_len,
        kv_seq_len,
        num_heads,
        per_head_dim,
        mask_fn=causal_mask if causal else None,
        dtype=dtype,
    )

    cfg = dict(
        softmax_scale=q.shape[-1] ** -0.5,
    )

    # Compare outputs.
    test_fn = ROCmTransformerEngineFlashAttention.default_config().set(**cfg).instantiate()
    ref_fn = ReferenceMHA.default_config().set(**cfg).instantiate()
    chex.assert_equal(test_fn.is_supported(query=q, key=k, value=v, bias=bias), True)

    _test_forward_and_backward(
        q, k, v, bias, ref_fn=ref_fn, test_fn=test_fn, forward_tol_fn=_cudnn_xla_forward_tol_fn
    )
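
Note that `common_attn_test_params` is defined above but not used by the two tests in this file. A hypothetical sketch of how it would be applied follows; the test name and body are placeholders, not part of the PR.

```
# Hypothetical usage of the shared parametrization helper; not part of this PR.
@common_attn_test_params
def test_rocmte_common_params(
    kv_len, dropout_rate, attention_bias_type, with_segment_ids, block_size, mask_fn, dtype
):
    # The stacked pytest.mark.parametrize decorators supply every argument above,
    # generating the full cross-product of configurations.
    ...
```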
@@ -11,6 +11,7 @@
    CuDNNGPUFlashAttention,
    CuDNNGPUFlashAttentionWithExplicitBias,
    PallasGPUFlashAttention,
    ROCmTransformerEngineFlashAttention,
)
from axlearn.common.flash_attention.gpu_decoding import GPUDecoding
from axlearn.common.flash_attention.tpu_attention import LegacyTPUFlashAttention, TPUSplashAttention

@@ -24,6 +25,7 @@
    gpu=[
        GPUDecoding,
        # For GPU, prefer cuDNN (without bias) whenever possible, as it's the fastest.
        ROCmTransformerEngineFlashAttention,
        CuDNNGPUFlashAttention,
        # Falls back to Pallas if cuDNN cannot be used without instantiating bias tensors.
        PallasGPUFlashAttention,

Reviewer comment (on the ROCmTransformerEngineFlashAttention line): how is […]
Reply: added a check in […]
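
(A hedged sketch of the kind of check the reply above may refer to; the `platform_version` probe, its exact string, and its placement are assumptions, not code from this PR.)

```
import jax

# Hypothetical guard, e.g. inside ROCmTransformerEngineFlashAttention.is_supported(),
# so the ROCm wrapper never shadows cuDNN on NVIDIA builds. Assumes the PJRT client
# exposes a platform_version string containing "rocm" on ROCm builds.
def _is_rocm_backend() -> bool:
    try:
        return "rocm" in jax.devices()[0].client.platform_version.lower()
    except Exception:  # pylint: disable=broad-except
        return False
```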
@@ -0,0 +1,28 @@
export LD_LIBRARY_PATH=/opt/rocm/lib:$LD_LIBRARY_PATH
export XLA_FLAGS="--xla_gpu_enable_cublaslt=True --xla_gpu_graph_level=0 --xla_gpu_autotune_level=0 --xla_gpu_enable_latency_hiding_scheduler=true"
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.975
export HSA_FORCE_FINE_GRAIN_PCIE=1
export GPU_MAX_HW_QUEUES=2
export HIP_FORCE_DEV_KERNARG=1
export NVTE_ALLOW_NONDETERMINISTIC_ALGO=1
export NVTE_FUSED_ATTN=1
export NVTE_FUSED_ATTN_CK=1
export NVTE_FUSED_ATTN_AOTRITON=1
export NVTE_CK_EXT_ASM=1
export NVTE_CK_ASM_ATOMIC_FP32=0
export NVTE_CK_ASM_NO_COEX=0
export NVTE_CK_ASM_RTZ_CVT=1
export NVTE_CK_BWD_V3=1
export NVTE_CK_V3_RTZ_CVT=2
export NVTE_CK_USES_BWD_V3=1
export NVTE_CK_IS_V3_ATOMIC_FP32=0
export NVTE_CK_IS_V3_SPEC=1
export NVTE_CK_HOW_V3_BF16_CVT=2


mkdir -p /tmp/gpt_c4_test; \
python3 -m axlearn.common.launch_trainer_main \
    --module=text.gpt.c4_trainer --config=fuji-70B-v2-flash-single-host \
    --trainer_dir=/tmp/gpt_c4_test --data_dir=gs://axlearn-public/tensorflow_datasets \
    --jax_backend=gpu \
    --mesh_selector="amd-mi300-single-node" \

Reviewer comment: How can we reproduce your experiments?
Reviewer comment: Is it possible to extract the lower-level API from Transformer Engine? We waited to integrate cuDNN from JAX native exactly to avoid a dependency on transformer_engine, which makes JAX upgrades and version control much more difficult.

Reviewer comment: FWIW, we are also open to plumbing this through an XLA custom call.
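
(For illustration of the JAX-native route mentioned above: a minimal sketch assuming a recent JAX where `jax.nn.dot_product_attention` accepts an `implementation` argument; it is not part of this PR and does not cover the ROCm fused-attention path.)

```
# Sketch of a fused attention call without a transformer_engine dependency.
import jax
import jax.numpy as jnp

def native_flash_attention(q, k, v, *, causal: bool, softmax_scale: float):
    # q, k, v: [batch, seq_len, num_heads, per_head_dim]. With implementation="cudnn",
    # XLA dispatches to the cuDNN fused kernel (and errors out if it is unavailable);
    # implementation=None lets XLA pick a backend automatically.
    return jax.nn.dot_product_attention(
        q, k, v, scale=softmax_scale, is_causal=causal, implementation="cudnn"
    )
```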