@@ -124,8 +124,8 @@ def _verify_ragged_paged_attention(
     max_num_pages_per_seq = (max_kv_len + page_size - 1) // page_size
     # The reason why we need to pad max_num_pages_per_seq is that
     # page_indices[1]=max_num_pages_per_seq and max_num_pages_per_seq%num_kv_pages_per_compute_block==0
-    max_num_pages_per_seq = self._get_closest_power_of_two(
-        max_num_pages_per_seq)
+    num_kv_pages_per_block = 128
+    max_num_pages_per_seq = self._round_up_closest_multiple_of(max_num_pages_per_seq, num_kv_pages_per_block)
     # The assert below mimics the reality that each page get a unique index.
     # But for testing, the assert could be omitted.
     # assert max_num_pages_per_seq*num_q_tokens <= num_pages, f"assert failed: max_num_pages_per_seq*num_q_tokens < num_pages. Got {max_num_pages_per_seq*num_q_tokens} and {num_pages}"
@@ -149,6 +149,7 @@ def _verify_ragged_paged_attention(
         page_indices,
         cu_q_lens,
         num_seqs,
+        num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
     )
     err.throw()  # noop if there is not err.
@@ -183,6 +184,9 @@ def _verify_ragged_paged_attention(
     self.assertTrue(
         jnp.allclose(actual_output, expected_output, atol=atol, rtol=rtol))
 
+  def _round_up_closest_multiple_of(self, x, base):
+    return (x + base - 1) // base * base
+
   def _get_closest_power_of_two(self, x):
     if x <= 0:
       raise ValueError(f"x must be positive. Got {x}")
@@ -225,14 +229,16 @@ def test_paged_attention_varlen_comprehensive(
       page_size: int,
       num_pages: int,
   ):
-    # assuming q_blk_size=128
+    if jtu.is_device_tpu(version=4) and head_dim == 256 and page_size == 32:
+      self.skipTest("TPU v4 has small VMEM. It will run into VMEM OOM. Skip the test.")
     self._verify_ragged_paged_attention(
         seq_lens,
         num_heads,
         head_dim,
         page_size,
         dtype,
         num_pages,
+        num_queries_per_block=64,
     )
 
   def test_paged_attention_mix_prefill_and_decode1(self,):
@@ -326,6 +332,7 @@ def test_paged_attention_extreme_one_tokens_per_sequence_min(self,):
 
   def test_paged_attention_q_len_should_be_no_longer_than_kv_len(self,):
     # assuming q_blk_size=128
+    # Here the q_len(1 or 511) is set up to be longer than the corresponding kv_len (0 or 256).
     seq_lens = [(1, 0), (511, 256)]  # [(q_len, kv_len),...]
     num_heads = (1, 1)
     head_dim = 128
@@ -361,8 +368,8 @@ def test_paged_attention_q_len_should_be_no_longer_than_kv_len(self,):
     max_num_pages_per_seq = (max_kv_len + page_size - 1) // page_size
     # The reason why we need to pad max_num_pages_per_seq is that
     # page_indices[1]=max_num_pages_per_seq and max_num_pages_per_seq%num_kv_pages_per_compute_block==0
-    max_num_pages_per_seq = self._get_closest_power_of_two(
-        max_num_pages_per_seq)
+    num_kv_pages_per_block = 128
+    max_num_pages_per_seq = self._round_up_closest_multiple_of(max_num_pages_per_seq, num_kv_pages_per_block)
     # The assert below mimics the reality that each page get a unique index.
     # But for testing, the assert could be omitted.
     assert max_num_pages_per_seq * num_q_tokens <= num_pages, f"assert failed: max_num_pages_per_seq*num_q_tokens < num_pages. Got {max_num_pages_per_seq*num_q_tokens} and {num_pages}"
@@ -388,6 +395,7 @@ def test_paged_attention_q_len_should_be_no_longer_than_kv_len(self,):
         page_indices,
         cu_q_lens,
         num_seqs,
+        num_kv_pages_per_block=num_kv_pages_per_block,
     )
     err.throw()
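A minimal standalone sketch (not part of the PR itself) of what the padding change above amounts to, assuming num_kv_pages_per_block = 128 as used in the diff: the new helper rounds max_num_pages_per_seq up to the next multiple of the KV-page block size, whereas the replaced _get_closest_power_of_two padded it to a power of two.

  def round_up_closest_multiple_of(x, base):
    # Same integer arithmetic as the _round_up_closest_multiple_of method added in the diff,
    # written as a free function for illustration.
    return (x + base - 1) // base * base

  # Example values with base=128:
  #   round_up_closest_multiple_of(5, 128)    -> 128
  #   round_up_closest_multiple_of(128, 128)  -> 128
  #   round_up_closest_multiple_of(129, 128)  -> 256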