Skip to content

Commit 88df022

Browse files
committed
round kv cache to allow expert sharding
Signed-off-by: Mohit Khatwani <mohitkhatwani@google.com>
1 parent 9d7b0d6 commit 88df022

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

tpu_inference/runner/kv_cache_manager.py

Lines changed: 12 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -394,9 +394,18 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
394394
num_blocks = kv_cache_tensor.size // page_size_bytes
395395
if duplicate_shared_layers:
396396
num_blocks //= num_shared_layers
397-
dp_size = self.runner.vllm_config.sharding_config.total_dp_size
398-
# num_blocks must be a multiple of dp_size
399-
num_blocks = (num_blocks // dp_size) * dp_size
397+
sharding_config = self.runner.vllm_config.sharding_config
398+
if self.use_mla and not sharding_config.sharding_strategy.get(
399+
"enable_dp_attention", False):
400+
# MLA KV cache is sharded with MLP_TENSOR = (attn_dp, attn_dp_expert, model, expert)
401+
divisor = (sharding_config.attn_dp_size *
402+
sharding_config.attn_dp_expert_size *
403+
sharding_config.tp_size *
404+
sharding_config.expert_size)
405+
else:
406+
divisor = sharding_config.total_dp_size
407+
# num_blocks must be a multiple of the sharding divisor
408+
num_blocks = (num_blocks // divisor) * divisor
400409

401410
if isinstance(layer_spec, MambaSpec):
402411
mamba_states = []

0 commit comments

Comments (0)