[TRTLLM-10319][feat] Expand dynamic speculation to MTP and PARD.#12262
zheyuf wants to merge 3 commits into NVIDIA:main
Conversation
Signed-off-by: Zheyu Fu <zheyuf@NVIDIA.com>
/bot run --disable-fail-fast
📝 Walkthrough

This PR introduces dynamic, per-iteration draft-length handling for speculative decoding by propagating a new

Changes
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes

Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks: ✅ Passed checks (3 passed)
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
2366-2412: ⚠️ Potential issue | 🔴 Critical — The first tree-decoding step still assumes the configured max width.
This branch still requires `num_draft_tokens == spec_tree_manager.max_total_draft_tokens` and appends the full `spec_dec_position_offsets[0]`. Dynamic PARD batches below the configured maximum will fail here on the first generation step or warmup; if assertions are stripped, `position_ids` becomes longer than `1 + num_draft_tokens`.

Possible fix
```diff
 if not self.is_draft_model and not spec_config.is_linear_tree:
     assert spec_tree_manager is not None
-    assert num_draft_tokens == spec_tree_manager.max_total_draft_tokens
+    assert num_draft_tokens <= spec_tree_manager.max_total_draft_tokens
     position_ids.extend(
         past_seen_token_num +
-        spec_tree_manager.spec_dec_position_offsets[
-            0]  # [max_total_draft_tokens + 1]
+        spec_tree_manager.spec_dec_position_offsets[0][
+            :1 + num_draft_tokens]
     )
```

The same runtime slice should be applied anywhere else that consumes `spec_dec_position_offsets[0]`.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/pyexecutor/model_engine.py` around lines 2366 - 2412, The code in the branch that handles tree decoding (inside model_engine where spec_tree_manager is used) assumes num_draft_tokens equals spec_tree_manager.max_total_draft_tokens and appends the entire spec_dec_position_offsets[0], which breaks dynamic PARD smaller-than-configured batches; remove the strict equality/assertion and instead extend position_ids with only the runtime slice of spec_tree_manager.spec_dec_position_offsets that corresponds to the actual tokens (e.g. use spec_tree_manager.spec_dec_position_offsets[0:1 + num_draft_tokens] or otherwise index up to 1 + num_draft_tokens) so position_ids length matches 1 + num_draft_tokens; apply the same runtime slicing anywhere else spec_dec_position_offsets[0] is consumed.
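To make the proposed slice concrete, here is a minimal self-contained sketch with hypothetical values (`max_total_draft_tokens`, the offset contents, and `past_seen_token_num` are all invented for illustration; a real tree's offsets need not be a simple range):

```python
# Hypothetical illustration of the suggested fix: slice the precomputed
# position offsets to the runtime draft length instead of appending them all.
max_total_draft_tokens = 7
# spec_dec_position_offsets[0] holds max_total_draft_tokens + 1 offsets;
# a linear tree would look like [0, 1, 2, ..., 7].
spec_dec_position_offsets = [list(range(max_total_draft_tokens + 1))]

past_seen_token_num = 100
num_draft_tokens = 3  # dynamic PARD batch below the configured maximum

position_ids = [past_seen_token_num + off
                for off in spec_dec_position_offsets[0][:1 + num_draft_tokens]]
assert len(position_ids) == 1 + num_draft_tokens
```

With the slice, `position_ids` always has exactly `1 + num_draft_tokens` entries regardless of the configured maximum.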
🧹 Nitpick comments (1)
tensorrt_llm/_torch/pyexecutor/py_executor.py (1)
1652: Consider moving the constant definition outside the loop.

`DRAFT_BUFFER_PAD` is redefined on each iteration of the for loop. While the performance impact is negligible, moving it before the loop (around line 1651) would be slightly cleaner.

♻️ Suggested refactor
```diff
 runtime_draft_len = get_draft_len_for_batch_size(
     self.model_engine.spec_config.draft_len_schedule,
     scheduled_batch.batch_size, self.model_engine.max_draft_len)

 # 2. Pad or truncate draft tokens to the resolved length
+DRAFT_BUFFER_PAD = 0  # Buffer sentinel, not PARD mask_token_id.
 for request in scheduled_batch.generation_requests:
-    DRAFT_BUFFER_PAD = 0  # Buffer sentinel, not PARD mask_token_id.
     current_num_draft_tokens = len(request.py_draft_tokens)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/pyexecutor/py_executor.py` at line 1652, DRAFT_BUFFER_PAD is being set inside the loop each iteration; pull the constant definition out of the loop by declaring DRAFT_BUFFER_PAD = 0 just once immediately before the enclosing for loop (so the loop body uses the already-defined symbol), ensuring any references inside the loop continue to use the same constant and no other logic changes are needed.
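The pad-or-truncate step that this loop performs can be sketched as follows (a minimal illustration under assumed names; `DRAFT_BUFFER_PAD` and `runtime_draft_len` come from the snippet above, the draft-token lists are invented):

```python
# Sketch of the pad-or-truncate step, with the constant hoisted out of the loop.
DRAFT_BUFFER_PAD = 0  # buffer sentinel, not the PARD mask_token_id
runtime_draft_len = 4

drafts = [[1, 2], [5, 6, 7, 8, 9]]  # hypothetical per-request draft tokens
resized = [(d + [DRAFT_BUFFER_PAD] * runtime_draft_len)[:runtime_draft_len]
           for d in drafts]
# Short drafts are padded with the sentinel, long drafts are truncated.
```

Every request ends up with exactly `runtime_draft_len` draft slots, which is what lets the attention buffers use a single per-iteration width.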
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/attention_backend/trtllm.py`:
- Around line 1602-1609: The code dereferences
spec_metadata.runtime_tokens_per_gen_step without guarding spec_metadata; update
the branch around runtime_draft_token_buffer_width calculation to first check
spec_metadata is not None (or explicitly enforce the precondition) and either
use a defined fixed-width fallback when spec_metadata is None or raise a clear
ValueError. Specifically, protect access to
spec_metadata.runtime_tokens_per_gen_step before computing
runtime_draft_token_buffer_width, then call
generate_spec_decoding_generation_length(runtime_draft_len=...), and compute
spec_decoding_position_offsets and spec_decoding_packed_mask only after
determining runtime_draft_token_buffer_width; reference spec_metadata,
runtime_tokens_per_gen_step, runtime_draft_token_buffer_width,
generate_spec_decoding_generation_length,
generate_spec_decoding_position_offsets, generate_spec_decoding_packed_mask, and
max_num_requests when making the guard or fallback change.
In `@tensorrt_llm/_torch/pyexecutor/model_engine.py`:
- Around line 1202-1204: The warmup-sizing uses
get_runtime_tokens_per_gen_step(draft_len) with a value that may already be a
buffer width (e.g. self.max_total_draft_tokens for non-dynamic path), inflating
sizes; change the call sites so _get_graphs_to_capture / warmup sizing use the
logical draft length (K) not the buffer width (2K-1). Concretely, compute a
logical runtime_draft_len from draft_len or from self.max_total_draft_tokens by
converting buffer-width to K when needed, then pass that logical value into
get_runtime_tokens_per_gen_step and use it to compute
runtime_draft_token_buffer_width, update any places that set
self.runtime_draft_len, the warmup request, and KV budgeting to use this logical
runtime_draft_len (symbols to adjust: get_runtime_tokens_per_gen_step,
runtime_tokens_per_gen_step, runtime_draft_token_buffer_width,
_get_graphs_to_capture, self.max_total_draft_tokens, self.runtime_draft_len).
In `@tensorrt_llm/_torch/speculative/interface.py`:
- Around line 285-287: Update the comment above the runtime_tokens_per_gen_step
variable to clarify the PARD edge case: explain that normally
runtime_tokens_per_gen_step equals 1 + runtime_draft_len, and for PARD it equals
2 * runtime_draft_len except when K=0 (in which case runtime_tokens_per_gen_step
is 1), referencing the PARD mode and the runtime_draft_len and K variables so
readers understand the K=0 special-case behavior for
runtime_tokens_per_gen_step.
In `@tensorrt_llm/_torch/speculative/mtp.py`:
- Around line 601-609: The THOP branch calling
torch.ops.trtllm.mtp_update_hidden_states_op currently passes runtime_draft_len
which causes THOP to only retain a shortened MTP history; change the argument to
max_draft_len (self.spec_config.num_nextn_predict_layers) so THOP refreshes the
full MTP history window the same way the eager path does, ensuring both branches
update the same number of draft entries (compare the call in the is_thop block
and the eager update that uses max_draft_len).
In `@tests/integration/defs/accuracy/test_llm_api_pytorch.py`:
- Around line 439-445: The test function test_pard_dynamic_draft_len is missing
the Hopper-gating decorator; add the `@skip_pre_hopper` decorator immediately
above the function definition so it matches other PARD tests and will be skipped
on pre-Hopper runners; ensure the decorator is imported/available where other
tests use skip_pre_hopper so the new annotation compiles and is applied to
test_pard_dynamic_draft_len.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: d13b632c-a973-4679-97c2-b9e441e96b1d
📒 Files selected for processing (10)

- tensorrt_llm/_torch/attention_backend/trtllm.py
- tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
- tensorrt_llm/_torch/pyexecutor/model_engine.py
- tensorrt_llm/_torch/pyexecutor/py_executor.py
- tensorrt_llm/_torch/speculative/interface.py
- tensorrt_llm/_torch/speculative/mtp.py
- tensorrt_llm/_torch/speculative/pard.py
- tensorrt_llm/llmapi/llm_args.py
- tests/integration/defs/accuracy/test_llm_api_pytorch.py
- tests/integration/test_lists/qa/llm_function_core.txt
```python
# Total runtime tokens per generation request for the current iteration,
# Normally, it equals 1 + runtime_draft_len. But for PARD, it equals 2 * runtime_draft_len.
runtime_tokens_per_gen_step: int = 1
```
Update the PARD runtime-token comment for K=0.
Line 286 states PARD uses `2 * runtime_draft_len`, but the runtime behavior for K=0 is 1 token. The comment should reflect that edge case to avoid confusion.
Suggested doc fix
```diff
-# Normally, it equals 1 + runtime_draft_len. But for PARD, it equals 2 * runtime_draft_len.
+# Normally, it equals 1 + runtime_draft_len.
+# For PARD, it equals 2 * runtime_draft_len when runtime_draft_len > 0, else 1.
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/speculative/interface.py` around lines 285 - 287, Update
the comment above the runtime_tokens_per_gen_step variable to clarify the PARD
edge case: explain that normally runtime_tokens_per_gen_step equals 1 +
runtime_draft_len, and for PARD it equals 2 * runtime_draft_len except when K=0
(in which case runtime_tokens_per_gen_step is 1), referencing the PARD mode and
the runtime_draft_len and K variables so readers understand the K=0 special-case
behavior for runtime_tokens_per_gen_step.
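The rule described above, including the K=0 special case, can be captured in a small helper (a hypothetical sketch mirroring the comment, not a function that exists in the codebase):

```python
# Hypothetical helper mirroring the documented rule for PARD's K=0 edge case.
def runtime_tokens_per_gen_step(runtime_draft_len, is_pard):
    if is_pard:
        # PARD verifies 2K tokens per step, except K=0 degenerates to
        # plain decoding of a single token.
        return 2 * runtime_draft_len if runtime_draft_len > 0 else 1
    # Non-PARD speculation: one target token plus K draft tokens.
    return 1 + runtime_draft_len

assert runtime_tokens_per_gen_step(3, is_pard=False) == 4
assert runtime_tokens_per_gen_step(3, is_pard=True) == 6
assert runtime_tokens_per_gen_step(0, is_pard=True) == 1
```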
```diff
 runtime_draft_len = spec_metadata.runtime_draft_len
 max_draft_len = self.spec_config.num_nextn_predict_layers

 if self.is_thop:
     _, _ = torch.ops.trtllm.mtp_update_hidden_states_op(
         input_ids, seq_lens, hidden_states,
         spec_metadata.mtp_hidden_states_ptrs,
         spec_metadata.mtp_past_tokens_ptrs, num_accepted_tokens,
-        mtp_num_modules, batch_size, num_contexts, hidden_size)
+        runtime_draft_len, batch_size, num_contexts, hidden_size)
```
🧩 Analysis chain
🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail
echo "== Python call sites =="
rg -n -C4 'mtp_update_hidden_states_op' tensorrt_llm
echo
echo "== Native op registration / implementation =="
rg -n -C6 'mtp_update_hidden_states(_op)?' . -g '!**/build/**' -g '!**/dist/**'
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 6303

🏁 Script executed:

```shell
sed -n '150,250p' cpp/tensorrt_llm/thop/specDecOp.cpp
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 4543

🏁 Script executed:

```shell
sed -n '610,660p' tensorrt_llm/_torch/speculative/mtp.py
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 2889

🏁 Script executed:

```shell
sed -n '660,710p' tensorrt_llm/_torch/speculative/mtp.py
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 2659

🏁 Script executed:

```shell
grep -rn "numMTPModules" cpp/tensorrt_llm --include="*.h" --include="*.cpp" -A2 -B2 | head -100
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 8028

🏁 Script executed:

```shell
grep -rn "invokeMTPUpdateHiddenStates" cpp/tensorrt_llm --include="*.cu" -A30 | head -80
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 5513

🏁 Script executed:

```shell
grep -n "mtpUpdateHiddenStatesKernel" cpp/tensorrt_llm/kernels/speculativeDecoding/mtpKernels.cu -A60 | head -100
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 4929

🏁 Script executed:

```shell
grep -n "mtp_hidden_states_pool\|mtp_past_tokens_pool" tensorrt_llm/_torch/speculative/mtp.py -B2 -A2 | head -80
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 1966
Verify THOP refreshes the full MTP history window.
The kernel code in `mtpKernels.cu` (lines 386 and 395) uses `numMTPModules` directly to control the history window size: `for (int ii = 0; ii < numMTPModules; ii++)`. Since line 609 now passes `runtime_draft_len` as this parameter, THOP mode will maintain a history window of only `runtime_draft_len` entries. Meanwhile, the eager path at lines 674-676 always updates `max_draft_len` entries in the pools. When `runtime_draft_len` shrinks and later grows, THOP mode will have discarded history that eager mode preserved, causing a divergence.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/speculative/mtp.py` around lines 601 - 609, The THOP
branch calling torch.ops.trtllm.mtp_update_hidden_states_op currently passes
runtime_draft_len which causes THOP to only retain a shortened MTP history;
change the argument to max_draft_len (self.spec_config.num_nextn_predict_layers)
so THOP refreshes the full MTP history window the same way the eager path does,
ensuring both branches update the same number of draft entries (compare the call
in the is_thop block and the eager update that uses max_draft_len).
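The divergence described above can be shown with a toy model of the history pool (hypothetical values; the real pools hold hidden states and tokens per request, but the window arithmetic is the same):

```python
# Toy illustration: the MTP pools keep a window of max_draft_len history
# entries, so both the THOP and eager branches should refresh all
# max_draft_len slots, not just runtime_draft_len of them.
max_draft_len = 4
past_tokens = [10, 11, 12, 13]   # full history window before this step
new_tokens = [20, 21]            # tokens produced this step

# Full-window refresh (eager-path behavior): keep the newest max_draft_len.
full = (past_tokens + new_tokens)[-max_draft_len:]
assert full == [12, 13, 20, 21]

# A shortened window of runtime_draft_len entries silently drops history
# that a later, larger draft length would have needed.
runtime_draft_len = 2
short = (past_tokens + new_tokens)[-runtime_draft_len:]
assert short == [20, 21]
```

When the schedule later raises the draft length back to 4, the THOP-style buffer has only 2 valid entries while the eager buffer still has 4, so the two paths produce different drafts.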
```python
@pytest.mark.skip_less_device_memory(60000)
@parametrize_with_ids("enable_max_concurrency,enable_draft_len_schedule", [
    (False, True),
    (True, False),
])
def test_pard_dynamic_draft_len(self, enable_max_concurrency,
                                enable_draft_len_schedule):
```
Add Hopper gating for the new PARD dynamic-draft test.
`test_pard_dynamic_draft_len` is missing `@skip_pre_hopper`, unlike other PARD tests in this class. This can fail on unsupported pre-Hopper runners.
🔧 Suggested patch

```diff
+@skip_pre_hopper
 @pytest.mark.skip_less_device_memory(60000)
 @parametrize_with_ids("enable_max_concurrency,enable_draft_len_schedule", [
     (False, True),
     (True, False),
 ])
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/integration/defs/accuracy/test_llm_api_pytorch.py` around lines 439 -
445, The test function test_pard_dynamic_draft_len is missing the Hopper-gating
decorator; add the `@skip_pre_hopper` decorator immediately above the function
definition so it matches other PARD tests and will be skipped on pre-Hopper
runners; ensure the decorator is imported/available where other tests use
skip_pre_hopper so the new annotation compiles and is applied to
test_pard_dynamic_draft_len.
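The gating rule behind `@skip_pre_hopper` can be sketched as a compute-capability check (a hypothetical sketch assuming the usual convention that Hopper is SM 9.0 / SM90; the helper name and shape are invented, not the repo's actual decorator):

```python
# Hypothetical predicate: Hopper GPUs report compute capability 9.0 (SM90),
# so anything below (9, 0) would be skipped by a skip_pre_hopper-style gate.
def is_pre_hopper(sm_major, sm_minor=0):
    return (sm_major, sm_minor) < (9, 0)

assert is_pre_hopper(8, 9)       # Ada-class SM89: test would be skipped
assert not is_pre_hopper(9, 0)   # Hopper SM90: test runs
```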
PR_Github #39151 [ run ] triggered by Bot. Commit:

PR_Github #39151 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #39310 [ run ] triggered by Bot. Commit:
```python
return self.is_mtp_one_model() or self.is_eagle3_one_model(
) or self.is_pard()
```
We should add draft/target support too
```python
task.evaluate(llm, extra_acc_spec="use_sa_spec")


@pytest.mark.skip_less_device_memory(60000)
@parametrize_with_ids("enable_max_concurrency,enable_draft_len_schedule", [
```
Is there a constraint on the SM version here?
PR_Github #39310 [ run ] completed with state
Summary by CodeRabbit
New Features
Tests
Description
This PR does two things:
Test Coverage
Added tests for MTP, MTP-Eagle, and PARD covering dynamic draft length and max concurrency control in tests/integration/defs/accuracy/test_llm_api_pytorch.py.

PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.