@@ -390,7 +390,8 @@ def _verify_trees_with_forward(
390390
391391 # Extract verification results - 现在返回 valid_lengths
392392 verified_tokens , kv_cache_position_ids , llm_generated_tokens , valid_lengths = self ._extract_best_verified_paths_fixed (
393- logits , batch_node_paths , input_ids , logits_processor , tree_tokens .shape [1 ], seq_lengths , is_first_iteration
393+ logits , batch_node_paths , input_ids , logits_processor , tree_tokens .shape [1 ], seq_lengths , is_first_iteration ,
394+ acceptance_top_k = 3
394395 )
395396 return verified_tokens , kv_cache_position_ids , new_past_key_values , llm_generated_tokens , valid_lengths
396397
@@ -582,6 +583,7 @@ def _extract_best_verified_paths_fixed(
582583 tree_len : int ,
583584 seq_lengths : torch .LongTensor ,
584585 is_first_iteration : bool ,
586+ acceptance_top_k : int = 1 ,
585587 ) -> Tuple [Optional [torch .Tensor ], torch .Tensor , torch .Tensor , torch .Tensor ]:
586588 """
587589 Returns:
@@ -621,9 +623,17 @@ def _extract_best_verified_paths_fixed(
621623 if pos >= seq_len :
622624 break
623625
624- predicted_token = torch .argmax (logits [batch_idx , pos ]).item ()
626+ if acceptance_top_k == 1 :
627+ predicted_token = torch .argmax (logits [batch_idx , pos ]).item ()
628+ accepted = (predicted_token == node .token_id )
629+ else :
630+ top_k_tokens = torch .topk (
631+ logits [batch_idx , pos ],
632+ k = acceptance_top_k
633+ ).indices .tolist ()
634+ accepted = (node .token_id in top_k_tokens )
625635
626- if predicted_token == node . token_id :
636+ if accepted :
627637 verified_tokens .append (node .token_id )
628638 absolute_position = tree_root_position + node .position_in_sequence + 1
629639 verified_positions .append (absolute_position )
0 commit comments