- # This is the original paged_attention_kernel copied from
+ # This is the adapted extended_paged_attention_kernel, copied from
# https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py#L400-L401
- # Don't review it. Will be removed later.

# Copyright 2024 The JAX Authors.
#
@@ -49,6 +48,7 @@ def __init__(
num_pages_to_load,
head_index,
):
+ # Original k_pages has shape [num_kv_heads, total_num_pages, page_size, head_dim].
self._vmem_buffer = vmem_buffer
self._scales_vmem_buffer = scales_vmem_buffer
self._num_pages_to_load = num_pages_to_load
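A minimal sketch (toy sizes, plain JAX, not kernel code) of how that paged layout is addressed: a token's key vector is found by translating its logical page through page_indices into a physical page of k_pages.

import jax.numpy as jnp

# Hypothetical toy sizes, for illustration only.
num_kv_heads, total_num_pages, page_size, head_dim = 2, 8, 4, 128
k_pages = jnp.zeros((num_kv_heads, total_num_pages, page_size, head_dim))

# page_indices[b, j] holds the physical page id of the j-th logical page of example b.
page_indices = jnp.array([[5, 2, 7, 0]])

b, kv_head, token_pos = 0, 1, 6                      # the 7th kv token of example 0
logical_page, offset = divmod(token_pos, page_size)  # logical page 1, offset 2
physical_page = page_indices[b, logical_page]        # physical page 2
k_vec = k_pages[kv_head, physical_page, offset]      # shape [head_dim]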
@@ -113,7 +113,6 @@ def wait_and_get_loaded(self) -> jax.Array:
jax_array = self._maybe_dequantize(jax_array, scales_jax_array)
return jax_array.reshape(-1, head_dim)

-
def paged_flash_attention_kernel(
lengths_ref,
page_indices_ref,
@@ -142,16 +141,21 @@ def paged_flash_attention_kernel(
program_ids=(),
):
"""Pallas kernel for paged attention."""
- if program_ids:
- core_index, b, h, i = program_ids
+ # xw32: the grid is (num_cores, query_len, batch_size, num_kv_heads, ...), so the
+ # program ids unpack to (core_index, q_idx, b, h, i).
+ # Note: the input q has shape [batch_size, query_len, num_heads, head_dim].
+ print(f'xw32 line147 {q_ref.shape=}')
+ if program_ids:  # inline_seq_dim case.
+ core_index, q_idx, b, h, i = program_ids  # The 2nd element is q_idx; it is unused here.
else:
- core_index, b, h, i = (
+ core_index, q_idx, b, h, i = (
pl.program_id(0),
pl.program_id(1),
pl.program_id(2),
pl.program_id(3),
+ pl.program_id(4),
)
num_kv_heads, _, page_size, _ = k_pages_hbm_ref.shape
+ # xw32: bk is the compute block size, i.e. the number of kv tokens per compute block.
bk = page_size * pages_per_compute_block
num_cores = pl.num_programs(0)
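A minimal sketch (toy numbers, assuming megacore_mode=None) of how the five program ids line up with the grid built further below, and of the bk arithmetic:

# Hypothetical toy sizes.
num_cores, query_len, batch_size, num_kv_heads = 2, 3, 4, 8
page_size, pages_per_compute_block, pages_per_sequence = 16, 8, 64

# Grid used by the non-inline path (megacore_mode=None):
grid = (
    num_cores,
    query_len,
    batch_size,
    num_kv_heads,
    pages_per_sequence // pages_per_compute_block,  # i ranges over kv compute blocks
)
# Inside the kernel, pl.program_id(0..4) then yields (core_index, q_idx, b, h, i).

# bk is the number of kv tokens covered by one compute block:
bk = page_size * pages_per_compute_block  # 16 * 8 = 128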
@@ -317,10 +321,16 @@ def paged_flash_attention_kernel_inline_seq_dim(
attn_logits_soft_cap: float | None,
megacore_mode: str | None,
):
- core_index, b, h = pl.program_id(0), pl.program_id(1), pl.program_id(2)
+ print(f'xw32 line325 {m_ref.shape=}')
+ core_index, q_idx, b, h = pl.program_id(0), pl.program_id(1), pl.program_id(2), pl.program_id(3)

# Initialize the output HBM buffers to avoid accessing garbage memory inside
# the kernel body below.
+ # Note: q_block_spec = pl.BlockSpec(
+ #     (None, None, num_heads // num_kv_heads, head_dim),  # [batch_size, query_len, num_heads, head_dim]
+ #     lambda core_index, q, b, h, *_: (b, q, h, 0),
+ # )
+ # and m_ref, l_ref, o_ref all use out_specs=q_block_spec.
m_ref[...] = jnp.full_like(m_ref, -jnp.inf)
l_ref[...] = jnp.zeros_like(l_ref)
o_ref[...] = jnp.zeros_like(o_ref)
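m_ref, l_ref, and o_ref hold the running max, the running softmax denominator, and the running unnormalized output. A minimal single-head sketch (plain jax.numpy, toy shapes, not the kernel's code) of the online-softmax update these buffers support:

import jax.numpy as jnp

def online_softmax_step(m, l, o, q_block, k_block, v_block):
  """One flash-attention step over a block of bk kv tokens.

  m: [q_len] running max, l: [q_len] running denominator,
  o: [q_len, head_dim] running unnormalized output.
  """
  s = q_block @ k_block.T                  # [q_len, bk] attention logits
  m_new = jnp.maximum(m, s.max(axis=-1))   # updated running max per query
  p = jnp.exp(s - m_new[:, None])          # probabilities rescaled by the new max
  alpha = jnp.exp(m - m_new)               # rescales the old accumulators
  l_new = alpha * l + p.sum(axis=-1)
  o_new = alpha[:, None] * o + p @ v_block
  return m_new, l_new, o_new

# The initialization above mirrors m = -inf, l = 0, o = 0; the final result is
# o / l[:, None] once every kv block has been folded in.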
@@ -332,7 +342,7 @@ def body(i, _):
buffer_index_ref,
step_ref,
q_ref,
- k_pages_hbm_ref,  #
+ k_pages_hbm_ref,
k_scales_pages_hbm_ref,
v_pages_hbm_ref,
v_scales_pages_hbm_ref,
@@ -350,11 +360,13 @@ def body(i, _):
mask_value=mask_value,
attn_logits_soft_cap=attn_logits_soft_cap,
megacore_mode=megacore_mode,
- program_ids=(core_index, b, h, i),
+ program_ids=(core_index, q_idx, b, h, i),
)
return ()

- bk = pages_per_compute_block * k_pages_hbm_ref.shape[-2]
+ # xw32: n.b. num_kv_heads, _, page_size, head_dim_k = k_pages.shape,
+ # so k_pages_hbm_ref.shape[-2] is the page_size.
+ bk = pages_per_compute_block * k_pages_hbm_ref.shape[-2]  # The total number of kv tokens across the pages in this compute block.

if megacore_mode == "batch":
num_cores = pl.num_programs(0)
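A minimal sketch (toy numbers, not kernel code) of how megacore_mode splits work across cores; the index arithmetic mirrors the BlockSpec index maps in paged_attention below:

# Hypothetical toy sizes.
num_cores, batch_size, num_kv_heads = 2, 4, 8

for core_index in range(num_cores):
  # megacore_mode == "batch": the grid's batch dim is batch_size // num_cores and
  # the real batch index is b * num_cores + core_index.
  batches = [b * num_cores + core_index for b in range(batch_size // num_cores)]
  # megacore_mode == "kv_head": the grid's kv-head dim is num_kv_heads // num_cores
  # and the real kv head index is h * num_cores + core_index.
  kv_heads = [h * num_cores + core_index for h in range(num_kv_heads // num_cores)]
  print(core_index, batches, kv_heads)  # only one of the two splits applies per mode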
@@ -391,7 +403,7 @@ def paged_attention(
"""Paged grouped query attention.

Args:
- q: A [batch_size, num_heads, head_dim] jax.Array.
+ q: A [batch_size, query_len, num_heads, head_dim] jax.Array.
k_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array.
v_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array.
lengths: An i32[batch_size] jax.Array containing the length of each example.
@@ -416,9 +428,8 @@ def paged_attention(
one kernel.

Returns:
- The output of attention([batch_size, num_heads, head_dim]).
+ The output of attention([batch_size, query_len, num_heads, head_dim]).
"""
- # return jnp.zeros_like(q, q.dtype)
if isinstance(k_pages, quantization_utils.QuantizedTensor):
k_pages, k_scales_pages = k_pages.weight, k_pages.scales
assert isinstance(k_scales_pages, jax.Array)  # For typing.
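A minimal usage sketch matching the updated docstring shapes (toy sizes; assumes the call signature is otherwise unchanged from the upstream kernel and that a TPU backend with Pallas support is available):

import jax
import jax.numpy as jnp

# Hypothetical toy sizes consistent with the docstring.
batch_size, query_len, num_heads, head_dim = 2, 4, 16, 128
num_kv_heads, total_num_pages, page_size = 2, 32, 16
pages_per_sequence = 8  # so each example can hold up to 128 kv tokens

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(k1, (batch_size, query_len, num_heads, head_dim), jnp.float32)
k_pages = jax.random.normal(k2, (num_kv_heads, total_num_pages, page_size, head_dim), jnp.float32)
v_pages = jax.random.normal(k3, (num_kv_heads, total_num_pages, page_size, head_dim), jnp.float32)
lengths = jnp.array([64, 100], dtype=jnp.int32)
page_indices = jnp.arange(batch_size * pages_per_sequence, dtype=jnp.int32).reshape(
    batch_size, pages_per_sequence)

out = paged_attention(
    q, k_pages, v_pages, lengths, page_indices,
    pages_per_compute_block=4,
)
# out has shape [batch_size, query_len, num_heads, head_dim].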
@@ -436,7 +447,8 @@ def paged_attention(
else:
v_scales_pages = None

- batch_size, num_heads, head_dim = q.shape
+ # TODO(xw32): consider renaming num_heads to num_query_heads.
+ batch_size, query_len, num_heads, head_dim = q.shape
num_kv_heads, _, page_size, head_dim_k = k_pages.shape
batch_size_paged_indices, pages_per_sequence = page_indices.shape
@@ -486,6 +498,7 @@ def paged_attention(
raise ValueError("megacore_mode must be one of ['kv_head', 'batch', None]")

if (num_heads // num_kv_heads) % 8 != 0:
+ # TODO(xw32): add the query_len dim to this branch later.
# Reshape q to hint XLA to pick a <1x128> layout otherwise it will pick a
# <8x128> layout for a <1x128> memref inside the kernel and error out.
q = q.reshape(batch_size, num_heads, 1, head_dim)
@@ -507,46 +520,53 @@ def paged_attention(
q_dtype_for_kernel_launch = jnp.float32
else:
if megacore_mode == "kv_head":
- # q.shape=[batch_size, num_heads, head_dim]
+ # q.shape=[batch_size, query_len, num_heads, head_dim]
+ # xw32q: Since this chunks the num_heads dimension into groups of
+ # (num_heads // num_kv_heads), does that mean this is MQA?
q_block_spec = pl.BlockSpec(
- (None, num_heads // num_kv_heads, head_dim),
- lambda core_index, b, h, *_: (b, h * num_cores + core_index, 0),
+ (None, None, num_heads // num_kv_heads, head_dim),
+ lambda core_index, q, b, h, *_: (b, q, h * num_cores + core_index, 0),
)
elif megacore_mode == "batch":
q_block_spec = pl.BlockSpec(
- (None, num_heads // num_kv_heads, head_dim),
- lambda core_index, b, h, *_: (b * num_cores + core_index, h, 0),
+ (None, None, num_heads // num_kv_heads, head_dim),
+ lambda core_index, q, b, h, *_: (b * num_cores + core_index, q, h, 0),
)
else:
+ # Here, if (num_heads // num_kv_heads) % 8 == 0, megacore_mode is None, and
+ # inline_seq_dim is True, then grid = (num_cores, query_len, batch_size, num_kv_heads).
q_block_spec = pl.BlockSpec(
- (None, num_heads // num_kv_heads, head_dim),
- lambda core_index, b, h, *_: (b, h, 0),
+ (None, None, num_heads // num_kv_heads, head_dim),
+ lambda core_index, q, b, h, *_: (b, q, h, 0),
)
q_dtype_for_kernel_launch = q.dtype

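A minimal sketch (toy sizes, megacore_mode=None, no Pallas execution) of what the q BlockSpec above expresses: each grid point (core_index, q, b, h) sees one batch element, one query position, and one group of num_heads // num_kv_heads query heads:

# Hypothetical toy sizes.
batch_size, query_len, num_heads, num_kv_heads, head_dim = 2, 3, 16, 2, 128
group = num_heads // num_kv_heads  # 8 query heads share one kv head

# Block shape (None, None, group, head_dim): one batch element, one query position,
# `group` query heads, the full head_dim. The index map
#   lambda core_index, q, b, h, *_: (b, q, h, 0)
# means grid point (core_index, q, b, h) reads the slice below.
def q_block_for(q_array, q_idx, b, h):
  return q_array[b, q_idx, h * group:(h + 1) * group, :]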
dimension_semantics: Sequence[Literal["parallel", "arbitrary"]]
if inline_seq_dim:
kernel = paged_flash_attention_kernel_inline_seq_dim
+ # query_len goes before batch_size and num_kv_heads so the flash attention kernel doesn't need to be changed.
grid = (
num_cores,
+ query_len,
batch_size // num_cores if megacore_mode == "batch" else batch_size,
num_kv_heads // num_cores
if megacore_mode == "kv_head"
else num_kv_heads,
)
# xw32q: shouldn't the batch dim and the kv_heads dim be parallel?
- dimension_semantics = ("parallel", "arbitrary", "arbitrary")
+ dimension_semantics = ("parallel", "arbitrary", "arbitrary", "arbitrary")
else:
kernel = paged_flash_attention_kernel
grid = (
num_cores,
+ query_len,
batch_size // num_cores if megacore_mode == "batch" else batch_size,
num_kv_heads // num_cores
if megacore_mode == "kv_head"
else num_kv_heads,
pages_per_sequence // pages_per_compute_block,
)  # type: ignore
- dimension_semantics = ("parallel", "arbitrary", "arbitrary", "arbitrary")
+ dimension_semantics = ("parallel", "arbitrary", "arbitrary", "arbitrary", "arbitrary")

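For reference, a sketch of the resulting grid and dimension_semantics for a toy configuration with megacore_mode=None, covering both branches above:

# Hypothetical toy sizes.
num_cores, query_len, batch_size, num_kv_heads = 2, 3, 4, 8
pages_per_sequence, pages_per_compute_block = 64, 8

# inline_seq_dim=True: the kv-sequence loop runs inside the kernel body.
grid_inline = (num_cores, query_len, batch_size, num_kv_heads)
sem_inline = ("parallel", "arbitrary", "arbitrary", "arbitrary")

# inline_seq_dim=False: the kv compute blocks become a fifth grid dimension.
grid_outer = (num_cores, query_len, batch_size, num_kv_heads,
              pages_per_sequence // pages_per_compute_block)
sem_outer = ("parallel", "arbitrary", "arbitrary", "arbitrary", "arbitrary")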
if k_scales_pages is not None and v_scales_pages is not None:
in_specs = [
@@ -597,7 +617,7 @@ def paged_attention(
),  # v_scales_pages buffer
pltpu.SemaphoreType.DMA,
)
- else:
+ else:  # either k_scales_pages or v_scales_pages is None.
in_specs = [
q_block_spec,
# Below 4 correspond to the 4 inputs: k_pages, k_scales_pages, v_pages, etc.
@@ -672,4 +692,4 @@ def paged_attention(
v_pages,
v_scales_pages,
)
- return out.reshape(batch_size, num_heads, head_dim).astype(q.dtype)
+ return out.reshape(batch_size, query_len, num_heads, head_dim).astype(q.dtype)