diff --git a/axlearn/common/flash_attention/gpu_attention.py b/axlearn/common/flash_attention/gpu_attention.py
index e89e90c4b..970312e4f 100644
--- a/axlearn/common/flash_attention/gpu_attention.py
+++ b/axlearn/common/flash_attention/gpu_attention.py
@@ -43,6 +43,10 @@
 )
 from jax.ad_checkpoint import checkpoint_name
 from jax.experimental import pallas as pl
+try:
+    from transformer_engine.jax.flax.transformer import DotProductAttention
+except ImportError:
+    pass
 
 from axlearn.common.attention_bias import (
     NEG_INF,
@@ -985,3 +989,104 @@ class CuDNNGPUFlashAttentionWithExplicitBias(CuDNNGPUFlashAttention):
     """
 
     _allow_explicit_bias = True
+
+
+class ROCmTransformerEngineFlashAttention(BaseFlashAttention):
+    """Wraps Transformer Engine DotProductAttention to enable Flash Attention on ROCm.
+
+    Currently it has only been tested on standard Llama-like configs, and only for training.
+    """
+
+    _allow_explicit_bias = False
+
+    def is_supported(
+        self, *, query: Tensor, key: Tensor, value: Tensor, bias: BaseAttentionBias
+    ) -> bool:
+        """See `BaseFlashAttention.is_supported`."""
+        if not super().is_supported(query=query, key=key, value=value, bias=bias):
+            return False
+
+        devices = jax.devices()
+        if "AMD" not in devices[0].device_kind:
+            return self._log_unsupported("not on ROCm backend")
+
+        try:
+            from transformer_engine.jax.flax.transformer import DotProductAttention
+        except ImportError:
+            return self._log_unsupported("could not import Transformer Engine")
+
+        if self.cfg.is_decoding:
+            return self._log_unsupported("currently only training has been tested with this attention implementation.")
+        else:
+            # The fused attention backend has no concept of block size. It only requires the
+            # length of query and key/value to be even.
+            if not self._check_block_size(query=query, key=key, block_size=2):
+                return False
+
+        if query.dtype not in (jnp.float16, jnp.bfloat16):
+            return self._log_unsupported(
+                f"{query.dtype=} is not supported. Only supports float16 and bfloat16."
+            )
+
+        if jax.default_backend() == "cpu":
+            return self._log_unsupported("we're on CPU emulation.")
+
+        head_dim = query.shape[-1]
+        if head_dim % 8 != 0:
+            return self._log_unsupported(f"{head_dim=} is not divisible by 8.")
+        if head_dim > 128:
+            return self._log_unsupported(f"{head_dim=} > 128")
+        _, sliding, explicit_bias = split(bias, CausalAttentionBias, SlidingWindowAttentionBias)
+
+        if sliding.has_value():
+            if self.cfg.dropout_rate != 0.0:
+                return self._log_unsupported("sliding window with dropout has not been tested.")
+        if explicit_bias.has_value() and not self._allow_explicit_bias:
+            return self._log_unsupported("we don't allow explicit bias at this stage.")
+
+        logging.info("Using %s.", self.name())
+        return True
+
+    @functools.partial(jax.jit, static_argnames=["self"])
+    def __call__(
+        self,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        bias: BaseAttentionBias,
+        prng_key: Optional[Tensor] = None,
+    ) -> Tensor:
+        """See `BaseFlashAttention.__call__`."""
+
+        args = dict(
+            query=query,
+            key=repeat_kv_heads(query.shape[2], key),
+            value=repeat_kv_heads(query.shape[2], value),
+        )
+
+        _, _, num_query_heads, head_dim = query.shape
+        _, _, num_kv_heads, _ = key.shape
+
+        causal, sliding, _ = split(
+            bias, CausalAttentionBias, SlidingWindowAttentionBias
+        )
+        mask_type = "causal" if (causal.has_value() or sliding.has_value()) else "no_mask"
+        window_size = None
+        if sliding.has_value():
+            window_size = (sliding.sliding_window_size, 0)
+
+        rocm_te_dot_product_attention = DotProductAttention(
+            head_dim=head_dim,
+            num_attention_heads=num_query_heads,
+            num_gqa_groups=num_kv_heads,
+            attn_mask_type=mask_type,
+            attn_bias_type="no_bias",
+            attention_dropout=self.cfg.dropout_rate,
+            dtype=query.dtype,
+            qkv_layout="BSHD_BSHD_BSHD",
+            transpose_batch_sequence=False,
+            scale_factor=self.cfg.softmax_scale,
+            window_size=window_size,
+        )
+
+        return rocm_te_dot_product_attention.apply({}, rngs={"dropout": prng_key}, **args)
\ No newline at end of file
diff --git a/axlearn/common/flash_attention/gpu_attention_rocm_test.py b/axlearn/common/flash_attention/gpu_attention_rocm_test.py
new file mode 100644
index 000000000..0c13e104f
--- /dev/null
+++ b/axlearn/common/flash_attention/gpu_attention_rocm_test.py
@@ -0,0 +1,246 @@
+# Copyright © 2023 Apple Inc.
+#
+# Some of the code in this file is adapted from:
+#
+# jax-ml/jax-triton:
+# Copyright 2023 The jax_triton Authors.
+# Licensed under the Apache License, Version 2.0 (the "License").
+
+"""Tests GPU FlashAttention kernels.
+
+Currently tested on MI300.
+To run tests in parallel on a multi-GPU machine, use this:
+```
+PARALLEL_GPU_TEST=1 pytest -n 8 axlearn/common/flash_attention/gpu_attention_rocm_test.py
+```
+"""
+import functools
+from typing import Any, Callable, Literal, Optional
+
+import chex
+import jax
+import jax.numpy as jnp
+import jax.random
+import pytest
+
+from axlearn.common.attention_bias import (
+    CausalAttentionBias,
+    MaskFn,
+    ZeroAttentionBias,
+    causal_mask,
+)
+from axlearn.common.flash_attention.common import ReferenceMHA
+from axlearn.common.flash_attention.gpu_attention import (
+    ROCmTransformerEngineFlashAttention,
+)
+from axlearn.common.flash_attention.test_utils import generate_attention_data
+from axlearn.common.utils import Tensor
+
+if jax.default_backend() not in ("gpu", "cpu"):
+    pytest.skip(reason="Incompatible hardware", allow_module_level=True)
+
+
+def _default_tol_fn(backend, dtype):
+    del backend
+    if dtype == jnp.bfloat16:
+        return dict(atol=0.05, rtol=1e-2)
+    if dtype == jnp.float16:
+        return dict(atol=0.05, rtol=1e-5)
+    if dtype == jnp.float32:
+        return dict(atol=0.025, rtol=1e-5)
+    raise ValueError(f"Unsupported dtype: {dtype}")
+
+
+TestFn = Callable[[Tensor, Tensor, Tensor], Tensor]
+TolFn = Callable[[str, Any], dict[str, float]]
+
+
+def _test_forward_and_backward(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    bias,
+    *,
+    ref_fn: TestFn,
+    test_fn: TestFn,
+    forward_tol_fn: Callable = _default_tol_fn,
+    backward_tol_fn: Callable = _default_tol_fn,
+):
+    ref_fn = jax.jit(ref_fn)
+    test_fn = jax.jit(test_fn)
+    prng_key = jax.random.PRNGKey(44)
+    jax_out = test_fn(q, k, v, bias, prng_key)
+    jax_ref_out = ref_fn(q, k, v, bias, prng_key)
+    backend = jax.default_backend()
+    chex.assert_trees_all_close(jax_out, jax_ref_out, **forward_tol_fn(backend, q.dtype))
+
+    # Compare gradients.
+    jax_grads = jax.grad(lambda *args: test_fn(*args).mean(), argnums=(0, 1, 2))(
+        q, k, v, bias, prng_key
+    )
+    jax_ref_grads = jax.grad(lambda *args: ref_fn(*args).mean(), argnums=(0, 1, 2))(
+        q, k, v, bias, prng_key
+    )
+    chex.assert_trees_all_close(jax_grads, jax_ref_grads, **backward_tol_fn(backend, q.dtype))
+
+
+def common_attn_test_params(func):
+    params = [
+        pytest.mark.parametrize("kv_len", [None, 512]),
+        pytest.mark.parametrize("dropout_rate", [0, 0.1]),
+        pytest.mark.parametrize("attention_bias_type", [None, "2d", "4d"]),
+        pytest.mark.parametrize("with_segment_ids", [True, False]),
+        pytest.mark.parametrize("block_size", [128]),  # Triton broken for block size != 128.
+        pytest.mark.parametrize("mask_fn", [causal_mask, None]),
+        pytest.mark.parametrize("dtype", [jnp.float16, jnp.float32]),
+    ]
+    # Apply in reverse order to stack correctly.
+    for param in reversed(params):
+        func = param(func)
+    return func
+
+
+@pytest.mark.parametrize("batch_size", [1])
+@pytest.mark.parametrize("seq_len", [512, 2048])
+@pytest.mark.parametrize("sliding_window_size", [256])
+@pytest.mark.parametrize("use_segment_ids", [False])
+@pytest.mark.parametrize("num_heads", [8])
+@pytest.mark.parametrize("per_head_dim", [128])
+@pytest.mark.parametrize("test_cls", [ROCmTransformerEngineFlashAttention])
+def test_sliding_window_mask(
+    batch_size,
+    seq_len,
+    num_heads,
+    per_head_dim,
+    sliding_window_size,
+    use_segment_ids,
+    test_cls,
+):
+    if jax.default_backend() != "gpu" and test_cls is ROCmTransformerEngineFlashAttention:
+        pytest.skip("ROCm requires GPU.")
+    q, k, v, bias = generate_attention_data(
+        batch_size,
+        seq_len,
+        seq_len,
+        num_heads,
+        per_head_dim,
+        sliding_window_sz=sliding_window_size,
+        with_segment_ids=use_segment_ids,
+    )
+
+    cfg = dict(
+        softmax_scale=q.shape[-1] ** -0.5,
+        interpret=jax.default_backend() == "cpu",
+    )
+    test_fn = test_cls.default_config().set(**cfg).instantiate()
+    chex.assert_equal(test_fn.is_supported(query=q, key=k, value=v, bias=bias), True)
+    ref_fn = ReferenceMHA.default_config().set(**cfg).instantiate()
+    _test_forward_and_backward(q, k, v, bias, ref_fn=ref_fn, test_fn=test_fn)
+
+
+# We test the ROCm TE DotProductAttention against the reference flash_attention.
+# Due to its algorithmic equivalence, the outputs should be close in both fp16 and bfloat16.
+@pytest.mark.parametrize(
+    "batch_size,num_heads,seq_len,per_head_dim",
+    [
+        (1, 2, 1024, 128),
+        (2, 2, 1024, 128),
+        (1, 4, 2048, 128),
+        (2, 8, 2048, 128),
+    ],
+)
+@pytest.mark.parametrize("causal", [True, False])
+@pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float16])
+def test_rocmte_against_xla_ref(
+    batch_size: int,
+    num_heads: int,
+    seq_len: int,
+    per_head_dim: int,
+    causal: bool,
+    dtype: jnp.dtype,
+):
+    if jax.default_backend() == "cpu":
+        pytest.skip(reason="ROCm function needs GPU.")
+
+    q, k, v, bias = generate_attention_data(
+        batch_size,
+        seq_len,
+        seq_len,
+        num_heads,
+        per_head_dim,
+        mask_fn=causal_mask if causal else None,
+        dtype=dtype,
+    )
+
+    cfg = dict(
+        softmax_scale=q.shape[-1] ** -0.5,
+    )
+
+    # Compare outputs.
+    test_fn = ROCmTransformerEngineFlashAttention.default_config().set(**cfg).instantiate()
+    chex.assert_equal(test_fn.is_supported(query=q, key=k, value=v, bias=bias), True)
+    ref_fn = ReferenceMHA.default_config().set(**cfg).instantiate()
+
+    def forward_tol_fn(backend, dtype):
+        del backend
+        if dtype == jnp.bfloat16:
+            return dict(atol=0.02, rtol=1e-5)
+        if dtype == jnp.float16:
+            return dict(atol=0.005, rtol=1e-5)
+
+    _test_forward_and_backward(
+        q, k, v, bias, ref_fn=ref_fn, test_fn=test_fn, forward_tol_fn=forward_tol_fn
+    )
+
+
+def _te_xla_forward_tol_fn(backend, dtype):
+    del backend
+    # The fused attention has higher diff when compared to non-fused attention in XLA.
+    if dtype == jnp.bfloat16:
+        return dict(atol=0.25, rtol=1e-3)
+    if dtype == jnp.float16:
+        return dict(atol=0.05, rtol=1e-3)
+
+
+@pytest.mark.parametrize(
+    "batch_size,num_heads,seq_len,kv_seq_len,per_head_dim",
+    [
+        (1, 1, 378, 676, 72),
+        (2, 4, 582, 582, 56),
+    ],
+)
+@pytest.mark.parametrize("causal", [True, False])
+@pytest.mark.parametrize("dtype", [jnp.float16, jnp.bfloat16])
+def test_rocmte_seqlen_head_support(
+    batch_size: int,
+    num_heads: int,
+    seq_len: int,
+    kv_seq_len: int,
+    per_head_dim: int,
+    causal: bool,
+    dtype: jnp.dtype,
+):
+    """Tests that ROCm TE supports any even sequence length and head dim % 8 == 0."""
+    if jax.default_backend() == "cpu":
+        pytest.skip(reason="ROCm function needs GPU.")
+    q, k, v, bias = generate_attention_data(
+        batch_size,
+        seq_len,
+        kv_seq_len,
+        num_heads,
+        per_head_dim,
+        mask_fn=causal_mask if causal else None,
+        dtype=dtype,
+    )
+
+    cfg = dict(
+        softmax_scale=q.shape[-1] ** -0.5,
+    )
+
+    # Compare outputs.
+    test_fn = ROCmTransformerEngineFlashAttention.default_config().set(**cfg).instantiate()
+    ref_fn = ReferenceMHA.default_config().set(**cfg).instantiate()
+    chex.assert_equal(test_fn.is_supported(query=q, key=k, value=v, bias=bias), True)
+
+    _test_forward_and_backward(
+        q, k, v, bias, ref_fn=ref_fn, test_fn=test_fn, forward_tol_fn=_te_xla_forward_tol_fn
+    )
diff --git a/axlearn/common/flash_attention/utils.py b/axlearn/common/flash_attention/utils.py
index e23d7d178..da889458f 100644
--- a/axlearn/common/flash_attention/utils.py
+++ b/axlearn/common/flash_attention/utils.py
@@ -11,6 +11,7 @@
     CuDNNGPUFlashAttention,
     CuDNNGPUFlashAttentionWithExplicitBias,
     PallasGPUFlashAttention,
+    ROCmTransformerEngineFlashAttention,
 )
 from axlearn.common.flash_attention.gpu_decoding import GPUDecoding
 from axlearn.common.flash_attention.tpu_attention import LegacyTPUFlashAttention, TPUSplashAttention
@@ -24,6 +25,7 @@
     gpu=[
         GPUDecoding,
         # For GPU, prefer cuDNN (without bias) whenever possible, as it's the fastest.
+        ROCmTransformerEngineFlashAttention,
         CuDNNGPUFlashAttention,
         # Fallbacks to Pallas if cuDNN cannot be used without instantiating bias tensors.
         PallasGPUFlashAttention,
diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py
index dc95d6110..cbe47237f 100644
--- a/axlearn/experiments/text/gpt/fuji.py
+++ b/axlearn/experiments/text/gpt/fuji.py
@@ -707,6 +707,24 @@ def get_trainer_kwargs(
                     ],
                 ),
                 ),
+                (
+                    "amd-mi300-single-node",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                mesh_shape=mesh_shape_from_axes(fsdp=-1)
+                            ),
+                            RematSpecModifier.default_config().set(
+                                remat_policies={
+                                    "model.decoder.transformer.layer": RematSpec(
+                                        prevent_cse=False,
+                                        policy=jax_remat_policies.nothing_saveable,
+                                    ),
+                                }
+                            ),
+                        ],
+                    ),
+                ),
             ),
         )
     else:
@@ -857,7 +875,7 @@ def wrapper(config_name: str = config_name):
                 arch=arch, model_size="golden-run-test", version=f"v{version.value}"
             )
         ] = wrapper
-        if model_size in ("1B", "3B", "7B", "8B"):
+        if model_size in ("1B", "3B", "7B", "8B", "70B"):
 
             def make_single_host_config(base_config_name: str) -> SpmdTrainer.Config:
                 """Make a single-host variant of the base config.
diff --git a/launch_70B_single_node.sh b/launch_70B_single_node.sh
new file mode 100644
index 000000000..15dc1d8ac
--- /dev/null
+++ b/launch_70B_single_node.sh
@@ -0,0 +1,28 @@
+export LD_LIBRARY_PATH=/opt/rocm/lib:$LD_LIBRARY_PATH
+export XLA_FLAGS="--xla_gpu_enable_cublaslt=True --xla_gpu_graph_level=0 --xla_gpu_autotune_level=0 --xla_gpu_enable_latency_hiding_scheduler=true"
+export XLA_PYTHON_CLIENT_MEM_FRACTION=0.975
+export HSA_FORCE_FINE_GRAIN_PCIE=1
+export GPU_MAX_HW_QUEUES=2
+export HIP_FORCE_DEV_KERNARG=1
+export NVTE_ALLOW_NONDETERMINISTIC_ALGO=1
+export NVTE_FUSED_ATTN=1
+export NVTE_FUSED_ATTN_CK=1
+export NVTE_FUSED_ATTN_AOTRITON=1
+export NVTE_CK_EXT_ASM=1
+export NVTE_CK_ASM_ATOMIC_FP32=0
+export NVTE_CK_ASM_NO_COEX=0
+export NVTE_CK_ASM_RTZ_CVT=1
+export NVTE_CK_BWD_V3=1
+export NVTE_CK_V3_RTZ_CVT=2
+export NVTE_CK_USES_BWD_V3=1
+export NVTE_CK_IS_V3_ATOMIC_FP32=0
+export NVTE_CK_IS_V3_SPEC=1
+export NVTE_CK_HOW_V3_BF16_CVT=2
+
+
+mkdir -p /tmp/gpt_c4_test; \
+python3 -m axlearn.common.launch_trainer_main \
+    --module=text.gpt.c4_trainer --config=fuji-70B-v2-flash-single-host \
+    --trainer_dir=/tmp/gpt_c4_test --data_dir=gs://axlearn-public/tensorflow_datasets \
+    --jax_backend=gpu \
+    --mesh_selector="amd-mi300-single-node"
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index c7746af31..63cc66b00 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,8 +23,8 @@ core = [
     "absl-py==2.1.0",
     "chex==0.1.88",
     "importlab==0.8.1",  # breaks pytype on 0.8
-    "jax==0.4.38",
-    "jaxlib==0.4.38",
+    "jax==0.4.35",
+    "jaxlib==0.4.35",
     "ml-dtypes==0.4.1",
     "msgpack==1.1.0",  # for checkpointing.
     "nltk==3.7",  # for text preprocessing
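Note (reviewer aid, not part of the diff): a minimal usage sketch of the new backend, mirroring `test_rocmte_against_xla_ref` above. It assumes a ROCm GPU with Transformer Engine installed and reuses the test helpers shipped in `axlearn.common.flash_attention.test_utils`; the shapes and dtype below are illustrative.

```python
import jax
import jax.numpy as jnp

from axlearn.common.attention_bias import causal_mask
from axlearn.common.flash_attention.common import ReferenceMHA
from axlearn.common.flash_attention.gpu_attention import ROCmTransformerEngineFlashAttention
from axlearn.common.flash_attention.test_utils import generate_attention_data

# Llama-like shapes: batch=2, seq_len=2048, 8 heads, head_dim=128, causal mask, bf16.
q, k, v, bias = generate_attention_data(
    2, 2048, 2048, 8, 128, mask_fn=causal_mask, dtype=jnp.bfloat16
)

cfg = dict(softmax_scale=q.shape[-1] ** -0.5)
attn = ROCmTransformerEngineFlashAttention.default_config().set(**cfg).instantiate()
ref = ReferenceMHA.default_config().set(**cfg).instantiate()

# is_supported() returns False on non-AMD devices or when Transformer Engine is missing,
# in which case the backend list in utils.py falls through to the cuDNN/Pallas paths.
assert attn.is_supported(query=q, key=k, value=v, bias=bias)

prng_key = jax.random.PRNGKey(0)  # Only consumed when dropout_rate > 0.
out = attn(q, k, v, bias, prng_key)
expected = ref(q, k, v, bias, prng_key)
print(jnp.max(jnp.abs(out.astype(jnp.float32) - expected.astype(jnp.float32))))
```

Because `ROCmTransformerEngineFlashAttention` is registered ahead of `CuDNNGPUFlashAttention` in the GPU backend list in `utils.py`, the device-kind check inside `is_supported` is what keeps NVIDIA hosts on the existing cuDNN path.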