Skip to content

Commit 83cd13e

Browse files
author
Xu Xiong
committed
set sd accepted when it's top3
1 parent ada668f commit 83cd13e

1 file changed

Lines changed: 13 additions & 3 deletions

File tree

src/bloombee/models/llama/speculative_model.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,8 @@ def _verify_trees_with_forward(
390390

391391
# Extract verification results - now returns valid_lengths
392392
verified_tokens, kv_cache_position_ids, llm_generated_tokens, valid_lengths = self._extract_best_verified_paths_fixed(
393-
logits, batch_node_paths, input_ids, logits_processor, tree_tokens.shape[1], seq_lengths, is_first_iteration
393+
logits, batch_node_paths, input_ids, logits_processor, tree_tokens.shape[1], seq_lengths, is_first_iteration,
394+
acceptance_top_k = 3
394395
)
395396
return verified_tokens, kv_cache_position_ids, new_past_key_values, llm_generated_tokens, valid_lengths
396397

@@ -582,6 +583,7 @@ def _extract_best_verified_paths_fixed(
582583
tree_len: int,
583584
seq_lengths: torch.LongTensor,
584585
is_first_iteration: bool,
586+
acceptance_top_k: int = 1,
585587
) -> Tuple[Optional[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
586588
"""
587589
Returns:
@@ -621,9 +623,17 @@ def _extract_best_verified_paths_fixed(
621623
if pos >= seq_len:
622624
break
623625

624-
predicted_token = torch.argmax(logits[batch_idx, pos]).item()
626+
if acceptance_top_k == 1:
627+
predicted_token = torch.argmax(logits[batch_idx, pos]).item()
628+
accepted = (predicted_token == node.token_id)
629+
else:
630+
top_k_tokens = torch.topk(
631+
logits[batch_idx, pos],
632+
k=acceptance_top_k
633+
).indices.tolist()
634+
accepted = (node.token_id in top_k_tokens)
625635

626-
if predicted_token == node.token_id:
636+
if accepted:
627637
verified_tokens.append(node.token_id)
628638
absolute_position = tree_root_position + node.position_in_sequence + 1
629639
verified_positions.append(absolute_position)

0 commit comments

Comments
 (0)