Support overlap scheduling for speculative decoding#9588
Open
timmy-feng wants to merge 2 commits intosgl-project:mainfrom
Open
Support overlap scheduling for speculative decoding#9588timmy-feng wants to merge 2 commits intosgl-project:mainfrom
timmy-feng wants to merge 2 commits intosgl-project:mainfrom
Conversation
f6446c8 to
16aac38
Compare
Collaborator
|
|
3a41324 to
25dd3a7
Compare
Collaborator
There was a problem hiding this comment.
This method and the overlap_prepare_for_verify method were deliberately kept separate to reduce the risk of breaking existing code.
19a3975 to
2e4d830
Compare
Co-authored-by: Nathan Wang <nathan.r.wang@gmail.com>
2e4d830 to
b3bec16
Compare
pythongiant
reviewed
Dec 30, 2025
| # but we copy seq_lens in the scheduler's stream. This is a problem because seq_lens may | ||
| # not have been mutated by EagleWorkerClient before the scheduler stream starts making | ||
| # a copy of it. To avoid this, we synchronize all streams before copying seq_lens. | ||
| torch.cuda.synchronize() |
There was a problem hiding this comment.
Quick suggestion, is the torch.cuda.synchronize() here unintentionally serializing execution? It forces a full device barrier and can kill overlap throughput. Do you think something like this preserves overlap rather than stalling the whole GPU?
event = torch.cuda.Event(blocking=False)
event.record(draft_stream)
# later when TP worker needs results
event.wait(tp_stream)This way only the TP stream waits on draft completion rather than synchronizing the entire device
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
Speculative decoding currently does not support overlap scheduling due to the sequential logic between the draft and target models. However, overlap scheduling has been shown to achieve up to 10% performance gains in non-speculative use cases. This PR achieves host overlap in speculative decoding with 5-10% improvement at various batch sizes.
Feature parity is in the works.
To enable this experimental feature, the
SGLANG_ENABLE_EXPERIMENTAL_EAGLE_OVERLAP_SCHEDULE=1environment variable must be set. Additionally, using Flash Attention 3 is recommended as there is a sync in the Flash Infer backend.Modifications
There should be no behavior change if the SGLANG_ENABLE_EXPERIMENTAL_EAGLE_OVERLAP_SCHEDULE environment variable is not set.
Host Syncs
The following was done to remove host syncs:
spec_stepsand adding padding handlersprocess_batch_result_decodeandfilter_batchrespectivelyresolve_last_batch_result-> eviction mask to scheduler)seq_lens_cpusince the host can only know the exact sequence length from one step agoEagle Client
After removal of all syncs,
EagleWorkerClientwas implemented with:forward_speculative_batch_generationfunction which puts work on a queue for the forward threadFutureSpecInfoclass which contains future buffers corresponding to each tensor inEagleDraftInputforward_thread_func_andresoluve_last_batch_resultwhich mirror their counterparts intp_worker_overlap_thread.pyFuture Work
We hope these items can be addressed in future PR's.
page_size > 1exists in this branch of work which supports paged attention for all backends other thanfa3eagle_worker.pyandeagle_worker_for_overlap_scheduer.py. We separated these two files for now to reduce the risk of breaking changes.Accuracy Tests
I ran GSM8K on an H100.
Benchmarking and Profiling
Benchmarks were run on an H200.
Before
With concurrency 1:
With concurrency 4:
After
With concurrency 1:
With concurrency 4:
Repro Script
This script was run on an H200:
Checklist