@@ -332,7 +332,7 @@ def body(i, _):
     buffer_index_ref,
     step_ref,
     q_ref,
-    k_pages_hbm_ref,
+    k_pages_hbm_ref,  #
     k_scales_pages_hbm_ref,
     v_pages_hbm_ref,
     v_scales_pages_hbm_ref,
@@ -418,7 +418,7 @@ def paged_attention(
   Returns:
     The output of attention([batch_size, num_heads, head_dim]).
   """
-  return jnp.zeros_like(q, jnp.int32)
+  # return jnp.zeros_like(q, q.dtype)
   if isinstance(k_pages, quantization_utils.QuantizedTensor):
     k_pages, k_scales_pages = k_pages.weight, k_pages.scales
     assert isinstance(k_scales_pages, jax.Array)  # For typing.
@@ -507,6 +507,7 @@ def paged_attention(
     q_dtype_for_kernel_launch = jnp.float32
   else:
     if megacore_mode == "kv_head":
+      # q.shape=[batch_size, num_heads, head_dim]
       q_block_spec = pl.BlockSpec(
           (None, num_heads // num_kv_heads, head_dim),
           lambda core_index, b, h, *_: (b, h * num_cores + core_index, 0),
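The index map above picks which q head block each (core_index, b, h) grid point sees: with megacore_mode="kv_head" the kv heads are split across the TensorCores, so the head-block index is h * num_cores + core_index. A minimal sketch of the same BlockSpec pattern, with a trivial copy kernel and shapes invented purely for illustration:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

num_cores, num_kv_heads, num_heads, head_dim = 2, 4, 8, 128
batch_size = 3
q = jnp.ones((batch_size, num_heads, head_dim), jnp.float32)

def copy_kernel(q_ref, o_ref):
  # q_ref is one (num_heads // num_kv_heads, head_dim) block of q.
  o_ref[...] = q_ref[...]

# Same block shape / index-map pattern as the "kv_head" megacore q_block_spec:
# grid point (core_index, b, h) reads head block h * num_cores + core_index
# of batch element b.
q_block_spec = pl.BlockSpec(
    (None, num_heads // num_kv_heads, head_dim),
    lambda core_index, b, h: (b, h * num_cores + core_index, 0),
)
out = pl.pallas_call(
    copy_kernel,
    grid=(num_cores, batch_size, num_kv_heads // num_cores),
    in_specs=[q_block_spec],
    out_specs=q_block_spec,
    out_shape=jax.ShapeDtypeStruct(q.shape, q.dtype),
)(q)  # identity copy; every head block is visited exactly once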
@@ -533,6 +534,7 @@ def paged_attention(
         if megacore_mode == "kv_head"
         else num_kv_heads,
     )
+    # xw32q: shouldn't batch dim and kv_heads dim be parallel?
     dimension_semantics = ("parallel", "arbitrary", "arbitrary")
   else:
     kernel = paged_flash_attention_kernel
@@ -549,12 +551,14 @@ def paged_attention(
   if k_scales_pages is not None and v_scales_pages is not None:
     in_specs = [
         q_block_spec,
+        # pltpu.TPUMemorySpace.ANY means we are putting everything in HBM.
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
         pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
         pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
         pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
     ]
     scratch_shapes = (
+        # xw32: how is the pltpu.VMEM being used? I see. It's used in the kernel.
         pltpu.VMEM(
             (
                 2,  # For double buffering during DMA copies.
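As the added comment says, pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY) leaves the K/V pages in HBM; the kernel then pulls the pages it needs into the pltpu.VMEM scratch buffers (the leading 2 is for double buffering) with explicit DMAs. A standalone sketch of that HBM-to-VMEM pattern, not the attention kernel itself, assuming a recent JAX where pl.pallas_call accepts scratch_shapes directly:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def hbm_to_vmem_kernel(x_hbm_ref, o_ref, vmem_buf, copy_sem):
  # x_hbm_ref stays in HBM (its BlockSpec uses memory_space=ANY); copy it
  # into the VMEM scratch buffer with an explicit async DMA, then compute
  # from VMEM.
  copy = pltpu.make_async_copy(x_hbm_ref, vmem_buf, copy_sem)
  copy.start()
  copy.wait()
  o_ref[...] = vmem_buf[...] * 2.0

x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)
out = pl.pallas_call(
    hbm_to_vmem_kernel,
    in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)],
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    scratch_shapes=[
        pltpu.VMEM(x.shape, x.dtype),  # VMEM scratch, like the k/v buffers above
        pltpu.SemaphoreType.DMA,       # DMA completion semaphore
    ],
)(x)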
@@ -596,10 +600,11 @@ def paged_attention(
   else:
     in_specs = [
         q_block_spec,
+        # The 4 entries below correspond to the inputs k_pages, k_scales_pages, v_pages, v_scales_pages.
         pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
-        None,  # type: ignore[list-item]
+        None,  # type: ignore[list-item]  k_scales_pages=None
         pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
-        None,  # type: ignore[list-item]
+        None,  # type: ignore[list-item]  v_scales_pages=None
     ]
     scratch_shapes = (
         pltpu.VMEM(
@@ -610,8 +615,8 @@ def paged_attention(
                 head_dim,
             ),
             k_pages.dtype,
-        ),  # k_pages buffer
-        None,
+        ),  # k_pages buffer, k_pages.shape=[num_kv_heads, total_num_pages, page_size, head_dim]
+        None,  # k_scales_pages=None
         pltpu.VMEM(
             (
                 2,  # For double buffering during DMA copies.
@@ -621,7 +626,7 @@ def paged_attention(
             ),
             v_pages.dtype,
         ),  # v_pages buffer
-        None,
+        None,  # v_scales_pages=None
         pltpu.SemaphoreType.DMA,
     )

@@ -656,6 +661,7 @@ def paged_attention(
           jax.ShapeDtypeStruct((*q.shape[:-1], 1), jnp.float32),
       ],
   )(
+      # The first 4 arguments are prefetched scalars.
       lengths,
       page_indices.reshape(-1),
       jnp.zeros((1,), jnp.int32),  # buffer index
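The four prefetched scalars (lengths, the flattened page_indices, the buffer index, and the step counter) are handed to the kernel as SMEM refs ahead of any tensor blocks, which is why the kernel signature in the first hunk begins with those refs, and it lets the index maps and DMA logic read them. A reduced sketch of the scalar-prefetch mechanism with a single made-up index array (gather_rows and its shapes are invented for illustration, not part of the kernel):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def gather_rows_kernel(row_ids_ref, x_ref, o_ref):
  # row_ids_ref is an SMEM scalar ref; the in_spec index map below has
  # already used it to pick which row block of x this grid step sees.
  o_ref[...] = x_ref[...]

def gather_rows(x, row_ids):
  grid_spec = pltpu.PrefetchScalarGridSpec(
      num_scalar_prefetch=1,  # row_ids is delivered before any tensor blocks
      grid=(row_ids.shape[0],),
      in_specs=[
          pl.BlockSpec((1, x.shape[1]), lambda i, row_ids_ref: (row_ids_ref[i], 0)),
      ],
      out_specs=pl.BlockSpec((1, x.shape[1]), lambda i, row_ids_ref: (i, 0)),
  )
  return pl.pallas_call(
      gather_rows_kernel,
      grid_spec=grid_spec,
      out_shape=jax.ShapeDtypeStruct((row_ids.shape[0], x.shape[1]), x.dtype),
  )(row_ids, x)

x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)
rows = gather_rows(x, jnp.array([3, 0, 5], jnp.int32))  # rows 3, 0, 5 of x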