
[https://nvbugs/6114821][fix] Remove torch.compile from spec dec sampling to prevent NCCL deadlock #13552

Open
tensorrt-cicd wants to merge 2 commits into NVIDIA:main from tensorrt-cicd:repair-bot-bug6114821

Conversation

Collaborator

@tensorrt-cicd tensorrt-cicd commented Apr 28, 2026

Summary

  • Root cause: With TP > 1, each rank runs torch.compile independently on the speculative decoding sampling function, which can select different Triton kernel implementations across ranks. This non-determinism causes divergent sampling outputs, leading to mismatched draft token acceptance counts and therefore mismatched batch shapes for subsequent NCCL collectives, resulting in a deadlock.
  • Fix: Removed the @torch.compile(options={"max-autotune": True}) decorator from sampling_batch_spec_dec_one_model in one_model_sampler.py (see the sketch after this list) and added a comment explaining why compilation must be avoided in this code path. The corresponding test waiver for the Eagle3 4-GPU accuracy test was also removed since the fix resolves the underlying failure.
  • Automated fix generated by repair-bot
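
For illustration, a minimal sketch of the change, assuming a simplified signature: only the decorator, the function name, and the rationale come from this PR; the body below is a placeholder, not the actual one_model_sampler.py implementation.

```python
import torch

# The decorator below is what this PR removes. With TP > 1 each rank compiles
# in its own process, and max-autotune may pick different Triton kernels per
# rank, so sampled tokens (and thus acceptance counts) can diverge.
#
# @torch.compile(options={"max-autotune": True})   # removed by this PR
def sampling_batch_spec_dec_one_model(logits: torch.Tensor,
                                      temperature: float = 1.0) -> torch.Tensor:
    """Sample one token per row (placeholder body).

    Intentionally left in eager mode so every rank executes identical code;
    see the NCCL-deadlock rationale above.
    """
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1)

# Usage of the simplified sketch:
tokens = sampling_batch_spec_dec_one_model(torch.randn(4, 32000), temperature=0.8)
```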

Test plan

  • Verify fix on the same GPU type as the original failure
  • Check for regressions in related tests

Links

Summary by CodeRabbit

  • Bug Fixes

    • Improved speculative decoding reliability in multi-GPU deployments by removing a compilation optimization that could cause kernel divergence across GPU ranks, ensuring consistent token sampling and acceptance behavior.
  • Tests

    • Enabled a previously waived speculative decoding test case to increase validation coverage for multi-GPU tensor parallel configurations.

…prevent NCCL deadlock

With non-greedy sampling (temperature > 0) in one-model speculative decoding
with TP > 1, torch.compile on the sampling function causes different compiled
code on different ranks (each rank compiles in a separate process). This
produces different sampling results across ranks, which diverges the acceptance
counts. Since acceptance counts determine the batch shape of subsequent
draft-model forward passes containing NCCL collectives, divergent tokens
cause an NCCL deadlock.
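
To make the failure mode concrete, the following is an illustrative torch.distributed sketch (not code from this change), assuming two GPUs launched via torchrun; a rank-dependent stand-in for the acceptance count sets the shape of a tensor that then enters an NCCL collective:

```python
# Illustrative only: divergent "acceptance counts" feeding an NCCL collective.
# Launch with: torchrun --nproc_per_node=2 <this_file>.py
import torch
import torch.distributed as dist

def main():
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    # Stand-in for rank-dependent compiled kernels: each rank "samples"
    # differently, so the number of accepted draft tokens can differ.
    torch.manual_seed(rank)
    accepted = int(torch.randint(1, 5, (1,)).item())

    # The next forward pass is shaped by `accepted`; if ranks disagree, the
    # collective sees mismatched element counts and can hang indefinitely.
    hidden = torch.ones(accepted, 8, device="cuda")
    dist.all_reduce(hidden)

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```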

Fix: remove torch.compile from sampling_batch_spec_dec_one_model so all
ranks execute identical eager-mode code. Also remove the waiver for the
affected test.

Signed-off-by: tensorrt-cicd <90828364+tensorrt-cicd@users.noreply.github.com>
Contributor

coderabbitai Bot commented Apr 28, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 949837c9-262a-4946-ad39-a468e09b89ef

📥 Commits

Reviewing files that changed from the base of the PR and between 4e830bb and 60a30e4.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/speculative/one_model_sampler.py
  • tests/integration/test_lists/waives.txt
💤 Files with no reviewable changes (1)
  • tests/integration/test_lists/waives.txt

📝 Walkthrough


The @torch.compile decorator is removed from the sampling_batch_spec_dec_one_model function in the speculative decoding sampler, and the docstring is updated to explain that compilation is intentionally omitted due to kernel selection divergence in multi-process settings. A corresponding test waive entry is removed from the integration test list.

Changes

  • Speculative Decoding Sampler (tensorrt_llm/_torch/speculative/one_model_sampler.py): Removed the @torch.compile(options={"max-autotune": True}) decorator from the sampling_batch_spec_dec_one_model function. Updated the docstring to document speculative decoding semantics and explain that compilation is intentionally omitted to prevent kernel selection divergence across ranks in multi-process (TP > 1) settings, which can lead to divergent sampled tokens and NCCL collective shape mismatches.
  • Integration Test Waives (tests/integration/test_lists/waives.txt): Removed the waive entry for a specific GPTOSS Eagle3 4-GPU test configuration (v2_kv_cache + trtllm + one_model + no_overlap_scheduler), allowing this test to run without being skipped.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 50.00%, which is insufficient; the required threshold is 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4 passed)

  • Title check (✅ Passed): The pull request title clearly and specifically summarizes the main change: removing torch.compile from speculative decoding sampling to fix an NCCL deadlock issue, directly matching the core modification in the changeset.
  • Description check (✅ Passed): The pull request description includes comprehensive coverage with root cause analysis, the specific fix applied, test verification, and relevant links, though it lacks explicit mapping to the template sections.
  • Linked Issues check (✅ Passed): Check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check (✅ Passed): Check skipped because no linked issues were found for this pull request.


⚔️ Resolve merge conflicts
  • Resolve merge conflict in branch repair-bot-bug6114821


Collaborator

@ziyixiong-nv ziyixiong-nv left a comment


The fix and explanation look reasonable to me. + @mikeiovine in case you have any concerns.

@ziyixiong-nv
Collaborator

/bot run

Signed-off-by: Ziyi Xiong <219238287+ziyixiong-nv@users.noreply.github.com>
@ziyixiong-nv
Collaborator

/bot run

@tensorrt-cicd
Collaborator Author

PR_Github #45912 [ run ] triggered by Bot. Commit: 7b48826 Link to invocation

