Skip to content

Commit b3bec16

Browse files
Avoid duplicating eagle_utils.py
1 parent 5c527ff commit b3bec16

File tree

10 files changed

+276
-1230
lines changed

10 files changed

+276
-1230
lines changed

python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py

Lines changed: 7 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import logging
4-
import os
54
from http import HTTPStatus
65
from typing import TYPE_CHECKING
76

@@ -147,31 +146,14 @@ def process_prebuilt_extend(
147146
hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)
148147

149148
# local import to avoid circular import
150-
if (
151-
os.environ.get("SGLANG_ENABLE_EXPERIMENTAL_EAGLE_OVERLAP_SCHEDULE", "0")
152-
== "1"
153-
):
154-
from sglang.srt.speculative.eagle_utils_for_overlap_scheduler import (
155-
EagleDraftInput,
156-
)
157-
158-
spec_info = EagleDraftInput(
159-
topk_p=topk_p,
160-
topk_index=topk_index,
161-
hidden_states=hidden_states,
162-
verified_id=self.output_ids,
163-
spec_steps=server_args.speculative_num_steps,
164-
)
165-
else:
166-
from sglang.srt.speculative.eagle_utils import EagleDraftInput
167-
168-
spec_info = EagleDraftInput(
169-
topk_p=topk_p,
170-
topk_index=topk_index,
171-
hidden_states=hidden_states,
172-
verified_id=self.output_ids,
173-
)
149+
from sglang.srt.speculative.eagle_utils import EagleDraftInput
174150

151+
spec_info = EagleDraftInput(
152+
topk_p=topk_p,
153+
topk_index=topk_index,
154+
hidden_states=hidden_states,
155+
verified_id=self.output_ids,
156+
)
175157
spec_info.prepare_for_extend(self)
176158
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
177159
self.spec_info = spec_info

python/sglang/srt/layers/attention/flashattention_backend.py

Lines changed: 6 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import os
43
from dataclasses import dataclass
54
from typing import TYPE_CHECKING, Optional, Union
65

@@ -14,14 +13,7 @@
1413
from sglang.srt.managers.schedule_batch import global_server_args_dict
1514
from sglang.srt.mem_cache.memory_pool import SWAKVPool
1615
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
17-
18-
if os.environ.get("SGLANG_ENABLE_EXPERIMENTAL_EAGLE_OVERLAP_SCHEDULE", "0") == "1":
19-
from sglang.srt.speculative.eagle_utils_for_overlap_scheduler import (
20-
EagleDraftInput,
21-
EagleVerifyInput,
22-
)
23-
else:
24-
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
16+
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
2517

