[TRTLLM-11043][feat] Add global pool support for suffix automaton speculative decoding #12130
cascade812 wants to merge 2 commits into NVIDIA:main from
Conversation
Signed-off-by: Guiju Zhang <7135567+cascade812@users.noreply.github.com>
/bot run

PR_Github #38644 [ run ] triggered by Bot. Commit:
📝 Walkthrough

This PR introduces a global pool search feature for suffix automaton-based speculative decoding. The implementation adds new CUDA kernels for cross-request pattern matching, refactors the suffix automaton header interface, and exposes global pool capabilities through Python configuration and manager APIs.

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks: 1 passed, 2 failed (1 warning, 1 inconclusive)
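To make the walkthrough concrete, here is a toy plain-Python sketch of the cross-request search idea: draft tokens are taken from whichever active context contains the longest match for the current request's suffix. The helpers `longest_suffix_match` and `extend_global` below are hypothetical illustrations, not the PR's CUDA kernels or Python API.

```python
# Toy illustration of the global-pool idea: find the longest suffix of the
# current request's recent tokens anywhere in the pool of all active
# contexts, then draft the tokens that followed the match.

def longest_suffix_match(suffix, context):
    """Return (match length, index just past the match) for the longest
    tail of `suffix` found in `context`; (0, -1) if nothing matches."""
    for n in range(len(suffix), 0, -1):
        pattern = suffix[-n:]
        for i in range(len(context) - n, -1, -1):
            if context[i:i + n] == pattern:
                return n, i + n
    return 0, -1

def extend_global(suffix, pool, max_draft_len):
    """Search every active context (not just the request's own) and draft
    up to max_draft_len tokens after the longest match."""
    best_len, best_end, best_ctx = 0, -1, None
    for ctx in pool:
        n, end = longest_suffix_match(suffix, ctx)
        if n > best_len:
            best_len, best_end, best_ctx = n, end, ctx
    if best_len == 0:
        return 0, []
    return best_len, best_ctx[best_end:best_end + max_draft_len]

pool = [[1, 2, 3, 4, 5],          # request A's context
        [7, 8, 9, 2, 3, 6, 6]]    # request B's context
match_len, draft = extend_global([0, 2, 3], pool, 2)
assert (match_len, draft) == (2, [4, 5])
```

The real implementation performs this search on-device across the slot pool; this sketch only shows the matching semantics.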
Actionable comments posted: 5
🧹 Nitpick comments (5)
tensorrt_llm/_torch/speculative/sa_worker.py (1)
284-299: Consider using underscore prefix for unused `match_len` in both branches.

The `match_len` return value is unused in both code paths. For consistency and to satisfy linters, prefix the unused variable with an underscore in both branches.

♻️ Proposed fix

```diff
 if sa_manager.enable_global_pool:
-    match_len, draft_tokens = sa_manager.extend_global(
+    _match_len, draft_tokens = sa_manager.extend_global(
         request_ids,
         accepted_tokens,
         num_accepted_tokens,
         max_draft_len,
         max_ngram_size=self._max_matching_ngram_size,
     )
 else:
-    match_len, draft_tokens = sa_manager.extend_ngram(
+    _match_len, draft_tokens = sa_manager.extend_ngram(
         request_ids,
         accepted_tokens,
         num_accepted_tokens,
         max_draft_len,
         max_ngram_size=self._max_matching_ngram_size,
     )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/speculative/sa_worker.py` around lines 284-299: the variable `match_len` returned by `sa_manager.extend_global` and `sa_manager.extend_ngram` is unused; update both branches (inside the `if sa_manager.enable_global_pool` block and the `else` block) to unpack into `_match_len, draft_tokens` to follow the underscore-prefix convention for unused values and satisfy linters, keeping the call arguments unchanged.

tests/unittest/_torch/speculative/test_sa.py (2)
332-338: Add `strict=True` to `zip()` for safer iteration.

When comparing speculative vs reference outputs, using `strict=True` ensures both lists have the same length and will raise an error if they don't match, preventing silent truncation.

♻️ Proposed fix

```diff
     # Verify 1: Identical results (correctness)
     for i, (text_spec, text_ref) in enumerate(zip(generated_text_spec,
-                                                  generated_text_ref)):
+                                                  generated_text_ref,
+                                                  strict=True)):
         assert text_spec == text_ref, (
             f"Prompt {i}: Global pool spec decode differs from baseline.\n"
             f"Spec: {text_spec}\nRef: {text_ref}")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/speculative/test_sa.py` around lines 332 - 338, The loop over generated_text_spec and generated_text_ref should use zip(..., strict=True) to ensure both sequences have the same length; update the for statement in tests/unittest/_torch/speculative/test_sa.py (the loop that iterates with for i, (text_spec, text_ref) in enumerate(zip(generated_text_spec, generated_text_ref)):) to call zip(generated_text_spec, generated_text_ref, strict=True) so a length mismatch raises immediately instead of silently truncating.
347-347: Remove extraneous f-string prefix.

This string has no placeholders, so the `f` prefix is unnecessary.

♻️ Proposed fix

```diff
-        print(f"Global pool spec decoding stats:")
+        print("Global pool spec decoding stats:")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/speculative/test_sa.py` at line 347: the print call uses an f-string for the literal text "Global pool spec decoding stats:", which has no placeholders; remove the unnecessary `f` prefix so it prints the same message as a plain string literal.

tests/integration/defs/accuracy/test_llm_api_pytorch.py (1)
365-366: Differentiate SA and SA+global-pool in `extra_acc_spec`.

These three new cases still report as `use_sa_spec`, exactly like the existing non-global SA tests. That makes the harness/results indistinguishable between the two configurations and can hide config-specific regressions. Use a distinct label such as `use_sa_spec,enable_global_pool`.

♻️ Proposed fix (same change at all three sites)

```diff
-            task.evaluate(llm, extra_acc_spec="use_sa_spec")
+            task.evaluate(llm,
+                          extra_acc_spec="use_sa_spec,enable_global_pool")
```

Also applies to: 448-449, 1667-1668
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/integration/defs/accuracy/test_llm_api_pytorch.py` around lines 365-366: the GSM8K test invocations that call `task.evaluate(llm, extra_acc_spec="use_sa_spec")` should use a distinct `extra_acc_spec` when the global-pool variant is enabled; change it to `"use_sa_spec,enable_global_pool"` so results are distinguishable from the plain SA runs, and apply the same change to the other two occurrences that currently pass `"use_sa_spec"`.

cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomatonParams.h (1)
26-26: Use the repo's standard include guard for this new header.
`#pragma once` does not match the required `TRTLLM_<FILENAME_IN_CAPS>_H` guard pattern for `.h` files.

As per coding guidelines, "Use a preprocessor guard in C++ header files with the format `TRTLLM_<FILENAME_IN_CAPS>_H` (e.g., `TRTLLM_FOO_BAR_HELLO_H`). Do not use directory names or trailing underscores".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomatonParams.h` at line 26: replace the non-conforming `#pragma once` with the repo-standard include guard TRTLLM_SUFFIXAUTOMATONPARAMS_H (an `#ifndef`/`#define` pair closed by `#endif` wrapping the entire header contents), ensuring the guard name is the uppercase filename without directory components or trailing underscores.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In
`@cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomatonParams.h`:
- Line 28: The include "tensorrt_llm/common/assert.h" in suffixAutomatonParams.h
is missing — replace it with the header that actually exports TLLM_CHECK (e.g.,
the project's public assert header that defines TLLM_CHECK) or remove the
include if TLLM_CHECK is already provided by an existing include (such as
config.h); update the include in suffixAutomatonParams.h so references to
TLLM_CHECK compile successfully.
- Around line 143-205: The checkParams() method in
SuffixAutomatonGlobalSearchParams must validate maxNgramSize to prevent
shared-memory overflow and reject 0; update
SuffixAutomatonGlobalSearchParams::checkParams() to assert that maxNgramSize is
either -1 (longest-match mode) or in the range [1, kMaxGlobalSuffixLen], and add
a clear TLLM_CHECK_WITH_INFO error message when maxNgramSize == 0 or
maxNgramSize > kMaxGlobalSuffixLen (referencing kMaxGlobalSuffixLen and
maxNgramSize in the message) so the kernel's fixed-ngram copy into
sharedSuffix[64] cannot overflow.
In `@tests/integration/defs/accuracy/test_llm_api_pytorch.py`:
- Around line 423-425: The test test_pard_sa_global_pool is being skipped (via
`@pytest.mark.skip`) because PARD has an accuracy bug when batch size > 1, but
global-pool behavior only matters with batch>1, so we must either fix the
batching bug or remove PARD from the feature path: locate the PARD-related
batching/forward code (search for symbols like test_pard_sa_global_pool,
skip_pre_hopper and the PARD/global-pool implementation functions such as
pard_forward or pard_sa_global_pool handler) and either (A) fix the batching
accuracy bug by correcting how multiple requests are aggregated/pooled (ensure
correct indexing, activation pooling and gradient/attention accumulation for
batch sizes >1), then re-enable the test by removing the `@pytest.mark.skip`, or
(B) revert/guard the PARD global-pool code path behind a feature flag and
disable it so the new PARD behavior is not exercised until the batch>1 fix is
implemented; pick one of these two actions and update tests accordingly.
In `@tests/torch/speculative/test_suffix_automaton.py`:
- Around line 663-689: The test test_extend_global_no_match must assert that
draft_tokens is zeroed when match_len == 0 to ensure no stale tokens are
returned from SuffixAutomatonManager.extend_global; update the test to verify
draft_tokens contains only zeros for the no-match rows (mirror the assertions
used in test_extend_ngram_no_match), and remove or replace any unused local
variables triggering Ruff warnings (e.g., drop unused placeholders or use the
asserted values) so the test both checks the zeroed draft buffer and eliminates
the unused-variable warning.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 2db5188f-a3ac-43bf-bf30-d9e5b9c7b731
📒 Files selected for processing (13)
- cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomaton.h
- cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomatonKernels.cu
- cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomatonKernels.h
- cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomatonParams.h
- cpp/tensorrt_llm/nanobind/suffixAutomaton/bindings.cpp
- tensorrt_llm/_torch/speculative/sa_enhancer.py
- tensorrt_llm/_torch/speculative/sa_worker.py
- tensorrt_llm/_torch/speculative/suffix_automaton.py
- tensorrt_llm/llmapi/llm_args.py
- tests/integration/defs/accuracy/test_llm_api_pytorch.py
- tests/integration/test_lists/qa/llm_function_core.txt
- tests/torch/speculative/test_suffix_automaton.py
- tests/unittest/_torch/speculative/test_sa.py
PR_Github #38644 [ run ] completed with state

Signed-off-by: Guiju Zhang <7135567+cascade812@users.noreply.github.com>

/bot run

PR_Github #38789 [ run ] triggered by Bot. Commit:

PR_Github #38789 [ run ] completed with state

/bot run

PR_Github #38825 [ run ] triggered by Bot. Commit:

PR_Github #38825 [ run ] completed with state
```python
    "Limitations: at most 1024 concurrent slots; suffix matching is "
    "capped at 64 tokens per request.")
```

```python
@model_validator(mode='after')
def validate_sa_config(self):
    if self.max_matching_ngram_size == 0:
        raise ValueError(
            "max_matching_ngram_size must be > 0 (fixed ngram) or -1 (longest match). "
            "Got 0.")
    if self.max_draft_len is None or self.max_draft_len <= 0:
        raise ValueError("max_draft_len must be > 0 for SA")
```
Do we need to validate `self.max_matching_ngram_size <= 64` when `enable_global_pool` is True, since it says suffix matching is capped at 64 tokens per request? Otherwise a config like

```python
SADecodingConfig(
    max_draft_len=4,
    enable_global_pool=True,
    max_matching_ngram_size=128,
)
```

will be accepted at construction time and fail later at runtime in `extend_global()`.
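A standalone sketch of the extra bound this comment suggests. This uses plain functions rather than the actual pydantic validator, and `GLOBAL_SUFFIX_CAP = 64` is assumed from the 64-token limit quoted in the docstring above, so treat names and messages as illustrative only:

```python
# Hypothetical sketch: reject at construction time what the global-pool
# kernel cannot honor, instead of failing later inside extend_global().
GLOBAL_SUFFIX_CAP = 64  # assumed from "capped at 64 tokens per request"

def validate_sa_config(max_draft_len, max_matching_ngram_size, enable_global_pool):
    if max_matching_ngram_size == 0:
        raise ValueError(
            "max_matching_ngram_size must be > 0 (fixed ngram) or -1 (longest match). "
            "Got 0.")
    if max_draft_len is None or max_draft_len <= 0:
        raise ValueError("max_draft_len must be > 0 for SA")
    # Suggested addition: bound the ngram size when the global pool is on.
    if enable_global_pool and max_matching_ngram_size > GLOBAL_SUFFIX_CAP:
        raise ValueError(
            f"max_matching_ngram_size ({max_matching_ngram_size}) exceeds the "
            f"{GLOBAL_SUFFIX_CAP}-token suffix cap when enable_global_pool=True")

validate_sa_config(4, 32, True)       # accepted
try:
    validate_sa_config(4, 128, True)  # now rejected at construction time
    rejected = False
except ValueError:
    rejected = True
assert rejected
```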
```cpp
matchLenOut[reqIdx] = 0;
matchSlotOut[reqIdx] = -1;
```
In the no-match branch, it doesn't clear out `draftTokensOut`, which could leave the draft tokens stale.

In sa_worker.py:

```python
if sa_manager.enable_global_pool:
    match_len, draft_tokens = sa_manager.extend_global(...)
else:
    match_len, draft_tokens = sa_manager.extend_ngram(...)
return draft_tokens
```

Is it possible that it will directly return the stale draft tokens without any gating?
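One possible host-side mitigation, as a hypothetical sketch (the `gate_draft_tokens` helper is invented for illustration and operates on plain lists, not the real tensors): mask any row whose `match_len` is zero before returning, so stale buffer contents can never leak to the caller.

```python
# Hypothetical gating sketch: zero out the draft tokens for any request
# whose match_len is 0, so stale buffer contents are never returned.
def gate_draft_tokens(match_len, draft_tokens):
    # match_len: list[int] of length batch
    # draft_tokens: list[list[int]] of shape [batch, max_draft_len]
    return [row if ml > 0 else [0] * len(row)
            for ml, row in zip(match_len, draft_tokens)]

match_len = [3, 0]
draft_tokens = [[5, 6, 7], [9, 9, 9]]  # second row holds stale tokens
gated = gate_draft_tokens(match_len, draft_tokens)
assert gated == [[5, 6, 7], [0, 0, 0]]
```

Clearing `draftTokensOut` inside the kernel's no-match branch would make such host-side gating unnecessary.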
```python
            match_len, draft_tokens = sa_manager.extend_global(
                request_ids,
                accepted_tokens,
                num_accepted_tokens,
                max_draft_len,
                max_ngram_size=self._max_matching_ngram_size,
            )
        else:
            match_len, draft_tokens = sa_manager.extend_ngram(
```
`match_len` is not used. Is it better to change it to `_, draft_tokens = sa_manager.extend_global()`?
Description
Add cross-request pattern matching to the suffix automaton (SA) speculative decoding implementation via a new `enable_global_pool` option, allowing each request to search all active SA states for the longest match instead of only its own context.

Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment `/bot help`.

Summary by CodeRabbit
Release Notes
New Features
Added `enable_global_pool` configuration option to EAGLE3, PARD, MTP, and SA decoding configurations.

Tests