Skip to content

Commit 5c527ff

Browse files
Support overlap scheduling for speculative decoding
Co-authored-by: Nathan Wang <nathan.r.wang@gmail.com>
1 parent 6b39f9c commit 5c527ff

12 files changed

+2605
-47
lines changed

python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import logging
4+
import os
45
from http import HTTPStatus
56
from typing import TYPE_CHECKING
67

@@ -146,14 +147,31 @@ def process_prebuilt_extend(
146147
hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)
147148

148149
# local import to avoid circular import
149-
from sglang.srt.speculative.eagle_utils import EagleDraftInput
150+
if (
151+
os.environ.get("SGLANG_ENABLE_EXPERIMENTAL_EAGLE_OVERLAP_SCHEDULE", "0")
152+
== "1"
153+
):
154+
from sglang.srt.speculative.eagle_utils_for_overlap_scheduler import (
155+
EagleDraftInput,
156+
)
157+
158+
spec_info = EagleDraftInput(
159+
topk_p=topk_p,
160+
topk_index=topk_index,
161+
hidden_states=hidden_states,
162+
verified_id=self.output_ids,
163+
spec_steps=server_args.speculative_num_steps,
164+
)
165+
else:
166+
from sglang.srt.speculative.eagle_utils import EagleDraftInput
167+
168+
spec_info = EagleDraftInput(
169+
topk_p=topk_p,
170+
topk_index=topk_index,
171+
hidden_states=hidden_states,
172+
verified_id=self.output_ids,
173+
)
150174

151-
spec_info = EagleDraftInput(
152-
topk_p=topk_p,
153-
topk_index=topk_index,
154-
hidden_states=hidden_states,
155-
verified_id=self.output_ids,
156-
)
157175
spec_info.prepare_for_extend(self)
158176
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
159177
self.spec_info = spec_info

python/sglang/srt/layers/attention/flashattention_backend.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import os
34
from dataclasses import dataclass
45
from typing import TYPE_CHECKING, Optional, Union
56

@@ -13,7 +14,14 @@
1314
from sglang.srt.managers.schedule_batch import global_server_args_dict
1415
from sglang.srt.mem_cache.memory_pool import SWAKVPool
1516
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
16-
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
17+
18+
if os.environ.get("SGLANG_ENABLE_EXPERIMENTAL_EAGLE_OVERLAP_SCHEDULE", "0") == "1":
19+
from sglang.srt.speculative.eagle_utils_for_overlap_scheduler import (
20+
EagleDraftInput,
21+
EagleVerifyInput,
22+
)
23+
else:
24+
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
1725

1826
if TYPE_CHECKING:
1927
from sglang.srt.layers.radix_attention import RadixAttention
@@ -1894,7 +1902,10 @@ def init_forward_metadata_replay_cuda_graph(
18941902
torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
18951903
)
18961904
accept_length = spec_info.accept_length[:bs]
1897-
if spec_info.accept_length_cpu:
1905+
if getattr(spec_info, "spec_steps", None) is not None:
1906+
# EAGLE + Overlap scheduling code path
1907+
metadata.max_seq_len_q = spec_info.spec_steps + 1
1908+
elif spec_info.accept_length_cpu:
18981909
metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
18991910
else:
19001911
metadata.max_seq_len_q = 1

python/sglang/srt/layers/logits_processor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ class LogitsProcessorOutput:
5959
# Used by speculative decoding (EAGLE)
6060
# The last hidden layers
6161
hidden_states: Optional[torch.Tensor] = None
62+
# Used by speculative decoding (EAGLE) + overlap scheduling
63+
# Speculative accept lengths
64+
accept_length: Optional[torch.Tensor] = None
6265

6366
## Part 2: This part will be assigned in python/sglang/srt/layers/sampler.py::Sampler
6467
# The logprobs of the next tokens. shape: [#seq]

python/sglang/srt/managers/schedule_batch.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,9 @@
108108
"quantization",
109109
"enable_custom_logit_processor",
110110
"disaggregation_mode",
111+
"speculative_num_steps",
112+
"speculative_eagle_topk",
113+
"speculative_num_draft_tokens",
111114
]
112115

113116
# Put some global args for easy access
@@ -903,6 +906,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
903906
# Speculative decoding
904907
spec_algorithm: SpeculativeAlgorithm = None
905908
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None
909+
# Used for EAGLE + Overlap scheduling only. Stores the temporary draft output token locations.
910+
draft_out_cache_loc: Optional[torch.Tensor] = None
906911

