Extend paged attention to support query_len>1 (#8328)
vanbasten23 authored Oct 31, 2024
1 parent efee3bc commit 1bac062
Showing 6 changed files with 1,130 additions and 3 deletions.
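
A minimal sketch (not part of this commit) of how the new multi-query entry point exercised by the tests below can be invoked. The function name, argument order, and tensor shapes mirror the added test and its _pagedattention_generate_qkv helper; the concrete sizes are illustrative only.

# Hedged usage sketch for the query_len > 1 paged attention path.
import torch
import torch_xla  # makes the 'xla' device available
from torch_xla.experimental.custom_kernel import multi_queries_paged_attention

batch_size, query_len, head_dim = 3, 64, 256
num_kv_heads, q_kv_head_ratio = 8, 4
num_heads = num_kv_heads * q_kv_head_ratio        # 32 query heads
page_size, max_kv_len = 16, 2048
pages_per_sequence = max_kv_len // page_size      # 128
total_pages = batch_size * pages_per_sequence     # 384

# The query tensor now carries an explicit query_len dimension (> 1).
q = torch.randn(batch_size, query_len, num_heads, head_dim).to("xla")
k_pages = torch.randn(num_kv_heads, total_pages, page_size, head_dim).to("xla")
v_pages = torch.randn(num_kv_heads, total_pages, page_size, head_dim).to("xla")
kv_seq_lens = torch.randint(
    query_len, max_kv_len, (batch_size,), dtype=torch.int32).to("xla")
page_indices = torch.randperm(
    total_pages, dtype=torch.int32).reshape(batch_size, -1).to("xla")

out = multi_queries_paged_attention(
    q, k_pages, v_pages, kv_seq_lens, page_indices,
    num_kv_pages_per_compute_block=256 // page_size,
    num_queries_per_compute_block=32,
)
# Expected to match q's leading shape: [batch_size, query_len, num_heads, head_dim].
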
180 changes: 179 additions & 1 deletion test/test_pallas.py
@@ -55,6 +55,7 @@ def _pagedattention_generate_qkv(
num_heads,
head_dim,
dtype=torch.float32,
query_len=None,
):
assert max_seq_len % page_size == 0
pages_per_sequence = max_seq_len // page_size
@@ -67,7 +68,10 @@ def _pagedattention_generate_qkv(
page_indices = torch.randperm(
batch_size * pages_per_sequence, dtype=torch.int32)
page_indices = page_indices.reshape(batch_size, pages_per_sequence)
q = torch.randn(batch_size, num_heads, head_dim, dtype=dtype)
if not query_len:
q = torch.randn(batch_size, num_heads, head_dim, dtype=dtype)
else:
q = torch.randn(batch_size, query_len, num_heads, head_dim, dtype=dtype)
return q, k_pages, v_pages, page_indices

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
@@ -547,6 +551,180 @@ def test_paged_attention_wrapper(self):
atol=1e-5,
rtol=1e-5))

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
"This test only works on TPUv4+.")
def test_paged_attention_multi_queries_wrapper(self):
from torch_xla.experimental.custom_kernel import multi_queries_paged_attention
from torch_xla.experimental.pallas_kernels.multi_queries_paged_attention_kernel import paged_attention as jax_multi_queries_paged_attention

dtype = torch.float32
page_size = 16
num_kv_heads = 8
q_kv_head_ratio = 4
head_dim = 256
num_queries_per_compute_block = 32
block_kv_size = 256

max_kv_len = 2048
query_len = 64
kv_seq_lens = torch.randint(query_len, max_kv_len, (3,), dtype=torch.int32)
assert query_len <= max_kv_len
for cur_kv_seq in kv_seq_lens:
assert query_len <= cur_kv_seq, f'query_len {query_len} must be less than or equal to kv_len {cur_kv_seq} for every sequence.'
batch_size = len(kv_seq_lens)
pages_per_sequence = max_kv_len // page_size
total_num_pages = batch_size * pages_per_sequence
assert max_kv_len <= total_num_pages * page_size

q, k_pages, v_pages, page_indices = self._pagedattention_generate_qkv(
kv_seq_lens,
page_size,
max_kv_len,
num_kv_heads,
num_kv_heads * q_kv_head_ratio,
head_dim,
dtype=dtype,
query_len=query_len,
)

q_xla = q.to("xla")
k_pages_xla = k_pages.to("xla")
v_pages_xla = v_pages.to("xla")
kv_seq_lens_xla = kv_seq_lens.to("xla")
page_indices_xla = page_indices.to("xla")

output = multi_queries_paged_attention(
q_xla,
k_pages_xla,
v_pages_xla,
kv_seq_lens_xla,
page_indices_xla,
num_kv_pages_per_compute_block=block_kv_size // page_size,
num_queries_per_compute_block=num_queries_per_compute_block,
)

nonkernel_output = multi_queries_paged_attention(
q_xla,
k_pages_xla,
v_pages_xla,
kv_seq_lens_xla,
page_indices_xla,
num_kv_pages_per_compute_block=block_kv_size // page_size,
num_queries_per_compute_block=num_queries_per_compute_block,
use_kernel=False,
)

q_jax = jnp.array(q.numpy(), dtype=jnp.float32)
k_pages_jax = jnp.array(k_pages.numpy(), dtype=jnp.float32)
v_pages_jax = jnp.array(v_pages.numpy(), dtype=jnp.float32)
kv_seq_lens_jax = jnp.array(kv_seq_lens.numpy(), dtype=jnp.int32)
page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
expected_output = torch.from_numpy(
np.array(
jax_multi_queries_paged_attention(
q_jax,
k_pages_jax,
v_pages_jax,
kv_seq_lens_jax,
page_indices_jax,
num_kv_pages_per_compute_block=block_kv_size // page_size,
num_queries_per_compute_block=num_queries_per_compute_block,
)))

self.assertTrue(
torch.allclose(
output.cpu(), expected_output.cpu(), atol=1e-5, rtol=1e-5))
self.assertTrue(
torch.allclose(
output.cpu(), nonkernel_output.cpu(), atol=1e-2, rtol=1e-2))

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
"This test only works on TPUv4+.")
def test_paged_attention_multi_queries_wrapper_with_dynamo(self):
from torch_xla.experimental.custom_kernel import multi_queries_paged_attention
from torch_xla.experimental.pallas_kernels.multi_queries_paged_attention_kernel import paged_attention as jax_multi_queries_paged_attention

dtype = torch.float32
page_size = 16
num_kv_heads = 8
q_kv_head_ratio = 4
head_dim = 256
num_queries_per_compute_block = 32
block_kv_size = 256

max_kv_len = 2048
query_len = 64
kv_seq_lens = torch.randint(query_len, max_kv_len, (3,), dtype=torch.int32)
assert query_len <= max_kv_len
for cur_kv_seq in kv_seq_lens:
assert query_len <= cur_kv_seq, f'query_len {query_len} must be less than or equal to kv_len {cur_kv_seq} for every sequence.'
batch_size = len(kv_seq_lens)
pages_per_sequence = max_kv_len // page_size
total_num_pages = batch_size * pages_per_sequence
assert max_kv_len <= total_num_pages * page_size

q, k_pages, v_pages, page_indices = self._pagedattention_generate_qkv(
kv_seq_lens,
page_size,
max_kv_len,
num_kv_heads,
num_kv_heads * q_kv_head_ratio,
head_dim,
dtype=dtype,
query_len=query_len,
)

q_xla = q.to("xla")
k_pages_xla = k_pages.to("xla")
v_pages_xla = v_pages.to("xla")
kv_seq_lens_xla = kv_seq_lens.to("xla")
page_indices_xla = page_indices.to("xla")

def multi_queries_paged_attention_wrapper(q, k_pages, v_pages, kv_seq_lens,
page_indices,
num_kv_pages_per_compute_block,
num_queries_per_compute_block,
use_kernel):
return torch.ops.xla.multi_queries_paged_attention(
q,
k_pages,
v_pages,
kv_seq_lens,
page_indices,
num_kv_pages_per_compute_block,
num_queries_per_compute_block,
use_kernel=use_kernel,
)

compiled_paged_attention = torch.compile(
multi_queries_paged_attention_wrapper, backend="openxla")

output = compiled_paged_attention(
q_xla,
k_pages_xla,
v_pages_xla,
kv_seq_lens_xla,
page_indices_xla,
num_kv_pages_per_compute_block=block_kv_size // page_size,
num_queries_per_compute_block=num_queries_per_compute_block,
use_kernel=True,
)

nonkernel_output = compiled_paged_attention(
q_xla,
k_pages_xla,
v_pages_xla,
kv_seq_lens_xla,
page_indices_xla,
num_kv_pages_per_compute_block=block_kv_size // page_size,
num_queries_per_compute_block=num_queries_per_compute_block,
use_kernel=False,
)

self.assertTrue(
torch.allclose(
output.cpu(), nonkernel_output.cpu(), atol=1e-2, rtol=1e-2))

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() != 4,
"This test only works on TPUv4 and TPUv5p.")
def test_paged_attention_wrapper_with_megacore_modes(self):
187 changes: 187 additions & 0 deletions test/test_tpu_paged_attention_kernel.py
@@ -0,0 +1,187 @@
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import test_util as jtu
from torch_xla.experimental.pallas_kernels.multi_queries_paged_attention_kernel import paged_attention
import jax.numpy as jnp
import numpy as np

jax.config.parse_flags_with_absl()


# Set up paged_attention inputs.
def _generate_qkv(
kv_seq_lens,
page_size,
max_kv_len,
query_len,
num_kv_heads,
num_q_heads,
head_dim,
prng_key,
dtype,
):
assert max_kv_len % page_size == 0
pages_per_sequence = max_kv_len // page_size
batch_size = len(kv_seq_lens)
total_pages = batch_size * pages_per_sequence
k1, k2, k3, k4 = jax.random.split(prng_key, 4)
k_pages = jax.random.normal(
k1, (num_kv_heads, total_pages, page_size, head_dim), dtype=dtype)
v_pages = jax.random.normal(
k2, (num_kv_heads, total_pages, page_size, head_dim), dtype=dtype)

page_indices = jnp.arange(batch_size * pages_per_sequence, dtype=jnp.int32)
page_indices = jax.random.permutation(k3, page_indices, independent=True)
page_indices = page_indices.reshape(batch_size, pages_per_sequence)
q = jax.random.normal(
k4, (batch_size, query_len, num_q_heads, head_dim), dtype=dtype)
return q, k_pages, v_pages, page_indices


def _ref_jax_extended_paged_attention(
q, # [batch_size, query_len, num_query_heads, head_size]
k_pages, # [num_kv_heads, total_num_pages, page_size, head_size]
v_pages, # [num_kv_heads, total_num_pages, page_size, head_size]
lengths, # [batch_size]
page_indices, # [batch_size, pages_per_sequence]
):
batch_size, query_len, num_query_heads, head_size = q.shape
num_kv_heads, total_num_pages, page_size, _ = k_pages.shape
num_query_per_kv = num_query_heads // num_kv_heads

outputs = []
for i in range(batch_size):
kv_len = lengths[i]
num_pages = (kv_len + page_size - 1) // page_size
indices = page_indices[i, :num_pages]

k = k_pages[:, indices]
k = jnp.permute_dims(k, (1, 2, 0, 3))
k = jnp.reshape(k, (num_pages * page_size, num_kv_heads, head_size))
k = k[:kv_len]

v = v_pages[:, indices]
v = jnp.permute_dims(v, (1, 2, 0, 3))
v = jnp.reshape(v, (num_pages * page_size, num_kv_heads, head_size))
v = v[:kv_len]

if num_query_per_kv != 1:
k = jnp.repeat(k, num_query_per_kv, axis=1)
v = jnp.repeat(v, num_query_per_kv, axis=1)

attn = jnp.einsum("qhd,khd->hqk", q[i], k)
attn = attn.astype('float32')
q_span = (kv_len - query_len) + jax.lax.broadcasted_iota(
jnp.int32, (query_len, kv_len), 0)
kv_span = jax.lax.broadcasted_iota(jnp.int32, (query_len, kv_len), 1)
mask = jnp.where(q_span < kv_span, float("-inf"), 0.)
with jax.numpy_rank_promotion("allow"):
attn = attn + mask
attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)
out = jnp.einsum("hqk,khd->qhd", attn, v)
outputs.append(out)
output = jnp.stack(outputs, axis=0)
return output


@jtu.with_config(jax_numpy_dtype_promotion="standard")
class PagedAttentionKernelTest(jtu.JaxTestCase):

def setUp(self):
super().setUp()


# def test_paged_attention(
# self,
# ):
# dtype = jnp.bfloat16
# page_size=16
# num_kv_heads = 8
# q_kv_head_ratio = 4
# head_dim = 256
# num_queries_per_compute_block = 32
# block_kv_size = 256

@parameterized.product(
dtype=(jnp.float32, jnp.bfloat16),
page_size=(16, 32, 64),
num_kv_heads=(1, 8),
q_kv_head_ratio=(1, 4, 8),
head_dim=(128, 256),
num_queries_per_compute_block=(16, 32),
block_kv_size=(128, 256),
)
def test_paged_attention(
self,
dtype,
page_size,
num_kv_heads,
q_kv_head_ratio,
head_dim,
num_queries_per_compute_block,
block_kv_size,
):

max_kv_len = 2048
query_len = 64
kv_seq_lens = jax.random.randint(
jax.random.key(0), (3,), query_len, max_kv_len)

assert query_len <= max_kv_len
for cur_kv_seq in kv_seq_lens:
assert query_len <= cur_kv_seq, f'query_len {query_len} must be less than or equal to kv_len {cur_kv_seq} for every sequence.'
batch_size = len(kv_seq_lens)
pages_per_sequence = max_kv_len // page_size
total_num_pages = batch_size * pages_per_sequence
assert max_kv_len <= total_num_pages * page_size

q, k_pages, v_pages, page_indices = _generate_qkv(
kv_seq_lens,
page_size,
max_kv_len,
query_len,
num_kv_heads,
num_kv_heads * q_kv_head_ratio,
head_dim,
jax.random.key(0),
dtype,
)

print(f'Running paged_attention with {query_len=}')
num_kv_pages_per_compute_block = block_kv_size // page_size
actual_output = paged_attention(
q,
k_pages,
v_pages,
kv_seq_lens,
page_indices,
num_kv_pages_per_compute_block=num_kv_pages_per_compute_block,
num_queries_per_compute_block=num_queries_per_compute_block,
)
# actual_output = jax.block_until_ready(actual_output)

# Run the ref impl.
expected_output = _ref_jax_extended_paged_attention(
q,
k_pages,
v_pages,
kv_seq_lens,
page_indices,
)

self.assertEqual(actual_output.shape, expected_output.shape)

if dtype == jnp.float32:
atol = 1e-2
rtol = 1e-2
elif dtype == jnp.bfloat16:
atol = 6e-1
rtol = 1e-1
else:
self.fail(f'Unsupported dtype: {dtype}')
self.assertTrue(
jnp.allclose(expected_output, actual_output, atol=atol, rtol=rtol))

if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
3 changes: 2 additions & 1 deletion test/tpu/run_tests.sh
@@ -25,8 +25,9 @@ python3 test/pjrt/test_dtypes.py
python3 test/pjrt/test_dynamic_plugin_tpu.py
python3 test/test_while_loop.py
python3 test/test_scan.py
python3 test/test_pallas.py
python3 test/test_pallas.py -v
python3 test/test_pallas_spmd.py
python3 test/test_tpu_paged_attention_kernel.py
python3 test/test_input_output_aliases.py
python3 test/test_gmm.py
python3 test/eager/test_eager_spmd.py