
[TRTLLM-11043][feat] Add global pool support for suffix automaton speculative decoding#12130

Open
cascade812 wants to merge 2 commits into NVIDIA:main from cascade812:guiju/at2

Conversation

cascade812 (Collaborator) commented Mar 12, 2026

Description

Add cross-request pattern matching to the suffix automaton (SA) speculative decoding implementation via a new enable_global_pool option, allowing each request to search all active SA states for the longest match instead of only its own context.
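For illustration, enabling the option through the LLM API might look like the following (a sketch only; the field names follow this PR's llm_args.py changes, and the import path is an assumption):

```python
# Sketch only -- field names follow this PR; the import path is an assumption.
from tensorrt_llm.llmapi import SADecodingConfig

spec_config = SADecodingConfig(
    max_draft_len=4,
    max_matching_ngram_size=-1,  # -1 selects longest-match mode
    enable_global_pool=True,     # new: search all active SA states, not only this request's context
)
```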

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

Summary by CodeRabbit

Release Notes

  • New Features

    • Added global suffix automaton pool for speculative decoding, enabling requests to share pattern matching across active states and improve token acceptance rates.
    • Added enable_global_pool configuration option to EAGLE3, PARD, MTP, and SA decoding configurations.
  • Tests

    • Added comprehensive test coverage for global pool functionality and cross-request pattern matching scenarios.

Signed-off-by: Guiju Zhang <7135567+cascade812@users.noreply.github.com>
@cascade812 cascade812 requested review from a team as code owners March 12, 2026 00:21
@cascade812 cascade812 requested review from hchings and zheyuf March 12, 2026 00:21
cascade812 (Collaborator, Author):

/bot run

tensorrt-cicd (Collaborator):

PR_Github #38644 [ run ] triggered by Bot. Commit: ecd0a15 Link to invocation

coderabbitai bot (Contributor) commented Mar 12, 2026

📝 Walkthrough

Walkthrough

This PR introduces a global pool search feature for suffix automaton-based speculative decoding. The implementation adds new CUDA kernels for cross-request pattern matching, refactors the suffix automaton header interface, and exposes global pool capabilities through Python configuration and manager APIs.

Changes

Cohort / File(s) Summary
Core Suffix Automaton Logic
cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomaton.h
Added lookupWithSuffix() method for longest-suffix prefix matching (appears in duplicate blocks). Updated copyright year.
Suffix Automaton CUDA Kernels
cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomatonKernels.cu
Introduced suffixAutomatonGlobalExtendKernel and suffixAutomatonGlobalSearchKernel for cross-request SA operations with per-block reduction logic. Added public API invokeSuffixAutomatonGlobalSearch to orchestrate kernels and nextPowerOf2 utility for thread block sizing.
Suffix Automaton Header Refactoring
cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomatonKernels.h
Removed legacy parameter structs and kernel invocation functions. Retained only minimal includes and added compatibility headers for CUDA utilities.
Suffix Automaton Parameters
cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomatonParams.h
New file consolidating parameter structures (SuffixAutomatonExtendParams, SuffixAutomatonExtendNgramParams, SuffixAutomatonGlobalSearchParams) with validation helpers and public API declarations moved from other headers.
Python Nanobind Bindings
cpp/tensorrt_llm/nanobind/suffixAutomaton/bindings.cpp
Updated header includes to use suffixAutomatonParams.h. Added invoke_global_search() export for cross-request pattern sharing. Updated copyright year.
Suffix Automaton Manager Configuration
tensorrt_llm/_torch/speculative/suffix_automaton.py
Added enable_global_pool configuration option to SAConfig. Introduced global pool GPU buffers (_gpu_active_slot_mask, _gpu_match_slot) with lazy allocation and pending update tracking. Implemented extend_global() method for CUDA-graph-compatible global search across active SA states.
Speculative Decoding Flow Conditionals
tensorrt_llm/_torch/speculative/sa_enhancer.py, tensorrt_llm/_torch/speculative/sa_worker.py
Added conditional branching to use extend_global() when enable_global_pool is enabled; otherwise preserves original extend() or extend_ngram() paths. No changes to surrounding buffer management or data flow.
Decoding Configuration
tensorrt_llm/llmapi/llm_args.py
Added enable_global_pool boolean field (default False) to four decoding config classes: Eagle3DecodingConfig, SADecodingConfig, MTPDecodingConfig, and PARDDecodingConfig with descriptive documentation.
Integration Tests
tests/integration/defs/accuracy/test_llm_api_pytorch.py, tests/integration/test_lists/qa/llm_function_core.txt
Added three new integration tests covering global pool variants: test_eagle3_sa_global_pool, test_pard_sa_global_pool, and test_bfloat16_mtp_sa_global_pool. Updated test registry with corresponding entries.
Unit Tests
tests/torch/speculative/test_suffix_automaton.py, tests/unittest/_torch/speculative/test_sa.py
Added TestExtendGlobal class with six comprehensive test methods covering cross-request matching, slot preference, no-match scenarios, active slot masking, single-request alignment, and CUDA graph compatibility. Added config validation and integration tests for enable_global_pool behavior.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • hchings
  • joyang-nv
  • zheyuf
  • mikeiovine
  • pcastonguay
  • StanleySun639
🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage — ⚠️ Warning: Docstring coverage is 73.81%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

  • Description check — ❓ Inconclusive: The PR description explains what is being added (enable_global_pool option for cross-request pattern matching) and why, but the Test Coverage section is empty and the PR Checklist items are unchecked. Resolution: complete the Test Coverage section by listing the specific test cases added (e.g., TestExtendGlobal, test_eagle3_sa_global_pool, test_pard_sa_global_pool, test_bfloat16_mtp_sa_global_pool, test_sa_config_global_pool, test_llama_sa_global_pool) and review the checklist items.

✅ Passed checks (1 passed)

  • Title check — ✅ Passed: The PR title clearly summarizes the main change: adding global pool support for suffix automaton speculative decoding, directly reflecting the feature described throughout the changeset.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Comment @coderabbitai help to get the list of available commands and usage tips.

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 5

🧹 Nitpick comments (5)
tensorrt_llm/_torch/speculative/sa_worker.py (1)

284-299: Consider using underscore prefix for unused match_len in both branches.

The match_len return value is unused in both code paths. For consistency and to satisfy linters, prefix the unused variable with an underscore in both branches.

♻️ Proposed fix
         if sa_manager.enable_global_pool:
-            match_len, draft_tokens = sa_manager.extend_global(
+            _match_len, draft_tokens = sa_manager.extend_global(
                 request_ids,
                 accepted_tokens,
                 num_accepted_tokens,
                 max_draft_len,
                 max_ngram_size=self._max_matching_ngram_size,
             )
         else:
-            match_len, draft_tokens = sa_manager.extend_ngram(
+            _match_len, draft_tokens = sa_manager.extend_ngram(
                 request_ids,
                 accepted_tokens,
                 num_accepted_tokens,
                 max_draft_len,
                 max_ngram_size=self._max_matching_ngram_size,
             )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/speculative/sa_worker.py` around lines 284 - 299, The
variable match_len returned by sa_manager.extend_global and
sa_manager.extend_ngram is unused; update both branches to unpack into
_match_len and draft_tokens instead of match_len and draft_tokens (inside the if
sa_manager.enable_global_pool block and the else block) to follow the
underscore-prefix convention for unused values and satisfy linters while keeping
the call arguments (request_ids, accepted_tokens, num_accepted_tokens,
max_draft_len, max_ngram_size=self._max_matching_ngram_size) unchanged.
tests/unittest/_torch/speculative/test_sa.py (2)

332-338: Add strict=True to zip() for safer iteration.

When comparing speculative vs reference outputs, using strict=True ensures both lists have the same length and will raise an error if they don't match, preventing silent truncation.

♻️ Proposed fix
     # Verify 1: Identical results (correctness)
     for i, (text_spec,
             text_ref) in enumerate(zip(generated_text_spec,
-                                       generated_text_ref)):
+                                       generated_text_ref, strict=True)):
         assert text_spec == text_ref, (
             f"Prompt {i}: Global pool spec decode differs from baseline.\n"
             f"Spec: {text_spec}\nRef: {text_ref}")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unittest/_torch/speculative/test_sa.py` around lines 332 - 338, The
loop over generated_text_spec and generated_text_ref should use zip(...,
strict=True) to ensure both sequences have the same length; update the for
statement in tests/unittest/_torch/speculative/test_sa.py (the loop that
iterates with for i, (text_spec, text_ref) in enumerate(zip(generated_text_spec,
generated_text_ref)):) to call zip(generated_text_spec, generated_text_ref,
strict=True) so a length mismatch raises immediately instead of silently
truncating.

347-347: Remove extraneous f-string prefix.

This string has no placeholders, so the f prefix is unnecessary.

♻️ Proposed fix
-    print(f"Global pool spec decoding stats:")
+    print("Global pool spec decoding stats:")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unittest/_torch/speculative/test_sa.py` at line 347, The print
statement using an f-string with the literal text "Global pool spec decoding
stats:" should be changed to a plain string literal—remove the unnecessary f
prefix on the print call (the print call currently written as an f-string) so it
prints the same message without using f-string formatting.
tests/integration/defs/accuracy/test_llm_api_pytorch.py (1)

365-366: Differentiate SA and SA+global-pool in extra_acc_spec.

These three new cases still report as use_sa_spec, exactly like the existing non-global SA tests. That makes the harness/results indistinguishable between the two configurations and can hide config-specific regressions. Use a distinct label such as use_sa_spec,enable_global_pool.

Proposed fix
-            task.evaluate(llm, extra_acc_spec="use_sa_spec")
+            task.evaluate(llm,
+                          extra_acc_spec="use_sa_spec,enable_global_pool")
-            task.evaluate(llm, extra_acc_spec="use_sa_spec")
+            task.evaluate(llm,
+                          extra_acc_spec="use_sa_spec,enable_global_pool")
-            task.evaluate(llm, extra_acc_spec="use_sa_spec")
+            task.evaluate(llm,
+                          extra_acc_spec="use_sa_spec,enable_global_pool")

Also applies to: 448-449, 1667-1668

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/integration/defs/accuracy/test_llm_api_pytorch.py` around lines 365 -
366, The test calls that construct GSM8K and then call task.evaluate(llm,
extra_acc_spec="use_sa_spec") should use a distinct extra_acc_spec when the
global pooling variant is enabled; update the extra_acc_spec argument in the
GSM8K test invocations (the lines that create GSM8K(self.MODEL_NAME) and call
task.evaluate) to "use_sa_spec,enable_global_pool" for the global-pool variants
so results are distinguishable from the plain SA runs, and apply the same change
to the other two occurrences that currently pass "use_sa_spec".
cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomatonParams.h (1)

26-26: Use the repo's standard include guard for this new header.

#pragma once does not match the required TRTLLM_<FILENAME_IN_CAPS>_H guard pattern for .h files.

As per coding guidelines, "Use a preprocessor guard in C++ header files with the format TRTLLM_<FILENAME_IN_CAPS>_H (e.g., TRTLLM_FOO_BAR_HELLO_H). Do not use directory names or trailing underscores".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomatonParams.h`
at line 26, Replace the non-conforming `#pragma` once in suffixAutomatonParams.h
with the repo standard include guard: add a preprocessor guard using the symbol
TRTLLM_SUFFIXAUTOMATONPARAMS_H (define if not already defined and end with
`#endif`), ensuring the guard name is the uppercase filename without directories
or trailing underscores and wraps the entire header contents; update any
existing header opening/closing to use TRTLLM_SUFFIXAUTOMATONPARAMS_H so the
file no longer uses `#pragma` once.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In
`@cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomatonParams.h`:
- Line 28: The include "tensorrt_llm/common/assert.h" in suffixAutomatonParams.h
is missing — replace it with the header that actually exports TLLM_CHECK (e.g.,
the project's public assert header that defines TLLM_CHECK) or remove the
include if TLLM_CHECK is already provided by an existing include (such as
config.h); update the include in suffixAutomatonParams.h so references to
TLLM_CHECK compile successfully.
- Around line 143-205: The checkParams() method in
SuffixAutomatonGlobalSearchParams must validate maxNgramSize to prevent
shared-memory overflow and reject 0; update
SuffixAutomatonGlobalSearchParams::checkParams() to assert that maxNgramSize is
either -1 (longest-match mode) or in the range [1, kMaxGlobalSuffixLen], and add
a clear TLLM_CHECK_WITH_INFO error message when maxNgramSize == 0 or
maxNgramSize > kMaxGlobalSuffixLen (referencing kMaxGlobalSuffixLen and
maxNgramSize in the message) so the kernel's fixed-ngram copy into
sharedSuffix[64] cannot overflow.

In `@tests/integration/defs/accuracy/test_llm_api_pytorch.py`:
- Around line 423-425: The test test_pard_sa_global_pool is being skipped (via
`@pytest.mark.skip`) because PARD has an accuracy bug when batch size > 1, but
global-pool behavior only matters with batch>1, so we must either fix the
batching bug or remove PARD from the feature path: locate the PARD-related
batching/forward code (search for symbols like test_pard_sa_global_pool,
skip_pre_hopper and the PARD/global-pool implementation functions such as
pard_forward or pard_sa_global_pool handler) and either (A) fix the batching
accuracy bug by correcting how multiple requests are aggregated/pooled (ensure
correct indexing, activation pooling and gradient/attention accumulation for
batch sizes >1), then re-enable the test by removing the `@pytest.mark.skip`, or
(B) revert/guard the PARD global-pool code path behind a feature flag and
disable it so the new PARD behavior is not exercised until the batch>1 fix is
implemented; pick one of these two actions and update tests accordingly.

In `@tests/torch/speculative/test_suffix_automaton.py`:
- Around line 663-689: The test test_extend_global_no_match must assert that
draft_tokens is zeroed when match_len == 0 to ensure no stale tokens are
returned from SuffixAutomatonManager.extend_global; update the test to verify
draft_tokens contains only zeros for the no-match rows (mirror the assertions
used in test_extend_ngram_no_match), and remove or replace any unused local
variables triggering Ruff warnings (e.g., drop unused placeholders or use the
asserted values) so the test both checks the zeroed draft buffer and eliminates
the unused-variable warning.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 2db5188f-a3ac-43bf-bf30-d9e5b9c7b731

📥 Commits

Reviewing files that changed from the base of the PR and between be20657 and ecd0a15.

📒 Files selected for processing (13)
  • cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomaton.h
  • cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomatonKernels.cu
  • cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomatonKernels.h
  • cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomatonParams.h
  • cpp/tensorrt_llm/nanobind/suffixAutomaton/bindings.cpp
  • tensorrt_llm/_torch/speculative/sa_enhancer.py
  • tensorrt_llm/_torch/speculative/sa_worker.py
  • tensorrt_llm/_torch/speculative/suffix_automaton.py
  • tensorrt_llm/llmapi/llm_args.py
  • tests/integration/defs/accuracy/test_llm_api_pytorch.py
  • tests/integration/test_lists/qa/llm_function_core.txt
  • tests/torch/speculative/test_suffix_automaton.py
  • tests/unittest/_torch/speculative/test_sa.py

tensorrt-cicd (Collaborator):

PR_Github #38644 [ run ] completed with state SUCCESS. Commit: ecd0a15
/LLM/main/L0_MergeRequest_PR pipeline #29974 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

Signed-off-by: Guiju Zhang <7135567+cascade812@users.noreply.github.com>
cascade812 (Collaborator, Author):

/bot run

tensorrt-cicd (Collaborator):

PR_Github #38789 [ run ] triggered by Bot. Commit: d447f57 Link to invocation

tensorrt-cicd (Collaborator):

PR_Github #38789 [ run ] completed with state SUCCESS. Commit: d447f57
/LLM/main/L0_MergeRequest_PR pipeline #30103 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

cascade812 (Collaborator, Author):

/bot run

tensorrt-cicd (Collaborator):

PR_Github #38825 [ run ] triggered by Bot. Commit: d447f57 Link to invocation

tensorrt-cicd (Collaborator):

PR_Github #38825 [ run ] completed with state SUCCESS. Commit: d447f57
/LLM/main/L0_MergeRequest_PR pipeline #30137 completed with status: 'SUCCESS'

CI Report

Link to invocation

zheyuf (Collaborator) left a comment

Thanks. Looks good to me. Just left some minor comments.

Comment on lines +1209 to 1219
        "Limitations: at most 1024 concurrent slots; suffix matching is "
        "capped at 64 tokens per request.")

    @model_validator(mode='after')
    def validate_sa_config(self):
        if self.max_matching_ngram_size == 0:
            raise ValueError(
                "max_matching_ngram_size must be > 0 (fixed ngram) or -1 (longest match). "
                "Got 0.")
        if self.max_draft_len is None or self.max_draft_len <= 0:
            raise ValueError("max_draft_len must be > 0 for SA")
zheyuf (Collaborator):

Do we need to validate self.max_matching_ngram_size <= 64 when enable_global_pool is True, since the docstring says suffix matching is capped at 64 tokens per request?

Otherwise config like

SADecodingConfig(
    max_draft_len=4,
    enable_global_pool=True,
    max_matching_ngram_size=128,
)

will be accepted at construction time and fail later at runtime in extend_global().
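One way to catch this at construction time (a sketch as a plain function; the 64-token cap and field names come from the docstring and config quoted above, while the constant name is hypothetical):

```python
MAX_GLOBAL_SUFFIX_LEN = 64  # per the "capped at 64 tokens per request" limitation

def validate_sa_config(max_matching_ngram_size, enable_global_pool):
    """Reject configs that would otherwise fail later inside extend_global()."""
    if max_matching_ngram_size == 0:
        raise ValueError(
            "max_matching_ngram_size must be > 0 (fixed ngram) or -1 (longest match)")
    if (enable_global_pool and max_matching_ngram_size != -1
            and max_matching_ngram_size > MAX_GLOBAL_SUFFIX_LEN):
        raise ValueError(
            f"enable_global_pool caps suffix matching at {MAX_GLOBAL_SUFFIX_LEN} "
            f"tokens; got max_matching_ngram_size={max_matching_ngram_size}")
```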

Comment on lines +336 to +337
matchLenOut[reqIdx] = 0;
matchSlotOut[reqIdx] = -1;
zheyuf (Collaborator):
In the no-match branch, draftTokensOut is not cleared, which could leave stale draft tokens in the buffer.

In sa_worker.py:

if sa_manager.enable_global_pool:
    match_len, draft_tokens = sa_manager.extend_global(...)
else:
    match_len, draft_tokens = sa_manager.extend_ngram(...)
return draft_tokens

Is it possible that it will directly return the stale draft tokens without any gating?
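One defensive option on the Python side would be to gate the returned buffer on match_len, so a zero-match row never propagates stale tokens (a sketch on plain lists; the real buffers are CUDA tensors and the helper name is hypothetical):

```python
def gate_draft_tokens(match_len, draft_tokens):
    """Zero every draft row whose match length is 0, so stale contents left
    in the output buffer from a previous step are never proposed."""
    return [row if n > 0 else [0] * len(row)
            for n, row in zip(match_len, draft_tokens)]
```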

Comment on lines +285 to +293
    match_len, draft_tokens = sa_manager.extend_global(
        request_ids,
        accepted_tokens,
        num_accepted_tokens,
        max_draft_len,
        max_ngram_size=self._max_matching_ngram_size,
    )
else:
    match_len, draft_tokens = sa_manager.extend_ngram(
zheyuf (Collaborator) commented Mar 14, 2026

match_len is not used. Would it be better to change this to _, draft_tokens = sa_manager.extend_global()?
