Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 70 additions & 15 deletions python/sglang/srt/arg_groups/speculative_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ def handle_speculative_decoding(server_args: "ServerArgs") -> None:
f"speculative_algorithm == EAGLE, got {server_args.speculative_algorithm}."
)

if server_args.speculative_adaptive:
_maybe_disable_adaptive(server_args)
if server_args.speculative_adaptive:
_init_adaptive_speculative_params(server_args)

if server_args.speculative_algorithm == "DFLASH":
_handle_dflash(server_args)
elif server_args.speculative_algorithm == "FROZEN_KV_MTP":
Expand All @@ -117,18 +122,6 @@ def handle_speculative_decoding(server_args: "ServerArgs") -> None:
elif server_args.speculative_algorithm == "NGRAM":
_handle_ngram(server_args)

if server_args.speculative_adaptive:
_maybe_disable_adaptive(server_args)
if server_args.speculative_adaptive:
from sglang.srt.speculative.adaptive_spec_params import (
validate_adaptive_initial_steps,
)

validate_adaptive_initial_steps(
server_args.speculative_num_steps,
server_args.speculative_adaptive_config,
)


def _handle_dflash(server_args: "ServerArgs") -> None:
if server_args.enable_dp_attention:
Expand Down Expand Up @@ -340,18 +333,20 @@ def _handle_eagle_family(server_args: "ServerArgs") -> None:
"DeepSeek MTP does not require setting speculative_draft_model_path."
)

if server_args.speculative_num_steps is None:
if (
not server_args.speculative_adaptive
and server_args.speculative_num_steps is None
):
assert (
server_args.speculative_eagle_topk is None
and server_args.speculative_num_draft_tokens is None
)
from sglang.srt.server_args import auto_choose_speculative_params

(
server_args.speculative_num_steps,
server_args.speculative_eagle_topk,
server_args.speculative_num_draft_tokens,
) = auto_choose_speculative_params(server_args)
) = _auto_choose_speculative_params(server_args, model_arch)

if (
server_args.attention_backend == "trtllm_mha"
Expand Down Expand Up @@ -462,3 +457,63 @@ def _maybe_disable_adaptive(server_args: "ServerArgs") -> None:
"Falling back to static speculative params."
)
server_args.speculative_adaptive = False


def _init_adaptive_speculative_params(server_args: "ServerArgs") -> None:
    """Populate the speculative parameters used when adaptive mode is enabled.

    The candidate step values are resolved from
    ``server_args.speculative_adaptive_config``. Then, in order:

    * ``speculative_eagle_topk`` is set to 1 when it was left unset.
    * ``speculative_num_steps`` is set to the middle entry (index
      ``len // 2``) of the candidate list when it was left unset.
    * ``speculative_num_draft_tokens`` is overwritten with
      ``speculative_num_steps + 1``.

    Raises:
        ValueError: if an explicitly provided ``speculative_num_steps`` is
            not one of the candidate steps.
    """
    from sglang.srt.speculative.adaptive_spec_params import (
        resolve_candidate_steps_from_config,
    )

    allowed_steps = resolve_candidate_steps_from_config(
        cfg_path=server_args.speculative_adaptive_config,
    )

    if server_args.speculative_eagle_topk is None:
        server_args.speculative_eagle_topk = 1

    num_steps = server_args.speculative_num_steps
    if num_steps is None:
        # No explicit starting point: begin from the middle candidate.
        num_steps = allowed_steps[len(allowed_steps) // 2]
        server_args.speculative_num_steps = num_steps
    elif num_steps not in allowed_steps:
        raise ValueError(
            f"--speculative-num-steps={num_steps} "
            f"is not in the adaptive config candidate_steps {allowed_steps}. "
            "Pass one of those values."
        )

    server_args.speculative_num_draft_tokens = num_steps + 1


def _auto_choose_speculative_params(
server_args: "ServerArgs", model_arch: str
) -> tuple:
"""
Automatically choose the parameters for speculative decoding.

You can tune them on your own models and prompts with scripts/playground/bench_speculative.py
"""
if server_args.speculative_algorithm == "STANDALONE":
return (3, 1, 4)
if model_arch in ["LlamaForCausalLM"]:
return (5, 4, 8)
elif model_arch in [
"DeepseekV32ForCausalLM",
"DeepseekV3ForCausalLM",
"DeepseekV2ForCausalLM",
"GptOssForCausalLM",
"Glm4MoeForCausalLM",
"Glm4MoeLiteForCausalLM",
"GlmMoeDsaForCausalLM",
"BailingMoeForCausalLM",
"BailingMoeV2ForCausalLM",
"BailingMoeV2_5ForCausalLM",
"MistralLarge3ForCausalLM",
"PixtralForConditionalGeneration",
"MiMoV2ForCausalLM",
"MiMoV2FlashForCausalLM",
]:
return (3, 1, 4)
elif model_arch in ["Grok1ForCausalLM", "Grok1VForCausalLM"]:
return (5, 4, 8)
else:
return (3, 1, 4)
38 changes: 0 additions & 38 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -8147,41 +8147,3 @@ def init_new(
).to_tcp(),
instance_id=instance_id,
)


