Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,8 @@ __global__ void copyBatchBlockOffsetsToDeviceKernel(SizeType32 const* __restrict
for (uint32_t j = 0; j < elemPerAccess; j++)
{
auto const val = src.unpacked[j];
dstK.unpacked[j] = (val == BAD_PAGE_INDEX) ? val : (indexScales[poolIdx] * val);
dstV.unpacked[j] = (val == BAD_PAGE_INDEX) ? val : (indexScales[poolIdx] * val + kvOffset[poolIdx]);
dstK.unpacked[j] = (val == BAD_PAGE_INDEX) ? 0 : (indexScales[poolIdx] * val);
dstV.unpacked[j] = (val == BAD_PAGE_INDEX) ? 0 : (indexScales[poolIdx] * val + kvOffset[poolIdx]);
}
}
}
Expand Down
102 changes: 91 additions & 11 deletions tensorrt_llm/_torch/pyexecutor/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,26 @@ def _get_kv_size_per_token(self):
mapping,
tokens_per_block=self._tokens_per_block)
elif self._should_create_separate_draft_kv_cache():
# One-model draft with separate KV cache layout
# One-model draft with separate KV cache layout.
# Pass num_layers explicitly since the HF config may report a
# different layer count than what is actually used at runtime
# (e.g. EAGLE3: config says 1, runtime uses 4).
# For PP, draft layers are only on the last rank (see
# get_pp_layers), so only that rank should include draft cost.
effective_draft_config = self._get_effective_draft_config()
kv_size_per_token += self._kv_cache_manager_cls.get_cache_size_per_token(
effective_draft_config,
mapping,
tokens_per_block=self._tokens_per_block)
if self._speculative_config.spec_dec_mode.is_external_drafter():
# External drafter: layers start from 0, normal PP distribution
kv_size_per_token += self._kv_cache_manager_cls.get_cache_size_per_token(
effective_draft_config,
mapping,
tokens_per_block=self._tokens_per_block)
elif mapping.is_last_pp_rank():
# EAGLE3/MTP: draft layers only on last PP rank
kv_size_per_token += self._kv_cache_manager_cls.get_cache_size_per_token(
effective_draft_config,
mapping,
tokens_per_block=self._tokens_per_block,
num_layers=self._get_num_draft_layers())
return kv_size_per_token

def _cal_max_memory(self, peak_memory, total_gpu_memory, fraction,
Expand Down Expand Up @@ -601,9 +615,21 @@ def _get_effective_draft_config(self) -> ModelConfig:
# layers as well.
return self._model_engine.model.model_config

def _get_num_draft_layers(self) -> int:
"""Return the actual number of draft KV cache layers.

This must stay in sync with the num_layers passed to the draft KV
cache manager constructor in _create_one_model_draft_kv_cache_manager.
"""
if self._speculative_config.spec_dec_mode.is_external_drafter():
return self._draft_config.pretrained_config.num_hidden_layers
return get_num_spec_layers(self._speculative_config)

def _create_one_model_draft_kv_cache_manager(
self,
estimating_kv_cache: bool = False) -> Optional[KVCacheManager]:
self,
estimating_kv_cache: bool = False,
kv_cache_config_override: Optional[KvCacheConfig] = None,
) -> Optional[KVCacheManager]:
"""
Create a KV cache manager for draft model layers in one-model mode
when target and draft have different KV cache layouts.
Expand All @@ -615,11 +641,10 @@ def _create_one_model_draft_kv_cache_manager(

# PARD, External Drafter: draft is a separate model, layers start from 0.
# Other methods (EAGLE3, MTP): draft layers are appended after target layers.
num_draft_layers = self._get_num_draft_layers()
if self._speculative_config.spec_dec_mode.is_external_drafter():
num_draft_layers = self._draft_config.pretrained_config.num_hidden_layers
spec_dec_layer_mask = [True] * num_draft_layers
else:
num_draft_layers = get_num_spec_layers(self._speculative_config)
spec_dec_layer_mask = [False] * target_num_layers + [
True
] * num_draft_layers
Expand Down Expand Up @@ -650,11 +675,12 @@ def _create_one_model_draft_kv_cache_manager(
# the sparse_attention_config. Get it from effective_draft_config which
# falls back to the target model's config for MTP mode.
sparse_attn_config = effective_draft_config.sparse_attention_config
draft_kv_config = kv_cache_config_override if kv_cache_config_override is not None else self._kv_cache_config
return _create_kv_cache_manager(
model_engine=None,
kv_cache_manager_cls=draft_kv_cache_manager_cls,
mapping=self._mapping,
kv_cache_config=self._kv_cache_config,
kv_cache_config=draft_kv_config,
tokens_per_block=self._tokens_per_block,
max_seq_len=self._max_seq_len,
max_batch_size=self._max_batch_size,
Expand All @@ -673,13 +699,63 @@ def _create_one_model_draft_kv_cache_manager(
num_layers=num_draft_layers,
)

def _split_kv_cache_budget_for_draft(self) -> Optional[KvCacheConfig]:
"""Split max_gpu_total_bytes between target and draft KV caches.

