Add ROCm support #1112 (Closed)
Changes from 1 commit (7 commits in this pull request):
- 0e4ffe5 enable attention via Transformer Engine for ROCm support (akasharidas)
- 14f0606 add 70B config and launch script for MI300 (akasharidas)
- 44c1ab9 set JAX version to 0.4.35 (akasharidas)
- 5377fea add basic docstring to ROCmTransformerEngineFlashAttention (akasharidas)
- 15a4f47 add unit tests for ROCm TE layer (akasharidas)
- ee4820c add check for ROCm backend (akasharidas)
- 9b37416 support sliding window attention in ROCm TE path (akasharidas)
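The backend check itself ("add check for ROCm backend", ee4820c) is not part of the file shown in this view. Purely as a hedged illustration of what such a check can look like in JAX, the sketch below detects an AMD GPU from the reported device kind; the helper name and the string heuristic are assumptions, not the actual implementation in this PR.

```python
# Illustrative sketch only: the actual check added in commit ee4820c may differ.
import jax


def _is_rocm_gpu() -> bool:  # hypothetical helper name, not taken from this PR
    """Heuristically detects an AMD (ROCm) GPU from the default device's reported kind."""
    devices = jax.devices()
    # On ROCm builds of JAX, GPU devices report an AMD device kind string.
    return jax.default_backend() == "gpu" and "amd" in devices[0].device_kind.lower()


if __name__ == "__main__":
    print(f"ROCm GPU detected: {_is_rocm_gpu()}")
```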
axlearn/common/flash_attention/gpu_attention_rocm_test.py (208 additions, 0 deletions)
````python
# Copyright © 2023 Apple Inc.
#
# Some of the code in this file is adapted from:
#
# jax-ml/jax-triton:
# Copyright 2023 The jax_triton Authors.
# Licensed under the Apache License, Version 2.0 (the "License").

"""Tests GPU FlashAttention kernels.

Currently tested on MI300. To run tests in parallel on a multi-GPU machine, use this:
```
PARALLEL_GPU_TEST=1 pytest -n 8 axlearn/common/flash_attention/gpu_attention_test.py
```
"""
import functools
from typing import Any, Callable, Literal, Optional

import chex
import jax
import jax.numpy as jnp
import jax.random
import pytest

from axlearn.common.attention_bias import (
    CausalAttentionBias,
    MaskFn,
    ZeroAttentionBias,
    causal_mask,
)
from axlearn.common.flash_attention.common import ReferenceMHA
from axlearn.common.flash_attention.gpu_attention import (
    ROCmTransformerEngineFlashAttention,
)
from axlearn.common.flash_attention.test_utils import generate_attention_data
from axlearn.common.utils import Tensor

if jax.default_backend() not in ("gpu", "cpu"):
    pytest.skip(reason="Incompatible hardware", allow_module_level=True)


def _default_tol_fn(backend, dtype):
    del backend
    if dtype == jnp.bfloat16:
        return dict(atol=0.05, rtol=1e-2)
    if dtype == jnp.float16:
        return dict(atol=0.05, rtol=1e-5)
    if dtype == jnp.float32:
        return dict(atol=0.025, rtol=1e-5)
    raise ValueError(f"Unsupported dtype: {dtype}")


TestFn = Callable[[Tensor, Tensor, Tensor], Tensor]
TolFn = Callable[[str, Any], dict[str, float]]


def _test_forward_and_backward(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    bias,
    *,
    ref_fn: TestFn,
    test_fn: TestFn,
    forward_tol_fn: Callable = _default_tol_fn,
    backward_tol_fn: Callable = _default_tol_fn,
):
    ref_fn = jax.jit(ref_fn)
    test_fn = jax.jit(test_fn)
    prng_key = jax.random.PRNGKey(44)
    jax_out = test_fn(q, k, v, bias, prng_key)
    jax_ref_out = ref_fn(q, k, v, bias, prng_key)
    backend = jax.default_backend()
    chex.assert_trees_all_close(jax_out, jax_ref_out, **forward_tol_fn(backend, q.dtype))

    # Compare gradients.
    jax_grads = jax.grad(lambda *args: ref_fn(*args).mean(), argnums=(0, 1, 2))(
        q, k, v, bias, prng_key
    )
    jax_ref_grads = jax.grad(lambda *args: test_fn(*args).mean(), argnums=(0, 1, 2))(
        q, k, v, bias, prng_key
    )
    chex.assert_trees_all_close(jax_grads, jax_ref_grads, **backward_tol_fn(backend, q.dtype))


def common_attn_test_params(func):
    params = [
        pytest.mark.parametrize("kv_len", [None, 512]),
        pytest.mark.parametrize("dropout_rate", [0, 0.1]),
        pytest.mark.parametrize("attention_bias_type", [None, "2d", "4d"]),
        pytest.mark.parametrize("with_segment_ids", [True, False]),
        pytest.mark.parametrize("block_size", [128]),  # Triton broken for block size != 128.
        pytest.mark.parametrize("mask_fn", [causal_mask, None]),
        pytest.mark.parametrize("dtype", [jnp.float16, jnp.float32]),
    ]
    # Apply in reverse order to stack correctly.
    for param in reversed(params):
        func = param(func)
    return func


# We test the ROCm TE DotProductAttention against the reference flash_attention.
# Due to its algorithmic equivalence, the outputs should be close in both fp16 and bfloat16.
@pytest.mark.parametrize(
    "batch_size,num_heads,seq_len,per_head_dim",
    [
        (1, 2, 1024, 128),
        (2, 2, 1024, 128),
        (1, 4, 2048, 128),
        (2, 8, 2048, 128),
    ],
)
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float16])
def test_rocmte_against_xla_ref(
    batch_size: int,
    num_heads: int,
    seq_len: int,
    per_head_dim: int,
    causal: bool,
    dtype: jnp.dtype,
):
    if jax.default_backend() == "cpu":
        pytest.skip(reason="ROCm function needs GPU.")

    q, k, v, bias = generate_attention_data(
        batch_size,
        seq_len,
        seq_len,
        num_heads,
        per_head_dim,
        mask_fn=causal_mask if causal else None,
        dtype=dtype,
    )

    cfg = dict(
        softmax_scale=q.shape[-1] ** -0.5,
    )

    # Compare outputs.
    test_fn = ROCmTransformerEngineFlashAttention.default_config().set(**cfg).instantiate()
    chex.assert_equal(test_fn.is_supported(query=q, key=k, value=v, bias=bias), True)
    ref_fn = ReferenceMHA.default_config().set(**cfg).instantiate()

    def forward_tol_fn(backend, dtype):
        del backend
        if dtype == jnp.bfloat16:
            return dict(atol=0.02, rtol=1e-5)
        if dtype == jnp.float16:
            return dict(atol=0.005, rtol=1e-5)

    _test_forward_and_backward(
        q, k, v, bias, ref_fn=ref_fn, test_fn=test_fn, forward_tol_fn=forward_tol_fn
    )


def _cudnn_xla_forward_tol_fn(backend, dtype):
    del backend
    # cuDNN has higher diff when compared to non-fused attention in XLA.
    if dtype == jnp.bfloat16:
        return dict(atol=0.25, rtol=1e-3)
    if dtype == jnp.float16:
        return dict(atol=0.05, rtol=1e-3)


@pytest.mark.parametrize(
    "batch_size,num_heads,seq_len,kv_seq_len,per_head_dim",
    [
        (1, 1, 378, 676, 72),
        (2, 4, 582, 582, 56),
    ],
)
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("dtype", [jnp.float16, jnp.bfloat16])
def test_rocmte_seqlen_head_support(
    batch_size: int,
    num_heads: int,
    seq_len: int,
    kv_seq_len: int,
    per_head_dim: int,
    causal: bool,
    dtype: jnp.dtype,
):
    """Tests that ROCm TE supports any even sequence length and head dim % 8 == 0."""
    if jax.default_backend() == "cpu":
        pytest.skip(reason="ROCm function needs GPU.")
    q, k, v, bias = generate_attention_data(
        batch_size,
        seq_len,
        kv_seq_len,
        num_heads,
        per_head_dim,
        mask_fn=causal_mask if causal else None,
        dtype=dtype,
    )

    cfg = dict(
        softmax_scale=q.shape[-1] ** -0.5,
    )

    # Compare outputs.
    test_fn = ROCmTransformerEngineFlashAttention.default_config().set(**cfg).instantiate()
    ref_fn = ReferenceMHA.default_config().set(**cfg).instantiate()
    chex.assert_equal(test_fn.is_supported(query=q, key=k, value=v, bias=bias), True)

    _test_forward_and_backward(
        q, k, v, bias, ref_fn=ref_fn, test_fn=test_fn, forward_tol_fn=_cudnn_xla_forward_tol_fn
    )
````

Inline review comment on the pytest command in the docstring: "This file path needs to be revised to the current file."
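For reviewers who want to exercise the new layer interactively on a ROCm machine, here is a minimal sketch distilled from the tests above. It reuses only calls that appear in this diff (the default_config().set(...).instantiate() pattern, is_supported, generate_attention_data, and the (q, k, v, bias, prng_key) call signature); the shapes, dtype, and final printout are arbitrary illustrative choices, not part of the PR.

```python
# Minimal sketch distilled from the tests in this diff; not part of the PR itself.
import jax
import jax.numpy as jnp

from axlearn.common.attention_bias import causal_mask
from axlearn.common.flash_attention.common import ReferenceMHA
from axlearn.common.flash_attention.gpu_attention import ROCmTransformerEngineFlashAttention
from axlearn.common.flash_attention.test_utils import generate_attention_data

# Arbitrary illustrative shapes: (batch, q_seq_len, kv_seq_len, num_heads, per_head_dim).
q, k, v, bias = generate_attention_data(
    2, 1024, 1024, 8, 128, mask_fn=causal_mask, dtype=jnp.bfloat16
)

cfg = dict(softmax_scale=q.shape[-1] ** -0.5)
rocm_attn = ROCmTransformerEngineFlashAttention.default_config().set(**cfg).instantiate()
ref_attn = ReferenceMHA.default_config().set(**cfg).instantiate()

# The layer reports up front whether it supports these inputs, as asserted in the tests.
assert rocm_attn.is_supported(query=q, key=k, value=v, bias=bias)

prng_key = jax.random.PRNGKey(0)
out = jax.jit(rocm_attn)(q, k, v, bias, prng_key)
ref = jax.jit(ref_attn)(q, k, v, bias, prng_key)
# Per the tolerances in the tests above, the max difference should be small in bf16.
print(jnp.max(jnp.abs(out - ref)))
```

As in `_test_forward_and_backward`, the same comparison can be repeated on gradients by taking `jax.grad` of the mean of each output.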
Review comment: "And feel free to make it AMD if you intend to maintain it."