Commit c8a012b

Finished implementing the v0. Also added a test that uses a single query token and verifies that extended_paged_attention generates the same result as the original paged_attention.
1 parent 58fe257 commit c8a012b

2 files changed: +119 −23 lines

test/test_pallas.py (+77 lines)
@@ -615,6 +615,83 @@ def test_extended_paged_attention(self):
             atol=1e-5,
             rtol=1e-5))
 
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
+                   "This test only works on TPUv4+.")
+  def test_extended_paged_attention_single_query(self):
+    from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention as jax_paged_attention
+    from torch_xla.experimental.pallas_kernels.extended_paged_attention_kernel0 import paged_attention as jax_extended_paged_attention0
+
+    # pallas_compute_block_size corresponds to the compute-block concept in
+    # flash attention per https://github.com/jax-ml/jax/blob/c6e5530aab9b859056883ccb3c1937259b998af0/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py#L400-L401
+    # and pages_per_compute_block is a tunable parameter in vLLM per
+    # https://github.com/vllm-project/vllm/blob/f5e1bf5d44877149eaabf9c04379a4e14a023145/vllm/attention/backends/pallas.py#L184
+    pallas_compute_block_size = 512
+    batch_size: int = 3
+    query_len: int = 1
+    num_query_heads: int = 64
+    num_kv_heads: int = 8
+    head_size: int = 128
+    dtype: torch.dtype = torch.float32
+    max_kv_len: int = 1024
+    page_size: int = 64
+    total_num_pages: int = 32
+    assert num_query_heads % num_kv_heads == 0
+    assert query_len <= max_kv_len
+    assert max_kv_len <= total_num_pages * page_size
+
+    q = torch.randn(batch_size, query_len, num_query_heads, head_size, dtype=dtype)
+    k_pages = torch.randn(num_kv_heads, total_num_pages, page_size, head_size, dtype=dtype)
+    v_pages = torch.rand_like(k_pages)
+    kv_seq_lengths = torch.randint(query_len, max_kv_len + 1, (batch_size,))
+    page_indices = torch.randint(0, total_num_pages, (batch_size, total_num_pages))
+
+    q_jax = jnp.array(q.numpy(), dtype=jnp.float32)
+    assert q_jax.shape == (batch_size, query_len, num_query_heads, head_size), f"Input q_jax has the wrong shape: {q_jax.shape}. Expect {(batch_size, query_len, num_query_heads, head_size)}."
+    k_pages_jax = jnp.array(k_pages.numpy(), dtype=jnp.float32)
+    v_pages_jax = jnp.array(v_pages.numpy(), dtype=jnp.float32)
+    kv_seq_lens_jax = jnp.array(kv_seq_lengths.numpy(), dtype=jnp.int32)
+    page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
+    actual_output = torch.from_numpy(
+        np.array(
+            jax_extended_paged_attention0(
+                q_jax,
+                k_pages_jax,
+                v_pages_jax,
+                kv_seq_lens_jax,
+                page_indices_jax,
+                pages_per_compute_block=pallas_compute_block_size // page_size,
+            )))
+
+    ref_q_jax = jnp.array(q.squeeze().numpy(), dtype=jnp.float32)
+    assert ref_q_jax.shape == (batch_size, num_query_heads, head_size), f"Input ref_q_jax has the wrong shape: {ref_q_jax.shape}. Expect {(batch_size, num_query_heads, head_size)}."
+    ref_k_pages_jax = jnp.array(k_pages.numpy(), dtype=jnp.float32)
+    ref_v_pages_jax = jnp.array(v_pages.numpy(), dtype=jnp.float32)
+    ref_kv_seq_lens_jax = jnp.array(kv_seq_lengths.numpy(), dtype=jnp.int32)
+    ref_page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
+    expected_output = torch.from_numpy(
+        np.array(
+            jax_paged_attention(
+                ref_q_jax,
+                ref_k_pages_jax,
+                ref_v_pages_jax,
+                ref_kv_seq_lens_jax,
+                ref_page_indices_jax,
+                pages_per_compute_block=pallas_compute_block_size // page_size,
+            )))
+
+    # print(f'{expected_output.cpu()=}')
+    # print(f'{actual_output.cpu()=}')
+    expected_output_cpu = expected_output.cpu()
+    actual_output_cpu = actual_output.cpu()
+    print(f'Output max diff: {(expected_output_cpu - actual_output_cpu).abs().max().item()}')
+    print(f'Output mean diff: {(expected_output_cpu - actual_output_cpu).abs().mean().item()}')
+    self.assertTrue(
+        torch.allclose(
+            expected_output_cpu,
+            actual_output_cpu,
+            atol=1e-5,
+            rtol=1e-5))
+
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() != 4,
                    "This test only works on TPUv4 and TPUv5p.")
   def test_paged_attention_wrapper_with_megacore_modes(self):
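The comparison in the test above rests on a little shape bookkeeping. The following NumPy-only sketch (constants copied from the test; it is an illustration, not part of the commit) shows why a single query token lets the extended kernel be checked against the original kernel:

import numpy as np

batch_size, query_len, num_query_heads, head_size = 3, 1, 64, 128
pallas_compute_block_size, page_size = 512, 64

# The extended kernel takes a 4-D query. With query_len == 1, squeezing that
# axis recovers the [batch_size, num_heads, head_dim] layout expected by the
# original paged_attention kernel, so both kernels consume the same data.
q = np.random.randn(batch_size, query_len, num_query_heads, head_size).astype(np.float32)
ref_q = q.squeeze(axis=1)
assert ref_q.shape == (batch_size, num_query_heads, head_size)

# Both kernels are launched with the same blocking: 512 // 64 == 8 pages per
# compute block.
assert pallas_compute_block_size // page_size == 8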

torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel0.py (+42, −23 lines)
@@ -49,6 +49,7 @@ def __init__(
       num_pages_to_load,
       head_index,
   ):
+    # Original k_pages has shape [num_kv_heads, total_num_pages, page_size, head_dim].
     self._vmem_buffer = vmem_buffer
     self._scales_vmem_buffer = scales_vmem_buffer
     self._num_pages_to_load = num_pages_to_load
@@ -113,7 +114,6 @@ def wait_and_get_loaded(self) -> jax.Array:
     jax_array = self._maybe_dequantize(jax_array, scales_jax_array)
     return jax_array.reshape(-1, head_dim)
 
-
 def paged_flash_attention_kernel(
     lengths_ref,
     page_indices_ref,
@@ -142,16 +142,21 @@ def paged_flash_attention_kernel(
     program_ids=(),
 ):
   """Pallas kernel for paged attention."""
-  if program_ids:
-    core_index, b, h, i = program_ids
+  # xw32: core_index, b, h, i index the num_cores, batch_size, num_kv_heads, and kv-block grid dims.
+  # Note: the original (unblocked) q has shape [batch_size, query_len, num_heads, head_dim].
+  print(f'xw32 line147 {q_ref.shape=}')
+  if program_ids:  # inline_seq_dim case.
+    core_index, q_idx, b, h, i = program_ids  # The 2nd element is q_idx; it is unused here.
   else:
-    core_index, b, h, i = (
+    core_index, q_idx, b, h, i = (
         pl.program_id(0),
         pl.program_id(1),
         pl.program_id(2),
         pl.program_id(3),
+        pl.program_id(4),
     )
   num_kv_heads, _, page_size, _ = k_pages_hbm_ref.shape
+  # xw32: bk is the overall compute block size, in KV tokens.
   bk = page_size * pages_per_compute_block
   num_cores = pl.num_programs(0)
 
@@ -317,10 +322,16 @@ def paged_flash_attention_kernel_inline_seq_dim(
     attn_logits_soft_cap: float | None,
     megacore_mode: str | None,
 ):
-  core_index, b, h = pl.program_id(0), pl.program_id(1), pl.program_id(2)
+  print(f'xw32 line325 {m_ref.shape=}')
+  core_index, q_idx, b, h = pl.program_id(0), pl.program_id(1), pl.program_id(2), pl.program_id(3)
 
   # Initialize the output HBM buffers to avoid accessing garbage memory inside
   # the kernel body below.
+  # Note: q_block_spec = pl.BlockSpec(
+  #     (None, None, num_heads // num_kv_heads, head_dim),  # (bs, query_len, num_heads, head_dim)
+  #     lambda core_index, q, b, h, *_: (b, q, h, 0),
+  # )
+  # and m_ref, l_ref, o_ref all use out_specs=q_block_spec.
   m_ref[...] = jnp.full_like(m_ref, -jnp.inf)
   l_ref[...] = jnp.zeros_like(l_ref)
   o_ref[...] = jnp.zeros_like(o_ref)
@@ -332,7 +343,7 @@ def body(i, _):
         buffer_index_ref,
         step_ref,
         q_ref,
-        k_pages_hbm_ref, #
+        k_pages_hbm_ref,
         k_scales_pages_hbm_ref,
         v_pages_hbm_ref,
         v_scales_pages_hbm_ref,
@@ -350,11 +361,13 @@ def body(i, _):
         mask_value=mask_value,
         attn_logits_soft_cap=attn_logits_soft_cap,
         megacore_mode=megacore_mode,
-        program_ids=(core_index, b, h, i),
+        program_ids=(core_index, q_idx, b, h, i),
     )
     return ()
 
-  bk = pages_per_compute_block * k_pages_hbm_ref.shape[-2]
+  # xw32: note num_kv_heads, _, page_size, head_dim_k = k_pages.shape,
+  # so k_pages_hbm_ref.shape[-2] is the page_size.
+  bk = pages_per_compute_block * k_pages_hbm_ref.shape[-2]  # Total KV tokens across the pages in one compute block.
 
   if megacore_mode == "batch":
     num_cores = pl.num_programs(0)
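As a worked example of the bk computation above, with the sizes used by the new test (the numbers are illustrative; the kernel itself only sees page_size and pages_per_compute_block):

page_size = 64                       # k_pages_hbm_ref.shape[-2]
pages_per_compute_block = 512 // 64  # 8, as passed in by the test
bk = pages_per_compute_block * page_size
assert bk == 512                     # KV tokens processed per compute block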
@@ -391,7 +404,7 @@ def paged_attention(
   """Paged grouped query attention.
 
   Args:
-    q: A [batch_size, num_heads, head_dim] jax.Array.
+    q: A [batch_size, query_len, num_heads, head_dim] jax.Array.
     k_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array.
     v_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array.
     lengths: A i32[batch_size] jax.Array the length of each example.
@@ -416,9 +429,8 @@ def paged_attention(
       one kernel.
 
   Returns:
-    The output of attention([batch_size, num_heads, head_dim]).
+    The output of attention([batch_size, query_len, num_heads, head_dim]).
   """
-  # return jnp.zeros_like(q, q.dtype)
   if isinstance(k_pages, quantization_utils.QuantizedTensor):
     k_pages, k_scales_pages = k_pages.weight, k_pages.scales
     assert isinstance(k_scales_pages, jax.Array)  # For typing.
@@ -436,7 +448,8 @@ def paged_attention(
   else:
     v_scales_pages = None
 
-  batch_size, num_heads, head_dim = q.shape
+  # TODO(xw32): consider renaming num_heads to num_query_heads.
+  batch_size, query_len, num_heads, head_dim = q.shape
   num_kv_heads, _, page_size, head_dim_k = k_pages.shape
   batch_size_paged_indices, pages_per_sequence = page_indices.shape
 
@@ -486,6 +499,7 @@ def paged_attention(
     raise ValueError("megacore_mode must be one of ['kv_head', 'batch', None]")
 
   if (num_heads // num_kv_heads) % 8 != 0:
+    # TODO(xw32): add the query_len dim to this branch later.
     # Reshape q to hint XLA to pick a <1x128> layout otherwise it will pick a
     # <8x128> layout for a <1x128> memref inside the kernel and error out.
     q = q.reshape(batch_size, num_heads, 1, head_dim)
@@ -507,46 +521,51 @@ def paged_attention(
     q_dtype_for_kernel_launch = jnp.float32
   else:
     if megacore_mode == "kv_head":
-      # q.shape=[batch_size, num_heads, head_dim]
+      # q.shape=[batch_size, query_len, num_heads, head_dim]
+      # xw32q: The q heads are chunked into groups of (num_heads // num_kv_heads);
+      # does that mean this is MQA?
       q_block_spec = pl.BlockSpec(
-          (None, num_heads // num_kv_heads, head_dim),
-          lambda core_index, b, h, *_: (b, h * num_cores + core_index, 0),
+          (None, None, num_heads // num_kv_heads, head_dim),
+          lambda core_index, q, b, h, *_: (b, q, h * num_cores + core_index, 0),
       )
     elif megacore_mode == "batch":
       q_block_spec = pl.BlockSpec(
-          (None, num_heads // num_kv_heads, head_dim),
-          lambda core_index, b, h, *_: (b * num_cores + core_index, h, 0),
+          (None, None, num_heads // num_kv_heads, head_dim),
+          lambda core_index, q, b, h, *_: (b * num_cores + core_index, q, h, 0),
       )
     else:
       q_block_spec = pl.BlockSpec(
-          (None, num_heads // num_kv_heads, head_dim),
-          lambda core_index, b, h, *_: (b, h, 0),
+          (None, None, num_heads // num_kv_heads, head_dim),
+          lambda core_index, q, b, h, *_: (b, q, h, 0),
       )
     q_dtype_for_kernel_launch = q.dtype
 
   dimension_semantics: Sequence[Literal["parallel", "arbitrary"]]
   if inline_seq_dim:
     kernel = paged_flash_attention_kernel_inline_seq_dim
+    # query_len goes before batch_size and num_kv_heads so the flash attention kernel itself doesn't need to change.
     grid = (
         num_cores,
+        query_len,
         batch_size // num_cores if megacore_mode == "batch" else batch_size,
         num_kv_heads // num_cores
         if megacore_mode == "kv_head"
         else num_kv_heads,
     )
     # xw32q: shouldn't the batch dim and kv_heads dim be parallel?
-    dimension_semantics = ("parallel", "arbitrary", "arbitrary")
+    dimension_semantics = ("parallel", "arbitrary", "arbitrary", "arbitrary")
   else:
     kernel = paged_flash_attention_kernel
     grid = (
         num_cores,
+        query_len,
         batch_size // num_cores if megacore_mode == "batch" else batch_size,
         num_kv_heads // num_cores
         if megacore_mode == "kv_head"
         else num_kv_heads,
         pages_per_sequence // pages_per_compute_block,
     )  # type: ignore
-    dimension_semantics = ("parallel", "arbitrary", "arbitrary", "arbitrary")
+    dimension_semantics = ("parallel", "arbitrary", "arbitrary", "arbitrary", "arbitrary")
 
   if k_scales_pages is not None and v_scales_pages is not None:
     in_specs = [
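To make the new indexing concrete, here is a rough NumPy illustration (not part of the commit) of the q slice each grid step is intended to see under the megacore_mode=None block spec, assuming the usual Pallas convention that a None entry in the block shape means a size-1 block that is squeezed out of the kernel's view; the grid order matches the new (num_cores, query_len, batch, kv_heads, ...) layout:

import numpy as np

batch_size, query_len, num_heads, num_kv_heads, head_dim = 3, 1, 64, 8, 128
group = num_heads // num_kv_heads  # query heads handled per kv head
q = np.random.randn(batch_size, query_len, num_heads, head_dim).astype(np.float32)

def q_block(b, q_idx, h):
  # Mirrors block shape (None, None, group, head_dim) with index map
  # lambda core_index, q, b, h, *_: (b, q, h, 0): block h along the heads
  # axis covers rows [h * group, (h + 1) * group).
  return q[b, q_idx, h * group:(h + 1) * group, :]

assert q_block(0, 0, 0).shape == (group, head_dim)
assert q_block(2, 0, 7).shape == (group, head_dim)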
@@ -597,7 +616,7 @@ def paged_attention(
         ),  # v_scales_pages buffer
         pltpu.SemaphoreType.DMA,
     )
-  else:
+  else:  # Either k_scales_pages or v_scales_pages is None.
     in_specs = [
         q_block_spec,
         # The 4 specs below correspond to the 4 inputs: k_pages, k_scales_pages, v_pages, etc.
@@ -672,4 +691,4 @@ def paged_attention(
       v_pages,
       v_scales_pages,
   )
-  return out.reshape(batch_size, num_heads, head_dim).astype(q.dtype)
+  return out.reshape(batch_size, query_len, num_heads, head_dim).astype(q.dtype)
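Putting the signature changes together, a call to the extended kernel is expected to look roughly like the sketch below, inferred from the updated docstring and the new test (it needs a TPU plus the torch_xla build from this branch to actually run):

import jax.numpy as jnp
from torch_xla.experimental.pallas_kernels.extended_paged_attention_kernel0 import (
    paged_attention as extended_paged_attention)

q = jnp.zeros((3, 1, 64, 128), jnp.float32)         # [batch, query_len, num_heads, head_dim]
k_pages = jnp.zeros((8, 32, 64, 128), jnp.float32)  # [num_kv_heads, total_num_pages, page_size, head_dim]
v_pages = jnp.zeros_like(k_pages)
lengths = jnp.full((3,), 1024, dtype=jnp.int32)     # KV length of each example
page_indices = jnp.zeros((3, 32), dtype=jnp.int32)  # page table of each example

out = extended_paged_attention(
    q, k_pages, v_pages, lengths, page_indices,
    pages_per_compute_block=512 // 64)
assert out.shape == q.shape  # [batch, query_len, num_heads, head_dim]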
