1+ from dataclasses import replace
12from typing import Dict , Generic , List , Optional , Tuple
23
34import torch
2425from ..utils import AuxStreamType
2526from .checkpoints .base_weight_mapper import BaseWeightMapper
2627from .modeling_utils import (DecoderModel , DecoderModelForCausalLM , TModel ,
27- register_auto_model )
28+ get_model_architecture , register_auto_model )
2829
2930
3031def _ensure_draft_vocab_size (config : PretrainedConfig ) -> None :
@@ -108,9 +109,9 @@ def __init__(
108109 config = model_config .pretrained_config
109110 self ._next_layer_regular = next_layer_regular
110111
111- predicted_tokens_per_seq = (
112- model_config .spec_config . max_total_draft_tokens +
113- 1 if model_config . spec_config is not None else 1 )
112+ predicted_tokens_per_seq = (model_config . spec_config . tokens_per_gen_step
113+ if model_config .spec_config is not None else
114+ 1 )
114115
115116 super ().__init__ (
116117 hidden_size = config .hidden_size ,
@@ -702,6 +703,70 @@ def apply_eagle3_fc(self, hidden_states: torch.Tensor) -> torch.Tensor:
702703 return hidden_states
703704
704705
class PARDForCausalLM(nn.Module):
    """Draft model wrapper for PARD (Parallel Draft) speculative decoding.

    Wraps a regular causal-LM draft model so it can be driven by the
    speculative-decoding machinery. See PARDWorker for the full algorithm
    description.
    """

    def __init__(self, draft_config):
        super().__init__()
        model_cls, _ = get_model_architecture(draft_config.pretrained_config)

        # Strip spec_config so the draft model itself does not recursively
        # initialize speculative decoding.
        config_without_spec = replace(draft_config, spec_config=None)

        # Weights will be loaded later by ModelLoader.load_draft_weights()
        self.draft_model_full = model_cls(config_without_spec)
        self.model = self.draft_model_full.model
        self.lm_head = self.draft_model_full.lm_head

        # Required by weight mappers
        self.model_config = config_without_spec
        self.config = config_without_spec.pretrained_config

        # Fall back: pard_token -> mask_token_id -> vocab_size
        pretrained = draft_config.pretrained_config
        self.mask_token_id = getattr(
            pretrained, 'pard_token',
            getattr(pretrained, 'mask_token_id', pretrained.vocab_size))
        logger.info(
            f"PARD draft model initialized with mask_token_id: {self.mask_token_id}"
        )

        self.logits_processor = None  # Set by caller after construction

    def load_weights(self, weights: Dict, weight_mapper=None, **kwargs):
        """Load weights into the PARD draft model.

        PARD draft weights are independent of the target model, so they are
        delegated directly to the wrapped full draft model.
        """
        self.draft_model_full.load_weights(weights=weights,
                                           weight_mapper=weight_mapper,
                                           **kwargs)

    def forward(
        self,
        attn_metadata,
        input_ids: torch.LongTensor = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        return_context_logits: bool = False,
        spec_metadata=None,
        hidden_states: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the wrapped draft model's backbone.

        Returns the backbone output twice to satisfy the two-element tuple
        contract expected by callers. NOTE(review): `hidden_states` and
        `return_context_logits` are accepted for interface compatibility but
        are not consumed here — confirm against the caller.
        """
        backbone_out = self.model(
            input_ids=input_ids,
            attn_metadata=attn_metadata,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            spec_metadata=spec_metadata,
            **kwargs,
        )

        return backbone_out, backbone_out
769+
705770class MTPForCausalLM (nn .Module ):
706771
707772 def __init__ (
@@ -917,6 +982,8 @@ def get_draft_model(model_config, draft_config, lm_head, model):
917982 lm_head , model )
918983 elif spec_dec_mode .is_mtp_eagle ():
919984 return MTPDraftModelForCausalLM (model_config )
985+ elif spec_dec_mode .is_pard ():
986+ return PARDForCausalLM (draft_config )
920987 else :
921988 raise NotImplementedError (
922989 f"get_draft_model does not support speculative decoding mode { spec_dec_mode } ."
@@ -967,11 +1034,27 @@ def __init__(self, model: TModel, model_config: ModelConfig[TConfig]):
9671034 self .draft_config .quant_config .kv_cache_quant_algo = \
9681035 model_config .quant_config .kv_cache_quant_algo
9691036
1037+ elif spec_config .spec_dec_mode .is_pard ():
1038+ self .draft_config = ModelConfig .from_pretrained (
1039+ model_config .spec_config .speculative_model ,
1040+ trust_remote_code = True ,
1041+ attn_backend = model_config .attn_backend ,
1042+ moe_backend = model_config .moe_backend ,
1043+ mapping = model_config .mapping ,
1044+ spec_config = None , # Avoid recursive spec-dec
1045+ max_num_tokens = model_config .max_num_tokens ,
1046+ moe_max_num_tokens = model_config .moe_max_num_tokens )
1047+ self .draft_config .quant_config .kv_cache_quant_algo = \
1048+ model_config .quant_config .kv_cache_quant_algo
1049+
9701050 self .use_separate_draft_kv_cache = should_use_separate_draft_kv_cache (
9711051 spec_config )
9721052
9731053 self .draft_model = get_draft_model (model_config , self .draft_config ,
9741054 self .lm_head , self .model )
1055+ if spec_config .spec_dec_mode .is_pard (
1056+ ) and self .draft_model is not None :
1057+ self .draft_model .logits_processor = self .logits_processor
9751058 self .spec_worker = get_spec_worker (
9761059 model_config .spec_config ,
9771060 model_config ,
@@ -980,7 +1063,10 @@ def __init__(self, model: TModel, model_config: ModelConfig[TConfig]):
9801063 self .epilogue .append (self .draft_model )
9811064 self .epilogue .append (self .spec_worker )
9821065
983- if self .draft_config is not None and model_config .spec_config .eagle3_model_arch == "llama3" :
1066+ # EAGLE3-specific logic: merge extra_attrs from draft model for Llama3
1067+ if (self .draft_config is not None and model_config .spec_config .
1068+ spec_dec_mode .is_eagle3_one_model ()
1069+ and model_config .spec_config .eagle3_model_arch == "llama3" ):
9841070 for key , value in self .draft_config .extra_attrs .items ():
9851071 assert key in ('attn_layers' , 'mla_layers' )
9861072 assert key in model_config .extra_attrs
@@ -1067,7 +1153,9 @@ def load_draft_weights(self,
10671153 weight_mapper : Optional [BaseWeightMapper ] = None ):
10681154 self .draft_model .load_weights (weights = weights ,
10691155 weight_mapper = weight_mapper )
1070- self .draft_model .load_weights_from_target_model (self )
1156+ # PARD has independent weights; other methods share with target model
1157+ if not self .model_config .spec_config .spec_dec_mode .is_pard ():
1158+ self .draft_model .load_weights_from_target_model (self )
10711159
10721160 def set_guided_decoder (self ,
10731161 guided_decoder : CapturableGuidedDecoder ) -> bool :
0 commit comments