Skip to content

Commit d99bde5

Browse files
committed
Split page_indices
1 parent a1ae3a9 commit d99bde5

File tree

2 files changed

+165
-56
lines changed

2 files changed

+165
-56
lines changed

test/test_ragged_paged_attention_kernel.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,8 @@ def _verify_ragged_paged_attention(
124124
max_num_pages_per_seq = (max_kv_len + page_size - 1) // page_size
125125
# The reason why we need to pad max_num_pages_per_seq is that
126126
# page_indices[1]=max_num_pages_per_seq and max_num_pages_per_seq must be a multiple of num_kv_pages_per_block (max_num_pages_per_seq%num_kv_pages_per_block==0)
127-
max_num_pages_per_seq = self._get_closest_power_of_two(
128-
max_num_pages_per_seq)
127+
num_kv_pages_per_block = 128
128+
max_num_pages_per_seq = self._round_up_closest_multiple_of(max_num_pages_per_seq, num_kv_pages_per_block)
129129
# The assert below mimics the reality that each page gets a unique index.
130130
# But for testing, the assert could be omitted.
131131
# assert max_num_pages_per_seq*num_q_tokens <= num_pages, f"assert failed: max_num_pages_per_seq*num_q_tokens < num_pages. Got {max_num_pages_per_seq*num_q_tokens} and {num_pages}"
@@ -149,6 +149,7 @@ def _verify_ragged_paged_attention(
149149
page_indices,
150150
cu_q_lens,
151151
num_seqs,
152+
num_kv_pages_per_block=num_kv_pages_per_block,
152153
num_queries_per_block=num_queries_per_block,
153154
)
154155
err.throw() # no-op if there is no error.
@@ -183,6 +184,9 @@ def _verify_ragged_paged_attention(
183184
self.assertTrue(
184185
jnp.allclose(actual_output, expected_output, atol=atol, rtol=rtol))
185186

187+
def _round_up_closest_multiple_of(self, x, base):
188+
return (x + base - 1) // base * base
189+
186190
def _get_closest_power_of_two(self, x):
187191
if x <= 0:
188192
raise ValueError(f"x must be positive. Got {x}")
@@ -225,14 +229,16 @@ def test_paged_attention_varlen_comprehensive(
225229
page_size: int,
226230
num_pages: int,
227231
):
228-
# assuming q_blk_size=128
232+
if jtu.is_device_tpu(version=4) and head_dim == 256 and page_size == 32:
233+
self.skipTest("TPU v4 has small VMEM. It will run into VMEM OOM. Skip the test.")
229234
self._verify_ragged_paged_attention(
230235
seq_lens,
231236
num_heads,
232237
head_dim,
233238
page_size,
234239
dtype,
235240
num_pages,
241+
num_queries_per_block=64,
236242
)
237243

238244
def test_paged_attention_mix_prefill_and_decode1(self,):
@@ -326,6 +332,7 @@ def test_paged_attention_extreme_one_tokens_per_sequence_min(self,):
326332

327333
def test_paged_attention_q_len_should_be_no_longer_than_kv_len(self,):
328334
# assuming q_blk_size=128
335+
# Here each q_len (1 or 511) is set up to be longer than its corresponding kv_len (0 or 256).
329336
seq_lens = [(1, 0), (511, 256)] # [(q_len, kv_len),...]
330337
num_heads = (1, 1)
331338
head_dim = 128
@@ -361,8 +368,8 @@ def test_paged_attention_q_len_should_be_no_longer_than_kv_len(self,):
361368
max_num_pages_per_seq = (max_kv_len + page_size - 1) // page_size
362369
# The reason why we need to pad max_num_pages_per_seq is that
363370
# page_indices[1]=max_num_pages_per_seq and max_num_pages_per_seq must be a multiple of num_kv_pages_per_block (max_num_pages_per_seq%num_kv_pages_per_block==0)
364-
max_num_pages_per_seq = self._get_closest_power_of_two(
365-
max_num_pages_per_seq)
371+
num_kv_pages_per_block=128
372+
max_num_pages_per_seq = self._round_up_closest_multiple_of(max_num_pages_per_seq, num_kv_pages_per_block)
366373
# The assert below mimics the reality that each page gets a unique index.
367374
# But for testing, the assert could be omitted.
368375
assert max_num_pages_per_seq * num_q_tokens <= num_pages, f"assert failed: max_num_pages_per_seq*num_q_tokens < num_pages. Got {max_num_pages_per_seq*num_q_tokens} and {num_pages}"
@@ -388,6 +395,7 @@ def test_paged_attention_q_len_should_be_no_longer_than_kv_len(self,):
388395
page_indices,
389396
cu_q_lens,
390397
num_seqs,
398+
num_kv_pages_per_block=num_kv_pages_per_block,
391399
)
392400
err.throw()
393401

0 commit comments

Comments
 (0)