@@ -219,7 +219,7 @@ def paged_flash_attention_kernel(
    v_scales_vmem_buffer,
    sem,
    *,
-    pages_per_sequence: int,  # bs, pages_per_sequence = page_indices.shape
+    pages_per_sequence: int,  # [bs, pages_per_sequence] = page_indices.shape
    batch_size: int,
    num_kv_pages_per_compute_block: int,
    num_queries_per_compute_block: int,
@@ -234,6 +234,7 @@ def paged_flash_attention_kernel(
      pl.program_id(2),
      pl.program_id(3),
  )
+  num_q_blks = pl.num_programs(2)
  num_queries_per_compute_block, num_q_heads_per_kv_head, head_dim = q_ref.shape
  assert q_ref.shape == (num_queries_per_compute_block, num_q_heads_per_kv_head, head_dim)
  num_kv_heads, total_num_pages, page_size, head_dim = k_pages_hbm_ref.shape
@@ -267,12 +268,16 @@ def advance_to_next_non_zero_length():
      )

    def advance_kv_head_idx():
+      # assumption: kv_blk_idx * compute_blk_size_kv >= lengths_ref[b]
      next_kv_head_idx = kv_head_idx + 1
-      return lax.cond(next_kv_head_idx < num_kv_heads, lambda: (b, next_kv_head_idx, 0), advance_b)
+      return lax.cond(q_blk_idx == num_q_blks - 1,
+                      lambda: lax.cond(next_kv_head_idx < num_kv_heads, lambda: (b, next_kv_head_idx, 0), advance_b),
+                      lambda: (b, kv_head_idx, 0))

    return lax.cond(kv_blk_idx * compute_blk_size_kv < lengths_ref[b], lambda: (b, kv_head_idx, kv_blk_idx), advance_kv_head_idx)

  def create_kv_async_copy_descriptors(b, kv_head_idx, kv_blk_idx, buffer_index):
+    pl.debug_print('line45 b={}, kv_head_idx={}', b, kv_head_idx)
    page_offset = b * pages_per_sequence + kv_blk_idx * num_kv_pages_per_compute_block
    pages_to_load = num_kv_pages_per_compute_block
    async_copy_k = MultiPageAsyncCopyDescriptor(
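The new `q_blk_idx == num_q_blks - 1` guard changes when the kernel rolls over to the next KV head: with the query-block dimension in the grid, the KV head (and batch) indices may only advance after the last query block of the current head has been issued; otherwise the kernel stays on the same head and restarts the KV-block loop for the next query block. Below is a minimal standalone sketch of that nested `lax.cond` decision (plain JAX, hypothetical scalar arguments, with `advance_b` simplified to `b + 1`), not the kernel code itself:

import jax
import jax.numpy as jnp
from jax import lax

def next_indices(b, kv_head_idx, q_blk_idx, num_q_blks, num_kv_heads):
  zero = jnp.int32(0)
  next_kv_head_idx = kv_head_idx + 1
  return lax.cond(
      q_blk_idx == num_q_blks - 1,
      # Last query block of this head: move to the next head, or (simplified
      # here to b + 1) to the next batch element once all heads are done.
      lambda: lax.cond(next_kv_head_idx < num_kv_heads,
                       lambda: (b, next_kv_head_idx, zero),
                       lambda: (b + 1, zero, zero)),
      # Otherwise stay on the same head and restart the kv-block loop.
      lambda: (b, kv_head_idx, zero))

b, h, qb = jnp.int32(0), jnp.int32(3), jnp.int32(1)
print(jax.jit(next_indices, static_argnums=(3, 4))(b, h, qb, 4, 8))  # (0, 3, 0)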
@@ -366,20 +371,21 @@ def prefetch_next_block():  # pylint: disable=unused-variable
  print(f'xw32 line357 {temp_out.shape=}')
  o_ref[...] = temp_out
  print(f'xw32 line359 {o_ref.shape=}')
+  step_ref[0] = step + 1

MIN_BLOCK_SIZE = 128

-@functools.partial(
-    jax.jit,
-    static_argnames=[
-        "num_kv_pages_per_compute_block",
-        "num_queries_per_compute_block",
-        "attn_logits_soft_cap",
-        "mask_value",
-        "megacore_mode",
-        "inline_seq_dim",
-    ],
-)
+# @functools.partial(
+#     jax.jit,
+#     static_argnames=[
+#         "num_kv_pages_per_compute_block",
+#         "num_queries_per_compute_block",
+#         "attn_logits_soft_cap",
+#         "mask_value",
+#         "megacore_mode",
+#         "inline_seq_dim",
+#     ],
+# )
def paged_attention(
    q: jax.Array,
    k_pages: jax.Array | quantization_utils.QuantizedTensor,
@@ -485,6 +491,11 @@ def paged_attention(

  # Here, we guarantee (num_q_heads // num_kv_heads) % 8 == 0
  # grid
+  # The query_len dim has to come before the kv_len dim for the FA v1 implementation: if we did it the other
+  # way around, then for each j (in [0, T_c]) and i (in [0, T_r]) we would load l_i and m_i from HBM and store
+  # them back to HBM, and for j+1 we would have to loop over i (in [0, T_r]) again, reloading l_i and m_i.
+  # But this is forbidden in Pallas: https://jax.readthedocs.io/en/latest/pallas/tpu/sparse.html#example-sparse-dense-matrix-multiplication
+  # "When we change output block Pallas will finally store the output into HBM and assume we never touch it again."
  grid = (
      batch_size,
      num_kv_heads,
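To make the ordering concrete, here is a minimal sketch (hypothetical sizes and names, not the kernel's real block-spec code) of the constraint the comment describes: the kv-block dimension is the innermost grid dimension and the output index map ignores it, so every kv block that accumulates into a given output (query) block runs before Pallas flushes that block to HBM.

# Illustrative only: toy grid sizes and a toy output index map.
batch_size, num_kv_heads, num_q_blocks, num_kv_blocks = 2, 8, 4, 16

grid = (batch_size, num_kv_heads, num_q_blocks, num_kv_blocks)

def o_index_map(b, kv_head_idx, q_blk_idx, kv_blk_idx):
  # kv_blk_idx is deliberately dropped: consecutive kv steps revisit the same
  # output block, so the running softmax stats (l_i, m_i) never have to
  # round-trip through HBM between kv blocks.
  return (b, kv_head_idx, q_blk_idx)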
@@ -562,7 +573,7 @@ def qo_index_map(batch_index, kv_head_index, q_seq_blk_idx, *_):
      pltpu.SemaphoreType.DMA,
  )

-  out = pl.pallas_call(
+  kernel = pl.pallas_call(
      functools.partial(
          paged_flash_attention_kernel,
          pages_per_sequence=pages_per_sequence,
@@ -585,7 +596,24 @@ def qo_index_map(batch_index, kv_head_index, q_seq_blk_idx, *_):
      # compiler_params=pltpu.TPUCompilerParams(
      #     dimension_semantics=dimension_semantics),  # do we need it?
      out_shape=out_shape,
-  )(
+  )
+  compiled_kernel = (
+      jax.jit(kernel)
+      .lower(
+          # The first 4 are prefetched scalars.
+          lengths,
+          page_indices.reshape(-1),
+          jnp.zeros((1,), jnp.int32),  # buffer index
+          jnp.zeros((1,), jnp.int32),  # step
+          q.astype(q_dtype_for_kernel_launch),
+          k_pages,
+          k_scales_pages,
+          v_pages,
+          v_scales_pages,
+      )
+      .compile({'xla_tpu_enable_log_recorder': 'true'})
+  )
+  outs = compiled_kernel(
      # The first 4 are prefetched scalars.
      lengths,
      page_indices.reshape(-1),
@@ -597,7 +625,7 @@ def qo_index_map(batch_index, kv_head_index, q_seq_blk_idx, *_):
      v_pages,
      v_scales_pages,
  )  # should get 4 return values
-  ret = out[0]
+  ret = outs[0]
  print(f'xw32 finished the pallas kernel. {ret.shape=} Returning...', flush=True)
  return ret.reshape(batch_size, query_len, num_q_heads, head_dim).astype(q.dtype)
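The split of `pl.pallas_call(...)` into `kernel`, an explicit `jax.jit(kernel).lower(...).compile(...)` step, and a separate call replaces the commented-out `@functools.partial(jax.jit, ...)` decorator above; compiling with `{'xla_tpu_enable_log_recorder': 'true'}` is presumably there so the `pl.debug_print` output added to the kernel shows up in the TPU worker logs. A minimal sketch of the same ahead-of-time pattern on a toy function (the compiler-option dict is TPU-specific; the rest is the standard `jax.jit` AOT API):

import jax
import jax.numpy as jnp

def toy(x):
  return x * 2.0

# Lower against concrete example arguments, then compile; on TPU a compiler
# options dict such as {'xla_tpu_enable_log_recorder': 'true'} could be passed
# to .compile() to surface in-kernel debug prints in the worker logs.
lowered = jax.jit(toy).lower(jnp.ones((8,), jnp.float32))
compiled = lowered.compile()
print(compiled(jnp.arange(8, dtype=jnp.float32)))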