@@ -31,8 +31,6 @@
 from tpu_inference import envs
 from tpu_inference.kernels.flash_attention.kernel import flash_attention
 from tpu_inference.kernels.mla.v2.kernel import mla_ragged_paged_attention
-from tpu_inference.kernels.ragged_paged_attention.v3.tuned_block_sizes import \
-    get_tuned_block_sizes
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
@@ -521,16 +519,9 @@ def mla_attention( |
     )

     def _mla_ragged_paged_attention(q, q_rope, k, k_rope, cache, *args):
-        max_num_tokens = q.shape[0]
-        max_num_seqs = md.seq_lens.shape[0]
-        pages_per_seq = md.block_tables.shape[0] // max_num_seqs
-
-        bkv_p, bq_sz = get_tuned_block_sizes(q.dtype, cache.dtype,
-                                             num_attention_heads, 1,
-                                             qk_nope_head_dim, cache.shape[1],
-                                             max_num_tokens, pages_per_seq)
-        num_kv_pages_per_block = min(min(pages_per_seq, bkv_p), 4)
-        num_queries_per_block = min(min(max_num_tokens, bq_sz), 4)
+        # TODO: use auto tuner to find the best block sizes.
+        num_kv_pages_per_block = 3
+        num_queries_per_block = 1

         out, new_cache = mla_ragged_paged_attention(
             q,
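The change hard-codes `num_kv_pages_per_block = 3` and `num_queries_per_block = 1` and defers block-size selection to a future auto tuner (see the TODO). Below is a minimal sketch of what such a tuner could look like. It is not part of this PR: it assumes the caller wraps the `mla_ragged_paged_attention` invocation in a closure that accepts the two block sizes and returns the kernel output; the helper name `autotune_block_sizes`, the candidate grid, and the timing harness are illustrative.

```python
import time
from typing import Callable, Iterable, Tuple

import jax


def autotune_block_sizes(
    run_kernel: Callable[[int, int], jax.Array],
    candidates: Iterable[Tuple[int, int]],
    warmup: int = 1,
    iters: int = 5,
) -> Tuple[int, int]:
    """Return the (num_kv_pages_per_block, num_queries_per_block) pair with the
    lowest average wall-clock time.

    `run_kernel` is assumed (hypothetically) to call mla_ragged_paged_attention
    with the given block sizes and return its output array.
    """
    best_pair, best_time = None, float("inf")
    for bkv_pages, bq_size in candidates:
        try:
            # Warm up so compilation time is not counted against the candidate.
            for _ in range(warmup):
                jax.block_until_ready(run_kernel(bkv_pages, bq_size))
            start = time.perf_counter()
            for _ in range(iters):
                jax.block_until_ready(run_kernel(bkv_pages, bq_size))
            elapsed = (time.perf_counter() - start) / iters
        except Exception:
            continue  # skip block sizes the kernel rejects
        if elapsed < best_time:
            best_pair, best_time = (bkv_pages, bq_size), elapsed
    if best_pair is None:
        raise RuntimeError("no candidate block sizes ran successfully")
    return best_pair
```

In practice the winning pair would be cached per configuration (query/cache dtype, head dims, page size, max tokens, pages per sequence), similar in spirit to the lookup that the removed `get_tuned_block_sizes` call provided for the ragged paged attention kernel.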