[Bug] Cuda Graph Issues with TRTLLM-GEN Backend with BatchDecodeWithPagedKVCacheWrapper #2511

@NihalPotdar

Description

Hi! I was running into an issue with BatchDecodeWithPagedKVCacheWrapper and the trtllm-gen backend on B200: it returns different results when replayed through a CUDA graph than when run without one.

I have attached a snippet of the CUDA graph code for reference:

During init:

self._cuda_graph_kv_indptr = torch.zeros(
    max_concurrent_requests + 1,
    dtype=torch.int32,
    device=self.device,
)

self._cuda_graph_kv_indices = torch.zeros(
    max_concurrent_requests * max_num_pages,
    dtype=torch.int32,
    device=self.device,
)

self._cuda_graph_kv_last_page_len = torch.ones(
    max_concurrent_requests,
    dtype=torch.int32,
    device=self.device,
)

During capture:

kv_indices=self._cuda_graph_kv_indices
kv_indptr=self._cuda_graph_kv_indptr[:batch_size + 1]
kv_last_page_len=self._cuda_graph_kv_last_page_len[:batch_size]

Before cuda graph replay:

self._cuda_graph_kv_indices[:total_pages].copy_(attention_data.kv_indices[:total_pages])
self._cuda_graph_kv_indptr[:batch_size + 1].copy_(attention_data.kv_indptr[:batch_size + 1])
self._cuda_graph_kv_last_page_len[:batch_size].copy_(attention_data.kv_last_page_len[:batch_size])
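
As a side note on why the copies above are in-place (`copy_` into slices of the pre-allocated buffers rather than rebinding the attributes): a CUDA graph captures the memory addresses of the tensors used during capture, so replay reads whatever happens to live in that storage. A minimal CPU-only sketch of the aliasing behavior (the view stands in for what the graph captured; names and values are illustrative, not from the original report):

```python
import torch

# A pre-allocated buffer, as in the init snippet above.
buf = torch.zeros(8, dtype=torch.int32)
captured_view = buf[:4]  # stands in for the slice the graph captured

new_data = torch.tensor([1, 2, 3, 4], dtype=torch.int32)

# Correct: in-place copy writes into the storage the graph captured,
# so the captured view observes the new values.
buf[:4].copy_(new_data)
assert torch.equal(captured_view, new_data)

# Incorrect: rebinding the name allocates fresh storage; the captured
# view still points at the old buffer and would be stale on replay.
buf = torch.arange(8, dtype=torch.int32)
assert not torch.equal(captured_view, buf[:4])
```

Note also that the replay code only refreshes `kv_indices[:total_pages]`; any entries past `total_pages` keep whatever values a previous iteration left behind, which may matter if the kernel reads beyond the live region.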

Questions:

  • I noticed that sglang, which calls these kernels directly rather than through the wrapper, passes the block table in during the capture phase. Is that a requirement for this backend?

Environment

Kernel: BatchDecodeWithPagedKVCacheWrapper
Batch size: 8
GPU: SM100 (B200)