Description
I was testing chain_speculative_sampling and noticed that output_emitted_token_num sometimes exceeds its expected upper bound.
In my test case it should never be greater than 3 (only three of the draft tokens are non-padding), but it occasionally returns 4, even though the probability tensors assign no mass to such a case.
I used the following test script:
import torch
from flashinfer.sampling import chain_speculative_sampling
def run_chain_speculative_sampling(draft_probs, draft_token_ids, verify_probs):
    batch_size, num_draft_tokens = draft_token_ids.shape
    # One uniform sample per draft token, plus one for the bonus token.
    uniform_samples = torch.rand(batch_size, num_draft_tokens + 1,
                                 device=draft_token_ids.device)
    accepted_token_ids, output_accepted_token_num, output_emitted_token_num = \
        chain_speculative_sampling(draft_probs, draft_token_ids,
                                   uniform_samples, verify_probs)
    return accepted_token_ids, output_accepted_token_num, output_emitted_token_num
token_ids = torch.tensor([[ 1, 0, 1, -1, -1]], device='cuda:0')  # indices 3-4 are -1 padding
draft_probs = torch.tensor([[[0.5883, 0.4117, 0.0000, 0.0000],
                             [1.0000, 0.0000, 0.0000, 0.0000],
                             [0.3803, 0.5435, 0.0762, 0.0000],
                             [0.0000, 0.0000, 0.0000, 0.0000],    # padding
                             [0.0000, 0.0000, 0.0000, 0.0000]]],  # padding
                           device='cuda:0')
verify_probs = torch.tensor([[[0.4555, 0.5445, 0.0000, 0.0000],
                              [1.0000, 0.0000, 0.0000, 0.0000],
                              [0.5783, 0.2831, 0.0000, 0.1386],
                              [0.2346, 0.6850, 0.0804, 0.0000],   # likely the bonus-token row
                              [0.0000, 0.0000, 0.0000, 0.0000],   # padding
                              [0.0000, 0.0000, 0.0000, 0.0000]]], # padding
                            device='cuda:0')
for i in range(10):
    accepted_token_ids, output_accepted_token_num, output_emitted_token_num = \
        run_chain_speculative_sampling(draft_probs, token_ids, verify_probs)
    print(f"------------------------------------------------------")
    print(f"accepted_token_ids : {accepted_token_ids.squeeze(0).tolist()}")
    print(f"output_accepted_token_num : {output_accepted_token_num.item()}")
    print(f"output_emitted_token_num : {output_emitted_token_num.item()}")
The reason the input contains seemingly unnecessary padding tokens and zero-probability rows is that I first detected the abnormal output while testing batching; I isolated the batch entry that was causing the problem and could still reproduce the same behavior with it in isolation.
Here are some sample outputs from multiple runs:
------------------------------------------------------
accepted_token_ids : [1, 0, 3, -1, -1, -1]
output_accepted_token_num : 3
output_emitted_token_num : 2
------------------------------------------------------
accepted_token_ids : [1, 0, 1, -1, 3, -1] # Unexpected 3 at 4th index!
output_accepted_token_num : 4 # Unexpected!
output_emitted_token_num : 4 # Unexpected!
------------------------------------------------------
accepted_token_ids : [1, 0, 0, -1, -1, -1]
output_accepted_token_num : 3
output_emitted_token_num : 2
------------------------------------------------------
accepted_token_ids : [1, 0, 1, -1, 3, -1] # Unexpected 3 at 4th index!
output_accepted_token_num : 4 # Unexpected!
output_emitted_token_num : 4 # Unexpected!
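For comparison, here is a minimal pure-PyTorch sketch of chain speculative sampling (the rejection-sampling scheme from Leviathan et al., 2023) that shows the behavior I expect. This is my own reference, not flashinfer's implementation: it draws its own uniforms instead of taking a uniform_samples argument, and the padding and bonus-token handling reflects my reading of the intended semantics. The point is structural: the emitted count is bounded by the number of non-padding draft tokens, so it can never reach 4 for the input above.

import torch

def reference_chain_speculative_sampling(draft_probs, draft_token_ids, verify_probs):
    # Shapes: draft_probs (B, N, V), draft_token_ids (B, N), verify_probs (B, N + 1, V).
    batch_size, num_draft_tokens, _ = draft_probs.shape
    accepted = torch.full((batch_size, num_draft_tokens + 1), -1,
                          dtype=torch.long, device=draft_probs.device)
    emitted = torch.zeros(batch_size, dtype=torch.long, device=draft_probs.device)
    for b in range(batch_size):
        n_real = int((draft_token_ids[b] >= 0).sum())  # non-padding drafts (3 here)
        pos = 0
        rejected = False
        for i in range(n_real):
            tok = int(draft_token_ids[b, i])
            p = float(verify_probs[b, i, tok])  # target model's probability of the draft token
            q = float(draft_probs[b, i, tok])   # draft model's probability of the same token
            if torch.rand(()).item() < p / q:   # accept with probability min(1, p / q)
                accepted[b, pos] = tok
                emitted[b] += 1
                pos += 1
            else:
                # Rejected: sample the replacement from the residual (p - q)+ and stop.
                residual = (verify_probs[b, i] - draft_probs[b, i]).clamp(min=0)
                accepted[b, pos] = torch.multinomial(residual, 1).item()
                pos += 1
                rejected = True
                break
        if not rejected:
            # All real drafts accepted: emit one bonus token from the next verify row.
            accepted[b, pos] = torch.multinomial(verify_probs[b, n_real], 1).item()
    # emitted counts accepted draft tokens and is capped at n_real by construction.
    return accepted, emitted

Running this reference against the tensors above, emitted never exceeds 3, which is the invariant the kernel output appears to violate.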
My torch version is 2.4.0 with CUDA 12.1, and I have reproduced the issue with both flashinfer 0.2.0 and 0.2.1.