
Commit 3fd5faf

dhansen-nvidia and venkywonka authored
[https://nvbugs/5911143][fix] add async worker to MTP/Eagle3 sampler,… (NVIDIA#11573)
Signed-off-by: Dan Hansen <1+dhansen-nvidia@users.noreply.github.com>
Signed-off-by: dhansen-nvidia <218031328+dhansen-nvidia@users.noreply.github.com>
Co-authored-by: Dan Hansen <1+dhansen-nvidia@users.noreply.github.com>
Co-authored-by: Venky <23023424+venkywonka@users.noreply.github.com>
1 parent 41dd9e0 commit 3fd5faf

35 files changed: +441, -251 lines

.pre-commit-config.yaml

Lines changed: 5 additions & 0 deletions
@@ -1473,3 +1473,8 @@ repos:
         entry: ./scripts/dco_check.py
         language: script
         stages: [commit-msg]
+      - id: pinned-memory-policy
+        name: Disallow raw pinned-memory APIs in runtime code
+        entry: ./scripts/check_pinned_memory_usage.py
+        language: script
+        files: ^(tensorrt_llm|triton_backend)/.*\.py$
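
For reference, this is the pattern the new hook enforces at call sites. A minimal before/after sketch; the helper names come from the checker's own messages below, while the tensor shapes and variable names are illustrative only:

    import torch
    from tensorrt_llm._utils import maybe_pin_memory, prefer_pinned

    # Flagged by the hook: raw pinned-memory APIs in runtime code.
    #   lens = torch.zeros(16, dtype=torch.int32).pin_memory()
    #   buf = torch.empty(16, dtype=torch.int64, pin_memory=True)

    # Compliant: route the pinning decision through the central helpers.
    lens = maybe_pin_memory(torch.zeros(16, dtype=torch.int32))
    buf = torch.empty(16, dtype=torch.int64, pin_memory=prefer_pinned())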
scripts/check_pinned_memory_usage.py

Lines changed: 78 additions & 0 deletions

@@ -0,0 +1,78 @@
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import ast
import pathlib
import sys


class PinnedMemoryUsageChecker(ast.NodeVisitor):

    def __init__(self, *, allow_direct_pin_memory: bool) -> None:
        self.allow_direct_pin_memory = allow_direct_pin_memory
        self.violations: list[tuple[int, str]] = []

    def visit_Call(self, node: ast.Call) -> None:
        if isinstance(node.func, ast.Attribute) and node.func.attr == "pin_memory":
            if not self.allow_direct_pin_memory:
                self.violations.append(
                    (
                        node.lineno,
                        "Use `maybe_pin_memory(tensor)` instead of direct `.pin_memory()`.",
                    )
                )

        for keyword in node.keywords:
            if (
                keyword.arg == "pin_memory"
                and isinstance(keyword.value, ast.Constant)
                and keyword.value.value is True
            ):
                self.violations.append(
                    (
                        node.lineno,
                        "Use `pin_memory=prefer_pinned()` instead of `pin_memory=True`.",
                    )
                )

        self.generic_visit(node)


def _check_file(path: pathlib.Path) -> list[tuple[int, str]]:
    try:
        source = path.read_text(encoding="utf-8")
    except OSError as exc:
        return [(0, f"Failed to read file: {exc}")]

    try:
        tree = ast.parse(source, filename=str(path))
    except SyntaxError as exc:
        return [(exc.lineno or 0, f"Failed to parse file: {exc.msg}")]

    allow_direct_pin_memory = path.as_posix().endswith("tensorrt_llm/_utils.py")
    checker = PinnedMemoryUsageChecker(allow_direct_pin_memory=allow_direct_pin_memory)
    checker.visit(tree)
    return checker.violations


def main(argv: list[str]) -> int:
    if len(argv) <= 1:
        return 0

    has_violations = False
    for file_arg in argv[1:]:
        path = pathlib.Path(file_arg)
        violations = _check_file(path)
        for lineno, message in violations:
            has_violations = True
            print(f"{path}:{lineno}: {message}")

    if has_violations:
        print("\nPinned-memory policy check failed.")
        print("Use `tensorrt_llm._utils.maybe_pin_memory()` for direct pinning.")
        return 1
    return 0


if __name__ == "__main__":
    raise SystemExit(main(sys.argv))
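
The helpers themselves are not part of this section's diff; the checker merely exempts tensorrt_llm/_utils.py, which implies they live there. A minimal sketch of what such helpers could look like, with the caveat that the gating condition and fallback behavior below are assumptions rather than code from this commit:

    import torch

    def prefer_pinned() -> bool:
        # Hypothetical central policy switch; the real criterion is not
        # shown in this diff. Pinning only helps when a CUDA device exists.
        return torch.cuda.is_available()

    def maybe_pin_memory(tensor: torch.Tensor) -> torch.Tensor:
        # Pin only when the central policy allows it, so call sites need no
        # try/except or environment checks of their own.
        if not prefer_pinned():
            return tensor
        try:
            return tensor.pin_memory()
        except RuntimeError:
            # Assumed fallback: degrade to pageable memory rather than fail.
            return tensor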

tensorrt_llm/_torch/attention_backend/interface.py

Lines changed: 3 additions & 2 deletions
@@ -14,6 +14,7 @@
 from ..speculative.interface import SpecMetadata
 from ..speculative.spec_tree_manager import SpecTreeManager

+from tensorrt_llm._utils import maybe_pin_memory
 from tensorrt_llm.functional import (PositionEmbeddingType, RopeEmbeddingUtils,
                                      RotaryScalingType)
 from tensorrt_llm.mapping import Mapping

@@ -199,7 +200,7 @@ def seq_lens(self, value: Optional[torch.Tensor]):

         # The model executor sets seq_lens to None initially.
         if self._seq_lens is not None:
-            self._seq_lens = self._seq_lens.pin_memory()
+            self._seq_lens = maybe_pin_memory(self._seq_lens)

         if self.is_cuda_graph and self._seq_lens_cuda is not None:
             # Very important: do not reallocate if we are using CUDA graphs.

@@ -249,7 +250,7 @@ def seq_lens_kv(self, value: Optional[torch.Tensor]):
         self.on_update()
         # The model executor sets seqlens to None initially.
         if self._seq_lens_kv is not None:
-            self._seq_lens_kv = self._seq_lens_kv.pin_memory()
+            self._seq_lens_kv = maybe_pin_memory(self._seq_lens_kv)
         self._seq_lens_kv_cuda = self._seq_lens_kv.cuda(non_blocking=True)

     @property
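
The second hunk shows why these tensors are pinned at all: the host copy feeds a non_blocking device transfer. A short illustration of the pairing (the values are illustrative; maybe_pin_memory is the helper sketched above):

    import torch
    from tensorrt_llm._utils import maybe_pin_memory

    # A host-to-device copy can only overlap with other GPU work when the
    # source is page-locked; from pageable memory, non_blocking=True quietly
    # falls back to a synchronous staging copy.
    seq_lens = maybe_pin_memory(torch.tensor([128, 64, 256], dtype=torch.int32))
    seq_lens_cuda = seq_lens.cuda(non_blocking=True)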

tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 15 additions & 17 deletions
@@ -18,7 +18,7 @@
 from tensorrt_llm._torch.modules.rotary_embedding import RotaryEmbedding
 from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
 from tensorrt_llm._torch.utils import maybe_compile, maybe_compiled_cat
-from tensorrt_llm._utils import get_size_in_bytes, get_sm_version
+from tensorrt_llm._utils import get_size_in_bytes, get_sm_version, prefer_pinned
 from tensorrt_llm.bindings import DataType
 from tensorrt_llm.bindings.executor import KvCacheConfig
 from tensorrt_llm.bindings.internal.batch_manager import \

@@ -339,7 +339,7 @@ def __post_init__(self):
         self.host_indexer_k_cache_block_offsets = torch.zeros_like(
             self.indexer_k_cache_block_offsets,
             device='cpu',
-            pin_memory=True,
+            pin_memory=prefer_pinned(),
         )

         if not self.enable_context_mla_with_cached_kv:

@@ -353,7 +353,7 @@ def __post_init__(self):
             self.host_ctx_cached_token_indptr = torch.zeros_like(
                 self.ctx_cached_token_indptr,
                 device='cpu',
-                pin_memory=True,
+                pin_memory=prefer_pinned(),
             )
             self.ctx_kv_indptr = self.get_empty(
                 self.cuda_graph_buffers,

@@ -365,7 +365,7 @@ def __post_init__(self):
             self.host_ctx_kv_indptr = torch.zeros_like(
                 self.ctx_kv_indptr,
                 device='cpu',
-                pin_memory=True,
+                pin_memory=prefer_pinned(),
             )

         # Only when MLA chunked prefill is enabled, we need to gather the full KV for indexer's logit computation.

@@ -385,7 +385,7 @@ def __post_init__(self):
             self.host_gen_cached_token_indptr = torch.zeros_like(
                 self.gen_cached_token_indptr,
                 device='cpu',
-                pin_memory=True,
+                pin_memory=prefer_pinned(),
             )
             self.gen_kv_indptr = self.get_empty(
                 self.cuda_graph_buffers,

@@ -397,7 +397,7 @@ def __post_init__(self):
             self.host_gen_kv_indptr = torch.zeros_like(
                 self.gen_kv_indptr,
                 device='cpu',
-                pin_memory=True,
+                pin_memory=prefer_pinned(),
             )
         # Indexer metadata
         # Separate slot mappings for non-interleaved layout (flat byte indices)

@@ -411,7 +411,7 @@ def __post_init__(self):
         self.host_slot_mapping_fp8 = torch.zeros_like(
             self.slot_mapping_fp8,
             device='cpu',
-            pin_memory=True,
+            pin_memory=prefer_pinned(),
         )
         self.slot_mapping_scale = self.get_empty(
             self.cuda_graph_buffers,

@@ -423,7 +423,7 @@ def __post_init__(self):
         self.host_slot_mapping_scale = torch.zeros_like(
             self.slot_mapping_scale,
             device='cpu',
-            pin_memory=True,
+            pin_memory=prefer_pinned(),
         )
         # Per-token request index buffer for topk_indices conversion
         self.req_idx_per_token = self.get_empty(

@@ -474,7 +474,7 @@ def __post_init__(self):
         self.host_topk_indices_buffer = torch.zeros_like(
             self.topk_indices_buffer,
             device='cpu',
-            pin_memory=True,
+            pin_memory=prefer_pinned(),
         )
         # Create expanded buffers for MTP support
         self.create_expanded_buffers(capture_graph=capture_graph)

@@ -491,7 +491,7 @@ def create_expanded_buffers(self, capture_graph=False):
         self.kv_lens_expanded_host = torch.zeros_like(
             self.kv_lens_expanded_cuda,
             device='cpu',
-            pin_memory=True,
+            pin_memory=prefer_pinned(),
         )
         self.block_table_expanded = self.get_empty(
             self.cuda_graph_buffers,

@@ -506,7 +506,7 @@ def create_expanded_buffers(self, capture_graph=False):
         self.host_block_table_expanded = torch.zeros_like(
             self.block_table_expanded,
             device='cpu',
-            pin_memory=True,
+            pin_memory=prefer_pinned(),
         )
         self.scheduler_metadata_buffer_expanded = self.get_empty(
             self.cuda_graph_buffers,

@@ -1171,12 +1171,10 @@ def prepare(metadata: DSAtrtllmAttentionMetadata):
         total_kv_per_request = seq_lens[:num_contexts] + start_positions[:num_contexts]
-        host_slot_mapping_fp8_fullkv = torch.empty(total_kv_len,
-                                                   dtype=torch.int64,
-                                                   pin_memory=True)
-        host_slot_mapping_scale_fullkv = torch.empty(total_kv_len,
-                                                     dtype=torch.int64,
-                                                     pin_memory=True)
+        host_slot_mapping_fp8_fullkv = torch.empty(
+            total_kv_len, dtype=torch.int64, pin_memory=prefer_pinned())
+        host_slot_mapping_scale_fullkv = torch.empty(
+            total_kv_len, dtype=torch.int64, pin_memory=prefer_pinned())

         req_indices = torch.repeat_interleave(
             torch.arange(num_contexts, dtype=torch.int64, device='cpu'),

tensorrt_llm/_torch/attention_backend/sparse/rocket.py

Lines changed: 3 additions & 3 deletions
@@ -15,7 +15,7 @@
 from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequestState
 from tensorrt_llm._torch.pyexecutor.resource_manager import (BlockManager,
                                                              KVCacheManager)
-from tensorrt_llm._utils import get_size_in_bytes
+from tensorrt_llm._utils import get_size_in_bytes, prefer_pinned
 from tensorrt_llm.bindings import DataType
 from tensorrt_llm.bindings.executor import KvCacheConfig
 from tensorrt_llm.bindings.internal.batch_manager import \

@@ -143,7 +143,7 @@ def __post_init__(self):
         self.host_kt_cache_block_offsets = torch.zeros_like(
             self.kt_cache_block_offsets,
             device='cpu',
-            pin_memory=True,
+            pin_memory=prefer_pinned(),
         )

         # Number of KT tokens for each sequence

@@ -594,7 +594,7 @@ def __post_init__(self):
         self.host_kt_cache_block_offsets = torch.zeros_like(
             self.kt_cache_block_offsets,
             device='cpu',
-            pin_memory=True,
+            pin_memory=prefer_pinned(),
         )

     def prepare(self) -> None:
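
These call-site changes are exactly what the new pre-commit script polices. A quick, self-contained way to watch the checker fire from a repo checkout (the temp-file dance is illustrative; the script path matches the hook's entry above):

    import pathlib
    import subprocess
    import sys
    import tempfile

    # Write a file that violates the policy, then run the checker over it.
    violating = "import torch\nx = torch.empty(4, pin_memory=True)\n"
    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
        f.write(violating)
        path = f.name

    result = subprocess.run(
        [sys.executable, "scripts/check_pinned_memory_usage.py", path],
        capture_output=True, text=True)
    # Expected: "<path>:2: Use `pin_memory=prefer_pinned()` ..." and exit 1.
    print(result.stdout)
    assert result.returncode == 1
    pathlib.Path(path).unlink()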
