Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions tensorrt_llm/_torch/models/modeling_deepseekv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,11 +322,15 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
v_head_dim = self.config.v_head_dim
kv_lora_rank = self.config.kv_lora_rank

tp_rank = self.model_config.mapping.tp_rank
tp_size = self.model_config.mapping.tp_size
target_tp_rank = self.model_config.mapping.tp_rank
target_tp_size = self.model_config.mapping.tp_size
tp_rank = target_tp_rank
tp_size = target_tp_size
cp_rank = self.model_config.mapping.cp_rank
cp_size = self.model_config.mapping.cp_size

draft_mapping = getattr(self.model, '_draft_mapping', None)

params_map = {'gate_up_proj': ['gate_proj', 'up_proj']}
all_named_modules = dict(self.model.named_modules())

Expand All @@ -342,6 +346,16 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
else:
names = name.split('.')
parent_module_name = '.'.join(names[:-1])
is_mtp_layer = ("model.layers" in name and len(names) > 2
and names[2].isdigit()
and int(names[2])
>= self.config.num_hidden_layers)
if is_mtp_layer and draft_mapping is not None:
tp_rank = draft_mapping.tp_rank
tp_size = draft_mapping.tp_size
else:
tp_rank = target_tp_rank
tp_size = target_tp_size
if "model.layers" in name and int(
names[2]) >= self.config.num_hidden_layers:
mtp_layer_idx = int(
Expand Down
46 changes: 33 additions & 13 deletions tensorrt_llm/_torch/models/modeling_speculative.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from transformers import LlamaConfig, PretrainedConfig

from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import create_draft_mapping

from ...functional import PositionEmbeddingType
from ..attention_backend import AttentionMetadata
Expand Down Expand Up @@ -961,7 +962,11 @@ def forward(self,
)


def get_draft_model(model_config, draft_config, lm_head, model):
def get_draft_model(model_config,
draft_config,
lm_head,
model,
draft_model_config=None):
assert getattr(model_config, 'spec_config', None) is not None
spec_dec_mode = model_config.spec_config.spec_dec_mode
if spec_dec_mode.is_eagle3_one_model():
Expand All @@ -979,7 +984,8 @@ def get_draft_model(model_config, draft_config, lm_head, model):
)

elif spec_dec_mode.is_mtp_one_model():
return MTPForCausalLM(model_config,
mtp_config = draft_model_config if draft_model_config is not None else model_config
return MTPForCausalLM(mtp_config,
model_config.pretrained_config.num_hidden_layers,
lm_head, model)
elif spec_dec_mode.is_mtp_eagle():
Expand Down Expand Up @@ -1009,6 +1015,13 @@ def __init__(self, model: TModel, model_config: ModelConfig[TConfig]):
spec_config = getattr(model_config, 'spec_config', None)
self.spec_config = spec_config
if spec_config and spec_config.spec_dec_mode.use_one_engine():
draft_tp = getattr(spec_config, 'draft_tp_size', None)
draft_mapping = create_draft_mapping(model_config.mapping,
draft_tp)
self._draft_mapping = (draft_mapping
if draft_mapping is not model_config.mapping
else None)

# Only create draft_model for modes MTP, Eagle3 (not SA)
if not spec_config.spec_dec_mode.is_sa():
if spec_config.spec_dec_mode.is_eagle3_one_model():
Expand All @@ -1017,7 +1030,7 @@ def __init__(self, model: TModel, model_config: ModelConfig[TConfig]):
MistralConfigLoader
self.draft_config = MistralConfigLoader().load(
spec_config.speculative_model,
mapping=model_config.mapping,
mapping=draft_mapping,
moe_backend=model_config.moe_backend,
moe_max_num_tokens=model_config.moe_max_num_tokens,
max_num_tokens=model_config.max_num_tokens,
Expand All @@ -1031,7 +1044,7 @@ def __init__(self, model: TModel, model_config: ModelConfig[TConfig]):
trust_remote_code=True,
attn_backend=model_config.attn_backend,
moe_backend=model_config.moe_backend,
mapping=model_config.mapping,
mapping=draft_mapping,
spec_config=model_config.spec_config,
max_num_tokens=model_config.max_num_tokens,
moe_max_num_tokens=model_config.moe_max_num_tokens)
Expand All @@ -1048,7 +1061,7 @@ def __init__(self, model: TModel, model_config: ModelConfig[TConfig]):
trust_remote_code=True,
attn_backend=model_config.attn_backend,
moe_backend=model_config.moe_backend,
mapping=model_config.mapping,
mapping=draft_mapping,
spec_config=None, # Avoid recursive spec-dec
max_num_tokens=model_config.max_num_tokens,
moe_max_num_tokens=model_config.moe_max_num_tokens)
Expand All @@ -1058,23 +1071,30 @@ def __init__(self, model: TModel, model_config: ModelConfig[TConfig]):
self.use_separate_draft_kv_cache = should_use_separate_draft_kv_cache(
spec_config)

