
Commit 92672e4

fixed a bug

1 parent e8ccd04 · commit 92672e4

2 files changed: +45 -17 lines changed


test/test_pallas.py (+1 -1)
@@ -797,7 +797,7 @@ def test_extended_paged_attention_v1_multiple_queries(self):
     # import pdb; pdb.set_trace()
     out_np = np.array(out)  # xw32: why does it hang?!
     actual_output = torch.from_numpy(out_np)
-    print('my new extended paged attention finished')
+    print('my new extended paged attention finished yay')

     # Run Woosuk's non-kernel impl.
     ref_q_torch = q.detach().clone()

torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel1.py (+44 -16)
@@ -219,7 +219,7 @@ def paged_flash_attention_kernel(
     v_scales_vmem_buffer,
     sem,
     *,
-    pages_per_sequence: int,  # bs, pages_per_sequence = page_indices.shape
+    pages_per_sequence: int,  # [bs, pages_per_sequence] = page_indices.shape
     batch_size: int,
     num_kv_pages_per_compute_block: int,
     num_queries_per_compute_block: int,
@@ -234,6 +234,7 @@ def paged_flash_attention_kernel(
       pl.program_id(2),
       pl.program_id(3),
   )
+  num_q_blks = pl.num_programs(2)
   num_queries_per_compute_block, num_q_heads_per_kv_head, head_dim = q_ref.shape
   assert q_ref.shape == (num_queries_per_compute_block, num_q_heads_per_kv_head, head_dim)
   num_kv_heads, total_num_pages, page_size, head_dim = k_pages_hbm_ref.shape
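
Note: the added `num_q_blks = pl.num_programs(2)` queries the grid extent along the query-block axis (grid dimension 2), which the revised `advance_kv_head_idx` below uses to detect the last query block. A minimal sketch of how `pl.program_id` / `pl.num_programs` behave inside a Pallas kernel (the toy kernel, grid, and shapes are illustrative assumptions, not from this commit):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def tag_kernel(o_ref):
  # Coordinates of this program instance along grid axes 0 and 1.
  i = pl.program_id(0)
  j = pl.program_id(1)
  # Grid extent along axis 1 (same idea as num_q_blks above).
  n_j = pl.num_programs(1)
  # Write a value that encodes (i, j) into the whole output block.
  o_ref[...] = jnp.full(o_ref.shape, i * n_j + j, dtype=o_ref.dtype)


out = pl.pallas_call(
    tag_kernel,
    grid=(2, 3),
    out_specs=pl.BlockSpec((8, 128), lambda i, j: (i, j)),
    out_shape=jax.ShapeDtypeStruct((16, 384), jnp.int32),
    interpret=True,  # interpreter mode so the sketch runs without a TPU
)()
print(int(out[0, 0]), int(out[8, 256]))  # 0 and 1 * 3 + 2 = 5
```

`pl.program_id(axis)` returns the current program's coordinate along `axis` and `pl.num_programs(axis)` the total count, so `pl.program_id(2) == pl.num_programs(2) - 1` identifies the final query block.
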
@@ -267,12 +268,16 @@ def advance_to_next_non_zero_length():
       )

     def advance_kv_head_idx():
+      # assumption: kv_blk_idx * compute_blk_size_kv >= lengths_ref[b]
       next_kv_head_idx = kv_head_idx + 1
-      return lax.cond(next_kv_head_idx < num_kv_heads, lambda: (b, next_kv_head_idx, 0), advance_b)
+      return lax.cond(q_blk_idx==num_q_blks-1,
+                      lambda: lax.cond(next_kv_head_idx < num_kv_heads, lambda: (b, next_kv_head_idx, 0), advance_b),
+                      lambda: (b, kv_head_idx, 0))

     return lax.cond(kv_blk_idx * compute_blk_size_kv < lengths_ref[b], lambda: (b, kv_head_idx, kv_blk_idx), advance_kv_head_idx)

   def create_kv_async_copy_descriptors(b, kv_head_idx, kv_blk_idx, buffer_index):
+    pl.debug_print('line45 b={}, kv_head_idx={}', b, kv_head_idx)
     page_offset = b * pages_per_sequence + kv_blk_idx * num_kv_pages_per_compute_block
     pages_to_load = num_kv_pages_per_compute_block
     async_copy_k = MultiPageAsyncCopyDescriptor(
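
Note: the rewritten `advance_kv_head_idx` only moves on to the next KV head (or, via `advance_b`, the next batch element) once the current query block is the last one; for earlier query blocks it keeps the same head and resets the KV-block index to 0 so the next query block re-walks the KV blocks. The added `pl.debug_print('line45 b={}, kv_head_idx={}', b, kv_head_idx)` is Pallas's device-side formatted print, useful for tracing which (batch, head) pair each grid step touches. A standalone sketch of the same nested `lax.cond` index-advancing logic (the function name, arguments, and the simplified `kv_blk_idx < num_kv_blks` test are assumptions for illustration, not the kernel's real helpers):

```python
import jax
from jax import lax


def next_indices(b, kv_head_idx, kv_blk_idx, q_blk_idx, *,
                 num_kv_heads, num_q_blks, num_kv_blks):
  """Toy index-advancing logic; not the kernel's real helper."""

  def advance_b():
    # All KV heads of this batch element are done: move to the next one.
    return (b + 1, 0, 0)

  def advance_kv_head_idx():
    next_kv_head_idx = kv_head_idx + 1
    # Only advance the KV head after the *last* query block; earlier query
    # blocks keep the same head and restart the KV walk from block 0.
    return lax.cond(
        q_blk_idx == num_q_blks - 1,
        lambda: lax.cond(next_kv_head_idx < num_kv_heads,
                         lambda: (b, next_kv_head_idx, 0),
                         advance_b),
        lambda: (b, kv_head_idx, 0))

  # More KV blocks left for this (batch, head)? Keep the indices as-is.
  return lax.cond(kv_blk_idx < num_kv_blks,
                  lambda: (b, kv_head_idx, kv_blk_idx),
                  advance_kv_head_idx)


fn = jax.jit(next_indices,
             static_argnames=('num_kv_heads', 'num_q_blks', 'num_kv_blks'))
# The KV-block index has run past the last block of the last KV head, and we
# are on the last query block, so advance to the next batch element: (1, 0, 0).
print(fn(0, 1, 4, 2, num_kv_heads=2, num_q_blks=3, num_kv_blks=4))
```

Because every `lax.cond` branch must return the same pytree structure, each branch yields a `(batch, kv_head, kv_block)` triple, which is what lets the kernel thread these indices from one grid step to the next.
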
@@ -366,20 +371,21 @@ def prefetch_next_block():  # pylint: disable=unused-variable
   print(f'xw32 line357 {temp_out.shape=}')
   o_ref[...] = temp_out
   print(f'xw32 line359 {o_ref.shape=}')
+  step_ref[0] = step + 1

 MIN_BLOCK_SIZE = 128

-@functools.partial(
-    jax.jit,
-    static_argnames=[
-        "num_kv_pages_per_compute_block",
-        "num_queries_per_compute_block",
-        "attn_logits_soft_cap",
-        "mask_value",
-        "megacore_mode",
-        "inline_seq_dim",
-    ],
-)
+# @functools.partial(
+#     jax.jit,
+#     static_argnames=[
+#         "num_kv_pages_per_compute_block",
+#         "num_queries_per_compute_block",
+#         "attn_logits_soft_cap",
+#         "mask_value",
+#         "megacore_mode",
+#         "inline_seq_dim",
+#     ],
+# )
 def paged_attention(
     q: jax.Array,
     k_pages: jax.Array | quantization_utils.QuantizedTensor,
@@ -485,6 +491,11 @@ def paged_attention(

   # Here, we guarantee (num_q_heads // num_kv_heads) % 8 == 0
   # grid
+  # query_len dim has to come before kv_len dim for the fa v1 implementation because if we do the other way around,
+  # then for each j ([0, T_c]) and i ([0, T_r]), we load l_i and m_i from HBM and store to HBM.
+  # then for j+1, we have to loop over i ([0, T_r]) which requires the load l_i and m_i.
+  # But this is forbidden in Pallas: https://jax.readthedocs.io/en/latest/pallas/tpu/sparse.html#example-sparse-dense-matrix-multiplication
+  # "When we change output block Pallas will finally store the output into HBM and assume we never touch it again."
   grid = (
       batch_size,
       num_kv_heads,
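
Note: the new comment block documents the grid-ordering constraint: Pallas writes an output block back to HBM as soon as a grid step selects a different output block, so the reduction over KV blocks must be the innermost grid dimension while the output (query) block index stays fixed. A hedged sketch of that revisiting pattern on a toy row-sum kernel (the kernel, shapes, and block sizes are assumptions; only the ordering rule is the point):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def row_sum_kernel(x_ref, o_ref):
  # grid = (num_row_blocks, num_col_blocks); axis 1 (columns) is innermost.
  j = pl.program_id(1)

  # Zero the output block only on the first column step; on later steps the
  # same output block is still resident, so we can keep accumulating into it.
  @pl.when(j == 0)
  def _():
    o_ref[...] = jnp.zeros_like(o_ref)

  o_ref[...] += jnp.sum(x_ref[...], axis=1, keepdims=True)


x = jnp.ones((16, 512), jnp.float32)
out = pl.pallas_call(
    row_sum_kernel,
    grid=(2, 4),  # 2 output (row) blocks x 4 reduction (column) steps
    in_specs=[pl.BlockSpec((8, 128), lambda i, j: (i, j))],
    # The output index map ignores j, so the output block only changes when i
    # does. With the grid order flipped, every step would pick a new output
    # block and Pallas would flush the partial sums to HBM prematurely.
    out_specs=pl.BlockSpec((8, 1), lambda i, j: (i, 0)),
    out_shape=jax.ShapeDtypeStruct((16, 1), jnp.float32),
    interpret=True,  # interpreter mode so the sketch runs without a TPU
)(x)
print(out.ravel())  # every row sums 512 ones -> 512.0
```

If the grid order were reversed so the reduction axis changed the output block on every step, the running statistics (the l_i/m_i of flash attention v1) would be flushed and could not be re-read, which is exactly what the quoted Pallas documentation forbids.
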
@@ -562,7 +573,7 @@ def qo_index_map(batch_index, kv_head_index, q_seq_blk_idx, *_):
       pltpu.SemaphoreType.DMA,
   )

-  out = pl.pallas_call(
+  kernel = pl.pallas_call(
       functools.partial(
           paged_flash_attention_kernel,
           pages_per_sequence=pages_per_sequence,
@@ -585,7 +596,24 @@ def qo_index_map(batch_index, kv_head_index, q_seq_blk_idx, *_):
       # compiler_params=pltpu.TPUCompilerParams(
       #     dimension_semantics=dimension_semantics),  # do we need it?
       out_shape=out_shape,
-  )(
+  )
+  compiled_kernel = (
+      jax.jit(kernel)
+      .lower(
+          # The first 4 are prefetched scalars.
+          lengths,
+          page_indices.reshape(-1),
+          jnp.zeros((1,), jnp.int32),  # buffer index
+          jnp.zeros((1,), jnp.int32),  # step
+          q.astype(q_dtype_for_kernel_launch),
+          k_pages,
+          k_scales_pages,
+          v_pages,
+          v_scales_pages,
+      )
+      .compile({'xla_tpu_enable_log_recorder': 'true'})
+  )
+  outs = compiled_kernel(
       # The first 4 are prefetched scalars.
       lengths,
       page_indices.reshape(-1),
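
Note: instead of invoking the `pl.pallas_call` result directly, the kernel is now lowered and compiled ahead of time with `jax.jit(kernel).lower(...).compile(...)`, passing the `xla_tpu_enable_log_recorder` option so the TPU runtime records kernel-side log output (e.g. from `pl.debug_print`); presumably this is also why the `@functools.partial(jax.jit, ...)` decorator on `paged_attention` is commented out above. A minimal sketch of the same ahead-of-time lower/compile workflow on a toy function (the function and inputs are assumptions; the compiler flag is the one from the diff):

```python
import jax
import jax.numpy as jnp


def f(x, y):
  return jnp.dot(x, y)


x = jnp.ones((128, 128), jnp.float32)
y = jnp.ones((128, 128), jnp.float32)

# Ahead-of-time workflow: trace and lower for these concrete shapes/dtypes,
# then compile, then call the compiled executable directly.
lowered = jax.jit(f).lower(x, y)
compiled = lowered.compile()
print(compiled(x, y).shape)  # (128, 128)

# The commit instead passes compiler options at compile time, e.g.
#   lowered.compile({'xla_tpu_enable_log_recorder': 'true'})
# which, on TPU, asks the runtime to record kernel log output.
```

Because the prefetched scalars (`lengths`, the flattened `page_indices`, and the two `jnp.zeros((1,), jnp.int32)` counters) and the tensor operands are passed to `.lower(...)`, the executable is specialized to their shapes and dtypes and is then called with the same arguments as `compiled_kernel(...)`.
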
@@ -597,7 +625,7 @@ def qo_index_map(batch_index, kv_head_index, q_seq_blk_idx, *_):
       v_pages,
       v_scales_pages,
   )  # should get 4 return values
-  ret = out[0]
+  ret = outs[0]
   print(f'xw32 finished the pallas kernel. {ret.shape=} Returning...', flush=True)
   return ret.reshape(batch_size, query_len, num_q_heads, head_dim).astype(q.dtype)
