11 changes: 5 additions & 6 deletions tensorrt_llm/_torch/attention_backend/trtllm.py
@@ -1599,15 +1599,14 @@ def update_spec_dec_param(
# Dynamic draft length needs position offsets and packed mask to be shaped for each runtime draft length.
# So we create cache for position offsets and packed mask for each draft length to avoid reallocation.
assert max_draft_len == max_total_draft_tokens, "max_draft_len should be equal to max_total_draft_tokens for linear tree"
runtime_draft_len = (spec_metadata.runtime_draft_len
if spec_metadata is not None else
max_draft_len)
runtime_draft_token_buffer_width = (
spec_metadata.runtime_tokens_per_gen_step - 1)
self.generate_spec_decoding_generation_length(
runtime_draft_len=runtime_draft_len)
runtime_draft_len=runtime_draft_token_buffer_width)
self.spec_decoding_position_offsets = generate_spec_decoding_position_offsets(
self.max_num_requests, runtime_draft_len)
self.max_num_requests, runtime_draft_token_buffer_width)
self.spec_decoding_packed_mask = generate_spec_decoding_packed_mask(
self.max_num_requests, runtime_draft_len)
self.max_num_requests, runtime_draft_token_buffer_width)

def generate_spec_decoding_generation_length(self, runtime_draft_len):
self.spec_decoding_generation_lengths[:self.max_num_requests].fill_(
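
Note on the attention-backend change above: the raw runtime draft length is replaced by a buffer width derived from spec_metadata.runtime_tokens_per_gen_step. Below is a minimal standalone sketch of that relationship, assuming the linear-tree case where each generation step processes one verified token plus the draft tokens; the tensor shapes are illustrative only and are not the exact layout produced by generate_spec_decoding_position_offsets or generate_spec_decoding_packed_mask.

import torch

# Illustrative sketch only: how the per-step token count maps to the
# draft-token buffer width used to size the spec-decoding tensors.
def spec_buffer_shapes(max_num_requests: int, runtime_tokens_per_gen_step: int):
    # One verified token plus the draft tokens are consumed per generation step,
    # so the draft buffer needs (runtime_tokens_per_gen_step - 1) slots per request.
    width = runtime_tokens_per_gen_step - 1
    tokens_per_step = width + 1
    # Position offsets: 0..tokens_per_step-1 relative to the last verified position.
    position_offsets = torch.arange(tokens_per_step,
                                    dtype=torch.int32).repeat(max_num_requests, 1)
    # Packed mask: one bit per token, packed into int32 words (layout assumed here).
    mask_words = (tokens_per_step + 31) // 32
    packed_mask = torch.zeros(max_num_requests, tokens_per_step, mask_words,
                              dtype=torch.int32)
    return position_offsets, packed_mask

offsets, mask = spec_buffer_shapes(max_num_requests=4, runtime_tokens_per_gen_step=4)
print(offsets.shape, mask.shape)  # torch.Size([4, 4]) torch.Size([4, 4, 1])
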
13 changes: 10 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
@@ -119,8 +119,10 @@ def __init__(self, config: CUDAGraphRunnerConfig):

def _create_shared_static_tensors(self):
"""Allocates static tensors sized for the largest possible batch."""
max_draft_len = self.config.original_max_total_draft_tokens if self.config.spec_config is not None else 0
token_per_request = max_draft_len + 1
runtime_draft_token_buffer_width = (
self.config.original_max_total_draft_tokens
if self.config.spec_config is not None else 0)
token_per_request = runtime_draft_token_buffer_width + 1
max_total_tokens = (self.max_supported_batch_size *
self.max_beam_width * token_per_request)
max_total_tokens = min(max_total_tokens, self.config.max_num_tokens)
@@ -444,6 +446,11 @@ def _get_padded_batch(self, batch: ScheduledRequests,
if padding_size + batch.batch_size > self.config.batch_size:
return 0

runtime_tokens_per_gen_step = (
self.spec_config.get_runtime_tokens_per_gen_step(runtime_draft_len)
if self.spec_config is not None else 1 + runtime_draft_len)
runtime_draft_token_buffer_width = runtime_tokens_per_gen_step - 1

# No padding if it would create too many concurrent requests.
# This is not strictly required, but we should probably
# respect the requirement just in case that changes in the future.
@@ -461,7 +468,7 @@
dummy_request = kv_cache_manager.add_dummy_requests(
[dummy_request_id],
is_gen=True,
max_num_draft_tokens=runtime_draft_len,
max_num_draft_tokens=runtime_draft_token_buffer_width,
use_mrope=self.config.use_mrope,
max_beam_width=self.config.max_beam_width,
draft_kv_cache_manager=draft_kv_cache_manager)
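
Note on the runner change above: the padded-batch path now derives the per-step token count from spec_config when it is present and falls back to 1 + runtime_draft_len otherwise, then sizes dummy requests and static buffers with the width (per-step count minus one). The following is a simplified, self-contained sketch of that sizing logic; the function and its wiring are stand-ins, not the real CUDAGraphRunner.

def static_token_budget(batch_size: int, beam_width: int, max_num_tokens: int,
                        spec_config=None, runtime_draft_len: int = 0) -> int:
    # Per-step token count: ask the spec config when speculation is enabled,
    # otherwise assume one verified token plus the (possibly zero) draft tokens.
    if spec_config is not None:
        tokens_per_step = spec_config.get_runtime_tokens_per_gen_step(runtime_draft_len)
    else:
        tokens_per_step = 1 + runtime_draft_len
    draft_buffer_width = tokens_per_step - 1
    tokens_per_request = draft_buffer_width + 1
    # Static tensors must cover the largest padded batch, capped by max_num_tokens.
    return min(batch_size * beam_width * tokens_per_request, max_num_tokens)

# Example: 64 requests, beam width 1, 3 draft tokens per step, no spec_config object.
print(static_token_budget(64, 1, 8192, spec_config=None, runtime_draft_len=3))  # 256
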
97 changes: 62 additions & 35 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -180,7 +180,9 @@ def __init__(
spec_config.tokens_per_gen_step -
1) if spec_config is not None else 0
# Saved before zeroing for draft models; used by update_spec_dec_param.
self._spec_dec_max_total_draft_tokens = spec_config.max_total_draft_tokens if spec_config is not None else 0
self._spec_dec_max_total_draft_tokens = (
spec_config.max_total_draft_tokens
if spec_config is not None else 0)

preserve_wrapped_eagle3_widths = (spec_config is not None
and is_draft_model
@@ -334,11 +336,14 @@ def __init__(
self.llm_args.attn_backend,
sparse_attn_config=self.sparse_attention_config)

self.get_runtime_tokens_per_gen_step = spec_config.get_runtime_tokens_per_gen_step if spec_config is not None else lambda runtime_draft_len: 1

if self.is_spec_decode:
self.spec_metadata = None
update_spec_config_from_model_config(self.spec_config,
self.model.config)
max_num_draft_tokens = self.original_max_total_draft_tokens * self.batch_size
max_num_draft_tokens = (self.original_max_total_draft_tokens *
self.batch_size)
self.draft_tokens_cuda = torch.empty((max_num_draft_tokens, ),
dtype=torch.int,
device='cuda')
@@ -869,7 +874,7 @@ def _get_graphs_to_capture(
return graphs

# Case 3: Target model (two-model) or one-model without dynamic draft
draft_lengths = [self.max_total_draft_tokens]
draft_lengths = [self.max_draft_len]
should_capture_no_spec = (
self.max_total_draft_tokens > 0
and not self.spec_config.spec_dec_mode.use_one_engine()
@@ -1194,12 +1199,15 @@ def _create_cuda_graph_warmup_request(

result = ScheduledRequests()
num_extra_decoding_steps = self._get_num_extra_decoding_steps()
runtime_tokens_per_gen_step = self.get_runtime_tokens_per_gen_step(
draft_len)
runtime_draft_token_buffer_width = runtime_tokens_per_gen_step - 1

# Add (batch_size - 1) dummy requests with seq_len=1.
requests = kv_cache_manager.add_dummy_requests(
list(range(batch_size - 1)),
is_gen=True,
max_num_draft_tokens=draft_len,
max_num_draft_tokens=runtime_draft_token_buffer_width,
use_mrope=self.use_mrope,
max_beam_width=self.max_beam_width,
num_extra_decoding_steps=num_extra_decoding_steps,
@@ -1216,26 +1224,29 @@
available_tokens = kv_cache_manager.get_num_available_tokens(
token_num_upper_bound=max_seq_len,
batch_size=batch_size,
max_num_draft_tokens=draft_len)
max_num_draft_tokens=runtime_draft_token_buffer_width)

# Also consider draft KV cache capacity when it exists
if draft_kv_cache_manager is not None:
draft_available_tokens = draft_kv_cache_manager.get_num_available_tokens(
batch_size=batch_size,
token_num_upper_bound=max_seq_len,
max_num_draft_tokens=draft_len)
max_num_draft_tokens=runtime_draft_token_buffer_width)
available_tokens = min(available_tokens, draft_available_tokens)

token_num = max(
1,
min(
available_tokens, max_seq_len - 1 -
get_num_extra_kv_tokens(self.spec_config) - draft_len))
available_tokens,
max_seq_len - 1 - get_num_extra_kv_tokens(self.spec_config) -
runtime_draft_token_buffer_width))
model_config = self.model.model_config.pretrained_config
max_position_embeddings = getattr(model_config,
'max_position_embeddings', None)
if max_position_embeddings is not None:
token_num = min(token_num, max_position_embeddings - draft_len)
token_num = min(
token_num,
max_position_embeddings - runtime_draft_token_buffer_width)

assert token_num > num_extra_decoding_steps, (
"Cannot fuse drafting loop. Not enough KV cache space for all draft tokens."
@@ -1246,7 +1257,7 @@
request_ids=[batch_size - 1],
token_nums=[token_num],
is_gen=True,
max_num_draft_tokens=draft_len,
max_num_draft_tokens=runtime_draft_token_buffer_width,
use_mrope=self.use_mrope,
max_beam_width=self.max_beam_width,
num_extra_decoding_steps=num_extra_decoding_steps,
@@ -1968,8 +1979,10 @@ def _update_target_input_tensors(
non_blocking=True)

# Prepare draft tokens
num_draft_tokens_per_extend_request = num_tokens_per_extend_request - 1
self.draft_tokens_cuda[:previous_batch_draft_tokens].copy_(
next_draft_tokens_device[previous_slots, :].flatten(),
next_draft_tokens_device[
previous_slots, :num_draft_tokens_per_extend_request].flatten(),
non_blocking=True)

# Compute kv_len_offsets and update offset tensors
@@ -2005,8 +2018,10 @@ def _apply_incremental_update_target(
# Pre-compute constants
extend_requests = scheduled_requests.generation_requests
num_extend_requests = len(extend_requests)
num_tokens_per_extend_request = self.runtime_draft_len + 1
spec_config = self.spec_config
num_tokens_per_extend_request = self.get_runtime_tokens_per_gen_step(
self.runtime_draft_len)
runtime_draft_token_buffer_width = num_tokens_per_extend_request - 1

prompt_lengths = torch.empty(num_extend_requests,
dtype=torch.int,
@@ -2068,7 +2083,8 @@ def _apply_incremental_update_target(
prompt_lengths = prompt_lengths.tolist()
num_cached_tokens_per_seq = num_cached_tokens_per_seq.tolist()

previous_batch_draft_tokens = num_extend_reqeust_wo_dummy * self.runtime_draft_len
previous_batch_draft_tokens = (num_extend_reqeust_wo_dummy *
runtime_draft_token_buffer_width)

self._update_target_input_tensors(
num_accepted_tokens_device=num_accepted_tokens_device,
@@ -2347,6 +2363,9 @@ def _prepare_tp_inputs(
# will contain previous batch indices of generation requests
previous_batch_indices = []
previous_pos_indices = []
runtime_tokens_per_gen_step = self.get_runtime_tokens_per_gen_step(
self.runtime_draft_len)
runtime_draft_token_buffer_width = runtime_tokens_per_gen_step - 1
for request in extend_requests:
request_ids.append(request.py_request_id)
request_accepted_path[
@@ -2405,16 +2424,16 @@ def _prepare_tp_inputs(
previous_batch_idx = request.py_batch_idx
request.py_batch_idx = request.py_seq_slot

sequence_lengths.append(1 + self.runtime_draft_len)
sequence_lengths.append(runtime_tokens_per_gen_step)
num_accepted_draft_tokens.append(
request.py_num_accepted_draft_tokens)
past_seen_token_num = request.max_beam_num_tokens - 1

draft_lens.append(self.runtime_draft_len)
draft_lens.append(runtime_draft_token_buffer_width)
gather_ids.extend(
list(
range(len(position_ids),
len(position_ids) + 1 + self.runtime_draft_len)))
len(position_ids) + runtime_tokens_per_gen_step)))
# For the target model + tree decoding
if not self.is_draft_model and not spec_config.is_linear_tree:
assert spec_tree_manager is not None
@@ -2427,19 +2446,19 @@
position_ids.extend(
list(
range(
past_seen_token_num, past_seen_token_num + 1 +
self.runtime_draft_len)))
past_seen_token_num, past_seen_token_num +
runtime_tokens_per_gen_step)))
# previous tensor
previous_batch_indices.append(previous_batch_idx)
previous_pos_indices.extend([previous_batch_idx] *
(1 + self.runtime_draft_len))
runtime_tokens_per_gen_step)

num_cached_tokens_per_seq.append(past_seen_token_num +
self.runtime_draft_len + 1)
runtime_tokens_per_gen_step)
request.cached_tokens = num_cached_tokens_per_seq[-1]
if self.enable_spec_decode and spec_config.spec_dec_mode.extend_ctx(
self.attn_backend) and spec_config.is_linear_tree:
prompt_lengths.append(1 + self.runtime_draft_len)
prompt_lengths.append(runtime_tokens_per_gen_step)
else:
prompt_lengths.append(request.py_prompt_len)
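
To make the bookkeeping in the extend-request loop above concrete, here is a small standalone sketch of how one generation request's per-step quantities all derive from runtime_tokens_per_gen_step; the 1 + runtime_draft_len fallback mirrors the linear-tree default, and tree modes may return a different count.

def extend_request_layout(past_seen_token_num: int, runtime_draft_len: int,
                          get_runtime_tokens_per_gen_step=None):
    tokens_per_step = (get_runtime_tokens_per_gen_step(runtime_draft_len)
                       if get_runtime_tokens_per_gen_step is not None
                       else 1 + runtime_draft_len)
    draft_buffer_width = tokens_per_step - 1
    return {
        # One verified token plus the draft tokens are fed through the model.
        "sequence_length": tokens_per_step,
        "draft_len": draft_buffer_width,
        # Positions continue from the tokens this request has already seen.
        "position_ids": list(range(past_seen_token_num,
                                   past_seen_token_num + tokens_per_step)),
        # The KV cache now holds the previously seen tokens plus this step's tokens.
        "num_cached_tokens": past_seen_token_num + tokens_per_step,
    }

print(extend_request_layout(past_seen_token_num=128, runtime_draft_len=3))
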

@@ -2740,30 +2759,36 @@ def previous_seq_slots_device():
# Initialize these two values to zeros
self.previous_pos_id_offsets_cuda *= 0
self.previous_kv_lens_offsets_cuda *= 0
runtime_tokens_per_gen_step = self.get_runtime_tokens_per_gen_step(
self.runtime_draft_len)
runtime_draft_token_buffer_width = runtime_tokens_per_gen_step - 1

if previous_batch_len > 0:
previous_slots = previous_seq_slots_device()
# previous input ids
previous_batch_tokens = previous_batch_len * (
1 + self.runtime_draft_len)
previous_batch_tokens = (previous_batch_len *
runtime_tokens_per_gen_step)
new_tokens = new_tokens_device.transpose(
0,
1)[previous_slots, :(1 + self.runtime_draft_len)].flatten()
1)[previous_slots, :runtime_tokens_per_gen_step].flatten()
self.input_ids_cuda[num_tokens:num_tokens +
previous_batch_tokens].copy_(
new_tokens, non_blocking=True)

# previous draft tokens
previous_batch_draft_tokens = previous_batch_len * self.runtime_draft_len
if self.runtime_draft_len > 0:
self.draft_tokens_cuda[num_draft_tokens:num_draft_tokens +
previous_batch_draft_tokens].copy_(
next_draft_tokens_device[
previous_slots, :self.
runtime_draft_len].flatten(),
non_blocking=True)
previous_batch_draft_tokens = (previous_batch_len *
runtime_draft_token_buffer_width)
if runtime_draft_token_buffer_width > 0:
self.draft_tokens_cuda[
num_draft_tokens:num_draft_tokens +
previous_batch_draft_tokens].copy_(
next_draft_tokens_device[
previous_slots, :
runtime_draft_token_buffer_width].flatten(),
non_blocking=True)
# prepare data for the preprocess inputs
kv_len_offsets_device = new_tokens_lens_device - self.runtime_draft_len - 1
kv_len_offsets_device = (new_tokens_lens_device -
runtime_tokens_per_gen_step)
previous_pos_indices_host = torch.tensor(
previous_pos_indices,
dtype=torch.int,
@@ -2789,8 +2814,8 @@ def previous_seq_slots_device():
extend_dummy_requests)
self.previous_pos_id_offsets_cuda[
(num_extend_reqeust_wo_dummy - previous_batch_len) *
(1 + self.runtime_draft_len):num_extend_reqeust_wo_dummy *
(1 + self.runtime_draft_len)].copy_(
runtime_tokens_per_gen_step:num_extend_reqeust_wo_dummy *
runtime_tokens_per_gen_step].copy_(
new_tokens_lens_device[self.previous_pos_indices_cuda[
0:previous_batch_tokens]],
non_blocking=True)
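
The slice arithmetic above follows a flattened [request, token] layout: each previous-batch request owns runtime_tokens_per_gen_step contiguous entries in the offset buffer. A tiny illustrative sketch of that indexing (names are stand-ins for the variables in the diff):

def previous_batch_slice(num_extend_requests_wo_dummy: int,
                         previous_batch_len: int,
                         tokens_per_step: int):
    # The previous-batch requests occupy the trailing `previous_batch_len`
    # of the non-dummy extend-request slots in this buffer.
    start = (num_extend_requests_wo_dummy - previous_batch_len) * tokens_per_step
    end = num_extend_requests_wo_dummy * tokens_per_step
    return start, end

print(previous_batch_slice(8, 3, 4))  # (20, 32) -> 3 requests * 4 tokens = 12 entries
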
@@ -3626,6 +3651,8 @@ def forward(self,
# Propagate runtime_draft_len (already set on self by py_executor)
# to spec_metadata so downstream code (eagle3, interface, trtllm) can read it.
spec_metadata.runtime_draft_len = self.runtime_draft_len
spec_metadata.runtime_tokens_per_gen_step = (
self.get_runtime_tokens_per_gen_step(self.runtime_draft_len))

attn_metadata.update_spec_dec_param(
batch_size=scheduled_requests.batch_size,
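
Finally, the forward() change closes the loop: the engine writes both runtime_draft_len and the derived runtime_tokens_per_gen_step onto spec_metadata before update_spec_dec_param, so the attention backend (first file in this diff) sizes its buffers from the same per-step count. A condensed sketch of that handoff, using a stand-in dataclass rather than the real SpecMetadata:

from dataclasses import dataclass

@dataclass
class SpecMetadataStub:
    runtime_draft_len: int = 0
    runtime_tokens_per_gen_step: int = 1

def propagate(spec_metadata: SpecMetadataStub, runtime_draft_len: int,
              get_runtime_tokens_per_gen_step) -> int:
    spec_metadata.runtime_draft_len = runtime_draft_len
    spec_metadata.runtime_tokens_per_gen_step = (
        get_runtime_tokens_per_gen_step(runtime_draft_len))
    # This is the width the attention backend derives for its spec-decoding buffers.
    return spec_metadata.runtime_tokens_per_gen_step - 1

width = propagate(SpecMetadataStub(), runtime_draft_len=3,
                  get_runtime_tokens_per_gen_step=lambda d: 1 + d)
print(width)  # 3
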