def auto_choose_speculative_params(self: ServerArgs):
    """
    Automatically choose the parameters for speculative decoding.

    The model architecture is read from the first entry of
    ``self.get_model_config().hf_config.architectures``. The result is a
    3-tuple of defaults selected by algorithm and architecture.

    You can tune them on your own models and prompts with scripts/playground/bench_speculative.py
    """
    # Resolve the architecture up front, before the algorithm check.
    model_config = self.get_model_config()
    arch = model_config.hf_config.architectures[0]

    # Default used for standalone speculative decoding and for all other models.
    generic_defaults = (3, 1, 4)
    # Default used for llama and grok architectures.
    llama_style_defaults = (5, 4, 8)

    if self.speculative_algorithm == "STANDALONE":
        return generic_defaults

    if arch in ("LlamaForCausalLM",):
        return llama_style_defaults

    explicitly_generic = (
        "DeepseekV32ForCausalLM",
        "DeepseekV3ForCausalLM",
        "DeepseekV2ForCausalLM",
        "GptOssForCausalLM",
        "Glm4MoeForCausalLM",
        "Glm4MoeLiteForCausalLM",
        "GlmMoeDsaForCausalLM",
        "BailingMoeForCausalLM",
        "BailingMoeV2ForCausalLM",
        "BailingMoeV2_5ForCausalLM",
        "MistralLarge3ForCausalLM",
        "PixtralForConditionalGeneration",
        "MiMoV2ForCausalLM",
        "MiMoV2FlashForCausalLM",
    )
    if arch in explicitly_generic:
        return generic_defaults

    if arch in ("Grok1ForCausalLM", "Grok1VForCausalLM"):
        return llama_style_defaults

    return generic_defaults
