Add ROCm support #1112
@@ -43,6 +43,10 @@
 )
 from jax.ad_checkpoint import checkpoint_name
 from jax.experimental import pallas as pl
+try:
+    from transformer_engine.jax.flax.transformer import DotProductAttention
+except ImportError:
+    pass
 
 from axlearn.common.attention_bias import (
     NEG_INF,
@@ -985,3 +989,95 @@ class CuDNNGPUFlashAttentionWithExplicitBias(CuDNNGPUFlashAttention):
     """
 
     _allow_explicit_bias = True
+
+
+class ROCmTransformerEngineFlashAttention(BaseFlashAttention):

Review comment: Add a unittest for this layer?

+    """Wraps Transformer Engine DotProductAttention to enable Flash Attention on ROCm.
+
+    Currently it has only been tested on standard Llama-like configs, and training only.
+    """
+
+    _allow_explicit_bias = False
+
+    def is_supported(
+        self, *, query: Tensor, key: Tensor, value: Tensor, bias: BaseAttentionBias
+    ) -> bool:
+        """See `BaseFlashAttention.is_supported`."""
+        if not super().is_supported(query=query, key=key, value=value, bias=bias):
+            return False
+
+        try:
+            from transformer_engine.jax.flax.transformer import DotProductAttention
+        except ImportError:
+            return self._log_unsupported("could not import Transformer Engine.")
+
+        if self.cfg.is_decoding:
+            return self._log_unsupported(
+                "currently only training has been tested with this attention implementation."
+            )
+        # Like cuDNN, Transformer Engine has no concept of block size. It only requires the
+        # length of query and key/value to be even.
+        if not self._check_block_size(query=query, key=key, block_size=2):
+            return False
+
+        if query.dtype not in (jnp.float16, jnp.bfloat16):
+            return self._log_unsupported(
+                f"{query.dtype=} is not supported. Only float16 and bfloat16 are supported."
+            )
+
+        if jax.default_backend() == "cpu":
+            return self._log_unsupported("we're on CPU emulation.")
+
+        head_dim = query.shape[-1]
+        if head_dim % 8 != 0:
+            return self._log_unsupported(f"{head_dim=} is not divisible by 8.")
+        if head_dim > 128:
+            return self._log_unsupported(f"{head_dim=} > 128.")
+
+        _, sliding, explicit_bias = split(bias, CausalAttentionBias, SlidingWindowAttentionBias)
+        if sliding.has_value():
+            return self._log_unsupported("sliding window attention has not been tested yet.")
+        if explicit_bias.has_value() and not self._allow_explicit_bias:
+            return self._log_unsupported("we don't allow explicit bias at this stage.")
+
+        logging.info("Using %s.", self.name())
+        return True
+
+    @functools.partial(jax.jit, static_argnames=["self"])
+    def __call__(
+        self,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        bias: BaseAttentionBias,
+        prng_key: Optional[Tensor] = None,
+    ) -> Tensor:
+        """See `BaseFlashAttention.__call__`."""
+        args = dict(
+            query=query,
+            key=repeat_kv_heads(query.shape[2], key),
+            value=repeat_kv_heads(query.shape[2], value),
+        )
+
+        _, _, num_query_heads, head_dim = query.shape
+        _, _, num_kv_heads, _ = key.shape
+
+        causal, _, _ = split(bias, CausalAttentionBias, SlidingWindowAttentionBias)
+        mask_type = "causal" if causal.has_value() else "no_mask"
+
+        rocm_te_dot_product_attention = DotProductAttention(
+            head_dim=head_dim,
+            num_attention_heads=num_query_heads,
+            num_gqa_groups=num_kv_heads,
+            attn_mask_type=mask_type,
+            attn_bias_type="no_bias",
+            attention_dropout=self.cfg.dropout_rate,
+            dtype=query.dtype,
+            qkv_layout="BSHD_BSHD_BSHD",
+            transpose_batch_sequence=False,
+            scale_factor=self.cfg.softmax_scale,
+        )
+
+        return rocm_te_dot_product_attention.apply({}, rngs={"dropout": prng_key}, **args)
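For context on the `repeat_kv_heads` calls above: the key/value tensors are handed to Transformer Engine with their head axis tiled up to the number of query heads. Below is a minimal sketch of what that repetition presumably does, assuming a `[batch, seq, heads, head_dim]` layout; `repeat_kv_heads_sketch` is a hypothetical stand-in, and axlearn's actual `repeat_kv_heads` may differ in details.

```python
import jax.numpy as jnp


def repeat_kv_heads_sketch(num_q_heads: int, kv: jnp.ndarray) -> jnp.ndarray:
    """Illustrative GQA helper: tile the KV head axis of a [batch, seq, kv_heads, head_dim]
    tensor so it matches the number of query heads."""
    num_kv_heads = kv.shape[2]
    if num_kv_heads == num_q_heads:
        return kv
    assert num_q_heads % num_kv_heads == 0, "query heads must be a multiple of KV heads"
    return jnp.repeat(kv, num_q_heads // num_kv_heads, axis=2)


# Example: 2 KV heads expanded to match 8 query heads.
kv = jnp.zeros((4, 128, 2, 64), dtype=jnp.bfloat16)
assert repeat_kv_heads_sketch(8, kv).shape == (4, 128, 8, 64)
```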
@@ -11,6 +11,7 @@
     CuDNNGPUFlashAttention,
     CuDNNGPUFlashAttentionWithExplicitBias,
     PallasGPUFlashAttention,
+    ROCmTransformerEngineFlashAttention,
 )
 from axlearn.common.flash_attention.gpu_decoding import GPUDecoding
 from axlearn.common.flash_attention.tpu_attention import LegacyTPUFlashAttention, TPUSplashAttention

@@ -24,6 +25,7 @@
     gpu=[
         GPUDecoding,
         # For GPU, prefer cuDNN (without bias) whenever possible, as it's the fastest.
+        ROCmTransformerEngineFlashAttention,

Review comment: how is …
Reply: added a check in …

         CuDNNGPUFlashAttention,
         # Fallbacks to Pallas if cuDNN cannot be used without instantiating bias tensors.
         PallasGPUFlashAttention,
@@ -0,0 +1,28 @@
+export LD_LIBRARY_PATH=/opt/rocm/lib:$LD_LIBRARY_PATH
+export XLA_FLAGS="--xla_gpu_enable_cublaslt=True --xla_gpu_graph_level=0 --xla_gpu_autotune_level=0 --xla_gpu_enable_latency_hiding_scheduler=true"
+export XLA_PYTHON_CLIENT_MEM_FRACTION=0.975
+export HSA_FORCE_FINE_GRAIN_PCIE=1
+export GPU_MAX_HW_QUEUES=2
+export HIP_FORCE_DEV_KERNARG=1
+export NVTE_ALLOW_NONDETERMINISTIC_ALGO=1
+export NVTE_FUSED_ATTN=1
+export NVTE_FUSED_ATTN_CK=1
+export NVTE_FUSED_ATTN_AOTRITON=1
+export NVTE_CK_EXT_ASM=1
+export NVTE_CK_ASM_ATOMIC_FP32=0
+export NVTE_CK_ASM_NO_COEX=0
+export NVTE_CK_ASM_RTZ_CVT=1
+export NVTE_CK_BWD_V3=1
+export NVTE_CK_V3_RTZ_CVT=2
+export NVTE_CK_USES_BWD_V3=1
+export NVTE_CK_IS_V3_ATOMIC_FP32=0
+export NVTE_CK_IS_V3_SPEC=1
+export NVTE_CK_HOW_V3_BF16_CVT=2
+
+
+mkdir -p /tmp/gpt_c4_test; \
+python3 -m axlearn.common.launch_trainer_main \
+  --module=text.gpt.c4_trainer --config=fuji-70B-v2-flash-single-host \
+  --trainer_dir=/tmp/gpt_c4_test --data_dir=gs://axlearn-public/tensorflow_datasets \
+  --jax_backend=gpu \
+  --mesh_selector="amd-mi300-single-node" \

Review comment: How can we reproduce your experiments?
@@ -23,8 +23,8 @@ core = [
     "absl-py==2.1.0",
     "chex==0.1.88",
     "importlab==0.8.1",  # breaks pytype on 0.8
-    "jax==0.4.38",
-    "jaxlib==0.4.38",
+    "jax==0.4.35",
+    "jaxlib==0.4.35",

Review comment: Is this intentional to downgrade the jax version?
Reply: I am guessing they are relying on a fixed version of transformer_engine, which has to be compatible with the JAX version.
Reply: This is because 0.4.35 is the closest available ROCm JAX release (https://github.com/ROCm/jax/releases). The next version available is 0.5, which we are open to upgrading to whenever axlearn decides to use it.

     "ml-dtypes==0.4.1",
     "msgpack==1.1.0",  # for checkpointing.
     "nltk==3.7",  # for text preprocessing
Review comment: Is it possible to extract the lower-level API from Transformer Engine? We waited to integrate cuDNN attention from native JAX exactly to avoid a dependency on transformer_engine, which makes JAX upgrades and version control much more difficult.
Reply: FWIW, we are also open to plumbing this through an XLA custom call.