
Commit e051e60

[None][feat] add GPUDirect RDMA draft offload for speculative decoding
Implement a CPU-initiated libibverbs RDMA transport to offload the draft model to a separate GPU, replacing local draft inference with a remote RDMA peer. The target model (TRT-LLM) writes accumulated output tokens to the draft server via RDMA Write; the draft server returns speculative tokens via RDMA Write back.

Key changes:
- rdma_draft_offload.py: ibverbs RC QP client with GPUDirect MR registration, QP state machine (RESET->INIT->RTR->RTS), and per-round request/response
- rdma_draft_protocol.py: fixed-size binary protocol (32B header + 256B tokens, MAGIC-checked, 4096B total) for target<->draft RDMA buffers
- draft_target.py: RDMA offload path in DraftTargetOneModelWorker.forward(), output token history accumulation, warmup pre-connection
- llm_args.py: DraftTargetDecodingConfig RDMA fields; allow speculative_model=None when draft_offload_enabled=True
- model_loader.py: skip draft weight loading when draft_offload_enabled
- modeling_speculative.py: skip draft model instantiation; thread is_warmup
- _util.py: skip separate draft KV cache when draft_offload_enabled
- model_engine.py: pass is_warmup flag through to model forward inputs
- .gitignore: ignore cmake-created symlinks deep_ep/deep_gemm/flash_mla

Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
1 parent e8ef1b5 commit e051e60

10 files changed

Lines changed: 1085 additions & 9 deletions
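
For orientation before the per-file diffs: the commit message only pins down the buffer sizes of the target<->draft protocol (32B header + 256B token region, MAGIC-checked, 4096B total). The sketch below shows one way such a fixed-size layout could be packed and MAGIC-checked in Python; the field names, ordering, and MAGIC value are assumptions for illustration, not the actual rdma_draft_protocol.py definitions.

# Illustrative sketch only: a fixed-size request/response layout matching the sizes
# quoted in the commit message (32B header + 256B token region, 4096B buffer).
# Field names, ordering, and the MAGIC value are assumed, not taken from
# rdma_draft_protocol.py.
import struct

MAGIC = 0x44524654           # assumed magic value
MAX_TOKENS = 64              # 64 * int32 = 256B token region
BUFFER_SIZE = 4096
HEADER_FMT = "<IHHIiiI8x"    # magic, version, flags, round_seq, position, num_tokens, max_draft_len, pad
TOKENS_FMT = f"<{MAX_TOKENS}i"
assert struct.calcsize(HEADER_FMT) == 32
assert struct.calcsize(TOKENS_FMT) == 256

def pack_request(round_seq: int, position: int, tokens: list[int], max_draft_len: int) -> bytes:
    tokens = tokens[-MAX_TOKENS:]                       # keep only the most recent context
    header = struct.pack(HEADER_FMT, MAGIC, 1, 0, round_seq, position, len(tokens), max_draft_len)
    body = struct.pack(TOKENS_FMT, *(tokens + [0] * (MAX_TOKENS - len(tokens))))
    return (header + body).ljust(BUFFER_SIZE, b"\x00")  # pad the RDMA buffer to its full size

def unpack_response(buf: bytes) -> list[int]:
    magic, _ver, _flags, _seq, _pos, num_tokens, _mdl = struct.unpack_from(HEADER_FMT, buf, 0)
    if magic != MAGIC:
        raise ValueError("response buffer failed MAGIC check")
    return list(struct.unpack_from(TOKENS_FMT, buf, 32)[:num_tokens])

Checking the MAGIC before trusting the rest of the buffer is cheap protection against reading a buffer the peer has not finished writing.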


.gitignore

Lines changed: 5 additions & 0 deletions
@@ -110,3 +110,8 @@ enroot/tensorrt_llm.devel.sqsh
 .claude/agent-memory/
 .claude/agent-tests/perf-test-sync/report.html
 .claude/agent-tests/perf-test-sync/results.json
+
+# Runtime third-party dependencies: symlinks created by cmake build, not part of this repo
+tensorrt_llm/deep_ep
+tensorrt_llm/deep_gemm
+tensorrt_llm/flash_mla
Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+"""Minimal target-side RDMA draft offload example.
+
+Start the fake draft server first, then run this script. The target model is a
+real TensorRT-LLM LLM; the draft model is temporarily replaced by the fake RDMA
+peer.
+"""
+
+from __future__ import annotations
+
+import argparse
+import os
+import sys
+from pathlib import Path
+
+_REPO_ROOT = Path(__file__).resolve().parents[2]
+if str(_REPO_ROOT) not in sys.path:
+    sys.path.insert(0, str(_REPO_ROOT))
+
+DEFAULT_MODEL = "/scratch.trt_llm_data/llm-models/Qwen3/Qwen3-8B"
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument("--model", default=DEFAULT_MODEL)
+    parser.add_argument("--prompt", default="Explain GPUDirect RDMA in one short sentence.")
+    parser.add_argument("--max-tokens", type=int, default=16)
+    parser.add_argument("--max-draft-len", type=int, default=5)
+    parser.add_argument("--draft-host", default="127.0.0.1")
+    parser.add_argument("--draft-port", type=int, default=47320)
+    parser.add_argument("--ib-dev", default="mlx5_0")
+    parser.add_argument(
+        "--gpu-id",
+        type=int,
+        default=0,
+        help="Physical GPU ID for RDMA memory registration (target side)",
+    )
+    parser.add_argument(
+        "--cuda-visible-devices",
+        default="0",
+        help="Set CUDA_VISIBLE_DEVICES before importing TensorRT-LLM.",
+    )
+    parser.add_argument("--max-batch-size", type=int, default=1)
+    parser.add_argument("--max-seq-len", type=int, default=512)
+    parser.add_argument("--max-num-tokens", type=int, default=512)
+    parser.add_argument("--kv-cache-max-tokens", type=int, default=512)
+    parser.add_argument("--kv-cache-free-gpu-memory-fraction", type=float, default=0.05)
+    return parser.parse_args()
+
+
+def main() -> None:
+    args = parse_args()
+    if args.cuda_visible_devices is not None:
+        os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_visible_devices
+
+    from tensorrt_llm import LLM, SamplingParams
+    from tensorrt_llm.llmapi import DraftTargetDecodingConfig, KvCacheConfig
+
+    spec_config = DraftTargetDecodingConfig(
+        max_draft_len=args.max_draft_len,
+        draft_offload_enabled=True,
+        draft_offload_nic_name=args.ib_dev,
+        draft_offload_server_host=args.draft_host,
+        draft_offload_server_port=args.draft_port,
+        draft_offload_gpu_id=args.gpu_id,
+    )
+    # Disable CUDA graphs: RDMA draft calls are Python-side operations and
+    # would not be re-executed on each CUDA graph replay.
+    # cuda_graph_config=None disables CUDA graphs entirely (empty CudaGraphConfig
+    # does NOT disable graphs because the validator auto-fills batch sizes).
+
+    llm = LLM(
+        model=args.model,
+        speculative_config=spec_config,
+        disable_overlap_scheduler=True,
+        tensor_parallel_size=1,
+        cuda_graph_config=None,
+        max_batch_size=args.max_batch_size,
+        max_seq_len=args.max_seq_len,
+        max_num_tokens=args.max_num_tokens,
+        kv_cache_config=KvCacheConfig(
+            max_tokens=args.kv_cache_max_tokens,
+            free_gpu_memory_fraction=args.kv_cache_free_gpu_memory_fraction,
+        ),
+    )
+    output = llm.generate(
+        args.prompt,
+        SamplingParams(max_tokens=args.max_tokens),
+        use_tqdm=False,
+    )
+    print(output.outputs[0].text)
+
+
+if __name__ == "__main__":
+    main()
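
A usage note on the example above: besides the DraftTargetDecodingConfig fields it sets, the draft_target.py diff further down also reads environment variables. TLLM_DRAFT_RDMA_OFFLOAD can force the offload path on in the worker, while TLLM_DRAFT_RDMA_NIC/HOST/PORT act only as fallbacks when the config object lacks the corresponding attributes. A minimal sketch, assuming the variables are set in the same process before the LLM is constructed; whether the env-only route works end to end depends on the rest of the configuration (the loader and KV-cache checks in this commit look only at the config flag):

# Sketch of the environment-variable route shown in draft_target.py below.
# The config fields used in the example above take precedence for NIC/host/port;
# these variables are fallbacks, and TLLM_DRAFT_RDMA_OFFLOAD force-enables the worker path.
import os

os.environ["TLLM_DRAFT_RDMA_OFFLOAD"] = "1"        # same effect in the worker as draft_offload_enabled=True
os.environ["TLLM_DRAFT_RDMA_NIC"] = "mlx5_0"       # fallback NIC if the config has no draft_offload_nic_name
os.environ["TLLM_DRAFT_RDMA_HOST"] = "127.0.0.1"   # fallback draft-server host
os.environ["TLLM_DRAFT_RDMA_PORT"] = "47320"       # fallback draft-server port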

tensorrt_llm/_torch/models/modeling_speculative.py

Lines changed: 7 additions & 2 deletions
@@ -1027,8 +1027,11 @@ def __init__(self, model: TModel, model_config: ModelConfig[TConfig]):
         spec_config = getattr(model_config, 'spec_config', None)
         self.spec_config = spec_config
         if spec_config and spec_config.spec_dec_mode.use_one_engine():
+            draft_offload_enabled = bool(
+                getattr(spec_config, "draft_offload_enabled", False))
             # Only create draft_model for modes MTP, Eagle3 (not SA)
-            if not spec_config.spec_dec_mode.is_sa():
+            if not spec_config.spec_dec_mode.is_sa(
+            ) and not draft_offload_enabled:
                 if spec_config.spec_dec_mode.is_eagle3_one_model():
                     if spec_config.eagle3_model_arch == "mistral_large3":
                         from tensorrt_llm._torch.models.checkpoints.mistral.config_loader import \
@@ -1105,6 +1108,7 @@ def forward(
         return_context_logits: bool = False,
         spec_metadata: Optional[SpecMetadata] = None,
         resource_manager=None,
+        is_warmup: bool = False,
         **kwargs,
     ) -> torch.Tensor:
         hidden_states = self.model(
@@ -1150,7 +1154,8 @@ def forward(
                 attn_metadata=attn_metadata,
                 spec_metadata=spec_metadata,
                 draft_model=self.draft_model,
-                resource_manager=resource_manager)
+                resource_manager=resource_manager,
+                is_warmup=is_warmup)
         else:
             logits = self.logits_processor.forward(
                 hidden_states,

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 2 additions & 0 deletions
@@ -624,6 +624,8 @@ def _should_create_separate_draft_kv_cache(self) -> bool:
                 "Attention DP is enabled, separate draft KV cache is not supported."
             )
             return False
+        if getattr(self._speculative_config, "draft_offload_enabled", False):
+            return False
         return should_use_separate_draft_kv_cache(self._speculative_config)
 
     def _get_effective_draft_config(self) -> ModelConfig:

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 4 additions & 0 deletions
@@ -2200,6 +2200,7 @@ def _apply_incremental_update_target(
             'inputs_embeds': None,
             "multimodal_params": [],
             'resource_manager': resource_manager,
+            'is_warmup': self.is_warmup,
         }
 
         if bool(lora_params):
@@ -3052,6 +3053,7 @@ def previous_seq_slots_device():
             'inputs_embeds': None,
             "multimodal_params": multimodal_params_list,
             'resource_manager': resource_manager,
+            'is_warmup': self.is_warmup,
         }
 
         if bool(lora_params):
@@ -3224,6 +3226,7 @@ def _prepare_tp_inputs_no_cache(
             'inputs_embeds': None,
             "multimodal_params": multimodal_params_list,
             'resource_manager': resource_manager,
+            'is_warmup': self.is_warmup,
         }
 
         if bool(lora_params):
@@ -3492,6 +3495,7 @@ def _prepare_star_attention_inputs(
             'position_ids': self.position_ids_cuda[:num_tokens].unsqueeze(0),
             'inputs_embeds': None,
             'resource_manager': resource_manager,
+            'is_warmup': self.is_warmup,
         }, gather_ids if is_spec_decode else None
 
     def _get_lora_params_from_requests(

tensorrt_llm/_torch/pyexecutor/model_loader.py

Lines changed: 6 additions & 4 deletions
@@ -394,8 +394,9 @@ def init_meta_tensor(t: torch.Tensor):
                 self._call_load_weights(model.load_weights, weights,
                                         self.weight_mapper)
 
-                if self.spec_config is not None and self.spec_config.spec_dec_mode.need_load_draft_weights(
-                ):
+                if (self.spec_config is not None and self.spec_config.
+                        spec_dec_mode.need_load_draft_weights() and not getattr(
+                            self.spec_config, "draft_offload_enabled", False)):
                     weights = checkpoint_loader.load_weights(
                         self.spec_config.speculative_model,
                         mapping=self.mapping)
@@ -414,8 +415,9 @@ def init_meta_tensor(t: torch.Tensor):
                 self.weight_mapper = checkpoint_loader.get_initialized_weight_mapper(
                     model, config)
                 initialize_dummy_weights(model)
-                if self.spec_config is not None and self.spec_config.spec_dec_mode.need_load_draft_weights(
-                ):
+                if (self.spec_config is not None and self.spec_config.
+                        spec_dec_mode.need_load_draft_weights() and not getattr(
+                            self.spec_config, "draft_offload_enabled", False)):
                     model.draft_model.load_weights_from_target_model(model)
 
             elif load_format == LoadFormat.VISION_ONLY:

tensorrt_llm/_torch/speculative/draft_target.py

Lines changed: 148 additions & 0 deletions
@@ -20,13 +20,15 @@
 layers are integrated into the target model's KV cache and run in a single forward pass.
 """
 
+import os
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional
 
 import torch
 from torch import nn
 
 from tensorrt_llm._utils import prefer_pinned
+from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
 
 from ..attention_backend import AttentionMetadata
@@ -38,6 +40,13 @@
     from ...llmapi.llm_args import DraftTargetDecodingConfig
 
 
+def _env_enabled(name: str, default: bool = False) -> bool:
+    value = os.environ.get(name)
+    if value is None:
+        return default
+    return str(value).strip().lower() in {"1", "true", "yes", "on"}
+
+
 @dataclass
 class DraftTargetOneModelSpecMetadata(SpecMetadata):
     """
@@ -96,6 +105,54 @@ def __init__(
         super().__init__(use_separate_draft_kv_cache)
         self.spec_config = spec_config
         self.mapping = mapping
+        self._rdma_offload_enabled = bool(
+            getattr(spec_config, "draft_offload_enabled", False)
+            or _env_enabled("TLLM_DRAFT_RDMA_OFFLOAD")
+        )
+        self._rdma_draft_client = None
+        # Accumulates ALL output tokens across decode rounds so the draft server
+        # can reconstruct the full generation context (prompt tokens are prepended
+        # by the server side using the known prompt text).
+        self._rdma_output_history: list[int] = []
+        if self._rdma_offload_enabled:
+            if getattr(mapping, "tp_size", 1) != 1 or getattr(mapping, "pp_size", 1) != 1:
+                raise RuntimeError(
+                    "RDMA draft offload target path currently supports only "
+                    "single-rank TP/PP. Disable draft_offload_enabled for "
+                    "multi-rank runs."
+                )
+            from .rdma_draft_offload import RdmaDraftOffloadClient, RdmaDraftOffloadConfig
+
+            self._rdma_draft_client = RdmaDraftOffloadClient(
+                RdmaDraftOffloadConfig(
+                    nic_name=getattr(
+                        spec_config,
+                        "draft_offload_nic_name",
+                        os.environ.get("TLLM_DRAFT_RDMA_NIC", "mlx5_0"),
+                    ),
+                    server_host=getattr(
+                        spec_config,
+                        "draft_offload_server_host",
+                        os.environ.get("TLLM_DRAFT_RDMA_HOST", "127.0.0.1"),
+                    ),
+                    server_port=int(
+                        getattr(
+                            spec_config,
+                            "draft_offload_server_port",
+                            os.environ.get("TLLM_DRAFT_RDMA_PORT", "47320"),
+                        )
+                    ),
+                    gpu_id=getattr(spec_config, "draft_offload_gpu_id", None),
+                    max_draft_len=int(spec_config.max_draft_len),
+                    buffer_size=int(getattr(spec_config, "draft_offload_buffer_size", 4096)),
+                )
+            )
+            logger.info(
+                "DraftTarget RDMA draft offload enabled: host=%s port=%s nic=%s",
+                self._rdma_draft_client.config.server_host,
+                self._rdma_draft_client.config.server_port,
+                self._rdma_draft_client.config.nic_name,
+            )
 
     @property
     def max_draft_len(self) -> int:
@@ -162,6 +219,7 @@ def forward(
         spec_metadata: DraftTargetOneModelSpecMetadata,
         draft_model: nn.Module,
         resource_manager=None,
+        is_warmup: bool = False,
     ):
         """
         Technically incorrect at the moment.
@@ -184,6 +242,46 @@
             logits, attn_metadata, spec_metadata
         )
 
+        if self._rdma_offload_enabled:
+            if bool(is_warmup):
+                # Warmup: initialize RDMA connection and exercise one real
+                # round-trip (tokens=[] → draft server sees only the prompt).
+                # This pre-warms the QP, GPU buffers, and NIC queues so the
+                # first real decode step does not pay connection-setup latency.
+                # Output history is NOT updated; the result is discarded.
+                self._rdma_draft_client.request(
+                    tokens=[],
+                    position=0,
+                    max_draft_len=self.max_draft_len,
+                    device=logits.device,
+                )
+                next_draft_tokens = torch.zeros(
+                    (batch_size, self.max_draft_len), dtype=torch.int32, device=logits.device
+                )
+            else:
+                next_draft_tokens = self._rdma_offload_draft_tokens(
+                    accepted_tokens=accepted_tokens,
+                    num_accepted_tokens=num_accepted_tokens,
+                    position_ids=position_ids,
+                    logits=logits,
+                    batch_size=batch_size,
+                )
+            next_new_tokens = self._prepare_next_new_tokens(
+                accepted_tokens,
+                next_draft_tokens,
+                spec_metadata.batch_indices_cuda,
+                batch_size,
+                num_accepted_tokens,
+            )
+            attn_metadata.use_spec_decoding = True
+            return {
+                "logits": raw_logits,
+                "new_tokens": accepted_tokens,
+                "new_tokens_lens": num_accepted_tokens,
+                "next_draft_tokens": next_draft_tokens,
+                "next_new_tokens": next_new_tokens,
+            }
+
         # Prepare attention metadata for speculative decoding and save state for restore
         self._prepare_attn_metadata_for_draft_target(attn_metadata, spec_metadata)
 
@@ -297,6 +395,56 @@
             "next_new_tokens": next_new_tokens,
         }
 
+    def _rdma_offload_draft_tokens(
+        self,
+        *,
+        accepted_tokens: torch.Tensor,
+        num_accepted_tokens: torch.Tensor,
+        position_ids: Optional[torch.Tensor],
+        logits: torch.Tensor,
+        batch_size: int,
+    ) -> torch.Tensor:
+        if self._rdma_draft_client is None:
+            raise RuntimeError("RDMA draft offload client was not initialized")
+        if int(batch_size) != 1:
+            raise RuntimeError("RDMA draft offload target path currently supports batch_size=1")
+
+        accepted_count = int(num_accepted_tokens[0].detach().cpu().item())
+        accepted_count = max(1, min(accepted_count, accepted_tokens.shape[1]))
+
+        if position_ids is None or int(position_ids.numel()) == 0:
+            position = 0
+        else:
+            position = int(position_ids.reshape(-1)[-1].detach().cpu().item())
+
+        # Accumulate all accepted output tokens for full-context draft inference.
+        for i in range(accepted_count):
+            self._rdma_output_history.append(int(accepted_tokens[0, i].detach().cpu().item()))
+        # Cap at MAX_TOKENS (64) to fit in the RDMA buffer.
+        tokens_to_send = self._rdma_output_history[-64:]
+
+        logger.info(
+            "[RDMA] _rdma_offload_draft_tokens: round=%d pos=%d ctx_len=%d",
+            self._rdma_draft_client.round_seq,
+            position,
+            len(tokens_to_send),
+        )
+        draft_tokens = self._rdma_draft_client.request(
+            tokens=tokens_to_send,
+            position=position,
+            max_draft_len=self.max_draft_len,
+            device=logits.device,
+        )
+        logger.info("[RDMA] got draft tokens: %s", draft_tokens)
+        if not draft_tokens:
+            draft_tokens = [0]
+        if len(draft_tokens) < self.max_draft_len:
+            draft_tokens = draft_tokens + [draft_tokens[-1]] * (
+                self.max_draft_len - len(draft_tokens)
+            )
+        draft_tokens = draft_tokens[: self.max_draft_len]
+        return torch.tensor([draft_tokens], dtype=torch.int32, device=logits.device)
+
     def sample_and_accept_draft_tokens(
         self,
         logits: torch.Tensor,
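
The reply-shaping rule at the end of _rdma_offload_draft_tokens above is worth restating on its own, since it is what guarantees the target always receives exactly max_draft_len draft tokens no matter how many the server returns. A minimal standalone sketch of the same behavior (the helper name is illustrative, not part of the diff):

# Restates the rule from _rdma_offload_draft_tokens: an empty reply degrades to
# token id 0, short replies are padded by repeating the last token, and long
# replies are truncated, so the result always has max_draft_len entries.
def pad_draft_tokens(draft_tokens: list[int], max_draft_len: int) -> list[int]:
    if not draft_tokens:
        draft_tokens = [0]
    if len(draft_tokens) < max_draft_len:
        draft_tokens = draft_tokens + [draft_tokens[-1]] * (max_draft_len - len(draft_tokens))
    return draft_tokens[:max_draft_len]

assert pad_draft_tokens([11, 22, 33], 5) == [11, 22, 33, 33, 33]
assert pad_draft_tokens([], 5) == [0, 0, 0, 0, 0]
assert pad_draft_tokens([1, 2, 3, 4, 5, 6], 5) == [1, 2, 3, 4, 5]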
