Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions tensorrt_llm/_torch/models/modeling_speculative.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
from dataclasses import replace
from typing import Dict, Generic, List, Optional, Tuple

Expand All @@ -24,6 +25,7 @@
should_use_separate_draft_kv_cache)
from ..utils import AuxStreamType
from .checkpoints.base_weight_mapper import BaseWeightMapper
from .modeling_auto import AutoModelForCausalLM
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM, TModel,
get_model_architecture, register_auto_model)

Expand Down Expand Up @@ -984,6 +986,8 @@ def get_draft_model(model_config, draft_config, lm_head, model):
return MTPDraftModelForCausalLM(model_config)
elif spec_dec_mode.is_pard():
return PARDForCausalLM(draft_config)
elif spec_dec_mode.is_draft_target_one_model():
return AutoModelForCausalLM.from_config(draft_config)
else:
raise NotImplementedError(
f"get_draft_model does not support speculative decoding mode {spec_dec_mode}."
Expand All @@ -1003,6 +1007,7 @@ def __init__(self, model: TModel, model_config: ModelConfig[TConfig]):
self.spec_worker = None
self.use_separate_draft_kv_cache = False
spec_config = getattr(model_config, 'spec_config', None)
self.spec_config = spec_config
if spec_config and spec_config.spec_dec_mode.use_one_engine():
# Only create draft_model for modes MTP, Eagle3 (not SA)
if not spec_config.spec_dec_mode.is_sa():
Expand Down Expand Up @@ -1037,7 +1042,7 @@ def __init__(self, model: TModel, model_config: ModelConfig[TConfig]):
self.draft_config.quant_config.kv_cache_quant_algo = \
model_config.quant_config.kv_cache_quant_algo

elif spec_config.spec_dec_mode.is_pard():
elif spec_config.spec_dec_mode.is_external_drafter():
self.draft_config = ModelConfig.from_pretrained(
model_config.spec_config.speculative_model,
trust_remote_code=True,
Expand Down Expand Up @@ -1160,10 +1165,15 @@ def load_weights(self,
def load_draft_weights(self,
weights: Dict,
weight_mapper: Optional[BaseWeightMapper] = None):
self.draft_model.load_weights(weights=weights,
weight_mapper=weight_mapper)
# PARD has independent weights; other methods share with target model
if not self.model_config.spec_config.spec_dec_mode.is_pard():
args = inspect.getfullargspec(self.draft_model.load_weights).args
if "weight_mapper" in args:
self.draft_model.load_weights(weights=weights,
weight_mapper=weight_mapper)
else:
self.draft_model.load_weights(weights=weights)

if self.spec_config and not self.spec_config.spec_dec_mode.is_external_drafter(
):
self.draft_model.load_weights_from_target_model(self)

def set_guided_decoder(self,
Expand Down
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,9 +606,9 @@ def _create_one_model_draft_kv_cache_manager(
target_pretrained_config = self._model_engine.model.model_config.pretrained_config
target_num_layers = target_pretrained_config.num_hidden_layers

# PARD: draft is a separate model, layers start from 0.
# PARD, External Drafter: draft is a separate model, layers start from 0.
# Other methods (EAGLE3, MTP): draft layers are appended after target layers.
if self._speculative_config.spec_dec_mode.is_pard():
if self._speculative_config.spec_dec_mode.is_external_drafter():
num_draft_layers = self._draft_config.pretrained_config.num_hidden_layers
spec_dec_layer_mask = [True] * num_draft_layers
else:
Expand Down
5 changes: 5 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/resource_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,11 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests):
== self.mapping.cp_size - 1 else 0),
req_beam_width, req)
else:
# Chunked prefill may schedule the same request across multiple
# context chunks. Sequence allocation must happen only once.
if not req.is_first_context_chunk:
continue

if self.impl.add_sequence(req.py_request_id, req.prompt_len,
req_beam_width, req):
for _ in range(self.num_extra_kv_tokens):
Expand Down
4 changes: 4 additions & 0 deletions tensorrt_llm/_torch/speculative/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from .auto_heuristic import suggest_spec_config
from .draft_target import (DraftTargetOneModelSpecMetadata,
DraftTargetOneModelWorker)
from .eagle3 import Eagle3SpecMetadata
from .interface import (SpecMetadata, SpecWorkerBase,
should_use_separate_draft_kv_cache)
Expand All @@ -18,6 +20,8 @@
get_spec_worker, update_spec_config_from_model_config)

__all__ = [
"DraftTargetOneModelSpecMetadata",
"DraftTargetOneModelWorker",
"Eagle3SpecMetadata",
"MTPEagleWorker",
"MTPSampler",
Expand Down
Loading