 layers are integrated into the target model's KV cache and run in a single forward pass.
 """

+import os
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional

 import torch
 from torch import nn

 from tensorrt_llm._utils import prefer_pinned
+from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping

 from ..attention_backend import AttentionMetadata
     from ...llmapi.llm_args import DraftTargetDecodingConfig


+def _env_enabled(name: str, default: bool = False) -> bool:
+    """Return True iff the env var `name` is set to a recognized truthy value."""
+    value = os.environ.get(name)
+    if value is None:
+        return default
+    return value.strip().lower() in {"1", "true", "yes", "on"}
+
+
 @dataclass
 class DraftTargetOneModelSpecMetadata(SpecMetadata):
     """
@@ -96,6 +105,54 @@ def __init__(
         super().__init__(use_separate_draft_kv_cache)
         self.spec_config = spec_config
         self.mapping = mapping
+        self._rdma_offload_enabled = bool(
+            getattr(spec_config, "draft_offload_enabled", False)
+            or _env_enabled("TLLM_DRAFT_RDMA_OFFLOAD")
+        )
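+        # Either the spec-config flag or the TLLM_DRAFT_RDMA_OFFLOAD env var
+        # can turn offload on; the env var serves as an operator-side override.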
+        self._rdma_draft_client = None
+        # Accumulates ALL output tokens across decode rounds so the draft server
+        # can reconstruct the full generation context (prompt tokens are prepended
+        # by the server side using the known prompt text).
+        self._rdma_output_history: list[int] = []
+        if self._rdma_offload_enabled:
+            if getattr(mapping, "tp_size", 1) != 1 or getattr(mapping, "pp_size", 1) != 1:
+                raise RuntimeError(
+                    "RDMA draft offload target path currently supports only "
+                    "single-rank TP/PP. Disable draft_offload_enabled for "
+                    "multi-rank runs."
+                )
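+            # Imported lazily so the RDMA extension is only loaded when the
+            # offload path is actually enabled.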
+            from .rdma_draft_offload import RdmaDraftOffloadClient, RdmaDraftOffloadConfig
+
+            self._rdma_draft_client = RdmaDraftOffloadClient(
+                RdmaDraftOffloadConfig(
+                    nic_name=getattr(
+                        spec_config,
+                        "draft_offload_nic_name",
+                        os.environ.get("TLLM_DRAFT_RDMA_NIC", "mlx5_0"),
+                    ),
+                    server_host=getattr(
+                        spec_config,
+                        "draft_offload_server_host",
+                        os.environ.get("TLLM_DRAFT_RDMA_HOST", "127.0.0.1"),
+                    ),
+                    server_port=int(
+                        getattr(
+                            spec_config,
+                            "draft_offload_server_port",
+                            os.environ.get("TLLM_DRAFT_RDMA_PORT", "47320"),
+                        )
+                    ),
+                    gpu_id=getattr(spec_config, "draft_offload_gpu_id", None),
+                    max_draft_len=int(spec_config.max_draft_len),
+                    buffer_size=int(getattr(spec_config, "draft_offload_buffer_size", 4096)),
+                )
+            )
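+            # Per-field resolution order: explicit spec_config attribute, then
+            # the TLLM_DRAFT_RDMA_* env var, then the built-in default
+            # (mlx5_0 / 127.0.0.1:47320).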
+            logger.info(
+                "DraftTarget RDMA draft offload enabled: "
+                f"host={self._rdma_draft_client.config.server_host} "
+                f"port={self._rdma_draft_client.config.server_port} "
+                f"nic={self._rdma_draft_client.config.nic_name}"
+            )

     @property
     def max_draft_len(self) -> int:
@@ -162,6 +219,7 @@ def forward(
         spec_metadata: DraftTargetOneModelSpecMetadata,
         draft_model: nn.Module,
         resource_manager=None,
+        is_warmup: bool = False,
     ):
         """
         Technically incorrect at the moment.
@@ -184,6 +242,46 @@ def forward(
             logits, attn_metadata, spec_metadata
         )

+        if self._rdma_offload_enabled:
+            if is_warmup:
+                # Warmup: initialize the RDMA connection and exercise one real
+                # round-trip (tokens=[] → the draft server sees only the prompt).
+                # This pre-warms the QP, GPU buffers, and NIC queues so the
+                # first real decode step does not pay connection-setup latency.
+                # Output history is NOT updated; the result is discarded.
+                self._rdma_draft_client.request(
+                    tokens=[],
+                    position=0,
+                    max_draft_len=self.max_draft_len,
+                    device=logits.device,
+                )
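+                # Warmup output is discarded, so all-zero drafts are returned
+                # purely to satisfy the expected tensor shape.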
+                next_draft_tokens = torch.zeros(
+                    (batch_size, self.max_draft_len), dtype=torch.int32, device=logits.device
+                )
+            else:
+                next_draft_tokens = self._rdma_offload_draft_tokens(
+                    accepted_tokens=accepted_tokens,
+                    num_accepted_tokens=num_accepted_tokens,
+                    position_ids=position_ids,
+                    logits=logits,
+                    batch_size=batch_size,
+                )
+            next_new_tokens = self._prepare_next_new_tokens(
+                accepted_tokens,
+                next_draft_tokens,
+                spec_metadata.batch_indices_cuda,
+                batch_size,
+                num_accepted_tokens,
+            )
+            attn_metadata.use_spec_decoding = True
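+            # Same output contract as the regular draft-model path below: raw
+            # target logits, accepted tokens and lengths, and next-round drafts.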
+            return {
+                "logits": raw_logits,
+                "new_tokens": accepted_tokens,
+                "new_tokens_lens": num_accepted_tokens,
+                "next_draft_tokens": next_draft_tokens,
+                "next_new_tokens": next_new_tokens,
+            }
+
         # Prepare attention metadata for speculative decoding and save state for restore
         self._prepare_attn_metadata_for_draft_target(attn_metadata, spec_metadata)

@@ -297,6 +395,56 @@ def forward(
             "next_new_tokens": next_new_tokens,
         }

+    def _rdma_offload_draft_tokens(
+        self,
+        *,
+        accepted_tokens: torch.Tensor,
+        num_accepted_tokens: torch.Tensor,
+        position_ids: Optional[torch.Tensor],
+        logits: torch.Tensor,
+        batch_size: int,
+    ) -> torch.Tensor:
+        if self._rdma_draft_client is None:
+            raise RuntimeError("RDMA draft offload client was not initialized")
+        if batch_size != 1:
+            raise RuntimeError("RDMA draft offload target path currently supports batch_size=1")
+
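+        # At least one token (the target model's own sample) is produced each
+        # step; clamp the accepted count to [1, accepted_tokens.shape[1]].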
+        accepted_count = int(num_accepted_tokens[0].item())
+        accepted_count = max(1, min(accepted_count, accepted_tokens.shape[1]))
+
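+        # The last flattened position id marks the newest target position;
+        # fall back to 0 when no position ids are available.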
+        if position_ids is None or position_ids.numel() == 0:
+            position = 0
+        else:
+            position = int(position_ids.reshape(-1)[-1].item())
+
+        # Accumulate all accepted output tokens for full-context draft inference.
+        self._rdma_output_history.extend(accepted_tokens[0, :accepted_count].tolist())
+        # Cap at MAX_TOKENS (64) to fit in the RDMA buffer.
+        tokens_to_send = self._rdma_output_history[-64:]
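+        # NOTE: generations longer than 64 output tokens are sent with a
+        # clipped context; assumption: the quality impact is bounded because
+        # every draft is still verified by the target model.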
+
+        logger.info(
+            "[RDMA] _rdma_offload_draft_tokens: "
+            f"round={self._rdma_draft_client.round_seq} "
+            f"pos={position} ctx_len={len(tokens_to_send)}"
+        )
+        draft_tokens = self._rdma_draft_client.request(
+            tokens=tokens_to_send,
+            position=position,
+            max_draft_len=self.max_draft_len,
+            device=logits.device,
+        )
+        logger.info(f"[RDMA] got draft tokens: {draft_tokens}")
+        if not draft_tokens:
+            # Never return an empty draft; a single pad token keeps shapes valid.
+            draft_tokens = [0]
+        if len(draft_tokens) < self.max_draft_len:
+            # Pad by repeating the last token so the tensor has a fixed width.
+            draft_tokens = draft_tokens + [draft_tokens[-1]] * (
+                self.max_draft_len - len(draft_tokens)
+            )
+        draft_tokens = draft_tokens[: self.max_draft_len]
+        return torch.tensor([draft_tokens], dtype=torch.int32, device=logits.device)
+
     def sample_and_accept_draft_tokens(
         self,
         logits: torch.Tensor,