
Commit 80aa8ca

[TRTLLM-10886][feat] Support PARD (Parallel Draft Model) in one-model spec dec (NVIDIA#11438)
Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
Parent: 3fd5faf
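For orientation: PARD drafts several tokens in a single forward pass. The draft model is fed the last accepted token followed by mask placeholders, one per draft position, and each masked position predicts one draft token; the diff below resolves this mask id from pard_token, then mask_token_id, then vocab_size. The sketch below is conceptual, not the commit's code, and the helper name is illustrative:

    import torch

    def build_pard_draft_input(last_token: int, mask_token_id: int,
                               num_draft_tokens: int) -> torch.Tensor:
        # One committed token, then one mask placeholder per draft position;
        # the draft model fills every masked position in a single step.
        return torch.tensor([last_token] + [mask_token_id] * num_draft_tokens)

    ids = build_pard_draft_input(last_token=42, mask_token_id=151644,
                                 num_draft_tokens=3)
    assert ids.tolist() == [42, 151644, 151644, 151644]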

File tree: 26 files changed (+797 −62 lines)


.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -1441,7 +1441,7 @@ repos:
         additional_dependencies:
           - tomli
         # add ignore words list
-        args: ["-L", "Mor,ans,thirdparty,subtiles", "--skip", "ATTRIBUTIONS-*.md,*.svg", "--skip", "security_scanning/*"]
+        args: ["-L", "Mor,ans,thirdparty,subtiles,PARD,pard", "--skip", "ATTRIBUTIONS-*.md,*.svg", "--skip", "security_scanning/*"]
         exclude: 'scripts/attribution/data/cas/.*$'
   - repo: https://github.com/astral-sh/ruff-pre-commit
     rev: v0.9.4

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 2 additions & 2 deletions
@@ -518,7 +518,7 @@ def __init__(

         # check for max total draft tokens
         if self.spec_config is not None:
-            self.max_total_draft_tokens = self.spec_config.max_total_draft_tokens
+            self.max_total_draft_tokens = self.spec_config.tokens_per_gen_step - 1
         else:
             self.max_total_draft_tokens = 0

@@ -1063,7 +1063,7 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
     max_total_draft_tokens = (
         0
         if ad_config.speculative_config is None
-        else ad_config.speculative_config.max_total_draft_tokens
+        else ad_config.speculative_config.tokens_per_gen_step - 1
     )

     # initialize model engine
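This substitution of max_total_draft_tokens with tokens_per_gen_step - 1 recurs throughout the commit: one generation step commits exactly one token, so the remaining tokens_per_gen_step - 1 slots are the draft budget. A minimal sketch of the invariant, with a hypothetical stand-in for the speculative config:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class SpecConfigStub:  # hypothetical stand-in for the real spec config
        tokens_per_gen_step: int

    def max_total_draft_tokens(cfg: Optional[SpecConfigStub]) -> int:
        # One token per step is the committed token; the rest are drafts.
        return cfg.tokens_per_gen_step - 1 if cfg is not None else 0

    assert max_total_draft_tokens(SpecConfigStub(tokens_per_gen_step=4)) == 3
    assert max_total_draft_tokens(None) == 0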

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 2 additions & 2 deletions
@@ -718,7 +718,7 @@ def __init__(
         reduce_output: bool = True,
     ):
         config = model_config.pretrained_config
-        predicted_tokens_per_seq = model_config.spec_config.max_total_draft_tokens + 1 if model_config.spec_config is not None else 1
+        predicted_tokens_per_seq = model_config.spec_config.tokens_per_gen_step if model_config.spec_config is not None else 1
         super().__init__(hidden_size=config.hidden_size,
                          num_attention_heads=config.num_attention_heads,
                          num_key_value_heads=config.num_key_value_heads,
@@ -766,7 +766,7 @@ def __init__(
         reduce_output: bool = True,
     ):
         config = model_config.pretrained_config
-        predicted_tokens_per_seq = model_config.spec_config.max_total_draft_tokens + 1 if model_config.spec_config is not None else 1
+        predicted_tokens_per_seq = model_config.spec_config.tokens_per_gen_step if model_config.spec_config is not None else 1

         super().__init__(hidden_size=config.hidden_size,
                          num_attention_heads=config.num_attention_heads,

tensorrt_llm/_torch/models/modeling_speculative.py

Lines changed: 94 additions & 6 deletions
@@ -1,3 +1,4 @@
+from dataclasses import replace
 from typing import Dict, Generic, List, Optional, Tuple

 import torch
@@ -24,7 +25,7 @@
 from ..utils import AuxStreamType
 from .checkpoints.base_weight_mapper import BaseWeightMapper
 from .modeling_utils import (DecoderModel, DecoderModelForCausalLM, TModel,
-                             register_auto_model)
+                             get_model_architecture, register_auto_model)


 def _ensure_draft_vocab_size(config: PretrainedConfig) -> None:
@@ -108,9 +109,9 @@ def __init__(
         config = model_config.pretrained_config
         self._next_layer_regular = next_layer_regular

-        predicted_tokens_per_seq = (
-            model_config.spec_config.max_total_draft_tokens +
-            1 if model_config.spec_config is not None else 1)
+        predicted_tokens_per_seq = (model_config.spec_config.tokens_per_gen_step
+                                    if model_config.spec_config is not None else
+                                    1)

         super().__init__(
             hidden_size=config.hidden_size,
@@ -702,6 +703,70 @@ def apply_eagle3_fc(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return hidden_states


+class PARDForCausalLM(nn.Module):
+    """Draft model wrapper for PARD (Parallel Draft) speculative decoding.
+
+    See PARDWorker for the full algorithm description.
+    """
+
+    def __init__(self, draft_config):
+        super().__init__()
+        DraftModelClass, _ = get_model_architecture(
+            draft_config.pretrained_config)
+
+        # Remove spec_config to prevent recursive spec-dec initialization
+        draft_config_no_spec = replace(draft_config, spec_config=None)
+
+        # Weights will be loaded later by ModelLoader.load_draft_weights()
+        self.draft_model_full = DraftModelClass(draft_config_no_spec)
+        self.model = self.draft_model_full.model
+        self.lm_head = self.draft_model_full.lm_head
+
+        # Required by weight mappers
+        self.model_config = draft_config_no_spec
+        self.config = draft_config_no_spec.pretrained_config
+
+        # Fall back: pard_token -> mask_token_id -> vocab_size
+        pretrained_config = draft_config.pretrained_config
+        self.mask_token_id = getattr(
+            pretrained_config, 'pard_token',
+            getattr(pretrained_config, 'mask_token_id',
+                    pretrained_config.vocab_size))
+        logger.info(
+            f"PARD draft model initialized with mask_token_id: {self.mask_token_id}"
+        )
+
+        self.logits_processor = None  # Set by caller after construction
+
+    def load_weights(self, weights: Dict, weight_mapper=None, **kwargs):
+        """Load weights into the PARD draft model."""
+        self.draft_model_full.load_weights(weights=weights,
+                                           weight_mapper=weight_mapper,
+                                           **kwargs)
+
+    def forward(
+        self,
+        attn_metadata,
+        input_ids: torch.LongTensor = None,
+        position_ids: torch.LongTensor | None = None,
+        inputs_embeds: torch.FloatTensor | None = None,
+        return_context_logits: bool = False,
+        spec_metadata=None,
+        hidden_states: torch.Tensor | None = None,
+        **kwargs,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        hidden_states_out = self.model(
+            input_ids=input_ids,
+            attn_metadata=attn_metadata,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            spec_metadata=spec_metadata,
+            **kwargs,
+        )
+
+        return hidden_states_out, hidden_states_out
+
+
 class MTPForCausalLM(nn.Module):

     def __init__(
@@ -917,6 +982,8 @@ def get_draft_model(model_config, draft_config, lm_head, model):
                                  lm_head, model)
     elif spec_dec_mode.is_mtp_eagle():
         return MTPDraftModelForCausalLM(model_config)
+    elif spec_dec_mode.is_pard():
+        return PARDForCausalLM(draft_config)
     else:
         raise NotImplementedError(
             f"get_draft_model does not support speculative decoding mode {spec_dec_mode}."
@@ -967,11 +1034,27 @@ def __init__(self, model: TModel, model_config: ModelConfig[TConfig]):
                 self.draft_config.quant_config.kv_cache_quant_algo = \
                     model_config.quant_config.kv_cache_quant_algo

+            elif spec_config.spec_dec_mode.is_pard():
+                self.draft_config = ModelConfig.from_pretrained(
+                    model_config.spec_config.speculative_model,
+                    trust_remote_code=True,
+                    attn_backend=model_config.attn_backend,
+                    moe_backend=model_config.moe_backend,
+                    mapping=model_config.mapping,
+                    spec_config=None,  # Avoid recursive spec-dec
+                    max_num_tokens=model_config.max_num_tokens,
+                    moe_max_num_tokens=model_config.moe_max_num_tokens)
+                self.draft_config.quant_config.kv_cache_quant_algo = \
+                    model_config.quant_config.kv_cache_quant_algo
+
             self.use_separate_draft_kv_cache = should_use_separate_draft_kv_cache(
                 spec_config)

             self.draft_model = get_draft_model(model_config, self.draft_config,
                                                self.lm_head, self.model)
+            if spec_config.spec_dec_mode.is_pard(
+            ) and self.draft_model is not None:
+                self.draft_model.logits_processor = self.logits_processor
             self.spec_worker = get_spec_worker(
                 model_config.spec_config,
                 model_config,
@@ -980,7 +1063,10 @@ def __init__(self, model: TModel, model_config: ModelConfig[TConfig]):
             self.epilogue.append(self.draft_model)
             self.epilogue.append(self.spec_worker)

-        if self.draft_config is not None and model_config.spec_config.eagle3_model_arch == "llama3":
+        # EAGLE3-specific logic: merge extra_attrs from draft model for Llama3
+        if (self.draft_config is not None and model_config.spec_config.
+                spec_dec_mode.is_eagle3_one_model()
+                and model_config.spec_config.eagle3_model_arch == "llama3"):
             for key, value in self.draft_config.extra_attrs.items():
                 assert key in ('attn_layers', 'mla_layers')
                 assert key in model_config.extra_attrs
@@ -1067,7 +1153,9 @@ def load_draft_weights(self,
                            weight_mapper: Optional[BaseWeightMapper] = None):
         self.draft_model.load_weights(weights=weights,
                                       weight_mapper=weight_mapper)
-        self.draft_model.load_weights_from_target_model(self)
+        # PARD has independent weights; other methods share with target model
+        if not self.model_config.spec_config.spec_dec_mode.is_pard():
+            self.draft_model.load_weights_from_target_model(self)

     def set_guided_decoder(self,
                            guided_decoder: CapturableGuidedDecoder) -> bool:
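The pard_token -> mask_token_id -> vocab_size fallback above is a plain getattr chain: vocab_size acts as a sentinel just past the vocabulary when neither attribute is present. A standalone sketch of the same resolution, with the config object stubbed:

    from types import SimpleNamespace

    def resolve_mask_token_id(cfg) -> int:
        # Prefer an explicit pard_token, then mask_token_id,
        # and finally fall back to vocab_size.
        return getattr(cfg, 'pard_token',
                       getattr(cfg, 'mask_token_id', cfg.vocab_size))

    cfg = SimpleNamespace(vocab_size=32000)
    assert resolve_mask_token_id(cfg) == 32000
    cfg.mask_token_id = 151643
    assert resolve_mask_token_id(cfg) == 151643
    cfg.pard_token = 151644
    assert resolve_mask_token_id(cfg) == 151644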

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 15 additions & 13 deletions
@@ -277,10 +277,10 @@ def _get_token_num_for_estimation(self) -> int:
     num_extra_tokens_per_seq = 1  # account for generated tokens
     spec_cfg = self._speculative_config
     if not self._llm_args.disable_overlap_scheduler and spec_cfg is not None:
-        num_extra_tokens_per_seq += spec_cfg.max_total_draft_tokens
+        num_extra_tokens_per_seq += spec_cfg.tokens_per_gen_step - 1

     if spec_cfg is not None:
-        num_extra_tokens_per_seq += spec_cfg.max_total_draft_tokens
+        num_extra_tokens_per_seq += spec_cfg.tokens_per_gen_step - 1
         num_extra_tokens_per_seq += get_num_extra_kv_tokens(spec_cfg)

     if self._dummy_reqs is None:
@@ -570,15 +570,17 @@ def _create_one_model_draft_kv_cache_manager(
     # Draft model layers in one-model mode start at target_num_layers.
     target_pretrained_config = self._model_engine.model.model_config.pretrained_config
     target_num_layers = target_pretrained_config.num_hidden_layers
-    # Use get_num_spec_layers to get the correct number of draft layers
-    # for the speculative decoding mode (e.g., num_eagle_layers for Eagle3)
-    num_draft_layers = get_num_spec_layers(self._speculative_config)

-    # Create layer_mask: False for target layers, True for draft layers.
-    # This ensures the draft KV cache manager uses the correct layer indices
-    # (e.g., layers 32, 33, ... instead of 0, 1, ...).
-    spec_dec_layer_mask = [False
-                           ] * target_num_layers + [True] * num_draft_layers
+    # PARD: draft is a separate model, layers start from 0.
+    # Other methods (EAGLE3, MTP): draft layers are appended after target layers.
+    if self._speculative_config.spec_dec_mode.is_pard():
+        num_draft_layers = self._draft_config.pretrained_config.num_hidden_layers
+        spec_dec_layer_mask = [True] * num_draft_layers
+    else:
+        num_draft_layers = get_num_spec_layers(self._speculative_config)
+        spec_dec_layer_mask = [False] * target_num_layers + [
+            True
+        ] * num_draft_layers

     # Get the effective draft config (explicit draft_config if available,
     # otherwise fall back to target model config for MTP).
@@ -1091,8 +1093,8 @@ def create_py_executor_instance(
         max_beam_width=max_beam_width,
         max_draft_len=spec_config.max_draft_len
         if spec_config is not None else 0,
-        max_total_draft_tokens=spec_config.max_total_draft_tokens
-        if spec_config is not None else 0,
+        max_total_draft_tokens=(spec_config.tokens_per_gen_step -
+                                1) if spec_config is not None else 0,
         kv_cache_transceiver=kv_cache_transceiver,
         guided_decoder=guided_decoder,
         start_worker=start_worker,
@@ -1120,7 +1122,7 @@ def create_torch_sampler_args(
     max_draft_len = (0 if speculative_config is None else
                      speculative_config.max_draft_len)
     max_total_draft_tokens = (0 if speculative_config is None else
-                              speculative_config.max_total_draft_tokens)
+                              speculative_config.tokens_per_gen_step - 1)

     return TorchSampler.Args(
         max_seq_len=max_seq_len,
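The branch above decides how the draft KV cache manager's layer mask is laid out: PARD runs a standalone draft model whose layers index from 0, while one-model EAGLE3/MTP draft layers sit after the target's layers. A small sketch of both layouts, assuming a 32-layer target and a 4-layer draft:

    def build_layer_mask(is_pard: bool, target_num_layers: int,
                         num_draft_layers: int) -> list:
        # PARD: the draft model is separate, so only its own layers exist.
        if is_pard:
            return [True] * num_draft_layers
        # One-model EAGLE3/MTP: draft layers are appended after target layers,
        # so the mask skips the target layers first.
        return [False] * target_num_layers + [True] * num_draft_layers

    assert build_layer_mask(True, 32, 4) == [True, True, True, True]
    assert build_layer_mask(False, 32, 4).index(True) == 32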

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 11 additions & 6 deletions
@@ -173,7 +173,11 @@ def __init__(
             ExpertStatistic.create(self.dist.rank)
         self.llm_args = llm_args
         self.original_max_draft_len = spec_config.max_draft_len if spec_config is not None else 0
-        self.original_max_total_draft_tokens = spec_config.max_total_draft_tokens if spec_config is not None else 0
+        self.original_max_total_draft_tokens = (
+            spec_config.tokens_per_gen_step -
+            1) if spec_config is not None else 0
+        # Saved before zeroing for draft models; used by update_spec_dec_param.
+        self._spec_dec_max_total_draft_tokens = spec_config.max_total_draft_tokens if spec_config is not None else 0

         # The draft model won't have any draft tokens attached to
         # generation requests when we invoke it autoregressively
@@ -342,7 +346,7 @@ def __init__(
             self.without_logits = self.spec_config.spec_dec_mode.without_logits(
             ) or self.model_is_wrapped
             self.max_draft_len = spec_config.max_draft_len
-            self.max_total_draft_tokens = spec_config.max_total_draft_tokens
+            self.max_total_draft_tokens = spec_config.tokens_per_gen_step - 1
         else:
             self.without_logits = False
             self.max_draft_len = 0
@@ -389,8 +393,9 @@ def __init__(
         # Pre-allocated buffers for draft model to avoid implicit synchronization
         # These are used to build index tensors without creating tensors from Python lists
         max_first_draft_tokens = self.batch_size * (
-            self.original_max_draft_len + 1) if spec_config else self.batch_size
-        tokens_per_draft = self.original_max_draft_len + 1
+            self.original_max_total_draft_tokens +
+            1) if spec_config else self.batch_size
+        tokens_per_draft = self.original_max_total_draft_tokens + 1
         self.idx_accepted_tokens_cache = None
         self.draft_token_positions_cache = None
         if spec_config:
@@ -1892,7 +1897,7 @@ def _apply_incremental_update_target(
         # Pre-compute constants
         extend_requests = scheduled_requests.generation_requests
         num_extend_requests = len(extend_requests)
-        num_tokens_per_extend_request = self.original_max_draft_len + 1
+        num_tokens_per_extend_request = self.original_max_total_draft_tokens + 1
         spec_config = self.spec_config

         prompt_lengths = torch.empty(num_extend_requests,
@@ -3480,7 +3485,7 @@ def forward(self,
             is_spec_dec_tree=spec_metadata.is_spec_dec_tree,
             is_spec_dec_dynamic_tree=spec_metadata.is_spec_dec_dynamic_tree,
             max_draft_len=self.original_max_draft_len,
-            max_total_draft_tokens=self.original_max_total_draft_tokens,
+            max_total_draft_tokens=self._spec_dec_max_total_draft_tokens,
             model_is_wrapped=self.model_is_wrapped,
             spec_metadata=spec_metadata,
             spec_tree_manager=spec_tree_manager,
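The pre-allocated index buffers above are now sized from the draft-token budget rather than max_draft_len, presumably because tree drafting can carry more total drafts than the chain depth. The arithmetic, as a standalone sketch:

    def draft_buffer_sizes(batch_size: int, max_total_draft_tokens: int):
        # Each request holds one committed token plus up to
        # max_total_draft_tokens draft tokens.
        tokens_per_draft = max_total_draft_tokens + 1
        max_first_draft_tokens = batch_size * tokens_per_draft
        return max_first_draft_tokens, tokens_per_draft

    assert draft_buffer_sizes(batch_size=8, max_total_draft_tokens=3) == (32, 4)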

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 5 additions & 5 deletions
@@ -415,12 +415,12 @@ def drafting_loop_wrapper(model):
             if use_tree_drafter:
                 return TreeDraftingLoopWrapper(
                     spec_config.max_draft_len,
-                    spec_config.max_total_draft_tokens, max_batch_size,
+                    spec_config.tokens_per_gen_step - 1, max_batch_size,
                     model)
             else:
                 return LinearDraftingLoopWrapper(
                     spec_config.max_draft_len,
-                    spec_config.max_total_draft_tokens, model)
+                    spec_config.tokens_per_gen_step - 1, model)
     else:
         drafting_loop_wrapper = None

@@ -460,11 +460,11 @@ def drafting_loop_wrapper(model):
     model_engine_max_seq_len = model_engine.max_seq_len
     net_max_seq_len = model_engine_max_seq_len
     if not llm_args.disable_overlap_scheduler and spec_config is not None:
-        model_engine_max_seq_len += spec_config.max_total_draft_tokens
+        model_engine_max_seq_len += spec_config.tokens_per_gen_step - 1

     if spec_config is not None:
         model_engine_max_seq_len += get_num_extra_kv_tokens(spec_config)
-        model_engine_max_seq_len += spec_config.max_total_draft_tokens
+        model_engine_max_seq_len += spec_config.tokens_per_gen_step - 1

     if has_draft_model_engine and not llm_args.disable_overlap_scheduler:
         logger.warning(
@@ -546,7 +546,7 @@ def drafting_loop_wrapper(model):
     }
     if spec_config is not None:
         kwargs[
-            "max_num_draft_tokens"] = spec_config.max_total_draft_tokens
+            "max_num_draft_tokens"] = spec_config.tokens_per_gen_step - 1

     if spec_config is None or spec_config.spec_dec_mode.support_guided_decoder(
     ):
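The engine's maximum sequence length gains headroom twice when speculation is enabled: once for drafts in flight under the overlap scheduler, and once for extra KV tokens plus the draft budget. A hedged recap of the accounting, with get_num_extra_kv_tokens replaced by a plain argument:

    def padded_max_seq_len(base: int, tokens_per_gen_step: int,
                           overlap_scheduler: bool,
                           num_extra_kv_tokens: int = 0) -> int:
        draft_budget = tokens_per_gen_step - 1
        if overlap_scheduler:
            base += draft_budget  # drafts in flight under overlap scheduling
        base += num_extra_kv_tokens + draft_budget  # KV slack plus drafts
        return base

    # e.g. base 4096, 4 tokens per step, overlap on, no extra KV tokens -> 4102
    assert padded_max_seq_len(4096, 4, True) == 4102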

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 2 additions & 1 deletion
@@ -351,7 +351,8 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int],
         self.attention_dp_events_gather_period_ms = kv_cache_config.attention_dp_events_gather_period_ms
         self.max_num_tokens = max_num_tokens
         self.max_draft_len = spec_config.max_draft_len if spec_config is not None else 0
-        self.max_total_draft_tokens = spec_config.max_total_draft_tokens if spec_config is not None else 0
+        self.max_total_draft_tokens = (spec_config.tokens_per_gen_step -
+                                       1) if spec_config is not None else 0

         # Determine max_attention_window_vec
         if kv_cache_config.max_attention_window is None:

tensorrt_llm/_torch/speculative/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -4,6 +4,7 @@
     should_use_separate_draft_kv_cache)
 from .mtp import MTPEagleWorker, MTPSpecMetadata, MTPWorker
 from .ngram import NGramDrafter, NGramPoolManager
+from .pard import PARDSpecMetadata, PARDWorker
 from .save_hidden_state import (SaveHiddenStatesResourceManager,
                                 SaveHiddenStatesSpecMetadata)
 from .spec_tree_manager import SpecTreeManager
@@ -19,6 +20,8 @@
     "MTPWorker",
     "NGramDrafter",
     "NGramPoolManager",
+    "PARDSpecMetadata",
+    "PARDWorker",
     "SaveHiddenStatesResourceManager",
     "SaveHiddenStatesSpecMetadata",
     "SpecMetadata",

tensorrt_llm/_torch/speculative/eagle3.py

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ def __init__(self, config: "EagleDecodingConfig", dtype: torch.dtype,
         from ...llmapi.llm_args import EagleDecodingConfig

         if isinstance(config, EagleDecodingConfig):
-            self.max_total_draft_tokens = config.max_total_draft_tokens
+            self.max_total_draft_tokens = config.tokens_per_gen_step - 1
         else:
             self.max_total_draft_tokens = self.max_draft_len
