Skip to content

Commit 5acc9f2

Browse files
Authored by Guangxiang Du
[MLA] use manually tuned block size for MLA (#2021)
Signed-off-by: Guangxiang Du <gxd@google.com>
Commit 5acc9f2 (1 parent: 6b7cf3d)

File tree

5 files changed

+30
-34
lines changed

5 files changed

+30
-34
lines changed

tests/layers/common/test_attention_interface.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -338,11 +338,6 @@ def test_mla_attention(monkeypatch, mesh):
338338
request_distribution=jnp.array([0, 0, NUM_SEQS], dtype=jnp.int32),
339339
)
340340

341-
mock_tuned_block_sizes = MagicMock(return_value=(8, 8))
342-
monkeypatch.setattr(
343-
"tpu_inference.layers.common.attention_interface.get_tuned_block_sizes",
344-
mock_tuned_block_sizes)
345-
346341
expected_output = jnp.full(q_TNA.shape, 0.5)
347342
expected_new_cache = jnp.full(kv_cache_shape, 0.1)
348343

@@ -365,15 +360,13 @@ def test_mla_attention(monkeypatch, mesh):
365360
sm_scale=0.1,
366361
)
367362

368-
# Verify mocked functions were called
369-
mock_tuned_block_sizes.assert_called_once()
370363
mock_mla_kernel.assert_called_once()
371364

372365
# Verify output correctness
373366
assert jnp.array_equal(output, expected_output)
374367
assert jnp.array_equal(final_kv_cache, expected_new_cache)
375368

376369
_, kernel_kwargs = mock_mla_kernel.call_args
377-
assert kernel_kwargs["num_kv_pages_per_block"] == 4
378-
assert kernel_kwargs["num_queries_per_block"] == 4
370+
assert kernel_kwargs["num_kv_pages_per_block"] == 3
371+
assert kernel_kwargs["num_queries_per_block"] == 1
379372
assert kernel_kwargs["sm_scale"] == 0.1

tests/platforms/test_tpu_platform.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ def test_check_and_update_config_block_size(self, mock_logger, mock_update,
251251
vllm_config.cache_config = MagicMock()
252252
vllm_config.cache_config.user_specified_block_size = False
253253
vllm_config.cache_config.block_size = 16
254+
vllm_config.model_config.use_mla = False
254255

255256
with patch.dict(
256257
'sys.modules', {

tpu_inference/layers/common/attention_interface.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@
3131
from tpu_inference import envs
3232
from tpu_inference.kernels.flash_attention.kernel import flash_attention
3333
from tpu_inference.kernels.mla.v2.kernel import mla_ragged_paged_attention
34-
from tpu_inference.kernels.ragged_paged_attention.v3.tuned_block_sizes import \
35-
get_tuned_block_sizes
3634
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
3735
from tpu_inference.layers.common.sharding import ShardingAxisName
3836
from tpu_inference.logger import init_logger
@@ -521,16 +519,9 @@ def mla_attention(
521519
)
522520

523521
def _mla_ragged_paged_attention(q, q_rope, k, k_rope, cache, *args):
524-
max_num_tokens = q.shape[0]
525-
max_num_seqs = md.seq_lens.shape[0]
526-
pages_per_seq = md.block_tables.shape[0] // max_num_seqs
527-
528-
bkv_p, bq_sz = get_tuned_block_sizes(q.dtype, cache.dtype,
529-
num_attention_heads, 1,
530-
qk_nope_head_dim, cache.shape[1],
531-
max_num_tokens, pages_per_seq)
532-
num_kv_pages_per_block = min(min(pages_per_seq, bkv_p), 4)
533-
num_queries_per_block = min(min(max_num_tokens, bq_sz), 4)
522+
# TODO: use auto tuner to find the best block sizes.
523+
num_kv_pages_per_block = 3
524+
num_queries_per_block = 1
534525

535526
out, new_cache = mla_ragged_paged_attention(
536527
q,

tpu_inference/layers/vllm/backends/flash_attn_mla.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import torch
1818
from jax.sharding import Mesh
1919
from torchax.interop import jax_view
20+
from vllm.config import VllmConfig
2021
from vllm.model_executor.layers.attention.mla_attention import MLAAttention
2122
from vllm.v1.attention.backend import (AttentionBackend, AttentionLayer,
2223
MLAAttentionImpl)
@@ -43,6 +44,10 @@ def get_name() -> str:
4344
def get_impl_cls() -> type["PallasMLAttentionBackend"]:
4445
return PallasMLAttentionBackendImpl
4546

47+
@staticmethod
48+
def get_page_size(vllm_config: VllmConfig) -> int:
49+
return 1024
50+
4651

4752
class PallasMLAttentionBackendImpl(MLAAttentionImpl):
4853

tpu_inference/platforms/tpu_platform.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -212,19 +212,25 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
212212
# For v0, the default block size is 16.
213213
if cache_config and not cache_config.user_specified_block_size:
214214
if vllm_config.model_config:
215-
from tpu_inference.layers.vllm.backends.flash_attn import \
216-
PallasAttentionBackend
217-
cache_config.block_size = PallasAttentionBackend.get_page_size(
218-
vllm_config) # type: ignore[assignment]
219-
min_page_size = PallasAttentionBackend.get_min_page_size(
220-
vllm_config)
221-
if min_page_size > cache_config.block_size:
222-
logger.warning(
223-
"Increase the page size from %s to %s to avoid SMEM OOM",
224-
cache_config.block_size,
225-
min_page_size,
226-
)
227-
cache_config.block_size = min_page_size # type: ignore[assignment]
215+
if vllm_config.model_config.use_mla:
216+
from tpu_inference.layers.vllm.backends.flash_attn_mla import \
217+
PallasMLAttentionBackend
218+
cache_config.block_size = PallasMLAttentionBackend.get_page_size(
219+
vllm_config) # type: ignore[assignment]
220+
else:
221+
from tpu_inference.layers.vllm.backends.flash_attn import \
222+
PallasAttentionBackend
223+
cache_config.block_size = PallasAttentionBackend.get_page_size(
224+
vllm_config) # type: ignore[assignment]
225+
min_page_size = PallasAttentionBackend.get_min_page_size(
226+
vllm_config)
227+
if min_page_size > cache_config.block_size:
228+
logger.warning(
229+
"Increase the page size from %s to %s to avoid SMEM OOM",
230+
cache_config.block_size,
231+
min_page_size,
232+
)
233+
cache_config.block_size = min_page_size # type: ignore[assignment]
228234
logger.info(
229235
f"Using KV cache block size: {cache_config.block_size}")
230236

0 commit comments

Comments (0)