draft_model_config = None
if draft_mapping is not model_config.mapping:
draft_model_config = replace(model_config,
mapping=draft_mapping)
if self.draft_config is None:
self.draft_config = draft_model_config

self.draft_model = get_draft_model(model_config,
self.draft_config,
self.lm_head, self.model)
self.lm_head, self.model,
draft_model_config)
Comment on lines +1074 to +1084
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

The synthetic MTP draft config still carries the target layer count.

replace(model_config, mapping=draft_mapping) only swaps the mapping. The cloned config still has pretrained_config.num_hidden_layers == target_num_layers, but tensorrt_llm/_torch/pyexecutor/_util.py now sizes the separate draft KV cache from effective_draft_config. In MTP one-model mode that makes KV estimation count all target layers instead of just the draft layers, which can drastically shrink max_tokens or raise false OOMs for the new draft_tp_size path. Please either synthesize a draft config with the draft layer count, or keep the layer count explicit in the KV-size path.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `tensorrt_llm/_torch/models/modeling_speculative.py`, around lines 1074-1084: the synthetic draft config created with `replace(model_config, mapping=draft_mapping)` still carries the original `pretrained_config.num_hidden_layers` (the target model's layer count), so KV-cache sizing computes from the wrong number of layers. Fix it one of two ways:
1. When building `draft_model_config` (the value passed into `get_draft_model`), also set the draft layer count — e.g. clone `pretrained_config` and assign `draft_model_config.pretrained_config.num_hidden_layers = draft_layer_count` — so the `effective_draft_config` consumed by `tensorrt_llm/_torch/pyexecutor/_util.py` reflects the actual number of draft layers.
2. Alternatively, propagate the draft layer count explicitly into the KV-sizing path, ensuring the functions that read `effective_draft_config` (or `draft_tp_size`) use the draft layer count rather than the target's `num_hidden_layers`.

if self.draft_model is not None:
self.epilogue.append(self.draft_model)
if spec_config.spec_dec_mode.is_pard(
) and self.draft_model is not None:
self.draft_model.logits_processor = self.logits_processor

# EAGLE3-specific logic: merge extra_attrs from draft model for Llama3
if (self.draft_config is not None and model_config.spec_config.
spec_dec_mode.is_eagle3_one_model()
and model_config.spec_config.eagle3_model_arch == "llama3"):
if (self.draft_config is not None
and hasattr(self.draft_config, 'extra_attrs')
and self.draft_config is not model_config):
for key, value in self.draft_config.extra_attrs.items():
assert key in ('attn_layers', 'mla_layers')
assert key in model_config.extra_attrs
model_config.extra_attrs[key].update(value)
if key in ('attn_layers', 'mla_layers'):
model_config.extra_attrs.setdefault(key, {}).update(
value)

# spec_worker is created for all one-engine modes (MTP, Eagle3, SA)
self.spec_worker = get_spec_worker(
Expand Down
88 changes: 50 additions & 38 deletions tensorrt_llm/_torch/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,16 +438,19 @@ def __init__(
assert self.mapping.has_cp_helix(
), f"CP type must be HELIX for Attention, but got {self.mapping.cp_config['cp_type']}."

mapping = Mapping(
world_size=dp_size * tp_size * pp_size * cp_size,
tp_size=tp_size,
pp_size=pp_size * dp_size,
cp_size=cp_size,
cp_config=self.mapping.cp_config,
rank=self.mapping.rank,
gpus_per_node=self.mapping.gpus_per_node,
enable_attention_dp=self.mapping.enable_attention_dp,
)
if dp_size == 1 and cp_size == 1:
mapping = self.mapping
else:
mapping = Mapping(
world_size=dp_size * tp_size * pp_size * cp_size,
tp_size=tp_size,
pp_size=pp_size * dp_size,
cp_size=cp_size,
cp_config=self.mapping.cp_config,
rank=self.mapping.rank,
gpus_per_node=self.mapping.gpus_per_node,
enable_attention_dp=self.mapping.enable_attention_dp,
)
Comment on lines +441 to +453
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Preserve the draft mapping in the CP/attention-DP branches.

The fast path now keeps self.mapping intact, but the fallback still rebuilds a plain Mapping. That drops the draft-specific topology this PR is introducing, so any draft run that also enables Helix CP or attention-DP can fall back to raw-rank sharding here and feed the wrong local slices into qkv_proj / o_proj.

Also applies to: 501-512, 1108-1120, 1216-1227

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `tensorrt_llm/_torch/modules/attention.py`, around lines 441-453: the fallback branch that constructs a plain `Mapping` discards the draft-specific topology that this PR introduces, so draft runs using Helix CP or attention-DP lose the draft mapping and can shard `qkv_proj`/`o_proj` with the wrong local slices. Instead of always instantiating a fresh `Mapping(...)` whenever `cp_size != 1` or `dp_size != 1`, preserve and reuse the draft-aware `self.mapping` — or clone it, adjusting only the fields that must change (`tp_size`, `pp_size`, `cp_size`, `cp_config`) — so the draft topology stays intact. Apply the same fix to the other occurrences noted at lines 501-512, 1108-1120, and 1216-1227.

self.tp_size = tp_size
self.cp_size = cp_size
self.tp_rank = mapping.tp_rank
Expand Down Expand Up @@ -495,15 +498,18 @@ def __init__(

# For Helix CP, combine TP and CP for the output projection so each
# rank's o_proj input is num_heads_tp_cp * head_dim.
mapping_o = Mapping(
world_size=dp_size * tp_size * pp_size * cp_size,
tp_size=tp_size * cp_size,
pp_size=pp_size * dp_size,
cp_size=1,
rank=self.mapping.rank,
gpus_per_node=self.mapping.gpus_per_node,
enable_attention_dp=self.mapping.enable_attention_dp,
)
if dp_size == 1 and cp_size == 1:
mapping_o = self.mapping
else:
mapping_o = Mapping(
world_size=dp_size * tp_size * pp_size * cp_size,
tp_size=tp_size * cp_size,
pp_size=pp_size * dp_size,
cp_size=1,
rank=self.mapping.rank,
gpus_per_node=self.mapping.gpus_per_node,
enable_attention_dp=self.mapping.enable_attention_dp,
)
self.mapping_o = mapping_o

self.o_proj = Linear(
Expand Down Expand Up @@ -1099,16 +1105,19 @@ def __init__(
assert self.mapping.has_cp_helix(
), f"CP type must be HELIX for MLA, but got {self.mapping.cp_config['cp_type']}."

mapping = Mapping(
world_size=pp_size * dp_size * tp_size * cp_size,
tp_size=tp_size,
pp_size=pp_size * dp_size,
cp_size=cp_size,
cp_config=self.mapping.cp_config,
rank=self.mapping.rank,
gpus_per_node=self.mapping.gpus_per_node,
enable_attention_dp=self.mapping.enable_attention_dp,
)
if dp_size == 1 and cp_size == 1:
mapping = self.mapping
else:
mapping = Mapping(
world_size=pp_size * dp_size * tp_size * cp_size,
tp_size=tp_size,
pp_size=pp_size * dp_size,
cp_size=cp_size,
cp_config=self.mapping.cp_config,
rank=self.mapping.rank,
gpus_per_node=self.mapping.gpus_per_node,
enable_attention_dp=self.mapping.enable_attention_dp,
)

assert self.num_heads % (tp_size * cp_size) == 0
self.num_heads_tp = self.num_heads // tp_size
Expand Down Expand Up @@ -1204,15 +1213,18 @@ def __init__(
requires_grad=False,
)

mapping_o = Mapping(
world_size=pp_size * dp_size * tp_size * cp_size,
tp_size=tp_size * cp_size,
pp_size=pp_size * dp_size,
cp_size=1,
rank=self.mapping.rank,
gpus_per_node=self.mapping.gpus_per_node,
enable_attention_dp=self.mapping.enable_attention_dp,
)
if dp_size == 1 and cp_size == 1:
mapping_o = self.mapping
else:
mapping_o = Mapping(
world_size=pp_size * dp_size * tp_size * cp_size,
tp_size=tp_size * cp_size,
pp_size=pp_size * dp_size,
cp_size=1,
rank=self.mapping.rank,
gpus_per_node=self.mapping.gpus_per_node,
enable_attention_dp=self.mapping.enable_attention_dp,
)
self.mapping_o = mapping_o
self.o_proj = Linear(
self.num_key_value_heads * self.v_head_dim,
Expand Down
17 changes: 15 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,10 @@ def _get_kv_size_per_token(self):
elif self._should_create_separate_draft_kv_cache():
# One-model draft with separate KV cache layout
effective_draft_config = self._get_effective_draft_config()
draft_mapping = effective_draft_config.mapping
kv_size_per_token += self._kv_cache_manager_cls.get_cache_size_per_token(
effective_draft_config,
mapping,
draft_mapping,
tokens_per_block=self._tokens_per_block)
return kv_size_per_token

Expand Down Expand Up @@ -571,6 +572,16 @@ def _should_create_separate_draft_kv_cache(self) -> bool:
back to the target model's config via _get_effective_draft_config().
"""
if self._mapping.enable_attention_dp:
draft_tp = getattr(self._speculative_config, 'draft_tp_size',
None)
if (draft_tp is not None
and draft_tp < self._mapping.tp_size):
raise ValueError(
f"draft_tp_size ({draft_tp}) < tp_size "
f"({self._mapping.tp_size}) requires separate "
"draft/target KV caches, which are not supported "
"with attention data parallelism."
)
logger.info(
"Attention DP is enabled, separate draft KV cache is not supported."
)
Expand Down Expand Up @@ -637,11 +648,13 @@ def _create_one_model_draft_kv_cache_manager(
"Falling back to KVCacheManager for draft model.")
draft_kv_cache_manager_cls = KVCacheManager

draft_mapping = effective_draft_config.mapping

estimating_kv_cache = estimating_kv_cache and not self._skip_est
return _create_kv_cache_manager(
model_engine=None,
kv_cache_manager_cls=draft_kv_cache_manager_cls,
mapping=self._mapping,
mapping=draft_mapping,
kv_cache_config=self._kv_cache_config,
tokens_per_block=self._tokens_per_block,
max_seq_len=self._max_seq_len,
Expand Down
3 changes: 2 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,9 +317,10 @@ def init_meta_tensor(t: torch.Tensor):

if self.spec_config is not None and self.spec_config.spec_dec_mode.need_load_draft_weights(
):
draft_mapping = model.draft_config.mapping
weights = checkpoint_loader.load_weights(
self.spec_config.speculative_model,
mapping=self.mapping)
mapping=draft_mapping)

draft_model_arch = model.draft_config.pretrained_config.architectures[
0]
Expand Down
7 changes: 7 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,13 @@ def create_py_executor(
# Disable separate draft KV cache in disaggregated mode
# Enable separate pool for None DI + Non-KVBM and Aggregated + KVBM
if cache_transceiver_config is not None:
draft_tp = getattr(spec_config, 'draft_tp_size', None)
if draft_tp is not None and draft_tp < mapping.tp_size:
raise ValueError(
f"draft_tp_size ({draft_tp}) < tp_size ({mapping.tp_size}) "
"requires separate draft/target KV caches, which are not "
"supported in disaggregated serving mode."
)
spec_config._allow_separate_draft_kv_cache = False

# chunk_unit_size may be changed to 64 when using flash mla
Expand Down
21 changes: 21 additions & 0 deletions tensorrt_llm/llmapi/llm_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,15 @@ class DecodingBaseConfig(StrictBaseModel):
_allow_greedy_draft_tokens: bool = PrivateAttr(True)
# Internal: record decoding_type alias used during parsing (for warnings).
_decoding_type_alias: Optional[str] = PrivateAttr(default=None)
draft_tp_size: Optional[PositiveInt] = Field(
default=None,
description=
"Tensor parallelism size for the draft model in one-model speculative decoding. "
"When set, must divide the target model's TP size evenly. If None, the draft model "
"uses the same TP size as the target model. Smaller values reduce communication "
"overhead for the (typically small) draft model at the cost of redundant computation."
)

# If set, drafting will use separate KV cache in one-model speculative decoding.
_allow_separate_draft_kv_cache: bool = PrivateAttr(True)

Expand Down Expand Up @@ -3294,6 +3303,18 @@ def validate_speculative_config(self):
if self.backend == "_autodeploy":
self.speculative_config._draft_target_one_model = False

draft_tp = self.speculative_config.draft_tp_size
if draft_tp is not None:
target_tp = self.parallel_config.tp_size
if draft_tp > target_tp:
raise ValueError(
f"draft_tp_size ({draft_tp}) must be <= target tp_size ({target_tp})."
)
if target_tp % draft_tp != 0:
raise ValueError(
f"target tp_size ({target_tp}) must be divisible by "
f"draft_tp_size ({draft_tp}).")

else:
self.decoding_config = None

Expand Down
Loading
Loading