Possible Bug in chain_speculative_sampling: output_emitted_token_num Exceeds Expected Limit #879

@JaeminK

Description

I was testing chain_speculative_sampling and noticed that output_emitted_token_num sometimes exceeds the expected limit.
Specifically, in my test case it should never be greater than 3, since draft_token_ids contains only three valid tokens (positions 3 and 4 are -1 padding with all-zero draft probabilities), yet in some instances it returns 4, even though no probability mass exists for a fourth draft token.

I used the following test script:

import torch
from flashinfer.sampling import chain_speculative_sampling

def run_chain_speculative_sampling(draft_probs, draft_token_ids, verify_probs):
    batch_size, num_draft_tokens = draft_token_ids.shape
    # One uniform sample per draft position, plus one for the bonus token.
    uniform_samples = torch.rand(batch_size, num_draft_tokens + 1, device=draft_token_ids.device)
    accepted_token_ids, output_accepted_token_num, output_emitted_token_num = \
        chain_speculative_sampling(draft_probs, draft_token_ids, uniform_samples, verify_probs)
    return accepted_token_ids, output_accepted_token_num, output_emitted_token_num

token_ids = torch.tensor([[ 1,  0,  1, -1, -1]], device='cuda:0')  # -1 marks padded (unused) draft positions
draft_probs = torch.tensor([[[0.5883, 0.4117, 0.0000, 0.0000],
                             [1.0000, 0.0000, 0.0000, 0.0000],
                             [0.3803, 0.5435, 0.0762, 0.0000],
                             [0.0000, 0.0000, 0.0000, 0.0000],
                             [0.0000, 0.0000, 0.0000, 0.0000]]], device='cuda:0')

verify_probs = torch.tensor([[[0.4555, 0.5445, 0.0000, 0.0000],
                             [1.0000, 0.0000, 0.0000, 0.0000],
                             [0.5783, 0.2831, 0.0000, 0.1386],
                             [0.2346, 0.6850, 0.0804, 0.0000],
                             [0.0000, 0.0000, 0.0000, 0.0000],
                             [0.0000, 0.0000, 0.0000, 0.0000]]], device='cuda:0')

for i in range(10):
    accepted_token_ids, output_accepted_token_num, output_emitted_token_num = run_chain_speculative_sampling(draft_probs, token_ids, verify_probs)

    print(f"------------------------------------------------------")
    print(f"accepted_token_ids : {accepted_token_ids.squeeze(0).tolist()}")
    print(f"output_accepted_token_num : {output_accepted_token_num.item()}")
    print(f"output_emitted_token_num : {output_emitted_token_num.item()}")

The reason for the seemingly unnecessary extra tokens and probability rows is that I first detected the abnormal output while testing batching. I isolated the batch entry that was causing the problem and could still reproduce the same behavior with it alone.

Here are some sample outputs from multiple runs:

------------------------------------------------------
accepted_token_ids : [1, 0, 3, -1, -1, -1]
output_accepted_token_num : 3
output_emitted_token_num : 2
------------------------------------------------------
accepted_token_ids : [1, 0, 1, -1, 3, -1] # Unexpected token 3 at index 4!
output_accepted_token_num : 4 # Unexpected!
output_emitted_token_num : 4 # Unexpected!
------------------------------------------------------
accepted_token_ids : [1, 0, 0, -1, -1, -1]
output_accepted_token_num : 3
output_emitted_token_num : 2
------------------------------------------------------
accepted_token_ids : [1, 0, 1, -1, 3, -1] # Unexpected token 3 at index 4!
output_accepted_token_num : 4 # Unexpected!
output_emitted_token_num : 4 # Unexpected!
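Note that in both unexpected runs the first three output tokens match the draft tokens, so it looks like every valid draft token was accepted. If the accept test sketched above is right, an all-zero uniform_samples tensor accepts every valid draft token (u * p_draft < p_target reduces to 0 < p_target), which should make the behavior reproducible without looping over random samples:

# Assumption: all-zero uniform samples accept every valid draft token,
# removing the randomness from the reproduction.
uniform_zeros = torch.zeros(1, token_ids.shape[1] + 1, device='cuda:0')
print(chain_speculative_sampling(draft_probs, token_ids, uniform_zeros, verify_probs))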

My torch version is 2.4.0 with CUDA 12.1, and I've reproduced the case on both flashinfer 0.2.0 and 0.2.1.
