Hi! I was running into an issue with `BatchDecodeWithPagedKVCacheWrapper` and the trtllm-gen backend on B200, where it returns different results when replayed with CUDA graphs than when run without them.
I have attached a snippet of the CUDA graph code for reference.
During init:
```python
self._cuda_graph_kv_indptr = torch.zeros(
    max_concurrent_requests + 1,
    dtype=torch.int32,
    device=self.device,
)
self._cuda_graph_kv_indices = torch.zeros(
    max_concurrent_requests * max_num_pages,
    dtype=torch.int32,
    device=self.device,
)
self._cuda_graph_kv_last_page_len = torch.ones(
    max_concurrent_requests,
    dtype=torch.int32,
    device=self.device,
)
```
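For reference, a minimal, self-contained sketch of how I construct the wrapper around these persistent buffers. The capacities and tensor names are placeholders, and the exact `backend` string is an assumption; the CUDA-graph constructor keywords follow the FlashInfer API:

```python
import torch
import flashinfer

device = "cuda"
max_concurrent_requests, max_num_pages = 8, 256  # placeholder capacities

# Persistent metadata buffers, allocated once at worst-case size.
kv_indptr = torch.zeros(max_concurrent_requests + 1, dtype=torch.int32, device=device)
kv_indices = torch.zeros(max_concurrent_requests * max_num_pages, dtype=torch.int32, device=device)
kv_last_page_len = torch.ones(max_concurrent_requests, dtype=torch.int32, device=device)

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
    workspace,
    "NHD",
    use_cuda_graph=True,
    # In CUDA-graph mode the wrapper holds on to these exact tensors, so the
    # same storage must back the kv metadata at capture time and at replay.
    paged_kv_indptr_buffer=kv_indptr,
    paged_kv_indices_buffer=kv_indices,
    paged_kv_last_page_len_buffer=kv_last_page_len,
    backend="trtllm-gen",  # backend selector; exact string is an assumption
)
```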
During capture:
```python
kv_indices=self._cuda_graph_kv_indices
kv_indptr=self._cuda_graph_kv_indptr[:batch_size + 1]
kv_last_page_len=self._cuda_graph_kv_last_page_len[:batch_size]
```
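Continuing the sketch above, the capture step would look roughly like this. The model shape constants and the `q`/`kv_cache` tensors are placeholders; `plan`/`run` are the current FlashInfer entry points, with `plan()` called outside the graph and only `run()` captured:

```python
batch_size = 8
num_qo_heads, num_kv_heads, head_dim, page_size = 32, 8, 128, 16  # placeholders

# Static input/output tensors that the captured graph will reuse on replay.
q = torch.randn(batch_size, num_qo_heads, head_dim, dtype=torch.bfloat16, device=device)
kv_cache = torch.randn(
    max_concurrent_requests * max_num_pages, 2, page_size, num_kv_heads, head_dim,
    dtype=torch.bfloat16, device=device,
)

wrapper.plan(
    kv_indptr[: batch_size + 1],
    kv_indices,
    kv_last_page_len[:batch_size],
    num_qo_heads, num_kv_heads, head_dim, page_size,
    q_data_type=torch.bfloat16,
)
wrapper.run(q, kv_cache)  # warm-up before capture
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out = wrapper.run(q, kv_cache)
```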
Before CUDA graph replay:
```python
self._cuda_graph_kv_indices[:total_pages].copy_(attention_data.kv_indices[:total_pages])
self._cuda_graph_kv_indptr[:batch_size + 1].copy_(attention_data.kv_indptr[:batch_size + 1])
self._cuda_graph_kv_last_page_len[:batch_size].copy_(attention_data.kv_last_page_len[:batch_size])
```
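And this is roughly how the mismatch shows up at replay time. `step_kv_*` and `eager_wrapper` (a second wrapper constructed identically but with `use_cuda_graph=False`) are hypothetical names for this sketch:

```python
# Copy the real metadata for this decode step into the persistent buffers,
# then replay the captured graph.
total_pages = int(step_kv_indptr[batch_size].item())
kv_indices[:total_pages].copy_(step_kv_indices[:total_pages])
kv_indptr[: batch_size + 1].copy_(step_kv_indptr[: batch_size + 1])
kv_last_page_len[:batch_size].copy_(step_kv_last_page_len[:batch_size])
g.replay()
graph_out = out.clone()

# Run the same step eagerly with identical inputs and compare.
eager_wrapper.plan(
    step_kv_indptr, step_kv_indices, step_kv_last_page_len,
    num_qo_heads, num_kv_heads, head_dim, page_size,
    q_data_type=torch.bfloat16,
)
eager_out = eager_wrapper.run(q, kv_cache)
torch.testing.assert_close(graph_out, eager_out)  # diverges in the failing case
```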
Questions:
- I noticed that sglang, which calls these kernels directly rather than through the wrapper, passes the block table in during the capture phase. Is that a requirement for this backend? (See the sketch below for what I mean.)
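To make the question concrete, the alternative pattern would populate the persistent block-table buffer with real page indices before capture, so that `plan()` and the captured kernel already see valid metadata (same hypothetical names as above; whether this is actually required is exactly what I am asking):

```python
# Fill the block table with real page indices *before* capture, as sglang
# appears to do, instead of leaving zeros and copying only at replay time.
kv_indices[:total_pages].copy_(step_kv_indices[:total_pages])
kv_indptr[: batch_size + 1].copy_(step_kv_indptr[: batch_size + 1])
kv_last_page_len[:batch_size].copy_(step_kv_last_page_len[:batch_size])

wrapper.plan(
    kv_indptr[: batch_size + 1],
    kv_indices,
    kv_last_page_len[:batch_size],
    num_qo_heads, num_kv_heads, head_dim, page_size,
    q_data_type=torch.bfloat16,
)
with torch.cuda.graph(g):
    out = wrapper.run(q, kv_cache)
```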
Environment
- Kernel: BatchDecodeWithPagedKVCacheWrapper
- Batch size: 8
- GPU: SM100 (B200)