
Commit 669d598

Finished implementing the v0. Also added a test that uses 1 query token and verifies that the extended paged_attention generates the same result as the original paged_attention.
1 parent 58fe257 commit 669d598
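
In short, the new test checks that for query_len = 1 the extended kernel reproduces the original kernel once the query dimension is squeezed out. A minimal sketch of that check follows (the helper name single_query_outputs_match is illustrative and not part of the commit; it assumes a TPU and the same imports the test uses):

import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention
from torch_xla.experimental.pallas_kernels.extended_paged_attention_kernel0 import paged_attention as extended_paged_attention


def single_query_outputs_match(q, k_pages, v_pages, kv_seq_lens, page_indices,
                               pages_per_compute_block):
  # q: [batch_size, 1, num_query_heads, head_dim]. The original kernel takes
  # [batch_size, num_query_heads, head_dim], so drop the query_len dimension
  # for the reference call and from the extended kernel's output.
  extended_out = extended_paged_attention(
      q, k_pages, v_pages, kv_seq_lens, page_indices,
      pages_per_compute_block=pages_per_compute_block)
  reference_out = paged_attention(
      q[:, 0], k_pages, v_pages, kv_seq_lens, page_indices,
      pages_per_compute_block=pages_per_compute_block)
  return jnp.allclose(extended_out[:, 0], reference_out, atol=1e-5, rtol=1e-5)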

File tree

2 files changed: +129 -25


test/test_pallas.py (+84)

@@ -615,6 +615,90 @@ def test_extended_paged_attention(self):
             atol=1e-5,
             rtol=1e-5))
 
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
+                   "This test only works on TPUv4+.")
+  def test_extended_paged_attention_single_query(self):
+    from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention as jax_paged_attention
+    from torch_xla.experimental.pallas_kernels.extended_paged_attention_kernel0 import paged_attention as jax_extended_paged_attention0
+
+    # flash_attn_block_size seems to be the compute-block concept in flash attention, per
+    # https://github.com/jax-ml/jax/blob/c6e5530aab9b859056883ccb3c1937259b998af0/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py#L400-L401
+    # and pages_per_compute_block seems to be a tunable parameter in vLLM, per
+    # https://github.com/vllm-project/vllm/blob/f5e1bf5d44877149eaabf9c04379a4e14a023145/vllm/attention/backends/pallas.py#L184
+    pallas_compute_block_size = 512
+    batch_size: int = 3
+    query_len: int = 1
+    num_query_heads: int = 64
+    num_kv_heads: int = 8
+    head_size: int = 128
+    dtype: torch.dtype = torch.float32
+    max_kv_len: int = 1024
+    page_size: int = 64
+    total_num_pages: int = 32
+    assert num_query_heads % num_kv_heads == 0
+    assert query_len <= max_kv_len
+    assert max_kv_len <= total_num_pages * page_size
+
+    q = torch.randn(batch_size, query_len, num_query_heads, head_size, dtype=dtype)
+    k_pages = torch.randn(num_kv_heads, total_num_pages, page_size, head_size, dtype=dtype)
+    v_pages = torch.rand_like(k_pages)
+    kv_seq_lengths = torch.randint(query_len, max_kv_len + 1, (batch_size,))
+    page_indices = torch.randint(0, total_num_pages, (batch_size, total_num_pages))
+
+    # Run the extended paged_attention with query_len=1.
+    q_jax = jnp.array(q.numpy(), dtype=jnp.float32)
+    assert q_jax.shape == (batch_size, query_len, num_query_heads, head_size), f"Input q_jax has the wrong shape: {q_jax.shape}. Expected {(batch_size, query_len, num_query_heads, head_size)}."
+    k_pages_jax = jnp.array(k_pages.numpy(), dtype=jnp.float32)
+    v_pages_jax = jnp.array(v_pages.numpy(), dtype=jnp.float32)
+    kv_seq_lens_jax = jnp.array(kv_seq_lengths.numpy(), dtype=jnp.int32)
+    page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
+    actual_output = torch.from_numpy(
+        np.array(
+            jax_extended_paged_attention0(
+                q_jax,
+                k_pages_jax,
+                v_pages_jax,
+                kv_seq_lens_jax,
+                page_indices_jax,
+                pages_per_compute_block=pallas_compute_block_size // page_size,
+            )))
+
+    # Run the original paged_attention.
+    ref_q_jax = jnp.array(q.squeeze().numpy(), dtype=jnp.float32)
+    assert ref_q_jax.shape == (batch_size, num_query_heads, head_size), f"Input ref_q_jax has the wrong shape: {ref_q_jax.shape}. Expected {(batch_size, num_query_heads, head_size)}."
+    assert jnp.allclose(q_jax[:, 0, ...], ref_q_jax)
+    ref_k_pages_jax = jnp.array(k_pages.numpy(), dtype=jnp.float32)
+    assert jnp.allclose(k_pages_jax, ref_k_pages_jax)
+    ref_v_pages_jax = jnp.array(v_pages.numpy(), dtype=jnp.float32)
+    assert jnp.allclose(v_pages_jax, ref_v_pages_jax)
+    ref_kv_seq_lens_jax = jnp.array(kv_seq_lengths.numpy(), dtype=jnp.int32)
+    assert jnp.allclose(kv_seq_lens_jax, ref_kv_seq_lens_jax)
+    ref_page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
+    assert jnp.allclose(page_indices_jax, ref_page_indices_jax)
+    expected_output = torch.from_numpy(
+        np.array(
+            jax_paged_attention(
+                ref_q_jax,
+                ref_k_pages_jax,
+                ref_v_pages_jax,
+                ref_kv_seq_lens_jax,
+                ref_page_indices_jax,
+                pages_per_compute_block=pallas_compute_block_size // page_size,
+            )))
+
+    # print(f'{expected_output.cpu()=}')
+    # print(f'{actual_output.cpu()=}')
+    expected_output_cpu = expected_output.cpu()
+    actual_output_cpu = actual_output.cpu()
+    print(f'Output max diff: {(expected_output_cpu - actual_output_cpu).abs().max().item()}')
+    print(f'Output mean diff: {(expected_output_cpu - actual_output_cpu).abs().mean().item()}')
+    self.assertTrue(
+        torch.allclose(
+            expected_output_cpu,
+            actual_output_cpu,
+            atol=1e-5,
+            rtol=1e-5))
+
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() != 4,
                    "This test only works on TPUv4 and TPUv5p.")
   def test_paged_attention_wrapper_with_megacore_modes(self):
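
For the sizes used in this test, the compute-block arithmetic works out as below. This small worked example only restates values from the test and the kernel's own bk = page_size * pages_per_compute_block relation:

page_size = 64                    # KV tokens stored per page
pallas_compute_block_size = 512   # target KV tokens per Pallas compute block
pages_per_compute_block = pallas_compute_block_size // page_size  # 512 // 64 = 8 pages
bk = page_size * pages_per_compute_block  # 64 * 8 = 512 KV tokens per compute block
assert bk == pallas_compute_block_size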

torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel0.py (+45 -25)

@@ -1,6 +1,5 @@
-# This is the original paged_attention_kernel copied from
+# This is the adapted extended_paged_attention_kernel copied from
 # https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py#L400-L401
-# Don't review it. Will be removed later.
 
 # Copyright 2024 The JAX Authors.
 #
@@ -49,6 +48,7 @@ def __init__(
       num_pages_to_load,
       head_index,
   ):
+    # Original k_pages has shape [num_kv_heads, total_num_pages, page_size, head_dim]
     self._vmem_buffer = vmem_buffer
     self._scales_vmem_buffer = scales_vmem_buffer
     self._num_pages_to_load = num_pages_to_load
@@ -113,7 +113,6 @@ def wait_and_get_loaded(self) -> jax.Array:
     jax_array = self._maybe_dequantize(jax_array, scales_jax_array)
     return jax_array.reshape(-1, head_dim)
 
-
 def paged_flash_attention_kernel(
     lengths_ref,
     page_indices_ref,
@@ -142,16 +141,21 @@ def paged_flash_attention_kernel(
     program_ids=(),
 ):
   """Pallas kernel for paged attention."""
-  if program_ids:
-    core_index, b, h, i = program_ids
+  # xw32: (core_index, q_idx, b, h, i) index (num_cores, query_len, batch_size, num_kv_heads, compute_block).
+  # Note: q now has shape [batch_size, query_len, num_heads, head_dim].
+  print(f'xw32 line147 {q_ref.shape=}')
+  if program_ids:  # inline_seq_dim case.
+    core_index, q_idx, b, h, i = program_ids  # The 2nd one is q_idx, but we don't use it.
   else:
-    core_index, b, h, i = (
+    core_index, q_idx, b, h, i = (
         pl.program_id(0),
         pl.program_id(1),
         pl.program_id(2),
         pl.program_id(3),
+        pl.program_id(4),
     )
   num_kv_heads, _, page_size, _ = k_pages_hbm_ref.shape
+  # xw32: bk is the overall compute-block size (in KV tokens).
   bk = page_size * pages_per_compute_block
   num_cores = pl.num_programs(0)
 
@@ -317,10 +321,16 @@ def paged_flash_attention_kernel_inline_seq_dim(
     attn_logits_soft_cap: float | None,
     megacore_mode: str | None,
 ):
-  core_index, b, h = pl.program_id(0), pl.program_id(1), pl.program_id(2)
+  print(f'xw32 line325 {m_ref.shape=}')
+  core_index, q_idx, b, h = pl.program_id(0), pl.program_id(1), pl.program_id(2), pl.program_id(3)
 
   # Initialize the output HBM buffers to avoid accessing garbage memory inside
   # the kernel body below.
+  # Note: q_block_spec = pl.BlockSpec(
+  #     (None, None, num_heads // num_kv_heads, head_dim),  # [bs, query_len, num_heads, head_dim]
+  #     lambda core_index, q, b, h, *_: (b, q, h, 0),
+  # )
+  # and m_ref, l_ref, o_ref all have out_specs=q_block_spec.
   m_ref[...] = jnp.full_like(m_ref, -jnp.inf)
   l_ref[...] = jnp.zeros_like(l_ref)
   o_ref[...] = jnp.zeros_like(o_ref)
@@ -332,7 +342,7 @@ def body(i, _):
         buffer_index_ref,
         step_ref,
         q_ref,
-        k_pages_hbm_ref, #
+        k_pages_hbm_ref,
         k_scales_pages_hbm_ref,
         v_pages_hbm_ref,
         v_scales_pages_hbm_ref,
@@ -350,11 +360,13 @@ def body(i, _):
         mask_value=mask_value,
         attn_logits_soft_cap=attn_logits_soft_cap,
         megacore_mode=megacore_mode,
-        program_ids=(core_index, b, h, i),
+        program_ids=(core_index, q_idx, b, h, i),
     )
     return ()
 
-  bk = pages_per_compute_block * k_pages_hbm_ref.shape[-2]
+  # xw32: note that num_kv_heads, _, page_size, head_dim_k = k_pages.shape,
+  # so k_pages_hbm_ref.shape[-2] is the page_size.
+  bk = pages_per_compute_block * k_pages_hbm_ref.shape[-2]  # Total KV tokens covered by the pages in one compute block.
 
   if megacore_mode == "batch":
     num_cores = pl.num_programs(0)
@@ -391,7 +403,7 @@ def paged_attention(
   """Paged grouped query attention.
 
   Args:
-    q: A [batch_size, num_heads, head_dim] jax.Array.
+    q: A [batch_size, query_len, num_heads, head_dim] jax.Array.
     k_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array.
     v_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array.
     lengths: A i32[batch_size] jax.Array the length of each example.
@@ -416,9 +428,8 @@ def paged_attention(
       one kernel.
 
   Returns:
-    The output of attention([batch_size, num_heads, head_dim]).
+    The output of attention([batch_size, query_len, num_heads, head_dim]).
   """
-  # return jnp.zeros_like(q, q.dtype)
   if isinstance(k_pages, quantization_utils.QuantizedTensor):
     k_pages, k_scales_pages = k_pages.weight, k_pages.scales
     assert isinstance(k_scales_pages, jax.Array)  # For typing.
@@ -436,7 +447,8 @@ def paged_attention(
   else:
     v_scales_pages = None
 
-  batch_size, num_heads, head_dim = q.shape
+  # TODO(xw32): consider renaming num_heads to num_query_heads.
+  batch_size, query_len, num_heads, head_dim = q.shape
   num_kv_heads, _, page_size, head_dim_k = k_pages.shape
   batch_size_paged_indices, pages_per_sequence = page_indices.shape
 
@@ -486,6 +498,7 @@ def paged_attention(
     raise ValueError("megacore_mode must be one of ['kv_head', 'batch', None]")
 
   if (num_heads // num_kv_heads) % 8 != 0:
+    # TODO(xw32): add the query_len dim to this branch later.
     # Reshape q to hint XLA to pick a <1x128> layout otherwise it will pick a
     # <8x128> layout for a <1x128> memref inside the kernel and error out.
     q = q.reshape(batch_size, num_heads, 1, head_dim)
@@ -507,46 +520,53 @@ def paged_attention(
     q_dtype_for_kernel_launch = jnp.float32
   else:
     if megacore_mode == "kv_head":
-      # q.shape=[batch_size, num_heads, head_dim]
+      # q.shape=[batch_size, query_len, num_heads, head_dim]
+      # xw32q: The way it chunks the `num_heads` dimension (num_heads // num_kv_heads),
+      # does that mean it is MQA?
       q_block_spec = pl.BlockSpec(
-          (None, num_heads // num_kv_heads, head_dim),
-          lambda core_index, b, h, *_: (b, h * num_cores + core_index, 0),
+          (None, None, num_heads // num_kv_heads, head_dim),
+          lambda core_index, q, b, h, *_: (b, q, h * num_cores + core_index, 0),
       )
     elif megacore_mode == "batch":
       q_block_spec = pl.BlockSpec(
-          (None, num_heads // num_kv_heads, head_dim),
-          lambda core_index, b, h, *_: (b * num_cores + core_index, h, 0),
+          (None, None, num_heads // num_kv_heads, head_dim),
+          lambda core_index, q, b, h, *_: (b * num_cores + core_index, q, h, 0),
      )
     else:
+      # Here, if (num_heads // num_kv_heads) % 8 == 0, megacore_mode is None and inline_seq_dim is True,
+      # then grid = (num_cores, query_len, batch_size, num_kv_heads).
       q_block_spec = pl.BlockSpec(
-          (None, num_heads // num_kv_heads, head_dim),
-          lambda core_index, b, h, *_: (b, h, 0),
+          (None, None, num_heads // num_kv_heads, head_dim),
+          lambda core_index, q, b, h, *_: (b, q, h, 0),
       )
     q_dtype_for_kernel_launch = q.dtype
 
   dimension_semantics: Sequence[Literal["parallel", "arbitrary"]]
   if inline_seq_dim:
     kernel = paged_flash_attention_kernel_inline_seq_dim
+    # query_len goes before batch_size and num_kv_heads so the flash_attention kernel doesn't need to be changed.
     grid = (
        num_cores,
+        query_len,
        batch_size // num_cores if megacore_mode == "batch" else batch_size,
        num_kv_heads // num_cores
        if megacore_mode == "kv_head"
        else num_kv_heads,
    )
    # xw32q: shouldn't batch dim and kv_heads dim be parallel?
-    dimension_semantics = ("parallel", "arbitrary", "arbitrary")
+    dimension_semantics = ("parallel", "arbitrary", "arbitrary", "arbitrary")
  else:
    kernel = paged_flash_attention_kernel
    grid = (
        num_cores,
+        query_len,
        batch_size // num_cores if megacore_mode == "batch" else batch_size,
        num_kv_heads // num_cores
        if megacore_mode == "kv_head"
        else num_kv_heads,
        pages_per_sequence // pages_per_compute_block,
    ) # type: ignore
-    dimension_semantics = ("parallel", "arbitrary", "arbitrary", "arbitrary")
+    dimension_semantics = ("parallel", "arbitrary", "arbitrary", "arbitrary", "arbitrary")
 
   if k_scales_pages is not None and v_scales_pages is not None:
     in_specs = [
@@ -597,7 +617,7 @@ def paged_attention(
         ), # v_scales_pages buffer
         pltpu.SemaphoreType.DMA,
     )
-  else:
+  else: # either k_scales_pages or v_scales_pages is None.
     in_specs = [
         q_block_spec,
         # Below 4 correspond to the 4 input: k_pages, k_scales_pages, q_pages, etc.
@@ -672,4 +692,4 @@ def paged_attention(
       v_pages,
       v_scales_pages,
   )
-  return out.reshape(batch_size, num_heads, head_dim).astype(q.dtype)
+  return out.reshape(batch_size, query_len, num_heads, head_dim).astype(q.dtype)
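
To see what the extra query_len grid dimension changes, here is a small standalone sketch (plain Python, illustrative only, using the sizes from the test above) of how the new q BlockSpec index map lambda core_index, q, b, h, *_: (b, q, h, 0) walks q in the inline_seq_dim, megacore_mode=None case:

# Grid after this commit: (num_cores, query_len, batch_size, num_kv_heads).
num_cores, query_len, batch_size = 1, 1, 3
num_query_heads, num_kv_heads, head_dim = 64, 8, 128
heads_per_block = num_query_heads // num_kv_heads  # 8 query heads share one KV head

q_index_map = lambda core_index, q, b, h, *_: (b, q, h, 0)

for core_index in range(num_cores):
  for q in range(query_len):
    for b in range(batch_size):
      for h in range(num_kv_heads):
        blk_b, blk_q, blk_h, _ = q_index_map(core_index, q, b, h)
        # With block shape (None, None, heads_per_block, head_dim), this grid step
        # reads q[blk_b, blk_q, blk_h * heads_per_block : (blk_h + 1) * heads_per_block, :].
        print((core_index, q, b, h), '->', (blk_b, blk_q, blk_h * heads_per_block, 0))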
