Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 51 additions & 39 deletions tensorrt_llm/_torch/attention_backend/trtllm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import math
import os
import weakref
Expand Down Expand Up @@ -758,6 +759,33 @@ def is_sm_version_trtllm_gen_kernel(self, sm):
return not (sm < 100 or sm in [120, 121])


@functools.cache
def generate_spec_decoding_position_offsets(max_num_requests: int,
                                            draft_len: int,
                                            device: str = 'cuda') -> torch.Tensor:
    """Build the per-request position-offset tensor for linear-tree spec decoding.

    Every request uses the same offsets 0..draft_len (the target token plus
    ``draft_len`` draft tokens), so one row is broadcast to all request slots.

    Args:
        max_num_requests: Number of request slots (first tensor dimension).
        draft_len: Number of draft tokens; the row width is ``draft_len + 1``.
        device: Device for the result; defaults to 'cuda' for backward
            compatibility with existing callers.

    Returns:
        An int32 tensor of shape ``[max_num_requests, draft_len + 1]`` where
        each row is ``[0, 1, ..., draft_len]``.

    NOTE: results are memoized by ``functools.cache`` and shared between
    callers — the returned tensor must be treated as read-only.
    """
    width = draft_len + 1
    row = torch.arange(width, dtype=torch.int, device=device)
    # Broadcast the single row to every request slot; contiguous() forces a
    # real allocation so the kernel sees a dense layout.
    return row.unsqueeze(0).expand(max_num_requests, -1).contiguous()


@functools.cache
def generate_spec_decoding_packed_mask(max_num_requests: int,
                                       draft_len: int,
                                       device: str = 'cuda') -> torch.Tensor:
    """Build the packed causal attention mask for linear-tree spec decoding.

    The mask for ``width = draft_len + 1`` tokens is bit-packed into int32
    blocks of 32 columns each: row ``i`` within a block stores the low
    ``i + 1`` bits set (1, 3, 7, ...), i.e. token ``i`` attends to tokens
    ``0..i`` of that block.

    Args:
        max_num_requests: Number of request slots (first tensor dimension).
        draft_len: Number of draft tokens; ``width = draft_len + 1`` rows.
        device: Device for the result; defaults to 'cuda' for backward
            compatibility with existing callers.

    Returns:
        An int32 tensor of shape
        ``[max_num_requests, width, ceil(width / 32)]``.

    NOTE: results are memoized by ``functools.cache`` and shared between
    callers — the returned tensor must be treated as read-only.

    NOTE(review): rows belonging to block ``b`` leave columns of earlier
    blocks at zero; for ``draft_len >= 32`` this means a row does not mark
    attendance to previous 32-token blocks.  This reproduces the original
    behavior — confirm against the kernel's packed-mask convention if long
    linear drafts are ever used.
    """
    width = draft_len + 1
    num_blocks = math.ceil(width / 32)
    mask = torch.zeros([max_num_requests, width, num_blocks],
                       dtype=torch.int,
                       device=device)
    for blk in range(num_blocks):
        # Rows covered by this 32-bit block; always >= 1 because
        # num_blocks == ceil(width / 32), so no early-exit check is needed.
        n = min(32, width - blk * 32)
        # Low i+1 bits set per row: 2^(i+1) - 1.  The int32 cast wraps
        # 2^32 - 1 to -1, whose bit pattern is the intended all-ones mask.
        vals = (torch.pow(2, torch.arange(n) + 1) - 1).int()
        mask[:, blk * 32:blk * 32 + n, blk] = vals
    return mask


