@@ -338,6 +338,9 @@ def forward_batch_speculative_generation(
338338 )
339339 return logits_output , next_token_ids , None , bid , False , batch .spec_info
340340 else :
341+ # Clone seq_lens because it will be modified in-place by verify
342+ batch .seq_lens = batch .seq_lens .clone ()
343+
341344 with self .draft_tp_context (self .draft_model_runner .tp_group ):
342345 spec_info = self .draft (batch )
343346 logits_output , verify_output , can_run_cuda_graph = (
@@ -631,7 +634,7 @@ def verify(self, batch: ModelWorkerBatch, spec_info: EagleVerifyInput):
631634 res .accepted_indices
632635 ]
633636 logits_output .hidden_states = logits_output .hidden_states [res .accepted_indices ]
634- logits_output .accept_length = res .draft_input .accept_length
637+ logits_output .accept_length = res .draft_input .accept_length . clone ()
635638
636639 # Prepare the batch for the next draft forwards.
637640 batch .forward_mode = (
@@ -679,11 +682,6 @@ def forward_draft_extend(
679682 def forward_draft_extend_after_decode (self , batch : ModelWorkerBatch ):
680683 assert isinstance (batch .spec_info , EagleDraftInput )
681684 # Backup fields that will be modified in-place
682- seq_lens_backup = batch .seq_lens .clone ()
683- req_pool_indices_backup = batch .req_pool_indices
684- accept_length_backup = batch .spec_info .accept_length
685- return_logprob_backup = batch .return_logprob
686-
687685 input_is_idle = batch .forward_mode .is_idle ()
688686
689687 if not input_is_idle and batch .spec_info .verified_id .numel () == 0 :
@@ -756,15 +754,9 @@ def forward_draft_extend_after_decode(self, batch: ModelWorkerBatch):
756754
757755 self ._detect_nan_if_needed (logits_output )
758756
759- # Restore backup.
760- # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
761757 batch .forward_mode = (
762758 ForwardMode .DECODE if not input_is_idle else ForwardMode .IDLE
763759 )
764- batch .seq_lens = seq_lens_backup
765- batch .req_pool_indices = req_pool_indices_backup
766- batch .spec_info .accept_length = accept_length_backup
767- batch .return_logprob = return_logprob_backup
768760
769761 def capture_for_decode (
770762 self , logits_output : LogitsProcessorOutput , draft_input : EagleDraftInput
0 commit comments