
Commit 58fe257

Implementing kernel v0
1 parent 43c2bf0 commit 58fe257

2 files changed (+14 −8 lines)

torch_xla/experimental/custom_kernel.py (+1 −1)
@@ -506,7 +506,7 @@ def extended_paged_attention(
         attn_logits_soft_cap,
     )
 
-  from torch_xla.experimental.pallas_kernels.extended_paged_attention_kernel import paged_attention
+  from torch_xla.experimental.pallas_kernels.extended_paged_attention_kernel0 import paged_attention
 
   assert megacore_mode in [
       "kv_head", "batch", None

torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel.py → torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel0.py (renamed, +13 −7)
@@ -332,7 +332,7 @@ def body(i, _):
         buffer_index_ref,
         step_ref,
         q_ref,
-        k_pages_hbm_ref,
+        k_pages_hbm_ref, #
         k_scales_pages_hbm_ref,
         v_pages_hbm_ref,
         v_scales_pages_hbm_ref,
@@ -418,7 +418,7 @@ def paged_attention(
   Returns:
     The output of attention([batch_size, num_heads, head_dim]).
   """
-  return jnp.zeros_like(q, jnp.int32)
+  # return jnp.zeros_like(q, q.dtype)
   if isinstance(k_pages, quantization_utils.QuantizedTensor):
     k_pages, k_scales_pages = k_pages.weight, k_pages.scales
     assert isinstance(k_scales_pages, jax.Array)  # For typing.
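The early-return stub is commented out here (with its dtype argument corrected from jnp.int32 to q.dtype), so the function now proceeds to the real kernel launch. A standalone check of what such a stub produces, with a made-up query shape:

import jax.numpy as jnp

# Hypothetical query tensor of shape [batch_size, num_heads, head_dim].
q = jnp.ones((4, 8, 128), jnp.float32)

# The stub returned an all-zeros placeholder; passing q.dtype keeps its dtype
# consistent with the real attention output, unlike the earlier jnp.int32.
stub_out = jnp.zeros_like(q, q.dtype)
assert stub_out.shape == q.shape and stub_out.dtype == q.dtype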
@@ -507,6 +507,7 @@ def paged_attention(
     q_dtype_for_kernel_launch = jnp.float32
   else:
     if megacore_mode == "kv_head":
+      # q.shape=[batch_size, num_heads, head_dim]
       q_block_spec = pl.BlockSpec(
           (None, num_heads // num_kv_heads, head_dim),
           lambda core_index, b, h, *_: (b, h * num_cores + core_index, 0),
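The new comment records the query layout that the BlockSpec below slices. A standalone sketch of how this block shape and index map behave under megacore_mode == "kv_head", with made-up sizes:

from jax.experimental import pallas as pl

# Illustrative sizes only.
batch_size, num_heads, num_kv_heads, head_dim, num_cores = 4, 8, 2, 128, 2

# q has shape [batch_size, num_heads, head_dim]. Each grid step loads the
# group of num_heads // num_kv_heads query heads that share one kv head; the
# index map selects batch b and kv-head group h * num_cores + core_index, so
# consecutive kv heads are spread across TensorCores.
q_block_spec = pl.BlockSpec(
    (None, num_heads // num_kv_heads, head_dim),
    lambda core_index, b, h, *_: (b, h * num_cores + core_index, 0),
)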
@@ -533,6 +534,7 @@ def paged_attention(
         if megacore_mode == "kv_head"
         else num_kv_heads,
     )
+    # xw32q: shouldn't batch dim and kv_heads dim be parallel?
     dimension_semantics = ("parallel", "arbitrary", "arbitrary")
   else:
     kernel = paged_flash_attention_kernel
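The xw32q note asks whether the batch and kv-head grid dimensions should also be marked "parallel". For reference, a self-contained toy example of how dimension_semantics is handed to the TPU compiler: it uses a trivial copy kernel and made-up shapes (not the paged-attention launch), needs a TPU backend to run, and uses the older dict-style compiler_params (newer JAX spells this pltpu.TPUCompilerParams(dimension_semantics=...)):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def copy_kernel(x_ref, o_ref):
  # Copies one (1, 8, 128) block per grid step.
  o_ref[...] = x_ref[...]

x = jnp.arange(4 * 8 * 128, dtype=jnp.float32).reshape(4, 8, 128)

y = pl.pallas_call(
    copy_kernel,
    grid=(4,),
    in_specs=[pl.BlockSpec((1, 8, 128), lambda i: (i, 0, 0))],
    out_specs=pl.BlockSpec((1, 8, 128), lambda i: (i, 0, 0)),
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    # "parallel" promises no cross-iteration dependency on that grid axis, so
    # the compiler may split it across cores; "arbitrary" keeps it sequential.
    # The kernel in this diff uses ("parallel", "arbitrary", "arbitrary").
    compiler_params=dict(mosaic=dict(dimension_semantics=("parallel",))),
)(x)
assert (y == x).all()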
@@ -549,12 +551,14 @@ def paged_attention(
   if k_scales_pages is not None and v_scales_pages is not None:
     in_specs = [
         q_block_spec,
+        # pltpu.TPUMemorySpace.ANY means we are putting everything in HBM.
         pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
         pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
         pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
         pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
     ]
     scratch_shapes = (
+        # xw32: how is the pltpu.VMEM being used? I see. It's used in the kernel.
         pltpu.VMEM(
             (
                 2,  # For double buffering during DMA copies.
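The two added comments document the memory plumbing: operands declared with pltpu.TPUMemorySpace.ANY are left in HBM rather than being auto-blocked into VMEM, and the VMEM scratch buffers are what the kernel DMAs pages into, with a leading dimension of 2 so one copy can be fetched while the other is consumed. A standalone sketch of both declarations, with made-up page-table sizes:

import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

# Illustrative sizes only.
pages_per_compute_block, page_size, head_dim = 4, 16, 128

# ANY keeps the operand in HBM; the kernel is then responsible for copying
# the pages it needs into VMEM via async DMA.
k_pages_spec = pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)

# VMEM scratch the kernel copies K pages into; the leading 2 is the
# double-buffering dimension mentioned in the diff.
k_vmem_buffer = pltpu.VMEM(
    (2, pages_per_compute_block, page_size, head_dim),
    jnp.float32,
)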
@@ -596,10 +600,11 @@ def paged_attention(
   else:
     in_specs = [
         q_block_spec,
+        # Below 4 correspond to the 4 input: k_pages, k_scales_pages, q_pages, etc.
         pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
-        None,  # type: ignore[list-item]
+        None,  # type: ignore[list-item] k_scales_pages=None
         pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
-        None,  # type: ignore[list-item]
+        None,  # type: ignore[list-item] v_scales_pages=None
     ]
     scratch_shapes = (
         pltpu.VMEM(
@@ -610,8 +615,8 @@ def paged_attention(
                 head_dim,
             ),
             k_pages.dtype,
-        ),  # k_pages buffer
-        None,
+        ),  # k_pages buffer, k_pages.shape=[num_kv_heads, total_num_pages, page_size, head_dim]
+        None,  # k_scales_pages=None
         pltpu.VMEM(
             (
                 2,  # For double buffering during DMA copies.
@@ -621,7 +626,7 @@ def paged_attention(
             ),
             v_pages.dtype,
         ),  # v_pages buffer
-        None,
+        None,  # v_scales_pages=None
         pltpu.SemaphoreType.DMA,
     )
 
@@ -656,6 +661,7 @@ def paged_attention(
           jax.ShapeDtypeStruct((*q.shape[:-1], 1), jnp.float32),
       ],
   )(
+      # The first 4 are prefetched scalars.
       lengths,
       page_indices.reshape(-1),
       jnp.zeros((1,), jnp.int32),  # buffer index
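The final added comment flags the first four call arguments (lengths, the flattened page_indices, the buffer index, and, per the kernel signature earlier in the file, the step counter) as scalar-prefetch operands. With pltpu.PrefetchScalarGridSpec(num_scalar_prefetch=4, ...) they are placed in SMEM and are readable by the BlockSpec index maps before the grid starts. A minimal standalone grid-spec sketch with made-up specs, not the kernel's actual ones:

from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

grid_spec = pltpu.PrefetchScalarGridSpec(
    # The first 4 positional arguments of the pallas_call invocation become
    # scalar-prefetch operands, available to index maps and the kernel body.
    num_scalar_prefetch=4,
    grid=(4, 2),  # illustrative grid only
    in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)],
    out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
    scratch_shapes=(pltpu.SemaphoreType.DMA,),
)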
