Skip to content

Commit f5e1e77

Browse files
committed
[jit_kernel] Use uint8_t* instead of bool* for tree_mask in ngram_utils kernel
1 parent 949e7f3 commit f5e1e77

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

python/sglang/jit_kernel/csrc/speculative/ngram_utils.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
// retrive_next_token: [bs, draft_token_num]
3131
// retrive_next_sibling: [bs, draft_token_num]
3232
__global__ void reconstructIndicesFromTreeMask(
33-
bool* tree_mask,
33+
uint8_t* tree_mask,
3434
int64_t* verified_seq_len,
3535
int64_t* positions,
3636
int64_t* retrive_index,
@@ -171,7 +171,7 @@ void reconstruct_indices_from_tree_mask(
171171

172172
LaunchKernel(grid, block, stream)(
173173
reconstructIndicesFromTreeMask,
174-
static_cast<bool*>(tree_mask.data_ptr()),
174+
static_cast<uint8_t*>(tree_mask.data_ptr()),
175175
static_cast<int64_t*>(verified_seq_len.data_ptr()),
176176
static_cast<int64_t*>(positions.data_ptr()),
177177
static_cast<int64_t*>(retrive_index.data_ptr()),

0 commit comments

Comments
 (0)