diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index 0f95c443a93..1f0b03f5ccf 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -1599,15 +1599,14 @@ def update_spec_dec_param( # Dynamic draft length needs position offsets and packed mask to be shaped for each runtime draft length. # So we create cache for position offsets and packed mask for each draft length to avoid reallocation. assert max_draft_len == max_total_draft_tokens, "max_draft_len should be equal to max_total_draft_tokens for linear tree" - runtime_draft_len = (spec_metadata.runtime_draft_len - if spec_metadata is not None else - max_draft_len) + runtime_draft_token_buffer_width = ( + spec_metadata.runtime_tokens_per_gen_step - 1) self.generate_spec_decoding_generation_length( - runtime_draft_len=runtime_draft_len) + runtime_draft_len=runtime_draft_token_buffer_width) self.spec_decoding_position_offsets = generate_spec_decoding_position_offsets( - self.max_num_requests, runtime_draft_len) + self.max_num_requests, runtime_draft_token_buffer_width) self.spec_decoding_packed_mask = generate_spec_decoding_packed_mask( - self.max_num_requests, runtime_draft_len) + self.max_num_requests, runtime_draft_token_buffer_width) def generate_spec_decoding_generation_length(self, runtime_draft_len): self.spec_decoding_generation_lengths[:self.max_num_requests].fill_( diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py index 93d20883347..5b68a89e4c9 100644 --- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py +++ b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py @@ -119,8 +119,10 @@ def __init__(self, config: CUDAGraphRunnerConfig): def _create_shared_static_tensors(self): """Allocates static tensors sized for the largest possible batch.""" - max_draft_len = self.config.original_max_total_draft_tokens if self.config.spec_config is not None else 0 - token_per_request = max_draft_len + 1 + runtime_draft_token_buffer_width = ( + self.config.original_max_total_draft_tokens + if self.config.spec_config is not None else 0) + token_per_request = runtime_draft_token_buffer_width + 1 max_total_tokens = (self.max_supported_batch_size * self.max_beam_width * token_per_request) max_total_tokens = min(max_total_tokens, self.config.max_num_tokens) @@ -444,6 +446,11 @@ def _get_padded_batch(self, batch: ScheduledRequests, if padding_size + batch.batch_size > self.config.batch_size: return 0 + runtime_tokens_per_gen_step = ( + self.spec_config.get_runtime_tokens_per_gen_step(runtime_draft_len) + if self.spec_config is not None else 1 + runtime_draft_len) + runtime_draft_token_buffer_width = runtime_tokens_per_gen_step - 1 + # No padding if it would create too many concurrent requests. # This is not strictly required, but we should probably # respect the requirement just in case that changes in the future. 
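As context for the renaming above and the call-site changes below, here is a minimal sketch (not part of the patch) of how `runtime_draft_token_buffer_width` relates to `get_runtime_tokens_per_gen_step`, assuming the definitions this diff adds to `tensorrt_llm/llmapi/llm_args.py` (base config: `1 + runtime_draft_len`; PARD: `2 * runtime_draft_len`, or 1 when the draft length is 0):

```python
def buffer_width(runtime_draft_len: int, is_pard: bool = False) -> int:
    """Draft-token slots reserved per generation request for the current step."""
    if is_pard:
        # PARD carries K drafts plus K - 1 mask tokens alongside the golden token.
        tokens_per_step = 1 if runtime_draft_len == 0 else 2 * runtime_draft_len
    else:
        tokens_per_step = 1 + runtime_draft_len
    return tokens_per_step - 1

assert buffer_width(4) == 4                 # MTP / EAGLE3 linear tree: K slots
assert buffer_width(4, is_pard=True) == 7   # PARD: 2K - 1 slots
assert buffer_width(0, is_pard=True) == 0   # no drafting this step
```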
@@ -461,7 +468,7 @@ def _get_padded_batch(self, batch: ScheduledRequests, dummy_request = kv_cache_manager.add_dummy_requests( [dummy_request_id], is_gen=True, - max_num_draft_tokens=runtime_draft_len, + max_num_draft_tokens=runtime_draft_token_buffer_width, use_mrope=self.config.use_mrope, max_beam_width=self.config.max_beam_width, draft_kv_cache_manager=draft_kv_cache_manager) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index ebff20bb6a6..708b48c0cd8 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -180,7 +180,9 @@ def __init__( spec_config.tokens_per_gen_step - 1) if spec_config is not None else 0 # Saved before zeroing for draft models; used by update_spec_dec_param. - self._spec_dec_max_total_draft_tokens = spec_config.max_total_draft_tokens if spec_config is not None else 0 + self._spec_dec_max_total_draft_tokens = ( + spec_config.max_total_draft_tokens + if spec_config is not None else 0) preserve_wrapped_eagle3_widths = (spec_config is not None and is_draft_model @@ -334,11 +336,14 @@ def __init__( self.llm_args.attn_backend, sparse_attn_config=self.sparse_attention_config) + self.get_runtime_tokens_per_gen_step = spec_config.get_runtime_tokens_per_gen_step if spec_config is not None else lambda runtime_draft_len: 1 + if self.is_spec_decode: self.spec_metadata = None update_spec_config_from_model_config(self.spec_config, self.model.config) - max_num_draft_tokens = self.original_max_total_draft_tokens * self.batch_size + max_num_draft_tokens = (self.original_max_total_draft_tokens * + self.batch_size) self.draft_tokens_cuda = torch.empty((max_num_draft_tokens, ), dtype=torch.int, device='cuda') @@ -869,7 +874,7 @@ def _get_graphs_to_capture( return graphs # Case 3: Target model (two-model) or one-model without dynamic draft - draft_lengths = [self.max_total_draft_tokens] + draft_lengths = [self.max_draft_len] should_capture_no_spec = ( self.max_total_draft_tokens > 0 and not self.spec_config.spec_dec_mode.use_one_engine() @@ -1194,12 +1199,15 @@ def _create_cuda_graph_warmup_request( result = ScheduledRequests() num_extra_decoding_steps = self._get_num_extra_decoding_steps() + runtime_tokens_per_gen_step = self.get_runtime_tokens_per_gen_step( + draft_len) + runtime_draft_token_buffer_width = runtime_tokens_per_gen_step - 1 # Add (batch_size - 1) dummy requests with seq_len=1. 
requests = kv_cache_manager.add_dummy_requests( list(range(batch_size - 1)), is_gen=True, - max_num_draft_tokens=draft_len, + max_num_draft_tokens=runtime_draft_token_buffer_width, use_mrope=self.use_mrope, max_beam_width=self.max_beam_width, num_extra_decoding_steps=num_extra_decoding_steps, @@ -1216,26 +1224,29 @@ def _create_cuda_graph_warmup_request( available_tokens = kv_cache_manager.get_num_available_tokens( token_num_upper_bound=max_seq_len, batch_size=batch_size, - max_num_draft_tokens=draft_len) + max_num_draft_tokens=runtime_draft_token_buffer_width) # Also consider draft KV cache capacity when it exists if draft_kv_cache_manager is not None: draft_available_tokens = draft_kv_cache_manager.get_num_available_tokens( batch_size=batch_size, token_num_upper_bound=max_seq_len, - max_num_draft_tokens=draft_len) + max_num_draft_tokens=runtime_draft_token_buffer_width) available_tokens = min(available_tokens, draft_available_tokens) token_num = max( 1, min( - available_tokens, max_seq_len - 1 - - get_num_extra_kv_tokens(self.spec_config) - draft_len)) + available_tokens, + max_seq_len - 1 - get_num_extra_kv_tokens(self.spec_config) - + runtime_draft_token_buffer_width)) model_config = self.model.model_config.pretrained_config max_position_embeddings = getattr(model_config, 'max_position_embeddings', None) if max_position_embeddings is not None: - token_num = min(token_num, max_position_embeddings - draft_len) + token_num = min( + token_num, + max_position_embeddings - runtime_draft_token_buffer_width) assert token_num > num_extra_decoding_steps, ( "Cannot fuse drafting loop. Not enough KV cache space for all draft tokens." @@ -1246,7 +1257,7 @@ def _create_cuda_graph_warmup_request( request_ids=[batch_size - 1], token_nums=[token_num], is_gen=True, - max_num_draft_tokens=draft_len, + max_num_draft_tokens=runtime_draft_token_buffer_width, use_mrope=self.use_mrope, max_beam_width=self.max_beam_width, num_extra_decoding_steps=num_extra_decoding_steps, @@ -1968,8 +1979,10 @@ def _update_target_input_tensors( non_blocking=True) # Prepare draft tokens + num_draft_tokens_per_extend_request = num_tokens_per_extend_request - 1 self.draft_tokens_cuda[:previous_batch_draft_tokens].copy_( - next_draft_tokens_device[previous_slots, :].flatten(), + next_draft_tokens_device[ + previous_slots, :num_draft_tokens_per_extend_request].flatten(), non_blocking=True) # Compute kv_len_offsets and update offset tensors @@ -2005,8 +2018,10 @@ def _apply_incremental_update_target( # Pre-compute constants extend_requests = scheduled_requests.generation_requests num_extend_requests = len(extend_requests) - num_tokens_per_extend_request = self.runtime_draft_len + 1 spec_config = self.spec_config + num_tokens_per_extend_request = self.get_runtime_tokens_per_gen_step( + self.runtime_draft_len) + runtime_draft_token_buffer_width = num_tokens_per_extend_request - 1 prompt_lengths = torch.empty(num_extend_requests, dtype=torch.int, @@ -2068,7 +2083,8 @@ def _apply_incremental_update_target( prompt_lengths = prompt_lengths.tolist() num_cached_tokens_per_seq = num_cached_tokens_per_seq.tolist() - previous_batch_draft_tokens = num_extend_reqeust_wo_dummy * self.runtime_draft_len + previous_batch_draft_tokens = (num_extend_reqeust_wo_dummy * + runtime_draft_token_buffer_width) self._update_target_input_tensors( num_accepted_tokens_device=num_accepted_tokens_device, @@ -2347,6 +2363,9 @@ def _prepare_tp_inputs( # will contain previous batch indices of generation requests previous_batch_indices = [] previous_pos_indices = [] + 
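The slicing added in `_update_target_input_tensors` above can be pictured with a tiny standalone example (made-up shapes, not the engine buffers): the draft-token tensor stays allocated at the static maximum width, and only the first runtime-width columns are gathered for the previous batch.

```python
import torch

max_total_draft_tokens, runtime_width = 4, 2
# Statically allocated at the maximum width; only the runtime slice is live.
next_draft_tokens_device = torch.arange(3 * max_total_draft_tokens).reshape(
    3, max_total_draft_tokens)
previous_slots = torch.tensor([0, 2])

flat = next_draft_tokens_device[previous_slots, :runtime_width].flatten()
assert flat.tolist() == [0, 1, 8, 9]  # 2 requests * runtime_width draft tokens
```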
runtime_tokens_per_gen_step = self.get_runtime_tokens_per_gen_step( + self.runtime_draft_len) + runtime_draft_token_buffer_width = runtime_tokens_per_gen_step - 1 for request in extend_requests: request_ids.append(request.py_request_id) request_accepted_path[ @@ -2405,16 +2424,16 @@ def _prepare_tp_inputs( previous_batch_idx = request.py_batch_idx request.py_batch_idx = request.py_seq_slot - sequence_lengths.append(1 + self.runtime_draft_len) + sequence_lengths.append(runtime_tokens_per_gen_step) num_accepted_draft_tokens.append( request.py_num_accepted_draft_tokens) past_seen_token_num = request.max_beam_num_tokens - 1 - draft_lens.append(self.runtime_draft_len) + draft_lens.append(runtime_draft_token_buffer_width) gather_ids.extend( list( range(len(position_ids), - len(position_ids) + 1 + self.runtime_draft_len))) + len(position_ids) + runtime_tokens_per_gen_step))) # For the target model + tree decoding if not self.is_draft_model and not spec_config.is_linear_tree: assert spec_tree_manager is not None @@ -2427,19 +2446,19 @@ def _prepare_tp_inputs( position_ids.extend( list( range( - past_seen_token_num, past_seen_token_num + 1 + - self.runtime_draft_len))) + past_seen_token_num, past_seen_token_num + + runtime_tokens_per_gen_step))) # previous tensor previous_batch_indices.append(previous_batch_idx) previous_pos_indices.extend([previous_batch_idx] * - (1 + self.runtime_draft_len)) + runtime_tokens_per_gen_step) num_cached_tokens_per_seq.append(past_seen_token_num + - self.runtime_draft_len + 1) + runtime_tokens_per_gen_step) request.cached_tokens = num_cached_tokens_per_seq[-1] if self.enable_spec_decode and spec_config.spec_dec_mode.extend_ctx( self.attn_backend) and spec_config.is_linear_tree: - prompt_lengths.append(1 + self.runtime_draft_len) + prompt_lengths.append(runtime_tokens_per_gen_step) else: prompt_lengths.append(request.py_prompt_len) @@ -2740,30 +2759,36 @@ def previous_seq_slots_device(): # Initialize these two values to zeros self.previous_pos_id_offsets_cuda *= 0 self.previous_kv_lens_offsets_cuda *= 0 + runtime_tokens_per_gen_step = self.get_runtime_tokens_per_gen_step( + self.runtime_draft_len) + runtime_draft_token_buffer_width = runtime_tokens_per_gen_step - 1 if previous_batch_len > 0: previous_slots = previous_seq_slots_device() # previous input ids - previous_batch_tokens = previous_batch_len * ( - 1 + self.runtime_draft_len) + previous_batch_tokens = (previous_batch_len * + runtime_tokens_per_gen_step) new_tokens = new_tokens_device.transpose( 0, - 1)[previous_slots, :(1 + self.runtime_draft_len)].flatten() + 1)[previous_slots, :runtime_tokens_per_gen_step].flatten() self.input_ids_cuda[num_tokens:num_tokens + previous_batch_tokens].copy_( new_tokens, non_blocking=True) # previous draft tokens - previous_batch_draft_tokens = previous_batch_len * self.runtime_draft_len - if self.runtime_draft_len > 0: - self.draft_tokens_cuda[num_draft_tokens:num_draft_tokens + - previous_batch_draft_tokens].copy_( - next_draft_tokens_device[ - previous_slots, :self. 
- runtime_draft_len].flatten(), - non_blocking=True) + previous_batch_draft_tokens = (previous_batch_len * + runtime_draft_token_buffer_width) + if runtime_draft_token_buffer_width > 0: + self.draft_tokens_cuda[ + num_draft_tokens:num_draft_tokens + + previous_batch_draft_tokens].copy_( + next_draft_tokens_device[ + previous_slots, : + runtime_draft_token_buffer_width].flatten(), + non_blocking=True) # prepare data for the preprocess inputs - kv_len_offsets_device = new_tokens_lens_device - self.runtime_draft_len - 1 + kv_len_offsets_device = (new_tokens_lens_device - + runtime_tokens_per_gen_step) previous_pos_indices_host = torch.tensor( previous_pos_indices, dtype=torch.int, @@ -2789,8 +2814,8 @@ def previous_seq_slots_device(): extend_dummy_requests) self.previous_pos_id_offsets_cuda[ (num_extend_reqeust_wo_dummy - previous_batch_len) * - (1 + self.runtime_draft_len):num_extend_reqeust_wo_dummy * - (1 + self.runtime_draft_len)].copy_( + runtime_tokens_per_gen_step:num_extend_reqeust_wo_dummy * + runtime_tokens_per_gen_step].copy_( new_tokens_lens_device[self.previous_pos_indices_cuda[ 0:previous_batch_tokens]], non_blocking=True) @@ -3626,6 +3651,8 @@ def forward(self, # Propagate runtime_draft_len (already set on self by py_executor) # to spec_metadata so downstream code (eagle3, interface, trtllm) can read it. spec_metadata.runtime_draft_len = self.runtime_draft_len + spec_metadata.runtime_tokens_per_gen_step = ( + self.get_runtime_tokens_per_gen_step(self.runtime_draft_len)) attn_metadata.update_spec_dec_param( batch_size=scheduled_requests.batch_size, diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 361b9548491..c0db8cd8de8 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -1642,22 +1642,42 @@ def _handle_dynamic_draft_len(self, from tensorrt_llm._torch.speculative.utils import \ get_draft_len_for_batch_size + spec_dec_mode = self.model_engine.spec_config.spec_dec_mode + # 1. Resolve runtime draft length from schedule runtime_draft_len = get_draft_len_for_batch_size( self.model_engine.spec_config.draft_len_schedule, scheduled_batch.batch_size, self.model_engine.max_draft_len) - # 2. Pad or truncate draft tokens to the resolved length - PADDING_TOKEN = 0 + DRAFT_BUFFER_PAD = 0 # Buffer sentinel, not PARD mask_token_id. for request in scheduled_batch.generation_requests: - current_draft_len = len(request.py_draft_tokens) - if current_draft_len < runtime_draft_len: - padding_needed = runtime_draft_len - current_draft_len - request.py_draft_tokens.extend([PADDING_TOKEN] * - padding_needed) - elif current_draft_len > runtime_draft_len: - request.py_draft_tokens = request.py_draft_tokens[: - runtime_draft_len] + current_num_draft_tokens = len(request.py_draft_tokens) + if spec_dec_mode.is_pard(): + # special case as PARD carries 2K-1 draft tokens per request + runtime_draft_token_buffer_width = ( + self.model_engine.spec_config. 
+ get_runtime_tokens_per_gen_step(runtime_draft_len) - 1) + current_runtime_draft_len = ( + current_num_draft_tokens + + 1) // 2 if current_num_draft_tokens > 0 else 0 + real_draft_tokens = request.py_draft_tokens[:min( + current_runtime_draft_len, runtime_draft_len)] + real_draft_tokens.extend( + [DRAFT_BUFFER_PAD] * + (runtime_draft_len - len(real_draft_tokens))) + request.py_draft_tokens = real_draft_tokens + [ + DRAFT_BUFFER_PAD + ] * (runtime_draft_token_buffer_width - + len(real_draft_tokens)) + else: + if current_num_draft_tokens < runtime_draft_len: + padding_needed = (runtime_draft_len - + current_num_draft_tokens) + request.py_draft_tokens.extend([DRAFT_BUFFER_PAD] * + padding_needed) + elif current_num_draft_tokens > runtime_draft_len: + request.py_draft_tokens = request.py_draft_tokens[: + runtime_draft_len] self.model_engine.runtime_draft_len = runtime_draft_len else: diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index 22fc5ad7738..88d6a496a4a 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -142,8 +142,8 @@ def support_capturable_guided_decoder(self): ) or self.is_external_drafter() or self.is_sa() def support_dynamic_draft_len(self): - # TODO: expand to all one-model algorithms - return self.is_eagle3_one_model() + return self.is_mtp_one_model() or self.is_eagle3_one_model( + ) or self.is_pard() def has_draft_model(self): return self.is_eagle3() or self.is_draft_target() or self.is_mtp_eagle() @@ -282,6 +282,9 @@ class SpecMetadata: # draft_len_schedule. Otherwise it equals max_draft_len (the static max). # Always set by model_engine.forward() before any downstream code reads it. runtime_draft_len: int = 0 + # Total runtime tokens per generation request for the current iteration, + # Normally, it equals 1 + runtime_draft_len. But for PARD, it equals 2 * runtime_draft_len. + runtime_tokens_per_gen_step: int = 1 # For non-greedy sampling on 1-model. allow_advanced_sampling: bool = False @@ -575,9 +578,8 @@ def _sample_and_accept_draft_tokens_base( num_accepted_tokens: [batch_size] - Number of accepted tokens per request """ # Derive draft length from the actual draft_tokens shape rather than - # spec_metadata.runtime_draft_len, because they can differ: PARD sets - # runtime_draft_len = 2K-1 for input sizing but only passes K draft - # tokens for acceptance; + # spec_metadata.runtime_draft_len, because callers may slice a wider + # runtime token layout down to the K draft tokens used for acceptance. runtime_draft_len = draft_tokens.shape[-1] num_gens = batch_size - num_contexts diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py index e381b5d7021..02fb97aa47d 100644 --- a/tensorrt_llm/_torch/speculative/mtp.py +++ b/tensorrt_llm/_torch/speculative/mtp.py @@ -227,7 +227,7 @@ def prepare(self): num_contexts = num_seqs - self.num_generations gen_request_ids = self.request_ids[num_contexts:] if gen_request_ids: - sa_manager.prepare(gen_request_ids, self.max_draft_len) + sa_manager.prepare(gen_request_ids, self.runtime_draft_len) class MTPSampler(SpecSamplerBase): @@ -341,7 +341,7 @@ def forward( - hidden states: H_E, H_F, H_G, H_H (H_H is invalid) Draft model: MTP1: - # For generation request, `mtp_num_modules` of tokens will be used as input. + # For generation request, `runtime_draft_len` tokens are used as input. 
- input tokens: FGX - input hidden states: H_E, H_F, H_G - KV cache: (BCDE) + FGX @@ -392,6 +392,13 @@ def forward( - new generated draft tokens: UVQ ''' + runtime_draft_len = spec_metadata.runtime_draft_len + # skip the draft forward if the runtime draft length is 0 + if runtime_draft_len == 0: + return self.skip_drafting(input_ids, position_ids, hidden_states, + logits, attn_metadata, spec_metadata, + draft_model) + batch_size = attn_metadata.num_seqs raw_logits = logits @@ -422,7 +429,8 @@ def forward( # update attn metadata if attn_metadata is not None: - self.change_attn_metadata(num_accepted_tokens, attn_metadata) + self.change_attn_metadata(num_accepted_tokens, attn_metadata, + spec_metadata) # Run MTP layers to predict draft tokens next_draft_tokens = [] @@ -433,7 +441,8 @@ def forward( resource_manager) with self.draft_kv_cache_context(attn_metadata, draft_kv_cache_manager): - for i, mtp_layer in enumerate(draft_model.mtp_layers): + for i, mtp_layer in enumerate( + draft_model.mtp_layers[:runtime_draft_len]): if self.guided_decoder is not None: new_tokens = draft_inputs['input_ids'][last_tokens_idx] self.guided_decoder.add_draft_batch(new_tokens, @@ -506,17 +515,17 @@ def skip_forward( resource_manager=None, ): batch_size = attn_metadata.num_seqs - mtp_num_modules = self.spec_config.num_nextn_predict_layers - accepted_tokens = torch.empty((batch_size, (mtp_num_modules + 1)), + runtime_draft_len = spec_metadata.runtime_draft_len + accepted_tokens = torch.empty((batch_size, (runtime_draft_len + 1)), dtype=torch.int, device=logits.device) num_accepted_tokens = torch.ones(batch_size, dtype=torch.int, device=logits.device) - next_draft_tokens = torch.empty((batch_size, mtp_num_modules), + next_draft_tokens = torch.empty((batch_size, runtime_draft_len), dtype=torch.int, device=logits.device) - next_new_tokens = torch.empty((batch_size, (mtp_num_modules + 1)), + next_new_tokens = torch.empty((batch_size, (runtime_draft_len + 1)), dtype=torch.int, device=logits.device) return { @@ -589,14 +598,15 @@ def unpack_sequence(packed_seq_cuda, seq_lens_cuda, seq_lens_cpu): seq_lens = attn_metadata.seq_lens_cuda seq_lens_cpu = attn_metadata.seq_lens hidden_size = hidden_states.shape[-1] - mtp_num_modules = self.spec_config.num_nextn_predict_layers + runtime_draft_len = spec_metadata.runtime_draft_len + max_draft_len = self.spec_config.num_nextn_predict_layers if self.is_thop: _, _ = torch.ops.trtllm.mtp_update_hidden_states_op( input_ids, seq_lens, hidden_states, spec_metadata.mtp_hidden_states_ptrs, spec_metadata.mtp_past_tokens_ptrs, num_accepted_tokens, - mtp_num_modules, batch_size, num_contexts, hidden_size) + runtime_draft_len, batch_size, num_contexts, hidden_size) else: assert len(spec_metadata.request_ids) == batch_size mtp_past_hidden_states_pool = spec_metadata.mtp_hidden_states_manager.mtp_past_hidden_states_pool @@ -625,7 +635,7 @@ def unpack_sequence(packed_seq_cuda, seq_lens_cuda, seq_lens_cpu): dim=1) ctx_batch_idx = spec_metadata.batch_indices_cuda[:num_contexts] row_indices_ctx = ctx_batch_idx.unsqueeze(1).expand( - -1, mtp_num_modules) + -1, max_draft_len) col_indices_ctx = (seq_lens_ctx.unsqueeze(1) + spec_metadata.draft_token_indices_cuda) new_mtp_past_tokens.append(cat_tokens_ctx[row_indices_ctx, @@ -636,10 +646,10 @@ def unpack_sequence(packed_seq_cuda, seq_lens_cuda, seq_lens_cpu): # generation if num_gens > 0: unpacked_input_ids_gen = input_ids[num_ctx_tokens:].reshape( - num_gens, mtp_num_modules + 1).int() + num_gens, runtime_draft_len + 1).int() hidden_states_gen = 
hidden_states[num_ctx_tokens:, :] unpacked_hidden_states_gen = hidden_states_gen.reshape( - num_gens, mtp_num_modules + 1, hidden_size) + num_gens, runtime_draft_len + 1, hidden_size) cat_tokens_gen = torch.cat( (mtp_tokens[num_contexts:], unpacked_input_ids_gen), dim=1) cat_hidden_states_gen = torch.cat( @@ -648,10 +658,10 @@ def unpack_sequence(packed_seq_cuda, seq_lens_cuda, seq_lens_cpu): dim=1) gen_batch_idx = spec_metadata.batch_indices_cuda[:num_gens] row_indices_gen = gen_batch_idx.unsqueeze(1).expand( - -1, mtp_num_modules) + -1, max_draft_len) col_indices_gen = ( num_accepted_tokens[num_contexts:].unsqueeze(1) + - spec_metadata.draft_token_indices_cuda) + spec_metadata.draft_token_indices_cuda[:max_draft_len]) new_mtp_past_tokens.append(cat_tokens_gen[row_indices_gen, col_indices_gen]) new_mtp_past_hidden_states.append( @@ -666,17 +676,17 @@ def unpack_sequence(packed_seq_cuda, seq_lens_cuda, seq_lens_cpu): new_mtp_past_hidden_states) @torch.compile(options={"max-autotune": True}) - def topk_kernel(self, gen_logprobs, num_gens, mtp_num_modules, + def topk_kernel(self, gen_logprobs, num_gens, runtime_draft_len, spec_metadata): topk_value, topk_indices = torch.topk(gen_logprobs, k=self.spec_config.relaxed_topk, dim=-1) - topk_indices = topk_indices.reshape(num_gens, mtp_num_modules + 1, + topk_indices = topk_indices.reshape(num_gens, runtime_draft_len + 1, self.spec_config.relaxed_topk) - topk_value = topk_value.reshape(num_gens, mtp_num_modules + 1, + topk_value = topk_value.reshape(num_gens, runtime_draft_len + 1, self.spec_config.relaxed_topk) draft_tokens = spec_metadata.draft_tokens.reshape( - num_gens, mtp_num_modules) + num_gens, runtime_draft_len) return topk_value, topk_indices, draft_tokens @torch.compile(options={"max-autotune": True}) @@ -761,14 +771,14 @@ def sample_and_accept_draft_tokens( batch_size = attn_metadata.num_seqs num_contexts = attn_metadata.num_contexts num_gens = batch_size - num_contexts - mtp_num_modules = self.spec_config.num_nextn_predict_layers + runtime_draft_len = spec_metadata.runtime_draft_len if logits.dim() == 1: logits = logits.unsqueeze(0) # The return buffer if self.spec_config.use_relaxed_acceptance_for_thinking or not self.is_thop: - accepted_tokens = torch.ones((batch_size, (mtp_num_modules + 1)), + accepted_tokens = torch.ones((batch_size, (runtime_draft_len + 1)), dtype=torch.int, device=logits.device) num_accepted_tokens = torch.ones(batch_size, @@ -804,41 +814,40 @@ def sample_and_accept_draft_tokens( # generation gen_logprobs = self.process_generation_logits(logits, num_contexts) topk_value, topk_indices, draft_tokens = self.topk_kernel( - gen_logprobs, num_gens, mtp_num_modules, spec_metadata) + gen_logprobs, num_gens, runtime_draft_len, spec_metadata) accepted_tokens, num_accepted_tokens = torch.ops.trtllm.mtp_relaxed_acceptance_op( spec_metadata.slot_ids, topk_value, topk_indices, draft_tokens, mtp_relaxed_delta_pool, num_accepted_tokens, accepted_tokens, - mtp_num_modules, batch_size, num_contexts, + runtime_draft_len, batch_size, num_contexts, self.spec_config.relaxed_topk, self.spec_config.relaxed_delta, self.spec_config.begin_thinking_phase_token, self.spec_config.end_thinking_phase_token) # Apply force override for relaxed acceptance path num_accepted_tokens = self._apply_force_accepted_tokens( - num_accepted_tokens, num_contexts, - spec_metadata.runtime_draft_len) + num_accepted_tokens, num_contexts, runtime_draft_len) # Strict acceptance else: if self.is_thop: # Temporary buffer target_tokens_cache = torch.zeros(batch_size 
* - (mtp_num_modules + 1), + (runtime_draft_len + 1), dtype=torch.int, device=logits.device) accepted_tokens, num_accepted_tokens = torch.ops.trtllm.mtp_sampling_and_accepted_draft_tokens_op( logits, spec_metadata.draft_tokens, target_tokens_cache, - mtp_num_modules, batch_size, num_contexts, logits.shape[-1]) + runtime_draft_len, batch_size, num_contexts, + logits.shape[-1]) # Apply force override for THOP path num_accepted_tokens = self._apply_force_accepted_tokens( - num_accepted_tokens, num_contexts, - spec_metadata.runtime_draft_len) + num_accepted_tokens, num_contexts, runtime_draft_len) else: # Reshape draft tokens for base implementation draft_tokens = spec_metadata.draft_tokens.reshape( - num_gens, mtp_num_modules) + num_gens, runtime_draft_len) # Use base implementation for strict acceptance accepted_tokens, num_accepted_tokens = self._sample_and_accept_draft_tokens_base( @@ -855,16 +864,17 @@ def sample_and_accept_draft_tokens( num_accepted_tokens=num_accepted_tokens, num_gens=num_gens, num_contexts=num_contexts, - max_draft_len=mtp_num_modules, + max_draft_len=runtime_draft_len, ) return accepted_tokens, num_accepted_tokens def change_attn_metadata(self, num_accepted_tokens: torch.Tensor, - attn_metadata: AttentionMetadata): + attn_metadata: AttentionMetadata, + spec_metadata: MTPSpecMetadata): self._prepare_attn_metadata_for_spec_dec(attn_metadata) batch_size = attn_metadata.num_seqs - mtp_num_modules = self.spec_config.num_nextn_predict_layers + runtime_draft_len = spec_metadata.runtime_draft_len num_contexts = attn_metadata.num_contexts attn_metadata._seq_lens[num_contexts:batch_size] -= 1 @@ -876,14 +886,14 @@ def change_attn_metadata(self, num_accepted_tokens: torch.Tensor, # buffer once the graph has been captured also - this will invalidate # the graph and force an expensive recapture. attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= ( - mtp_num_modules + 1 - + runtime_draft_len + 1 - num_accepted_tokens[num_contexts:batch_size]) if attn_metadata.kv_cache_params is not None and not attn_metadata.is_cuda_graph: for i in range(num_contexts, batch_size): # used for vanilla MLA, list on cpu attn_metadata.kv_cache_params.num_cached_tokens_per_seq[ - i] -= mtp_num_modules + 1 - num_accepted_tokens[i].item() + i] -= runtime_draft_len + 1 - num_accepted_tokens[i].item() def prepare_drafter_inputs( self, @@ -901,8 +911,8 @@ def prepare_drafter_inputs( Args: input_ids: torch.IntTensor [num_tokens] - The input ids of all requests. Flattened. - num_tokens = sum(all prompts) + num_generation * (mtp_num_modules + 1) + The input ids of all requests. Flatten. + num_tokens = sum(all prompts) + num_generation * (runtime_draft_len + 1) position_ids: torch.IntTensor [1][num_tokens] @@ -929,8 +939,8 @@ def prepare_drafter_inputs( Returns: draft_inputs input_ids: torch.Tensor [num_tokens] - The new input ids of all requests. Flattened. - num_tokens = sum(all prompts) + num_generation * (mtp_num_modules) + The new input ids of all requests. Flatten. 
+ num_tokens = sum(all prompts) + num_generation * (runtime_draft_len) position_ids: torch.Tensor [1, num_tokens] @@ -954,7 +964,7 @@ def prepare_drafter_inputs( num_gens = batch_size - num_contexts mtp_past_hidden_states_pool = spec_metadata.mtp_hidden_states_manager.mtp_past_hidden_states_pool mtp_past_tokens_pool = spec_metadata.mtp_hidden_states_manager.mtp_past_tokens_pool - mtp_num_modules = self.spec_config.num_nextn_predict_layers + runtime_draft_len = spec_metadata.runtime_draft_len if self.is_thop: # Temporary buffer @@ -976,7 +986,7 @@ def prepare_drafter_inputs( spec_metadata.mtp_hidden_states_ptrs, spec_metadata.mtp_past_tokens_ptrs, hidden_states, accepted_tokens, num_accepted_tokens, return_input_ids, - return_hidden_states, mtp_num_modules, batch_size, + return_hidden_states, runtime_draft_len, batch_size, num_contexts, hidden_size) else: @@ -1001,11 +1011,18 @@ def prepare_drafter_inputs( accepted_tokens_gen = accepted_tokens[num_contexts:, :] input_ids_gen = accepted_tokens_gen[gen_batch_idx, gen_token_idx].unsqueeze(1) - input_ids_gen = torch.concat( - [mtp_past_tokens_pool[slot_ids][:, 1:], input_ids_gen], - dim=1) + + if runtime_draft_len > 1: + history_tokens = mtp_past_tokens_pool[slot_ids][:, -( + runtime_draft_len - 1):] + else: + history_tokens = torch.empty((num_gens, 0), + dtype=torch.int, + device=input_ids.device) + input_ids_gen = torch.concat([history_tokens, input_ids_gen], + dim=1) hidden_states_gen = mtp_past_hidden_states_pool[ - slot_ids].flatten(0, 1) + slot_ids][:, -runtime_draft_len:, :].flatten(0, 1) return_input_ids_list.append(input_ids_gen.flatten(0, 1)) return_hidden_states_list.append(hidden_states_gen) # Concatenate into continuous buffers @@ -1019,9 +1036,9 @@ def prepare_drafter_inputs( position_ids_list.append(position_ids[:num_ctx_tokens]) if num_gens > 0: position_ids_gen = position_ids[num_ctx_tokens:].reshape( - num_gens, mtp_num_modules + 1)[:, -mtp_num_modules:] + num_gens, runtime_draft_len + 1)[:, -runtime_draft_len:] position_ids_gen = position_ids_gen - ( - 1 + mtp_num_modules - + 1 + runtime_draft_len - num_accepted_tokens[num_contexts:].unsqueeze(1)) position_ids_list.append(position_ids_gen.flatten()) return_position_ids = torch.concat(position_ids_list, dim=-1) @@ -1142,6 +1159,12 @@ def forward( draft_model, resource_manager=None, ): + runtime_draft_len = spec_metadata.runtime_draft_len + # skip the draft forward if the runtime draft length is 0 + if runtime_draft_len == 0: + return self.skip_drafting(input_ids, position_ids, hidden_states, + logits, attn_metadata, spec_metadata, + draft_model) batch_size = attn_metadata.num_seqs num_contexts = attn_metadata.num_contexts @@ -1191,7 +1214,7 @@ def prepare_position_ids_and_last_tokens(position_ids, attn_metadata): # Predict draft tokens next_draft_tokens = [] with self.draft_kv_cache_context(attn_metadata, draft_kv_cache_manager): - for i in range(self.mtp_num_modules): + for i in range(runtime_draft_len): if i == 0: hidden_states = draft_model.mtp_layers[0]( embed_tokens=draft_model.embed_tokens, @@ -1200,7 +1223,7 @@ def prepare_position_ids_and_last_tokens(position_ids, attn_metadata): start_ids_gen = ( spec_metadata.batch_indices_cuda[:num_gens] * - (self.mtp_num_modules + 1)).long() + (runtime_draft_len + 1)).long() gather_ids_gen = (start_ids_gen + num_accepted_tokens[num_contexts:] - 1 + attn_metadata.num_ctx_tokens) @@ -1281,7 +1304,7 @@ def prepare_position_ids_and_last_tokens(position_ids, attn_metadata): # update kv_lens_cuda if hasattr(attn_metadata, 
'kv_lens_cuda'): attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= ( - self.mtp_num_modules - + runtime_draft_len - num_accepted_tokens[num_contexts:]) attn_metadata.kv_lens_cuda[:num_contexts] += 1 # update metadata for flash mla @@ -1367,6 +1390,7 @@ def prepare_drafter_inputs( spec_metadata: MTPSpecMetadata, ): num_contexts = attn_metadata.num_contexts + runtime_draft_len = spec_metadata.runtime_draft_len # context input_ids_ctx = self._prepare_context_input_ids( @@ -1374,7 +1398,8 @@ def prepare_drafter_inputs( accepted_tokens, num_contexts) # generation - input_ids_gen = accepted_tokens[num_contexts:, :].flatten() + input_ids_gen = accepted_tokens[num_contexts:, :runtime_draft_len + + 1].flatten() # get draft inputs input_ids = torch.concat([input_ids_ctx, input_ids_gen], dim=0) diff --git a/tensorrt_llm/_torch/speculative/pard.py b/tensorrt_llm/_torch/speculative/pard.py index f4da0e6c9d2..872d07c9bbd 100644 --- a/tensorrt_llm/_torch/speculative/pard.py +++ b/tensorrt_llm/_torch/speculative/pard.py @@ -99,15 +99,6 @@ def __init__( def max_draft_len(self) -> int: return self.spec_config.max_draft_len - @property - def _draft_tokens_per_req(self) -> int: - """Total tokens per gen request in the draft forward. - - Uses 2K to fit all accepted tokens (up to K+1) plus K-1 mask tokens, - ensuring K unique predictions regardless of how many tokens were accepted. - """ - return 2 * self.max_draft_len - def _prepare_attn_metadata_for_pard(self, attn_metadata, spec_metadata): """ Save attn_metadata fields that PARD modifies during forward. @@ -190,13 +181,25 @@ def forward( num_gens = batch_size - num_contexts raw_logits = logits - K = self.max_draft_len + K = spec_metadata.runtime_draft_len + + if K == 0: + return self.skip_drafting( + input_ids, + position_ids, + hidden_states, + logits, + attn_metadata, + spec_metadata, + draft_model, + ) self._execute_guided_decoder_if_present(logits) # draft_tokens buffer has (2K-1) entries per gen request; extract the K real drafts if num_gens > 0: - draft_tokens = spec_metadata.draft_tokens.reshape(num_gens, 2 * K - 1)[:, :K] + draft_tokens = spec_metadata.draft_tokens[: num_gens * (2 * K - 1)] + draft_tokens = draft_tokens.reshape(num_gens, 2 * K - 1)[:, :K] else: draft_tokens = spec_metadata.draft_tokens.reshape(0, K) @@ -262,14 +265,13 @@ def forward( gen_start_idx = attn_metadata.num_ctx_tokens request_bases = ( - torch.arange(num_gens, dtype=torch.long, device="cuda") - * self._draft_tokens_per_req + torch.arange(num_gens, dtype=torch.long, device="cuda") * (2 * K) + gen_start_idx ) gen_num_accepted = num_accepted_tokens[num_contexts:batch_size].long() base_offsets = gen_num_accepted - 1 # M = bonus position - offsets = torch.arange(self.max_draft_len, dtype=torch.long, device="cuda") + offsets = torch.arange(K, dtype=torch.long, device="cuda") gen_gather_ids = ( request_bases.unsqueeze(1) + base_offsets.unsqueeze(1) + offsets.unsqueeze(0) @@ -281,7 +283,7 @@ def forward( ) vocab_size = gen_logits.shape[-1] - gen_logits = gen_logits.reshape(num_gens, self.max_draft_len, vocab_size) + gen_logits = gen_logits.reshape(num_gens, K, vocab_size) # Use torch.argmax directly to avoid cute_argmax stride issues d2t = getattr(draft_model.model, "d2t", None) @@ -384,6 +386,8 @@ def prepare_1st_drafter_inputs( num_contexts = attn_metadata.num_contexts batch_size = attn_metadata.num_seqs num_gens = batch_size - num_contexts + runtime_draft_len = spec_metadata.runtime_draft_len + total_tokens_per_req = 2 * runtime_draft_len if ( hasattr(self.spec_config, 
"mask_token_id") @@ -412,8 +416,6 @@ def prepare_1st_drafter_inputs( gen_num_accepted = num_accepted_tokens[num_contexts : num_contexts + num_gens] gen_accepted_tokens = accepted_tokens[num_contexts : num_contexts + num_gens, :] - total_tokens_per_req = self._draft_tokens_per_req # 2K - # Start with all mask tokens request_ids_2d = torch.full( (num_gens, total_tokens_per_req), @@ -452,9 +454,9 @@ def prepare_1st_drafter_inputs( - total_tokens_per_req ) else: - gen_pos_starts = position_ids[ - attn_metadata.num_ctx_tokens :: self._draft_tokens_per_req - ][:num_gens] + gen_pos_starts = position_ids[attn_metadata.num_ctx_tokens :: total_tokens_per_req][ + :num_gens + ] offsets = torch.arange(total_tokens_per_req, dtype=torch.int32, device="cuda") position_ids_gen = (gen_pos_starts.unsqueeze(1) + offsets.unsqueeze(0)).flatten() diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 8f512dcab02..3122dd9dd68 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -860,6 +860,10 @@ def tokens_per_gen_step(self) -> int: """Total tokens per gen request in one spec dec iteration (including golden token).""" return 1 + self.max_total_draft_tokens + def get_runtime_tokens_per_gen_step(self, runtime_draft_len: int) -> int: + """Total tokens per gen request for the current runtime draft length.""" + return 1 + runtime_draft_len + def num_capture_layers(self) -> int: return 0 @@ -1450,6 +1454,10 @@ def tokens_per_gen_step(self) -> int: """PARD needs 2K tokens per gen request: K+1 accepted + K-1 masks.""" return 2 * self.max_draft_len + def get_runtime_tokens_per_gen_step(self, runtime_draft_len: int) -> int: + """PARD needs 2K runtime tokens per gen request for logical draft length K.""" + return 1 if runtime_draft_len == 0 else 2 * runtime_draft_len + def supports_backend(self, backend: str) -> bool: return backend == "pytorch" diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 79fd6f1d65c..fe0026ad0ff 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -436,6 +436,43 @@ def test_pard_sa(self): task = GSM8K(self.MODEL_NAME) task.evaluate(llm, extra_acc_spec="use_sa_spec") + @pytest.mark.skip_less_device_memory(60000) + @parametrize_with_ids("enable_max_concurrency,enable_draft_len_schedule", [ + (False, True), + (True, False), + ]) + def test_pard_dynamic_draft_len(self, enable_max_concurrency, + enable_draft_len_schedule): + max_concurrency = 100 if enable_max_concurrency else None + draft_len_schedule = { + 50: 4, + 200: 3, + 350: 2 + } if enable_draft_len_schedule else None + max_draft_len = 4 + pytorch_config = dict( + disable_overlap_scheduler=False, + cuda_graph_config=CudaGraphConfig( + max_batch_size=500 + if draft_len_schedule or max_concurrency is not None else None), + ) + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75) + pard_model_dir = f"{llm_models_root()}/PARD-Llama-3.2-1B" + pard_config = PARDDecodingConfig( + max_draft_len=max_draft_len, + speculative_model=pard_model_dir, + max_concurrency=max_concurrency, + draft_len_schedule=draft_len_schedule, + ) + with LLM(self.MODEL_PATH, + kv_cache_config=kv_cache_config, + enable_chunked_prefill=False, + max_num_tokens=8192, + **pytorch_config, + speculative_config=pard_config) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + @skip_pre_hopper def test_ngram(self): max_bs = 16 @@ -1688,6 +1725,83 
@@ def test_bfloat16_mtp_sa(self): task = GSM8K(self.MODEL_NAME) task.evaluate(llm, extra_acc_spec="use_sa_spec") + @pytest.mark.skip_less_device_memory(60000) + @parametrize_with_ids("enable_max_concurrency,enable_draft_len_schedule", [ + (False, True), + (True, False), + ]) + def test_bfloat16_mtp_dynamic_draft_len(self, enable_max_concurrency, + enable_draft_len_schedule): + max_concurrency = 100 if enable_max_concurrency else None + draft_len_schedule = { + 50: 4, + 200: 3, + 350: 2 + } if enable_draft_len_schedule else None + max_draft_len = 4 + cuda_graph_config = CudaGraphConfig( + max_batch_size=500 + if draft_len_schedule or max_concurrency is not None else None) + pytorch_config = dict( + disable_overlap_scheduler=False, + cuda_graph_config=cuda_graph_config, + ) + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75) + mtp_config = MTPDecodingConfig( + num_nextn_predict_layers=max_draft_len, + max_draft_len=max_draft_len, + max_concurrency=max_concurrency, + draft_len_schedule=draft_len_schedule, + ) + with LLM(self.MODEL_PATH, + kv_cache_config=kv_cache_config, + enable_chunked_prefill=False, + max_num_tokens=8192, + **pytorch_config, + speculative_config=mtp_config) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + + @pytest.mark.skip_less_device_memory(60000) + @parametrize_with_ids("enable_max_concurrency,enable_draft_len_schedule", [ + (False, True), + (True, False), + ]) + def test_bfloat16_mtp_eagle_dynamic_draft_len(self, enable_max_concurrency, + enable_draft_len_schedule): + max_concurrency = 100 if enable_max_concurrency else None + draft_len_schedule = { + 50: 4, + 200: 3, + 350: 2 + } if enable_draft_len_schedule else None + max_draft_len = 4 + cuda_graph_config = CudaGraphConfig( + max_batch_size=500 + if draft_len_schedule or max_concurrency is not None else None) + pytorch_config = dict( + disable_overlap_scheduler=False, + cuda_graph_config=cuda_graph_config, + ) + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75) + # Force MTP-Eagle one-model path. 
+ mtp_config = MTPDecodingConfig( + num_nextn_predict_layers=max_draft_len, + max_draft_len=max_draft_len, + use_mtp_vanilla=False, + mtp_eagle_one_model=True, + max_concurrency=max_concurrency, + draft_len_schedule=draft_len_schedule, + ) + with LLM(self.MODEL_PATH, + kv_cache_config=kv_cache_config, + enable_chunked_prefill=False, + max_num_tokens=8192, + **pytorch_config, + speculative_config=mtp_config) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + @pytest.mark.skip_less_device(4) @parametrize_with_ids("torch_compile", [False, True]) @parametrize_with_ids("attention_dp,cuda_graph,overlap_scheduler", diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt index 533803a0cbc..b8d5b790652 100644 --- a/tests/integration/test_lists/qa/llm_function_core.txt +++ b/tests/integration/test_lists/qa/llm_function_core.txt @@ -37,6 +37,8 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard[overlap_scheduler=True] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard[overlap_scheduler=False] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard_sa +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard_dynamic_draft_len[enable_max_concurrency=False-enable_draft_len_schedule=True] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard_dynamic_draft_len[enable_max_concurrency=True-enable_draft_len_schedule=False] accuracy/test_llm_api_pytorch.py::TestLlama3_2_1B::test_auto_dtype accuracy/test_llm_api_pytorch.py::TestLlama3_2_1B::test_fp8_prequantized accuracy/test_llm_api_pytorch.py::TestLlama3_2_3B::test_auto_dtype @@ -94,6 +96,10 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_python_sched accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_python_scheduler[mtp_nextn=2-attention_dp=False-cuda_graph=False-overlap_scheduler=False-enable_chunked_prefill=True] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_python_scheduler[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-enable_chunked_prefill=True] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_2_model_mtp +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_mtp_dynamic_draft_len[enable_max_concurrency=False-enable_draft_len_schedule=True] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_mtp_dynamic_draft_len[enable_max_concurrency=True-enable_draft_len_schedule=False] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_mtp_eagle_dynamic_draft_len[enable_max_concurrency=False-enable_draft_len_schedule=True] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_mtp_eagle_dynamic_draft_len[enable_max_concurrency=True-enable_draft_len_schedule=False] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] diff --git a/tests/unittest/_torch/speculative/test_mtp.py 
b/tests/unittest/_torch/speculative/test_mtp.py index d3965f477c0..b15a832595b 100644 --- a/tests/unittest/_torch/speculative/test_mtp.py +++ b/tests/unittest/_torch/speculative/test_mtp.py @@ -312,6 +312,7 @@ def test_sample_and_accept_draft_tokens(self, test_case_name, max_total_draft_tokens=mtp_num_modules, mtp_num_modules=mtp_num_modules) spec_metadata.draft_tokens = draft_tokens + spec_metadata.runtime_draft_len = mtp_num_modules # mtp worker mtpworker = MTPWorker(spec_config) @@ -901,6 +902,7 @@ def test_mtp_update_mtp_hidden_states( spec_metadata.mtp_hidden_states_manager.mtp_past_hidden_states_pool = mtp_hidden_states_tensor_pool spec_metadata.mtp_hidden_states_manager.mtp_past_tokens_pool = mtp_tokens_tensor_pool + spec_metadata.runtime_draft_len = num_nextn_predict_layers spec_metadata.prepare() mtpworker = MTPWorker(spec_config) @@ -1397,6 +1399,7 @@ def test_prepare_drafter_inputs( spec_metadata.mtp_hidden_states_manager.mtp_past_hidden_states_pool = mtp_hidden_states_tensor_pool spec_metadata.mtp_hidden_states_manager.mtp_past_tokens_pool = mtp_tokens_tensor_pool + spec_metadata.runtime_draft_len = num_nextn_predict_layers spec_metadata.prepare() mtpworker = MTPWorker(spec_config)
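Finally, a worked sketch (not part of the patch) of the PARD branch added in `py_executor._handle_dynamic_draft_len` above: a generation request arrives with a 2K' - 1 wide draft buffer for its previous logical draft length K', and is repacked to the buffer width implied by the newly resolved `runtime_draft_len` K. The helper name and pad value below are illustrative only.

```python
def repack_pard_draft_buffer(draft_tokens: list, runtime_draft_len: int,
                             pad: int = 0) -> list:
    # A PARD request carries 2*K' - 1 entries for its previous logical draft length K'.
    current_k = (len(draft_tokens) + 1) // 2 if draft_tokens else 0
    real = list(draft_tokens[:min(current_k, runtime_draft_len)])
    real += [pad] * (runtime_draft_len - len(real))  # pad real drafts up to K
    # Buffer width for logical K: get_runtime_tokens_per_gen_step(K) - 1,
    # i.e. 2K - 1 (or 0 when K == 0).
    width = (1 if runtime_draft_len == 0 else 2 * runtime_draft_len) - 1
    return real + [pad] * (width - len(real))

assert len(repack_pard_draft_buffer([11, 0, 12, 0, 13], runtime_draft_len=2)) == 3  # K'=3 -> K=2
assert repack_pard_draft_buffer([11], runtime_draft_len=3) == [11, 0, 0, 0, 0]      # K'=1 -> K=3
assert repack_pard_draft_buffer([], runtime_draft_len=0) == []                      # drafting skipped
```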