
Commit 405ef39

round kv cache to allow expert sharding
Signed-off-by: Mohit Khatwani <mohitkhatwani@google.com>
1 parent 9d7b0d6 commit 405ef39

File tree

1 file changed: +37 -0 lines changed


tpu_inference/runner/kv_cache_manager.py

Lines changed: 37 additions & 0 deletions
@@ -389,6 +389,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         for j, layer_name in enumerate(kv_cache_tensor.shared_by):
             layer_spec = layer_name_to_spec[layer_name]
 
+<<<<<<< HEAD
             page_size_bytes = layer_spec.page_size_bytes
             assert kv_cache_tensor.size % page_size_bytes == 0
             num_blocks = kv_cache_tensor.size // page_size_bytes
@@ -397,6 +398,42 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             dp_size = self.runner.vllm_config.sharding_config.total_dp_size
             # num_blocks must be a multiple of dp_size
             num_blocks = (num_blocks // dp_size) * dp_size
+=======
+            page_size_bytes = layer_spec.page_size_bytes
+            assert kv_cache_tensor.size % page_size_bytes == 0
+            num_blocks = kv_cache_tensor.size // page_size_bytes
+            sharding_config = self.runner.vllm_config.sharding_config
+            if self.use_mla and not sharding_config.enable_dp_attention:
+                # MLA KV cache is sharded with MLP_TENSOR = (attn_dp, attn_dp_expert, model, expert)
+                divisor = (sharding_config.attn_dp_size *
+                           sharding_config.attn_dp_expert_size *
+                           sharding_config.tp_size *
+                           sharding_config.expert_size)
+            else:
+                divisor = sharding_config.total_dp_size
+            # num_blocks must be a multiple of the sharding divisor
+            num_blocks = (num_blocks // divisor) * divisor
+            # NOTE: we'll multiply the num_kv_heads by 2 in the function
+            if self.use_mla:
+                head_size = self.runner.model_config.hf_config.kv_lora_rank + \
+                    self.runner.model_config.hf_config.qk_rope_head_dim
+            else:
+                head_size = layer_spec.head_size
+            kv_cache = create_kv_caches(
+                num_blocks=num_blocks,
+                block_size=layer_spec.block_size,
+                num_kv_heads=layer_spec.num_kv_heads,
+                head_size=head_size,
+                mesh=self.runner.mesh,
+                layer_names=[f'kv_cache_tensor.{i}'],
+                cache_dtype=t2j_dtype(layer_spec.dtype),
+                use_mla=self.use_mla,
+            )[0]
+            kv_caches.append(kv_cache)
+            num_blocks_list.append(num_blocks)
+            for layer_name in kv_cache_tensor.shared_by:
+                self.runner.layer_name_to_kvcache_index[layer_name] = i
+>>>>>>> 0c544707 (round kv cache to allow expert sharding)
 
             if isinstance(layer_spec, MambaSpec):
                 mamba_states = []
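
The substance of the change is that the number of KV-cache blocks is rounded down to a multiple of a sharding divisor, so the cache tensor can be split evenly across the data-parallel and expert axes of the mesh. Below is a minimal sketch of that rounding logic; the function name and the sharding sizes are hypothetical example values for illustration, not taken from this repository.

def round_num_blocks(tensor_size_bytes: int, page_size_bytes: int,
                     divisor: int) -> int:
    """Round the block count down to a multiple of the sharding divisor."""
    assert tensor_size_bytes % page_size_bytes == 0
    num_blocks = tensor_size_bytes // page_size_bytes
    # Drop the remainder so every shard gets the same whole number of blocks.
    return (num_blocks // divisor) * divisor

# Hypothetical MLA case: the divisor is the product of every mesh axis the
# cache is sharded over (attn_dp, attn_dp_expert, model/tp, expert).
attn_dp_size, attn_dp_expert_size, tp_size, expert_size = 2, 1, 4, 8
mla_divisor = attn_dp_size * attn_dp_expert_size * tp_size * expert_size  # 64

page_size_bytes = 2048
print(round_num_blocks(10_000 * page_size_bytes, page_size_bytes, mla_divisor))
# 9984 (10_000 blocks rounded down to the nearest multiple of 64)

When MLA is in use, the diff also derives the cache head size from the model config as kv_lora_rank + qk_rope_head_dim instead of the per-layer head_size, since the latent KV cache stores the compressed KV latent plus the decoupled RoPE component.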
