Skip to content

Commit f326e2f

Browse files
committed
[SPEC] feat: init adaptive spec params from config
1 parent 42fe025 commit f326e2f

5 files changed

Lines changed: 73 additions & 54 deletions

File tree

python/sglang/srt/arg_groups/speculative_hook.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,11 @@ def handle_speculative_decoding(server_args: "ServerArgs") -> None:
108108
f"speculative_algorithm == EAGLE, got {server_args.speculative_algorithm}."
109109
)
110110

111+
if server_args.speculative_adaptive:
112+
_maybe_disable_adaptive(server_args)
113+
if server_args.speculative_adaptive:
114+
_init_adaptive_speculative_params(server_args)
115+
111116
if server_args.speculative_algorithm == "DFLASH":
112117
_handle_dflash(server_args)
113118
elif server_args.speculative_algorithm == "FROZEN_KV_MTP":
@@ -117,18 +122,6 @@ def handle_speculative_decoding(server_args: "ServerArgs") -> None:
117122
elif server_args.speculative_algorithm == "NGRAM":
118123
_handle_ngram(server_args)
119124

120-
if server_args.speculative_adaptive:
121-
_maybe_disable_adaptive(server_args)
122-
if server_args.speculative_adaptive:
123-
from sglang.srt.speculative.adaptive_spec_params import (
124-
validate_adaptive_initial_steps,
125-
)
126-
127-
validate_adaptive_initial_steps(
128-
server_args.speculative_num_steps,
129-
server_args.speculative_adaptive_config,
130-
)
131-
132125

133126
def _handle_dflash(server_args: "ServerArgs") -> None:
134127
if server_args.enable_dp_attention:
@@ -356,7 +349,10 @@ def _handle_eagle_family(server_args: "ServerArgs") -> None:
356349
"DeepSeek MTP does not require setting speculative_draft_model_path."
357350
)
358351