2618
if TYPE_CHECKING:
2719
from sglang.srt.layers.radix_attention import RadixAttention
@@ -340,6 +332,9 @@ def __init__(
340332
model_runner.token_to_kv_pool.full_to_swa_index_mapping
341333
)
342334
self.topk = model_runner.server_args.speculative_eagle_topk or 0
335+
self.enable_overlap_schedule = (
336+
not model_runner.server_args.disable_overlap_schedule
337+
)
343338
self.speculative_num_steps = speculative_num_steps
344339
self.speculative_num_draft_tokens = (
345340
model_runner.server_args.speculative_num_draft_tokens
@@ -1902,9 +1897,9 @@ def init_forward_metadata_replay_cuda_graph(
19021897
torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
19031898
)
19041899
accept_length = spec_info.accept_length[:bs]
1905-
if getattr(spec_info, "spec_steps", None) is not None:
1900+
if self.enable_overlap_schedule:
19061901
# EAGLE + Overlap scheduling code path
1907-
metadata.max_seq_len_q = spec_info.spec_steps + 1
1902+
metadata.max_seq_len_q = self.speculative_num_steps + 1
19081903
elif spec_info.accept_length_cpu:
19091904
metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
19101905
else:

python/sglang/srt/managers/schedule_batch.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -906,7 +906,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
906906
# Speculative decoding
907907
spec_algorithm: SpeculativeAlgorithm = None
908908
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None
909-
# Used for EAGLE + Overlap scheduling only. Stores the temporary draft output token locations.
909+
# Used for EAGLE + Overlap scheduling. Stores the temporary draft output KV cache locations.
910910
draft_out_cache_loc: Optional[torch.Tensor] = None
911911

912912
# Whether to return hidden states
@@ -1695,7 +1695,9 @@ def filter_batch(
16951695

16961696
self.sampling_info.filter_batch(keep_indices, keep_indices_device)
16971697
if self.spec_info:
1698-
self.spec_info.filter_batch(keep_indices_device)
1698+
self.spec_info.filter_batch(
1699+
keep_indices_device, has_been_filtered=not self.enable_overlap
1700+
)
16991701

17001702
def merge_batch(self, other: "ScheduleBatch"):
17011703
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
@@ -1948,7 +1950,7 @@ class ModelWorkerBatch:
19481950
# If set, the output of the batch contains the hidden states of the run.
19491951
capture_hidden_mode: CaptureHiddenMode = None
19501952
hicache_consumer_index: int = 0
1951-
# Used for EAGLE + Overlap scheduling only. Stores the temporary draft output token locations.
1953+
# Used for EAGLE + Overlap scheduling. Stores the temporary draft output KV cache locations.
19521954
draft_out_cache_loc: Optional[torch.Tensor] = None
19531955

19541956
# Overlap event

python/sglang/srt/managers/scheduler.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1814,7 +1814,6 @@ def run_batch(
18141814

18151815
model_worker_batch = batch.get_model_worker_batch()
18161816
if self.enable_overlap:
1817-
# TODO (timmy): Do not alias seq_lens between forward and scheduler threads.
18181817
# Optimistically estimate the seq_lens_cpu for the next draft forward
18191818
model_worker_batch.seq_lens_cpu.add_(
18201819
self.server_args.speculative_num_steps + 1

python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py

Lines changed: 7 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import bisect
4-
import os
54
from typing import TYPE_CHECKING, Callable
65

76
import torch
@@ -21,12 +20,7 @@
2120
ForwardBatch,
2221
ForwardMode,
2322
)
24-
25-
if os.environ.get("SGLANG_ENABLE_EXPERIMENTAL_EAGLE_OVERLAP_SCHEDULE", "0") == "1":
26-
from sglang.srt.speculative.eagle_utils_for_overlap_scheduler import EagleDraftInput
27-
else:
28-
from sglang.srt.speculative.eagle_utils import EagleDraftInput
29-
23+
from sglang.srt.speculative.eagle_utils import EagleDraftInput
3024
from sglang.srt.utils import (
3125
require_attn_tp_gather,
3226
require_gathered_buffer,
@@ -210,23 +204,12 @@ def capture_one_batch_size(self, num_seqs: int, forward: Callable):
210204
global_dp_buffer_len = None
211205
global_num_tokens_for_logprob = None
212206

213-
if (
214-
os.environ.get("SGLANG_ENABLE_EXPERIMENTAL_EAGLE_OVERLAP_SCHEDULE", "0")
215-
== "1"
216-
):
217-
spec_info = EagleDraftInput(
218-
topk_p=topk_p,
219-
topk_index=topk_index,
220-
hidden_states=hidden_states,
221-
capture_hidden_mode=CaptureHiddenMode.LAST,
222-
)
223-
else:
224-
spec_info = EagleDraftInput(
225-
topk_p=topk_p,
226-
topk_index=topk_index,
227-
hidden_states=hidden_states,
228-
capture_hidden_mode=CaptureHiddenMode.LAST,
229-
)
207+
spec_info = EagleDraftInput(
208+
topk_p=topk_p,
209+
topk_index=topk_index,
210+
hidden_states=hidden_states,
211+
capture_hidden_mode=CaptureHiddenMode.LAST,
212+
)
230213

231214
# Forward batch
232215
forward_batch = ForwardBatch(

python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py

Lines changed: 5 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import bisect
4-
import os
54
from typing import TYPE_CHECKING, Callable
65

76
import torch
@@ -22,15 +21,7 @@
2221
ForwardBatch,
2322
ForwardMode,
2423
)
25-
26-
if os.environ.get("SGLANG_ENABLE_EXPERIMENTAL_EAGLE_OVERLAP_SCHEDULE", "0") == "1":
27-
from sglang.srt.speculative.eagle_utils_for_overlap_scheduler import (
28-
EagleDraftInput,
29-
fast_topk,
30-
)
31-
else:
32-
from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
33-
24+
from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
3425
from sglang.srt.utils import (
3526
require_attn_tp_gather,
3627
require_gathered_buffer,
@@ -236,20 +227,10 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
236227
else:
237228
global_dp_buffer_len = None
238229

239-
if (
240-
os.environ.get("SGLANG_ENABLE_EXPERIMENTAL_EAGLE_OVERLAP_SCHEDULE", "0")
241-
== "1"
242-
):
243-
spec_info = EagleDraftInput(
244-
hidden_states=hidden_states,
245-
accept_length=accept_length,
246-
spec_steps=self.speculative_num_steps,
247-
)
248-
else:
249-
spec_info = EagleDraftInput(
250-
hidden_states=hidden_states,
251-
accept_length=accept_length,
252-
)
230+
spec_info = EagleDraftInput(
231+
hidden_states=hidden_states,
232+
accept_length=accept_length,
233+
)
253234
spec_info.positions = None
254235

255236
# Forward batch

0 commit comments

Comments
 (0)