@dataclass(kw_only=True)
class TrtllmAttentionMetadata(AttentionMetadata):
workspace: Optional[torch.Tensor] = None
Expand Down Expand Up @@ -1466,15 +1494,19 @@ def update_spec_dec_param(

# Parameters can be fixed and not changed during runtime if the
if self.is_spec_decoding_enabled:
# Skip pre-allocating position_offsets and packed_mask when dynamic draft length is enabled.
# We will use per-draft-len cached tensors for position_offsets and packed_mask instead.
# Currently dynamic draft length is only supported for linear tree (not is_spec_dec_tree).

# These buffers are accessed more like removing input padding,
# rather than using max_total_draft_tokens + 1 as the offset between different requests.
if self.spec_decoding_position_offsets is None:
if is_spec_dec_tree and self.spec_decoding_position_offsets is None:
self.spec_decoding_position_offsets = torch.empty(
[self.max_num_requests, max_total_draft_tokens + 1],
dtype=torch.int,
device='cuda',
)
if self.spec_decoding_packed_mask is None:
if is_spec_dec_tree and self.spec_decoding_packed_mask is None:
self.spec_decoding_packed_mask = torch.empty(
[
self.max_num_requests, max_total_draft_tokens + 1,
Expand Down Expand Up @@ -1510,7 +1542,7 @@ def update_spec_dec_param(
spec_decoding_generation_lengths, non_blocking=True)
else:
self.generate_spec_decoding_generation_length(
max_draft_len=max_total_draft_tokens)
runtime_draft_len=max_total_draft_tokens)

# Case 2/3: static tree
elif self.is_spec_dec_tree and not self.is_spec_dec_dynamic_tree and spec_metadata is not None:
Expand Down Expand Up @@ -1558,48 +1590,28 @@ def update_spec_dec_param(
self.spec_decoding_packed_mask.reshape(
-1)[:(max_draft_len + 1) * batch_size].copy_(
spec_decoding_packed_mask, non_blocking=True)
# generation_lengths
self.generate_spec_decoding_generation_length(
max_draft_len=max_draft_len)
runtime_draft_len=max_draft_len)

# Case 4: linear tree
else:
# Currently dynamic draft length is only supported for linear tree
# Dynamic draft length needs position offsets and packed mask to be shaped for each runtime draft length.
# So we create cache for position offsets and packed mask for each draft length to avoid reallocation.
assert max_draft_len == max_total_draft_tokens, "max_draft_len should be equal to max_total_draft_tokens for linear tree"
# Prepare for the linear-tree.
# Populate the mask that won't change during inference phase.
self.generate_spec_decoding_position_offsets(
max_draft_len=max_draft_len)
self.generate_spec_decoding_packed_mask(
max_draft_len=max_draft_len)
runtime_draft_len = (spec_metadata.runtime_draft_len
if spec_metadata is not None else
max_draft_len)
self.generate_spec_decoding_generation_length(
max_draft_len=max_draft_len)

def generate_spec_decoding_position_offsets(self, max_draft_len):
position_offset = torch.arange(max_draft_len + 1,
dtype=torch.int,
device='cpu',
pin_memory=prefer_pinned())
# fill all the batches with same position offset
self.spec_decoding_position_offsets.copy_(position_offset,
non_blocking=True)

def generate_spec_decoding_packed_mask(self, max_draft_len):
num_blocks = math.ceil((max_draft_len + 1) / 32)
tmp_max_draft_len = max_draft_len + 1
for block_idx in range(num_blocks):
if tmp_max_draft_len < 0:
break
dummy_idx = torch.arange(min(32, tmp_max_draft_len))
spec_decoding_packed_mask = torch.pow(2, dummy_idx + 1) - 1
self.spec_decoding_packed_mask[:, :, block_idx].copy_(
spec_decoding_packed_mask, non_blocking=True)
tmp_max_draft_len -= 32

def generate_spec_decoding_generation_length(self, max_draft_len):
spec_decoding_generation_length = torch.full((self.max_num_requests, ),
max_draft_len + 1)
self.spec_decoding_generation_lengths[:self.max_num_requests].copy_(
spec_decoding_generation_length, non_blocking=True)
runtime_draft_len=runtime_draft_len)
self.spec_decoding_position_offsets = generate_spec_decoding_position_offsets(
self.max_num_requests, runtime_draft_len)
self.spec_decoding_packed_mask = generate_spec_decoding_packed_mask(
self.max_num_requests, runtime_draft_len)

def generate_spec_decoding_generation_length(self, runtime_draft_len):
    """Set every request slot's generation length to ``runtime_draft_len + 1``.

    Each generation step consumes the target token plus the runtime draft
    tokens, so all slots share the same constant length.
    """
    tokens_per_step = runtime_draft_len + 1
    # Scalar slice assignment writes the constant in place, equivalent to
    # fill_() on the first max_num_requests entries.
    self.spec_decoding_generation_lengths[:self.max_num_requests] = tokens_per_step

def is_sm_version_trtllm_gen_kernel(self, sm):
    """Return True when the trtllm-gen kernels support SM version ``sm``.

    Supported: SM 100 and newer, excluding the 120/121 variants.
    """
    excluded_sms = (120, 121)
    return sm >= 100 and sm not in excluded_sms
Expand Down
65 changes: 51 additions & 14 deletions tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class CUDAGraphRunnerConfig:
mapping: Optional[Mapping]
dist: Optional[Distributed]
kv_cache_manager_key: Any
dynamic_draft_len_mapping: Optional[Dict[int, int]] = None
sparse_attention_config: Optional[BaseSparseAttentionConfig] = None


Expand Down Expand Up @@ -108,7 +109,8 @@ def __init__(self, config: CUDAGraphRunnerConfig):
Callable[[], Optional[torch.Tensor]]] = {}
self.graph_metadata: Dict[KeyType, Dict[str, Any]] = {}
self.memory_pool = config.cuda_graph_mem_pool
self.padding_dummy_request: Optional["Request"] = None
self.padding_dummy_requests: Dict[int, "Request"] = {}
self.dynamic_draft_len_mapping = config.dynamic_draft_len_mapping

