@@ -389,6 +389,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
389389 for j , layer_name in enumerate (kv_cache_tensor .shared_by ):
390390 layer_spec = layer_name_to_spec [layer_name ]
391391
 392+ <<<<<<< HEAD
392393 page_size_bytes = layer_spec .page_size_bytes
393394 assert kv_cache_tensor .size % page_size_bytes == 0
394395 num_blocks = kv_cache_tensor .size // page_size_bytes
@@ -397,6 +398,42 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
397398 dp_size = self .runner .vllm_config .sharding_config .total_dp_size
398399 # num_blocks must be a multiple of dp_size
399400 num_blocks = (num_blocks // dp_size ) * dp_size
 401+ =======
402+ page_size_bytes = layer_spec .page_size_bytes
403+ assert kv_cache_tensor .size % page_size_bytes == 0
404+ num_blocks = kv_cache_tensor .size // page_size_bytes
405+ sharding_config = self .runner .vllm_config .sharding_config
406+ if self .use_mla and not sharding_config .enable_dp_attention :
407+ # MLA KV cache is sharded with MLP_TENSOR = (attn_dp, attn_dp_expert, model, expert)
408+ divisor = (sharding_config .attn_dp_size *
409+ sharding_config .attn_dp_expert_size *
410+ sharding_config .tp_size *
411+ sharding_config .expert_size )
412+ else :
413+ divisor = sharding_config .total_dp_size
414+ # num_blocks must be a multiple of the sharding divisor
415+ num_blocks = (num_blocks // divisor ) * divisor
416+ # NOTE: we'll multiply the num_kv_heads by 2 in the function
417+ if self .use_mla :
418+ head_size = self .runner .model_config .hf_config .kv_lora_rank + \
419+ self .runner .model_config .hf_config .qk_rope_head_dim
420+ else :
421+ head_size = layer_spec .head_size
422+ kv_cache = create_kv_caches (
423+ num_blocks = num_blocks ,
424+ block_size = layer_spec .block_size ,
425+ num_kv_heads = layer_spec .num_kv_heads ,
426+ head_size = head_size ,
427+ mesh = self .runner .mesh ,
428+ layer_names = [f'kv_cache_tensor.{ i } ' ],
429+ cache_dtype = t2j_dtype (layer_spec .dtype ),
430+ use_mla = self .use_mla ,
431+ )[0 ]
432+ kv_caches .append (kv_cache )
433+ num_blocks_list .append (num_blocks )
434+ for layer_name in kv_cache_tensor .shared_by :
435+ self .runner .layer_name_to_kvcache_index [layer_name ] = i
 436+ >>>>>>> 0c544707 (round kv cache to allow expert sharding)
400437
401438 if isinstance (layer_spec , MambaSpec ):
402439 mamba_states = []
0 commit comments