Skip to content

Commit 1dd6ffa

Browse files
authored
Merge pull request #8 from modal-labs/timmy/fix-seq-lens-race
Fix seq_lens race
2 parents 7cb5193 + 5af7f68 commit 1dd6ffa

2 files changed

Lines changed: 9 additions & 13 deletions

File tree

python/sglang/srt/managers/scheduler_output_processor_mixin.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -231,8 +231,12 @@ def process_batch_result_decode(
231231
self.token_to_kv_pool_allocator.free(free_cache_loc_cpu.to("cuda", non_blocking=True))
232232

233233
if self.spec_algorithm.is_eagle():
234+
# TODO (timmy): when does this happen?
235+
if batch.seq_lens is not None:
236+
batch.seq_lens.add_(logits_output.accept_length + 1)
237+
234238
accept_length = logits_output.accept_length.tolist()
235-
idx_to_batch = [i for i, length in enumerate(accept_length) for _ in range(length)]
239+
idx_to_batch = [i for i, length in enumerate(accept_length) for _ in range(length + 1)]
236240
else:
237241
idx_to_batch = list(range(len(batch.reqs)))
238242

python/sglang/srt/speculative/eagle_worker.py

Lines changed: 4 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -338,6 +338,9 @@ def forward_batch_speculative_generation(
338338
)
339339
return logits_output, next_token_ids, None, bid, False, batch.spec_info
340340
else:
341+
# Clone seq_lens because it will be modified in-place by verify
342+
batch.seq_lens = batch.seq_lens.clone()
343+
341344
with self.draft_tp_context(self.draft_model_runner.tp_group):
342345
spec_info = self.draft(batch)
343346
logits_output, verify_output, can_run_cuda_graph = (
@@ -631,7 +634,7 @@ def verify(self, batch: ModelWorkerBatch, spec_info: EagleVerifyInput):
631634
res.accepted_indices
632635
]
633636
logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]
634-
logits_output.accept_length = res.draft_input.accept_length
637+
logits_output.accept_length = res.draft_input.accept_length.clone()
635638

636639
# Prepare the batch for the next draft forwards.
637640
batch.forward_mode = (
@@ -679,11 +682,6 @@ def forward_draft_extend(
679682
def forward_draft_extend_after_decode(self, batch: ModelWorkerBatch):
680683
assert isinstance(batch.spec_info, EagleDraftInput)
681684
# Backup fields that will be modified in-place
682-
seq_lens_backup = batch.seq_lens.clone()
683-
req_pool_indices_backup = batch.req_pool_indices
684-
accept_length_backup = batch.spec_info.accept_length
685-
return_logprob_backup = batch.return_logprob
686-
687685
input_is_idle = batch.forward_mode.is_idle()
688686

689687
if not input_is_idle and batch.spec_info.verified_id.numel() == 0:
@@ -756,15 +754,9 @@ def forward_draft_extend_after_decode(self, batch: ModelWorkerBatch):
756754

757755
self._detect_nan_if_needed(logits_output)
758756

759-
# Restore backup.
760-
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
761757
batch.forward_mode = (
762758
ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE
763759
)
764-
batch.seq_lens = seq_lens_backup
765-
batch.req_pool_indices = req_pool_indices_backup
766-
batch.spec_info.accept_length = accept_length_backup
767-
batch.return_logprob = return_logprob_backup
768760

769761
def capture_for_decode(
770762
self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput

0 commit comments

Comments
 (0)