self.shared_static_tensors: Dict[str, torch.Tensor] = {}
if self.enabled:
Expand Down Expand Up @@ -209,7 +211,7 @@ def get_graph_key(
key = (batch_size, draft_len, spec_resource_manager.is_first_draft,
short_seq_len_mode)
else:
# With dynamic spec decode, the draft length maybe zero even when enable_spec_decode is True,
# With dynamic spec decode, the draft length may be zero even when enable_spec_decode is True,
# so we need to get the draft length from the batch instead of using enable_spec_decode.
draft_len_list = []
for request in batch.generation_requests:
Expand Down Expand Up @@ -423,51 +425,66 @@ def _get_padded_batch(self, batch: ScheduledRequests,
or new_batch_size > self.max_supported_batch_size):
return 0

padded_batch_size = self._round_up_batch_size(new_batch_size)
# When dynamic draft length is enabled (one-model path), we treat the determined runtime draft length
# as the source of truth and pad the batch size up to the nearest existing graph
# for that draft length.
if (self.spec_config and self.spec_config.draft_len_schedule
and self.spec_config.spec_dec_mode.support_dynamic_draft_len()):
padded_batch_size = self._round_up_batch_size_with_draft_len(
new_batch_size, runtime_draft_len)
else:
padded_batch_size = self._round_up_batch_size(new_batch_size)

if batch_size == padded_batch_size:
return 0

padding_size = padded_batch_size - batch_size
if padding_size <= 0:
return 0
if padding_size + batch.batch_size > self.config.batch_size:
return 0

# No padding if it would create too many concurrent requests.
# This is not strictly required, but we should probably
# respect the requirement just in case that changes in the future.
if self.padding_dummy_request is None:
# Use per-draft-len dummy requests for dynamic draft length support.
if runtime_draft_len not in self.padding_dummy_requests:

# Get draft KV cache manager only for one-model speculative decoding.
# In two-model mode, each model has its own KV cache manager, so
# draft_kv_cache_manager should be None.
draft_kv_cache_manager = get_draft_kv_cache_manager(
self.spec_config, resource_manager)

self.padding_dummy_request = kv_cache_manager.add_dummy_requests(
[CUDA_GRAPH_DUMMY_REQUEST_ID],
# Use unique dummy request ID per draft length
dummy_request_id = CUDA_GRAPH_DUMMY_REQUEST_ID - runtime_draft_len
dummy_request = kv_cache_manager.add_dummy_requests(
[dummy_request_id],
is_gen=True,
max_num_draft_tokens=runtime_draft_len,
use_mrope=self.config.use_mrope,
max_beam_width=self.config.max_beam_width,
draft_kv_cache_manager=draft_kv_cache_manager)

if self.padding_dummy_request is None:
if dummy_request is None:
return 0
else:
self.padding_dummy_request = self.padding_dummy_request[0]
self.padding_dummy_request.is_cuda_graph_dummy = True
dummy_request = dummy_request[0]
dummy_request.is_cuda_graph_dummy = True

spec_res_mgr = resource_manager.get_resource_manager(
ResourceManagerType.SPEC_RESOURCE_MANAGER)
if spec_res_mgr:
spec_res_mgr.add_dummy_requests([CUDA_GRAPH_DUMMY_REQUEST_ID])
spec_res_mgr.add_dummy_requests([dummy_request_id])
self.padding_dummy_requests[runtime_draft_len] = dummy_request

if (isinstance(kv_cache_manager, MambaCacheManager)
and not use_cpp_mamba_cache_manager()):
kv_cache_manager.reorder_state_indices_when_padding_requests(
batch_size, padding_size)

self.padding_dummy_request.py_draft_tokens = [0] * runtime_draft_len
batch.generation_requests.extend([self.padding_dummy_request] *
padding_size)
padding_dummy_request = self.padding_dummy_requests[runtime_draft_len]
batch.generation_requests.extend([padding_dummy_request] * padding_size)
return padding_size

def _round_up_batch_size(self, batch_size: int) -> int:
Expand All @@ -479,6 +496,26 @@ def _round_up_batch_size(self, batch_size: int) -> int:
return 0
return self.supported_batch_sizes[idx]

def _round_up_batch_size_with_draft_len(self, batch_size: int,
draft_len: int) -> int:
"""Finds the smallest graph batch size >= batch_size that also matches the given draft_len."""
if not self.dynamic_draft_len_mapping:
# Fallback to regular round up if no mapping
return self._round_up_batch_size(batch_size)

start_idx = bisect.bisect_left(self.supported_batch_sizes, batch_size)
# Negate the list to make it non-decreasing for bisect
# (draft_len decreases as batch_size increases in the schedule)
draft_lens = [
self.dynamic_draft_len_mapping.get(self.supported_batch_sizes[i], 0)
for i in range(start_idx, len(self.supported_batch_sizes))
]
idx = bisect.bisect_left(draft_lens, -draft_len, key=lambda x: -x)
if idx < len(draft_lens) and draft_lens[idx] == draft_len:
return self.supported_batch_sizes[start_idx + idx]
# No suitable graph found
return 0

@contextlib.contextmanager
def pad_batch(self,
scheduled_requests: ScheduledRequests,
Expand All @@ -502,7 +539,7 @@ def clear(self):
self.graphs.clear()
self.graph_outputs.clear()
self.graph_metadata.clear()
self.padding_dummy_request = None
self.padding_dummy_requests = {}
del self.memory_pool
self.memory_pool = None
torch.cuda.empty_cache()
Loading