18 changes: 4 additions & 14 deletions python/sglang/srt/speculative/adaptive_spec_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ def adaptive_unsupported_reason(server_args: ServerArgs) -> str | None:
f"speculative_algorithm={server_args.speculative_algorithm} "
"(only EAGLE/EAGLE3 are supported)"
)
if server_args.speculative_eagle_topk != 1:
if (
server_args.speculative_eagle_topk is not None
and server_args.speculative_eagle_topk != 1
):
return (
f"speculative_eagle_topk={server_args.speculative_eagle_topk} "
"(only topk=1 is supported)"
Expand Down Expand Up @@ -126,19 +129,6 @@ def resolve_candidate_steps_from_config(
return sorted(all_steps)


def validate_adaptive_initial_steps(
    initial_steps: int,
    cfg_path: str | None = None,
) -> None:
    """Check that the initial step count is a candidate of some BS slot.

    Args:
        initial_steps: The requested starting number of speculative steps.
        cfg_path: Adaptive config path forwarded to
            ``resolve_candidate_steps_from_config``.

    Raises:
        ValueError: if ``initial_steps`` is not among the resolved candidates.
    """
    allowed_steps = resolve_candidate_steps_from_config(cfg_path)
    if initial_steps in allowed_steps:
        return
    raise ValueError(
        f"--speculative-num-steps={initial_steps} is not in the adaptive "
        f"config candidate_steps {allowed_steps}. Pass one of those values."
    )


class AdaptiveStepSlot:
"""Tracks acceptance rate via EMA and adapts num_steps accordingly.

Expand Down
6 changes: 0 additions & 6 deletions test/registered/spec/eagle/test_adaptive_speculative.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,6 @@ def setUpClass(cls):
"EAGLE",
"--speculative-draft-model-path",
cls.draft_model,
"--speculative-num-steps",
"1",
"--speculative-eagle-topk",
"1",
"--speculative-num-draft-tokens",
"2",
"--speculative-adaptive",
"--speculative-adaptive-config",
cls.adaptive_config_path,
Expand Down
33 changes: 33 additions & 0 deletions test/registered/unit/server_args/test_server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import tempfile
import unittest
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import sglang.srt.server_args as server_args_module
Expand Down Expand Up @@ -522,6 +523,38 @@ def test_external_corpus_max_tokens_must_be_positive(self):
self.assertIn("external-corpus-max-tokens", str(context.exception))


class TestAdaptiveSpecArgs(CustomTestCase):
def test_adaptive_defaults_to_config_step_when_spec_params_omitted(self):
with tempfile.NamedTemporaryFile("w", suffix=".json") as f:
Comment on lines +527 to +528

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The test uses `SimpleNamespace` to mock the model configuration, but `SimpleNamespace` is not imported in this file. This will cause a `NameError` when the test is executed. Please import `SimpleNamespace` from `types`.

    def test_adaptive_defaults_to_config_step_when_spec_params_omitted(self):
        from types import SimpleNamespace
        with tempfile.NamedTemporaryFile("w", suffix=".json") as f:

json.dump(
{
"1": {"candidate_steps": [1, 3, 5]},
"8": {"candidate_steps": [1]},
},
f,
)
f.flush()

args = ServerArgs(model_path="dummy")
args.speculative_algorithm = "EAGLE"
args.speculative_adaptive = True
args.speculative_adaptive_config = f.name
args.device = "cuda"
args.get_model_config = lambda: SimpleNamespace(

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Import SimpleNamespace before using it

This new test uses `SimpleNamespace` but the module never imports it, so when this test runs with the normal test dependencies installed it raises `NameError` while constructing `args.get_model_config` and never exercises `handle_speculative_decoding`. Add `from types import SimpleNamespace` (as done in other tests) so the adaptive-default coverage can pass.

Useful? React with 👍 / 👎.

hf_config=SimpleNamespace(
architectures=["LlamaForCausalLM"],
get_text_config=lambda: SimpleNamespace(),
)
)

handle_speculative_decoding(args)

self.assertTrue(args.speculative_adaptive)
self.assertEqual(args.speculative_eagle_topk, 1)
self.assertEqual(args.speculative_num_steps, 3)
self.assertEqual(args.speculative_num_draft_tokens, 4)


class TestDeepEPWaterfillArgs(CustomTestCase):
def test_waterfill_enforces_shared_experts_fusion(self):
server_args = ServerArgs(
Expand Down
21 changes: 0 additions & 21 deletions test/registered/unit/spec/test_adaptive_spec_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
AdaptiveSpeculativeParams,
AdaptiveStepSlot,
resolve_candidate_steps_from_config,
validate_adaptive_initial_steps,
)
from sglang.test.ci.ci_register import register_cpu_ci, register_xpu_ci

Expand Down Expand Up @@ -403,25 +402,5 @@ def test_unions_and_dedups_across_slots(self):
self.assertEqual(steps, [1, 3, 5, 7])


class TestValidateAdaptiveInitialSteps(unittest.TestCase):
    """Checks which initial step values validate_adaptive_initial_steps accepts."""

    def test_accepts_value_from_any_slot(self):
        slot_config = {
            "1": {"candidate_steps": [1, 5]},
            "8": {"candidate_steps": [1, 3, 7]},
        }
        with tempfile.NamedTemporaryFile("w", suffix=".json") as cfg_file:
            json.dump(slot_config, cfg_file)
            cfg_file.flush()

            # Belonging to any one slot suffices: 5 appears only under slot
            # "1" and 7 only under slot "8", yet both must pass.
            for accepted in (5, 7):
                validate_adaptive_initial_steps(accepted, cfg_path=cfg_file.name)

            # 9 appears under no slot, so it must be rejected.
            self.assertRaises(
                ValueError,
                validate_adaptive_initial_steps,
                9,
                cfg_path=cfg_file.name,
            )


if __name__ == "__main__":
unittest.main()
Loading