@@ -49,6 +49,7 @@ def __init__(
       num_pages_to_load,
       head_index,
   ):
+    # Original k_pages has shape [num_kv_heads, total_num_pages, page_size, head_dim]
     self._vmem_buffer = vmem_buffer
     self._scales_vmem_buffer = scales_vmem_buffer
     self._num_pages_to_load = num_pages_to_load
@@ -113,7 +114,6 @@ def wait_and_get_loaded(self) -> jax.Array:
     jax_array = self._maybe_dequantize(jax_array, scales_jax_array)
     return jax_array.reshape(-1, head_dim)
 
-
 def paged_flash_attention_kernel(
     lengths_ref,
     page_indices_ref,
@@ -142,16 +142,21 @@ def paged_flash_attention_kernel(
     program_ids=(),
 ):
   """Pallas kernel for paged attention."""
-  if program_ids:
-    core_index, b, h, i = program_ids
+  # xw32: program_ids = (core_index, q_idx, b, h, i), one index per grid axis (num_cores, query_len, batch_size, num_kv_heads, kv compute blocks).
+  # Note: the original q.shape is [batch_size, query_len, num_heads, head_dim].
+  print(f'xw32 line147 {q_ref.shape=}')
+  if program_ids:  # inline_seq_dim case.
+    core_index, q_idx, b, h, i = program_ids  # The second element is q_idx; it is unused here.
   else:
-    core_index, b, h, i = (
+    core_index, q_idx, b, h, i = (
         pl.program_id(0),
         pl.program_id(1),
         pl.program_id(2),
         pl.program_id(3),
+        pl.program_id(4),
     )
   num_kv_heads, _, page_size, _ = k_pages_hbm_ref.shape
+  # xw32: bk is the overall compute block size, i.e. the number of KV tokens per compute block.
   bk = page_size * pages_per_compute_block
   num_cores = pl.num_programs(0)
 
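As a side note on `bk` (a minimal sketch with assumed page geometry, not taken from this patch): each compute block covers `page_size * pages_per_compute_block` KV tokens, so the last grid axis walks over a sequence's page table in steps of `bk` tokens.

```python
# Assumed sizes, for illustration only.
page_size = 16               # KV tokens stored per page
pages_per_compute_block = 8  # pages processed per kernel step

bk = page_size * pages_per_compute_block
print(bk)  # 128 KV tokens per compute block

pages_per_sequence = 64      # assumed page-table width per sequence
print(pages_per_sequence // pages_per_compute_block)  # 8 compute blocks per sequence
```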
@@ -317,10 +322,16 @@ def paged_flash_attention_kernel_inline_seq_dim(
     attn_logits_soft_cap: float | None,
     megacore_mode: str | None,
 ):
-  core_index, b, h = pl.program_id(0), pl.program_id(1), pl.program_id(2)
+  print(f'xw32 line325 {m_ref.shape=}')
+  core_index, q_idx, b, h = pl.program_id(0), pl.program_id(1), pl.program_id(2), pl.program_id(3)
 
   # Initialize the output HBM buffers to avoid accessing garbage memory inside
   # the kernel body below.
+  # Note: q_block_spec = pl.BlockSpec(
+  #     (None, None, num_heads // num_kv_heads, head_dim),  # [batch_size, query_len, num_heads, head_dim]
+  #     lambda core_index, q, b, h, *_: (b, q, h, 0),
+  # )
+  # m_ref, l_ref and o_ref all use q_block_spec as their out_specs.
   m_ref[...] = jnp.full_like(m_ref, -jnp.inf)
   l_ref[...] = jnp.zeros_like(l_ref)
   o_ref[...] = jnp.zeros_like(o_ref)
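To make the `q_block_spec` note above concrete, here is a minimal sketch (assumed sizes, plain array slicing instead of Pallas machinery) of which slice of `q` the block shape `(None, None, num_heads // num_kv_heads, head_dim)` together with the index map `lambda core_index, q, b, h, *_: (b, q, h, 0)` hands to one kernel instance:

```python
import jax.numpy as jnp

# Assumed sizes, for illustration only.
batch_size, query_len, num_heads, num_kv_heads, head_dim = 2, 3, 16, 4, 128
group = num_heads // num_kv_heads  # query heads that share one KV head

q = jnp.zeros((batch_size, query_len, num_heads, head_dim))

# For grid indices (core_index, q_idx, b, h), the index map returns (b, q_idx, h, 0).
# `None` block dims are squeezed, so the kernel sees a (group, head_dim) block:
b, q_idx, h = 1, 2, 3
q_block = q[b, q_idx, h * group:(h + 1) * group, :]
print(q_block.shape)  # (4, 128): one query position, one group of query heads
```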
@@ -332,7 +343,7 @@ def body(i, _):
         buffer_index_ref,
         step_ref,
         q_ref,
-        k_pages_hbm_ref,  #
+        k_pages_hbm_ref,
         k_scales_pages_hbm_ref,
         v_pages_hbm_ref,
         v_scales_pages_hbm_ref,
@@ -350,11 +361,13 @@ def body(i, _):
         mask_value=mask_value,
         attn_logits_soft_cap=attn_logits_soft_cap,
         megacore_mode=megacore_mode,
-        program_ids=(core_index, b, h, i),
+        program_ids=(core_index, q_idx, b, h, i),
     )
     return ()
 
-  bk = pages_per_compute_block * k_pages_hbm_ref.shape[-2]
+  # xw32: note that num_kv_heads, _, page_size, head_dim_k = k_pages.shape,
+  # so k_pages_hbm_ref.shape[-2] is the page_size.
+  bk = pages_per_compute_block * k_pages_hbm_ref.shape[-2]  # Total KV tokens across the pages in one compute block.
 
   if megacore_mode == "batch":
     num_cores = pl.num_programs(0)
@@ -391,7 +404,7 @@ def paged_attention(
   """Paged grouped query attention.
 
   Args:
-    q: A [batch_size, num_heads, head_dim] jax.Array.
+    q: A [batch_size, query_len, num_heads, head_dim] jax.Array.
     k_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array.
     v_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array.
     lengths: A i32[batch_size] jax.Array the length of each example.
@@ -416,9 +429,8 @@ def paged_attention(
       one kernel.
 
   Returns:
-    The output of attention([batch_size, num_heads, head_dim]).
+    The output of attention([batch_size, query_len, num_heads, head_dim]).
   """
-  # return jnp.zeros_like(q, q.dtype)
   if isinstance(k_pages, quantization_utils.QuantizedTensor):
     k_pages, k_scales_pages = k_pages.weight, k_pages.scales
     assert isinstance(k_scales_pages, jax.Array)  # For typing.
@@ -436,7 +448,8 @@ def paged_attention(
   else:
     v_scales_pages = None
 
-  batch_size, num_heads, head_dim = q.shape
+  # TODO(xw32): consider renaming num_heads to num_query_heads
+  batch_size, query_len, num_heads, head_dim = q.shape
   num_kv_heads, _, page_size, head_dim_k = k_pages.shape
   batch_size_paged_indices, pages_per_sequence = page_indices.shape
 
@@ -486,6 +499,7 @@ def paged_attention(
     raise ValueError("megacore_mode must be one of ['kv_head', 'batch', None]")
 
   if (num_heads // num_kv_heads) % 8 != 0:
+    # TODO(xw32): add the query_len dim to this branch later.
     # Reshape q to hint XLA to pick a <1x128> layout otherwise it will pick a
     # <8x128> layout for a <1x128> memref inside the kernel and error out.
     q = q.reshape(batch_size, num_heads, 1, head_dim)
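For intuition on when this branch is taken (assumed head counts, not from the patch): the layout-hint reshape only fires when the ratio of query heads to KV heads is not a multiple of 8.

```python
def needs_layout_hint(num_heads, num_kv_heads):
  # Mirrors the condition guarding the reshape above.
  return (num_heads // num_kv_heads) % 8 != 0

print(needs_layout_hint(8, 8))   # True:  ratio 1, layout-hint branch (no query_len support yet)
print(needs_layout_hint(64, 8))  # False: ratio 8, regular BlockSpec path below
```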
@@ -507,46 +521,51 @@ def paged_attention(
     q_dtype_for_kernel_launch = jnp.float32
   else:
     if megacore_mode == "kv_head":
-      # q.shape=[batch_size, num_heads, head_dim]
+      # q.shape=[batch_size, query_len, num_heads, head_dim]
+      # xw32q: The block shape chunks the `num_heads` dimension into groups of (num_heads // num_kv_heads);
+      # does that mean this is grouped/multi-query attention?
       q_block_spec = pl.BlockSpec(
-          (None, num_heads // num_kv_heads, head_dim),
-          lambda core_index, b, h, *_: (b, h * num_cores + core_index, 0),
+          (None, None, num_heads // num_kv_heads, head_dim),
+          lambda core_index, q, b, h, *_: (b, q, h * num_cores + core_index, 0),
       )
     elif megacore_mode == "batch":
       q_block_spec = pl.BlockSpec(
-          (None, num_heads // num_kv_heads, head_dim),
-          lambda core_index, b, h, *_: (b * num_cores + core_index, h, 0),
+          (None, None, num_heads // num_kv_heads, head_dim),
+          lambda core_index, q, b, h, *_: (b * num_cores + core_index, q, h, 0),
       )
     else:
       q_block_spec = pl.BlockSpec(
-          (None, num_heads // num_kv_heads, head_dim),
-          lambda core_index, b, h, *_: (b, h, 0),
+          (None, None, num_heads // num_kv_heads, head_dim),
+          lambda core_index, q, b, h, *_: (b, q, h, 0),
       )
     q_dtype_for_kernel_launch = q.dtype
 
   dimension_semantics: Sequence[Literal["parallel", "arbitrary"]]
   if inline_seq_dim:
     kernel = paged_flash_attention_kernel_inline_seq_dim
+    # query_len goes before batch_size and num_kv_heads so the flash attention kernel itself doesn't need to change.
     grid = (
         num_cores,
+        query_len,
         batch_size // num_cores if megacore_mode == "batch" else batch_size,
         num_kv_heads // num_cores
         if megacore_mode == "kv_head"
         else num_kv_heads,
     )
     # xw32q: shouldn't batch dim and kv_heads dim be parallel?
-    dimension_semantics = ("parallel", "arbitrary", "arbitrary")
+    dimension_semantics = ("parallel", "arbitrary", "arbitrary", "arbitrary")
   else:
     kernel = paged_flash_attention_kernel
     grid = (
         num_cores,
+        query_len,
         batch_size // num_cores if megacore_mode == "batch" else batch_size,
         num_kv_heads // num_cores
         if megacore_mode == "kv_head"
         else num_kv_heads,
         pages_per_sequence // pages_per_compute_block,
     )  # type: ignore
-    dimension_semantics = ("parallel", "arbitrary", "arbitrary", "arbitrary")
+    dimension_semantics = ("parallel", "arbitrary", "arbitrary", "arbitrary", "arbitrary")
 
   if k_scales_pages is not None and v_scales_pages is not None:
     in_specs = [
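Since `dimension_semantics` must name one semantic per grid axis, both tuples grow by one entry together with the grid. A quick consistency sketch with assumed sizes (inline_seq_dim=False, megacore_mode=None):

```python
# Assumed sizes, for illustration only.
num_cores, query_len, batch_size, num_kv_heads = 2, 4, 8, 16
pages_per_sequence, pages_per_compute_block = 64, 8

grid = (
    num_cores,
    query_len,
    batch_size,
    num_kv_heads,
    pages_per_sequence // pages_per_compute_block,
)
dimension_semantics = ("parallel", "arbitrary", "arbitrary", "arbitrary", "arbitrary")
assert len(grid) == len(dimension_semantics)  # 5 grid axes, 5 semantics entries
print(grid)  # (2, 4, 8, 16, 8)
```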
@@ -597,7 +616,7 @@ def paged_attention(
         ),  # v_scales_pages buffer
         pltpu.SemaphoreType.DMA,
     )
-  else:
+  else:  # either k_scales_pages or v_scales_pages is None.
     in_specs = [
         q_block_spec,
         # Below 4 correspond to the 4 input: k_pages, k_scales_pages, q_pages, etc.
@@ -672,4 +691,4 @@ def paged_attention(
       v_pages,
       v_scales_pages,
   )
-  return out.reshape(batch_size, num_heads, head_dim).astype(q.dtype)
+  return out.reshape(batch_size, query_len, num_heads, head_dim).astype(q.dtype)
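Finally, a hedged end-to-end sketch of the caller-visible shapes after this change. It assumes the patched `paged_attention` is already in scope and that the public parameters stay as in the docstring (`q`, `k_pages`, `v_pages`, `lengths`, `page_indices`, plus the `pages_per_compute_block` keyword); all sizes are made up.

```python
import jax.numpy as jnp

# Assumed sizes, for illustration only.
batch_size, query_len, num_heads, head_dim = 4, 8, 64, 128
num_kv_heads, total_num_pages, page_size, pages_per_sequence = 8, 1024, 16, 64

q = jnp.zeros((batch_size, query_len, num_heads, head_dim), jnp.float32)
k_pages = jnp.zeros((num_kv_heads, total_num_pages, page_size, head_dim), jnp.float32)
v_pages = jnp.zeros_like(k_pages)
lengths = jnp.full((batch_size,), 512, dtype=jnp.int32)
page_indices = jnp.zeros((batch_size, pages_per_sequence), dtype=jnp.int32)

out = paged_attention(
    q, k_pages, v_pages, lengths, page_indices,
    pages_per_compute_block=8,
)
assert out.shape == (batch_size, query_len, num_heads, head_dim)
```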