When using KVCacheManagerV2 with a separate draft KV cache,
max_gpu_total_bytes represents the total budget for both target and
draft combined. This method splits the budget proportionally based
on their per-token KV cache sizes.

Returns a cloned KvCacheConfig for the draft, or None if no split is
needed. Also modifies self._kv_cache_config.max_gpu_total_bytes
in-place for the target.
"""
total_budget = self._kv_cache_config.max_gpu_total_bytes
if total_budget is None or total_budget <= 0:
return None

total_kv = self._get_kv_size_per_token()
target_kv = self._kv_cache_manager_cls.get_cache_size_per_token(
self._model_engine.model.model_config,
self._mapping,
tokens_per_block=self._tokens_per_block)
draft_kv = total_kv - target_kv
if total_kv <= 0 or draft_kv <= 0:
return None

draft_budget = int(total_budget * draft_kv / total_kv)
target_budget = total_budget - draft_budget

logger.info(
f"Splitting KV cache budget: total={total_budget / GB:.2f} GiB, "
f"target={target_budget / GB:.2f} GiB ({target_kv}B/tok), "
f"draft={draft_budget / GB:.2f} GiB ({draft_kv}B/tok)")

self._kv_cache_config.max_gpu_total_bytes = target_budget

draft_kv_cache_config = self._kv_cache_config.model_copy()
draft_kv_cache_config.max_gpu_total_bytes = draft_budget
return draft_kv_cache_config

def build_managers(self,
resources: Dict,
estimating_kv_cache: bool = False) -> None:
"""Construct KV caches for model and draft model (if applicable)."""
if self._skip_est:
self.configure_kv_cache_capacity()

# For V2 with separate one-model draft KV cache, split the total budget
# between target and draft before creating either manager.
# Only split for the final managers, not during estimation — estimation
# uses max_tokens-based logic and must not have its config mutated.
# Two-model draft is excluded: V2 does not support two-model mode.
draft_kv_cache_config = None
if (not estimating_kv_cache
and self._should_create_separate_draft_kv_cache()
and issubclass(self._kv_cache_manager_cls, KVCacheManagerV2)):
draft_kv_cache_config = self._split_kv_cache_budget_for_draft()

kv_cache_manager = self._create_kv_cache_manager(
self._model_engine, estimating_kv_cache)

Expand All @@ -691,12 +767,16 @@ def build_managers(self,

# Two-model speculative decoding: draft model has separate engine
if self._draft_model_engine is not None:
assert draft_kv_cache_config is None, (
"KVCacheManagerV2 does not support two-model speculative decoding "
"with separate draft KV cache budget splitting.")
draft_kv_cache_manager = self._create_kv_cache_manager(
self._draft_model_engine, estimating_kv_cache)
# One-model speculative decoding with different KV layouts
elif self._should_create_separate_draft_kv_cache():
draft_kv_cache_manager = self._create_one_model_draft_kv_cache_manager(
estimating_kv_cache)
estimating_kv_cache,
kv_cache_config_override=draft_kv_cache_config)

resources[ResourceManagerType.KV_CACHE_MANAGER] = kv_cache_manager
resources[
Expand Down
51 changes: 37 additions & 14 deletions tensorrt_llm/_torch/pyexecutor/resource_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,9 @@ def calculate_scaling_factor_size_bytes(
# TODO: refactor get_cache_size_per_token and get_cache_bytes_per_token to use the same logic
@staticmethod
def get_cache_size_per_token(model_config: ModelConfigPython,
mapping: Mapping, **kwargs):
mapping: Mapping,
num_layers: Optional[int] = None,
**kwargs):

# get num key value heads
config = model_config.pretrained_config
Expand All @@ -833,9 +835,18 @@ def get_cache_size_per_token(model_config: ModelConfigPython,
head_dim = head_dim * num_key_value_heads // tp_size
kv_factor = 2

# provide at least 1 layer to prevent division by zero cache size
num_attention_layers = max(
len(mapping.pp_layers(model_config.get_num_attention_layers())), 1)
# When num_layers is explicitly provided (e.g. for draft models
# where the HF config layer count differs from runtime), use it
# directly without PP distribution. Draft layers have their own
# PP assignment logic (see get_pp_layers) that doesn't match the
# standard uniform split, so pp_layers() would give wrong results.
if num_layers is not None:
num_attention_layers = max(num_layers, 1)
else:
# provide at least 1 layer to prevent division by zero cache size
num_attention_layers = max(
len(mapping.pp_layers(model_config.get_num_attention_layers())),
1)
# K and V
mem_per_token = kv_factor * num_attention_layers * head_dim
# The data type bytes.
Expand Down Expand Up @@ -2107,15 +2118,13 @@ def release_resources(current_request: LlmRequest,
new_capacity = kv_cache.capacity + max_num_draft_tokens + 1
success = kv_cache.resize(new_capacity)
if not success:
raise ValueError(
f"Failed to resize capacity of KV cache for request {req.py_request_id} to {new_capacity} tokens for dummy request"
)
release_resources(req)
return None
if draft_kv_cache is not None:
success = draft_kv_cache.resize(new_capacity)
if not success:
raise ValueError(
f"Failed to resize capacity of draft KV cache for request {req.py_request_id} to {new_capacity} tokens for dummy request"
)
release_resources(req, free_draft_resources=True)
return None

# TODO: Planning to get dummy_data from each model. Before that, we need to add dummy mrope_config to the request here.
if use_mrope:
Expand Down Expand Up @@ -2314,7 +2323,9 @@ def get_needed_resource_to_completion(self, request: LlmRequest) -> int:
# TODO: refactor get_cache_size_per_token and get_cache_bytes_per_token to use the same logic
@staticmethod
def get_cache_size_per_token(model_config: ModelConfigPython,
mapping: Mapping, **kwargs):
mapping: Mapping,
num_layers: Optional[int] = None,
**kwargs):
# get kv cache dtype bytes
mem_per_token = 2
quant_config = model_config.quant_config
Expand Down Expand Up @@ -2343,9 +2354,18 @@ def get_cache_size_per_token(model_config: ModelConfigPython,
head_dim = head_dim * num_key_value_heads // tp_size
kv_factor = 2

# provide at least 1 layer to prevent division by zero cache size
num_attention_layers = max(
len(mapping.pp_layers(model_config.get_num_attention_layers())), 1)
# When num_layers is explicitly provided (e.g. for draft models
# where the HF config layer count differs from runtime), use it
# directly without PP distribution. Draft layers have their own
# PP assignment logic (see get_pp_layers) that doesn't match the
# standard uniform split, so pp_layers() would give wrong results.
if num_layers is not None:
num_attention_layers = max(num_layers, 1)
else:
# provide at least 1 layer to prevent division by zero cache size
num_attention_layers = max(
len(mapping.pp_layers(model_config.get_num_attention_layers())),
1)
mem_per_token *= num_attention_layers * head_dim

# K and V
Expand Down Expand Up @@ -2421,6 +2441,9 @@ def _create_kv_cache(self, request_id: int, lora_task_id: int | None,
memoryview(buffer.numpy()))
return kv_cache

def reset_reuse_state(self):
self.impl.clear_reusable_blocks()


class SlotManager:

Expand Down
2 changes: 1 addition & 1 deletion tensorrt_llm/llmapi/llm_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -2159,7 +2159,7 @@ class KvCacheConfig(StrictBaseModel, PybindMirror):
description="The number of tokens per block.")

use_kv_cache_manager_v2: bool = Field(
default=False,
default=True,
status="prototype",
description="Whether to use the KV cache manager v2 (experimental).")

Expand Down
Loading
Loading