Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions tensorrt_llm/_torch/speculative/one_model_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ def apply_temperature(
return logits.div_(temp.unsqueeze(dim=1))


@torch.compile(options={"max-autotune": True})
def sampling_batch_spec_dec_one_model(
logits: torch.Tensor,
temperatures: torch.Tensor,
Expand All @@ -84,11 +83,14 @@ def sampling_batch_spec_dec_one_model(
offset: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
CUDA-graph compatible sampling. Supports mixed sampling params.

We can't do dynamic kernel selection inside graphs, so this might
be slower than a torch.argmax for greedy requests. This is why advanced
sampling is opt-in for now.
Sampling for speculative decoding. Supports mixed sampling params.

NOTE: torch.compile is intentionally omitted. With TP > 1 each rank
runs in a separate process; compiled sampling can produce slightly
different results across ranks (e.g. different Triton kernel selections).
In spec dec, sampled tokens determine acceptance counts and therefore
the batch shape of subsequent draft-model NCCL collectives — divergent
tokens cause a deadlock.
"""
logits = apply_temperature(logits, temperatures)
if use_flashinfer:
Expand Down
1 change: 0 additions & 1 deletion tests/integration/test_lists/waives.txt
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,6 @@ disaggregated/test_disaggregated.py::test_disaggregated_gpt_oss_120b_harmony[gpt
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-dp4-cutlass-auto] SKIP (https://nvbugs/5596343)
unittest/_torch/speculative/test_eagle3.py::test_llama_eagle3_dynamic_tree[True-False] SKIP (https://nvbugs/6113021)
perf/test_perf_sanity.py::test_e2e[aggr_upload-deepseek_r1_fp4_v2_grace_blackwell-r1_fp4_v2_tep4_mtp3_1k1k] SKIP (https://nvbugs/6114727)
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[v2_kv_cache-trtllm-one_model-no_overlap_scheduler] SKIP (https://nvbugs/6114821)
accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype[False] SKIP (https://nvbugs/6120535)
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp4] SKIP (https://nvbugs/6110074)
test_doc.py::test_url_validity SKIP (https://nvbugs/6109719)
Expand Down
Loading