diff --git a/tensorrt_llm/_torch/speculative/pard.py b/tensorrt_llm/_torch/speculative/pard.py index b919d6e16b6..fa25627bc14 100644 --- a/tensorrt_llm/_torch/speculative/pard.py +++ b/tensorrt_llm/_torch/speculative/pard.py @@ -4,6 +4,7 @@ import torch from torch import nn +from tensorrt_llm._utils import prefer_pinned from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping @@ -34,7 +35,9 @@ def prepare(self): assert self.request_ids is not None num_seqs = len(self.request_ids) - batch_indices = torch.arange(num_seqs, dtype=torch.int, device="cpu", pin_memory=True) + batch_indices = torch.arange( + num_seqs, dtype=torch.int, device="cpu", pin_memory=prefer_pinned() + ) self.batch_indices_cuda[:num_seqs].copy_(batch_indices, non_blocking=True)