359-
if server_args.speculative_num_steps is None:
352+
if (
353+
not server_args.speculative_adaptive
354+
and server_args.speculative_num_steps is None
355+
):
360356
assert (
361357
server_args.speculative_eagle_topk is None
362358
and server_args.speculative_num_draft_tokens is None
@@ -472,3 +468,31 @@ def _maybe_disable_adaptive(server_args: "ServerArgs") -> None:
472468
"Falling back to static speculative params."
473469
)
474470
server_args.speculative_adaptive = False
471+
472+
473+
def _init_adaptive_speculative_params(server_args: "ServerArgs") -> None:
474+
from sglang.srt.speculative.adaptive_spec_params import (
475+
resolve_candidate_steps_from_config,
476+
)
477+
478+
candidate_steps = resolve_candidate_steps_from_config(
479+
cfg_path=server_args.speculative_adaptive_config,
480+
)
481+
482+
if server_args.speculative_eagle_topk is None:
483+
server_args.speculative_eagle_topk = 1
484+
485+
if server_args.speculative_num_steps is None:
486+
server_args.speculative_num_steps = candidate_steps[len(candidate_steps) // 2]
487+
488+
if server_args.speculative_num_steps not in candidate_steps:
489+
raise ValueError(
490+
f"--speculative-num-steps={server_args.speculative_num_steps} "
491+
f"is not in the adaptive config candidate_steps {candidate_steps}. "
492+
"Pass one of those values."
493+
)
494+
495+
server_args.speculative_num_draft_tokens = (
496+
server_args.speculative_num_steps + 1
497+
)
498+

python/sglang/srt/speculative/adaptive_spec_params.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,10 @@ def adaptive_unsupported_reason(server_args: ServerArgs) -> str | None:
4949
f"speculative_algorithm={server_args.speculative_algorithm} "
5050
"(only EAGLE/EAGLE3 are supported)"
5151
)
52-
if server_args.speculative_eagle_topk != 1:
52+
if (
53+
server_args.speculative_eagle_topk is not None
54+
and server_args.speculative_eagle_topk != 1
55+
):
5356
return (
5457
f"speculative_eagle_topk={server_args.speculative_eagle_topk} "
5558
"(only topk=1 is supported)"
@@ -126,19 +129,6 @@ def resolve_candidate_steps_from_config(
126129
return sorted(all_steps)
127130

128131

129-
def validate_adaptive_initial_steps(
130-
initial_steps: int,
131-
cfg_path: str | None = None,
132-
) -> None:
133-
"""Require the initial step to be a candidate of some BS slot."""
134-
candidate_steps = resolve_candidate_steps_from_config(cfg_path)
135-
if initial_steps not in candidate_steps:
136-
raise ValueError(
137-
f"--speculative-num-steps={initial_steps} is not in the adaptive "
138-
f"config candidate_steps {candidate_steps}. Pass one of those values."
139-
)
140-
141-
142132
class AdaptiveStepSlot:
143133
"""Tracks acceptance rate via EMA and adapts num_steps accordingly.
144134

test/registered/spec/eagle/test_adaptive_speculative.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,6 @@ def setUpClass(cls):
7272
"EAGLE",
7373
"--speculative-draft-model-path",
7474
cls.draft_model,
75-
"--speculative-num-steps",
76-
"1",
77-
"--speculative-eagle-topk",
78-
"1",
79-
"--speculative-num-draft-tokens",
80-
"2",
8175
"--speculative-adaptive",
8276
"--speculative-adaptive-config",
8377
cls.adaptive_config_path,

test/registered/unit/server_args/test_server_args.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,38 @@ def test_external_corpus_max_tokens_must_be_positive(self):
522522
self.assertIn("external-corpus-max-tokens", str(context.exception))
523523

524524

525+
class TestAdaptiveSpecArgs(CustomTestCase):
526+
def test_adaptive_defaults_to_config_step_when_spec_params_omitted(self):
527+
with tempfile.NamedTemporaryFile("w", suffix=".json") as f:
528+
json.dump(
529+
{
530+
"1": {"candidate_steps": [1, 3, 5]},
531+
"8": {"candidate_steps": [1]},
532+
},
533+
f,
534+
)
535+
f.flush()
536+
537+
args = ServerArgs(model_path="dummy")
538+
args.speculative_algorithm = "EAGLE"
539+
args.speculative_adaptive = True
540+
args.speculative_adaptive_config = f.name
541+
args.device = "cuda"
542+
args.get_model_config = lambda: SimpleNamespace(
543+
hf_config=SimpleNamespace(
544+
architectures=["LlamaForCausalLM"],
545+
get_text_config=lambda: SimpleNamespace(),
546+
)
547+
)
548+
549+
handle_speculative_decoding(args)
550+
551+
self.assertTrue(args.speculative_adaptive)
552+
self.assertEqual(args.speculative_eagle_topk, 1)
553+
self.assertEqual(args.speculative_num_steps, 3)
554+
self.assertEqual(args.speculative_num_draft_tokens, 4)
555+
556+
525557
class TestDeepEPWaterfillArgs(CustomTestCase):
526558
def test_waterfill_enforces_shared_experts_fusion(self):
527559
server_args = ServerArgs(

test/registered/unit/spec/test_adaptive_spec_params.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
AdaptiveSpeculativeParams,
77
AdaptiveStepSlot,
88
resolve_candidate_steps_from_config,
9-
validate_adaptive_initial_steps,
109
)
1110
from sglang.test.ci.ci_register import register_cpu_ci, register_xpu_ci
1211

@@ -403,25 +402,5 @@ def test_unions_and_dedups_across_slots(self):
403402
self.assertEqual(steps, [1, 3, 5, 7])
404403

405404

406-
class TestValidateAdaptiveInitialSteps(unittest.TestCase):
407-
def test_accepts_value_from_any_slot(self):
408-
with tempfile.NamedTemporaryFile("w", suffix=".json") as f:
409-
json.dump(
410-
{
411-
"1": {"candidate_steps": [1, 5]},
412-
"8": {"candidate_steps": [1, 3, 7]},
413-
},
414-
f,
415-
)
416-
f.flush()
417-
# Membership in any slot is enough: 5 lives in the smallest slot,
418-
# 7 only in a larger slot -- both accepted.
419-
validate_adaptive_initial_steps(5, cfg_path=f.name)
420-
validate_adaptive_initial_steps(7, cfg_path=f.name)
421-
# 9 is in no slot -> rejected.
422-
with self.assertRaises(ValueError):
423-
validate_adaptive_initial_steps(9, cfg_path=f.name)
424-
425-
426405
if __name__ == "__main__":
427406
unittest.main()

0 commit comments

Comments
 (0)