907912
# Whether to return hidden states
908913
return_hidden_states: bool = False
@@ -1540,7 +1545,22 @@ def prepare_for_decode(self):
15401545
self.forward_mode = ForwardMode.DECODE
15411546
bs = len(self.reqs)
15421547

1543-
if self.spec_algorithm.is_eagle():
1548+
if self.enable_overlap and self.spec_algorithm.is_eagle():
1549+
assert (
1550+
self.token_to_kv_pool_allocator.page_size == 1
1551+
), "Eagle + Overlap Scheduler currently only supports page size 1"
1552+
self.draft_out_cache_loc, backup_state = self.alloc_token_slots(
1553+
bs
1554+
* global_server_args_dict["speculative_num_steps"]
1555+
* global_server_args_dict["speculative_eagle_topk"],
1556+
backup_state=True,
1557+
)
1558+
self.token_to_kv_pool_allocator.restore_state(backup_state)
1559+
self.out_cache_loc = self.alloc_token_slots(
1560+
bs * global_server_args_dict["speculative_num_draft_tokens"]
1561+
)
1562+
return
1563+
elif self.spec_algorithm.is_eagle():
15441564
# if spec decoding is used, the decode batch is prepared inside
15451565
# `forward_batch_speculative_generation` after running draft models.
15461566
return
@@ -1648,11 +1668,20 @@ def filter_batch(
16481668
if self.multimodal_inputs is not None:
16491669
self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
16501670
self.req_pool_indices = self.req_pool_indices[keep_indices_device]
1671+
1672+
if self.spec_algorithm.is_eagle() and self.enable_overlap:
1673+
# In eagle overlap mode, seq_lens is mutated in the EagleWorkerClient's forward_stream,
1674+
# but we copy seq_lens in the scheduler's stream. This is a problem because seq_lens may
1675+
# not have been mutated by EagleWorkerClient before the scheduler stream starts making
1676+
# a copy of it. To avoid this, we synchronize all streams before copying seq_lens.
1677+
torch.cuda.synchronize()
1678+
16511679
self.seq_lens = self.seq_lens[keep_indices_device]
16521680
self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
16531681
self.out_cache_loc = None
16541682
self.seq_lens_sum = self.seq_lens.sum().item()
1655-
self.output_ids = self.output_ids[keep_indices_device]
1683+
if self.output_ids is not None:
1684+
self.output_ids = self.output_ids[keep_indices_device]
16561685
self.return_logprob = any(req.return_logprob for req in self.reqs)
16571686
if self.return_logprob:
16581687
self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
@@ -1766,6 +1795,7 @@ def get_model_worker_batch(
17661795
token_type_ids=self.token_type_ids,
17671796
spec_algorithm=self.spec_algorithm,
17681797
spec_info=self.spec_info,
1798+
draft_out_cache_loc=self.draft_out_cache_loc,
17691799
hicache_consumer_index=self.hicache_consumer_index,
17701800
capture_hidden_mode=(
17711801
CaptureHiddenMode.FULL
@@ -1918,6 +1948,8 @@ class ModelWorkerBatch:
19181948
# If set, the output of the batch contains the hidden states of the run.
19191949
capture_hidden_mode: CaptureHiddenMode = None
19201950
hicache_consumer_index: int = 0
1951+
# Used for EAGLE + Overlap scheduling only. Stores the temporary draft output token locations.
1952+
draft_out_cache_loc: Optional[torch.Tensor] = None
19211953

19221954
# Overlap event
19231955
launch_done: Optional[threading.Event] = None

python/sglang/srt/managers/scheduler.py

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def __init__(
317317
logger.info("Overlap scheduler is disabled for embedding models.")
318318

319319
# Launch a tensor parallel worker
320-
if self.enable_overlap:
320+
if self.enable_overlap and not self.spec_algorithm.is_eagle():
321321
TpWorkerClass = TpModelWorkerClient
322322
else:
323323
TpWorkerClass = TpModelWorker
@@ -334,9 +334,16 @@ def __init__(
334334

335335
# Launch a draft worker for speculative decoding
336336
if self.spec_algorithm.is_eagle():
337-
from sglang.srt.speculative.eagle_worker import EAGLEWorker
337+
if self.enable_overlap:
338+
from sglang.srt.speculative.eagle_worker_overlap_thread import (
339+
EAGLEWorkerClient as EAGLEWorkerClass,
340+
)
341+
else:
342+
from sglang.srt.speculative.eagle_worker import (
343+
EAGLEWorker as EAGLEWorkerClass,
344+
)
338345

339-
self.draft_worker = EAGLEWorker(
346+
self.draft_worker = EAGLEWorkerClass(
340347
gpu_id=gpu_id,
341348
tp_rank=tp_rank,
342349
moe_ep_rank=moe_ep_rank,
@@ -820,15 +827,25 @@ def event_loop_overlap(self):
820827
tmp_batch = ScheduleBatch(
821828
reqs=None,
822829
forward_mode=ForwardMode.DUMMY_FIRST,
823-
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
830+
next_batch_sampling_info=(
831+
self.draft_worker.cur_sampling_info
832+
if self.enable_overlap and self.spec_algorithm.is_eagle()
833+
else self.tp_worker.cur_sampling_info
834+
),
824835
)
825836
self.process_batch_result(tmp_batch, None, batch.launch_done)
826837

827838
if self.last_batch:
828839
# Process the results of the last batch
829840
tmp_batch, tmp_result = self.result_queue.popleft()
830841
tmp_batch.next_batch_sampling_info = (
831-
self.tp_worker.cur_sampling_info if batch else None
842+
(
843+
self.draft_worker.cur_sampling_info
844+
if self.enable_overlap and self.spec_algorithm.is_eagle()
845+
else self.tp_worker.cur_sampling_info
846+
)
847+
if batch
848+
else None
832849
)
833850
# NOTE: we should use the current launched batch's launch_done event instead of the last batch's
834851
self.process_batch_result(
@@ -1789,6 +1806,38 @@ def run_batch(
17891806
self.tp_worker.forward_batch_generation(model_worker_batch)
17901807
)
17911808
bid = model_worker_batch.bid
1809+
elif self.enable_overlap:
1810+
if batch.has_grammar:
1811+
raise NotImplementedError(
1812+
"Grammar + EAGLE + Overlap is not supported for now"
1813+
)
1814+
1815+
model_worker_batch = batch.get_model_worker_batch()
1816+
if self.enable_overlap:
1817+
# TODO (timmy): Do not alias seq_lens between forward and scheduler threads.
1818+
# Optimistically estimate the seq_lens_cpu for the next draft forward
1819+
model_worker_batch.seq_lens_cpu.add_(
1820+
self.server_args.speculative_num_steps + 1
1821+
)
1822+
1823+
# Populate fields needed to reuse batch for verify
1824+
model_worker_batch.extend_seq_lens = batch.extend_lens
1825+
model_worker_batch.extend_prefix_lens = batch.prefix_lens
1826+
model_worker_batch.extend_logprob_start_lens = (
1827+
batch.extend_logprob_start_lens
1828+
)
1829+
1830+
(
1831+
logits_output,
1832+
next_token_ids,
1833+
free_cache_loc_cpu,
1834+
bid,
1835+
can_run_cuda_graph,
1836+
next_spec_info,
1837+
) = self.draft_worker.forward_batch_speculative_generation(
1838+
model_worker_batch
1839+
)
1840+
batch.spec_info = next_spec_info
17921841
else:
17931842
(
17941843
logits_output,

python/sglang/srt/managers/scheduler_output_processor_mixin.py

Lines changed: 55 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,14 @@ def process_batch_result_prefill(
5151
)
5252

5353
if self.enable_overlap:
54-
logits_output, next_token_ids, _ = (
55-
self.tp_worker.resolve_last_batch_result(launch_done)
56-
)
54+
if self.spec_algorithm.is_eagle():
55+
logits_output, next_token_ids, _, _, _, _ = (
56+
self.draft_worker.resolve_last_batch_result(launch_done)
57+
)
58+
else:
59+
logits_output, next_token_ids, _ = (
60+
self.tp_worker.resolve_last_batch_result(launch_done)
61+
)
5762
else:
5863
# Move next_token_ids and logprobs to cpu
5964
next_token_ids = next_token_ids.tolist()
@@ -205,9 +210,25 @@ def process_batch_result_decode(
205210
self.num_generated_tokens += len(batch.reqs)
206211

207212
if self.enable_overlap:
208-
logits_output, next_token_ids, can_run_cuda_graph = (
209-
self.tp_worker.resolve_last_batch_result(launch_done)
210-
)
213+
if self.spec_algorithm.is_eagle():
214+
(
215+
logits_output,
216+
next_token_ids,
217+
free_cache_loc_cpu,
218+
# Note: It's important we use out_cache_loc here and not batch.out_cache_loc.
219+
# out_cache_loc stores the out cache locations for the accepted tokens in
220+
# the target verify step, which is what we want. However, batch.out_cache_loc
221+
# contains the out cache locations for all tokens. If we use that, we will end
222+
# up freeing the wrong locations when we free extra delayed tokens in specdec.
223+
out_cache_loc,
224+
_,
225+
can_run_cuda_graph,
226+
) = self.draft_worker.resolve_last_batch_result(launch_done)
227+
else:
228+
logits_output, next_token_ids, can_run_cuda_graph = (
229+
self.tp_worker.resolve_last_batch_result(launch_done)
230+
)
231+
out_cache_loc = batch.out_cache_loc
211232
next_token_logprobs = logits_output.next_token_logprobs
212233
elif batch.spec_algorithm.is_none():
213234
# spec decoding handles output logprobs inside verify process.
@@ -217,28 +238,46 @@ def process_batch_result_decode(
217238

218239
self.token_to_kv_pool_allocator.free_group_begin()
219240

241+
if self.enable_overlap and self.spec_algorithm.is_eagle():
242+
if free_cache_loc_cpu is not None:
243+
free_cache_loc_cpu = free_cache_loc_cpu[free_cache_loc_cpu != 0]
244+
self.token_to_kv_pool_allocator.free(
245+
free_cache_loc_cpu.to("cuda", non_blocking=True)
246+
)
247+
248+
accept_length = logits_output.accept_length.tolist()
249+
idx_to_batch = [
250+
i for i, length in enumerate(accept_length) for _ in range(length + 1)
251+
]
252+
253+
num_generated_tokens_this_batch = len(idx_to_batch)
254+
self.num_generated_tokens += num_generated_tokens_this_batch
255+
self.spec_num_total_accepted_tokens += num_generated_tokens_this_batch
256+
self.spec_num_total_forward_ct += len(batch.reqs)
257+
else:
258+
idx_to_batch = list(range(len(batch.reqs)))
259+
220260
# Check finish condition
221261
# NOTE: the length of reqs and next_token_ids don't match if it is spec decoding.
222262
# We should ignore using next_token_ids for spec decoding cases.
223-
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
263+
for i, (b, next_token_id) in enumerate(zip(idx_to_batch, next_token_ids)):
264+
req = batch.reqs[b]
224265
if req.is_retracted:
225266
continue
226267

227268
if self.enable_overlap and req.finished():
228269
# Free the one extra delayed token
229270
if self.page_size == 1:
230-
self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
271+
self.token_to_kv_pool_allocator.free(out_cache_loc[i : i + 1])
231272
else:
232273
# Only free when the extra token is in a new page
233274
if (
234275
len(req.origin_input_ids) + len(req.output_ids) - 1
235276
) % self.page_size == 0:
236-
self.token_to_kv_pool_allocator.free(
237-
batch.out_cache_loc[i : i + 1]
238-
)
277+
self.token_to_kv_pool_allocator.free(out_cache_loc[i : i + 1])
239278
continue
240279

241-
if batch.spec_algorithm.is_none():
280+
if batch.spec_algorithm.is_none() or self.enable_overlap:
242281
# speculative worker will solve the output_ids in speculative decoding
243282
req.output_ids.append(next_token_id)
244283

@@ -247,8 +286,10 @@ def process_batch_result_decode(
247286
self.tree_cache.cache_finished_req(req)
248287
req.time_stats.completion_time = time.time()
249288

250-
if req.return_logprob and batch.spec_algorithm.is_none():
251-
# speculative worker handles logprob in speculative decoding
289+
if req.return_logprob and (
290+
batch.spec_algorithm.is_none() or self.enable_overlap
291+
):
292+
# non-overlap speculative worker handles logprob in speculative decoding
252293
req.output_token_logprobs_val.append(next_token_logprobs[i])
253294
req.output_token_logprobs_idx.append(next_token_id)
254295
if req.top_logprobs_num > 0:

0 commit comments

Comments
 (0)