Hi, thanks for your great work on SGLang and SpecForge!
I am trying to test https://huggingface.co/Rayzl/qwen2.5-vl-7b-eagle3-sgl on Qwen2.5-VL using the reference configs from #102, but the speculative decoding performance is far below expectations.
Below is a detailed report of my setup, logs, and results.
1. My SGLang server command
python -m sglang.launch_server \
--model-path /ch/pretrained_models/Qwen2.5-VL-7B-Instruct \
--speculative-draft-model-path /ch/pretrained_models/qwen2.5-vl-7b-eagle3-sgl \
--speculative-algorithm EAGLE3 \
--speculative-num-steps 4 \
--speculative-eagle-topk 1 \
--speculative-num-draft-tokens 24 \
--trust-remote-code \
--chunked-prefill-size -1 \
--cuda-graph-max-bs 1 \
--tp 1 \
--mem-fraction-static 0.7 \
--host 0.0.0.0 \
--port 8080
Client benchmark:
python run_mmstar.py --host http://0.0.0.0 --port 8080 --parallel 1 --num-questions 50
2. Results (Qwen2.5-VL with EAGLE3)
Average Latency: 92.421 s
Average Output throughput: 41.960 token/s
Average Accept length: 1.037
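For reference, the numbers above come from run_mmstar.py; a single multimodal request can also be sent by hand through SGLang's OpenAI-compatible endpoint. This is only a sketch, and the image URL is just a placeholder:
# single multimodal request against the server started in section 1 (image URL is a placeholder)
curl http://0.0.0.0:8080/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "/ch/pretrained_models/Qwen2.5-VL-7B-Instruct",
    "messages": [{
      "role": "user",
      "content": [
        {"type": "image_url", "image_url": {"url": "https://example.com/sample.jpg"}},
        {"type": "text", "text": "Describe this image."}
      ]
    }],
    "max_tokens": 256
  }'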
3. SGLang logs (accept length always ≈ 1)
Below are several captured decode logs:
[2025-11-19 15:25:25] Decode batch, #running-req: 1, #token: 353, token usage: 0.04, accept len: 1.02, accept rate: 0.20, cuda graph: True, gen throughput (token/s): 44.13, #queue-req: 0,
[2025-11-19 15:25:26] Decode batch, #running-req: 1, #token: 393, token usage: 0.05, accept len: 1.00, accept rate: 0.20, cuda graph: True, gen throughput (token/s): 43.13, #queue-req: 0,
[2025-11-19 15:25:27] Decode batch, #running-req: 1, #token: 433, token usage: 0.05, accept len: 1.00, accept rate: 0.20, cuda graph: True, gen throughput (token/s): 43.12, #queue-req: 0,
[2025-11-19 15:25:28] Decode batch, #running-req: 1, #token: 474, token usage: 0.06, accept len: 1.02, accept rate: 0.20, cuda graph: True, gen throughput (token/s): 44.20, #queue-req: 0,
[2025-11-19 15:25:29] Decode batch, #running-req: 1, #token: 514, token usage: 0.06, accept len: 1.00, accept rate: 0.20, cuda graph: True, gen throughput (token/s): 43.01, #queue-req: 0,
[2025-11-19 15:25:30] Decode batch, #running-req: 1, #token: 557, token usage: 0.06, accept len: 1.07, accept rate: 0.21, cuda graph: True, gen throughput (token/s): 46.21, #queue-req: 0,
This suggests that the draft model’s predictions are almost always rejected.
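For reference, the average accept length can be computed directly from these logs; this assumes the server output was redirected to a file such as server.log:
# average the "accept len" values logged by SGLang during decode
grep -o 'accept len: [0-9.]*' server.log | awk '{ sum += $3; n += 1 } END { print sum / n }'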
4. Similar behavior on Qwen3-VL
The result is essentially the same: accept_len ≈ 1.
5. However: Llama-3.1-8B + EAGLE3 works correctly
Using the same speculative settings:
- speculative-num-steps=4
- speculative-eagle-topk=6
- speculative-num-draft-tokens=24
with https://huggingface.co/lmsys/sglang-EAGLE-LLaMA3-Instruct-8B on GSM8K, I get the expected results:
Average Latency: 52.161 s
Average Output throughput: 86.099 token/s
Average Accept length: 2.313
So the EAGLE3 pipeline works normally on Llama models.
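For completeness, the Llama baseline was launched the same way as in section 1, just with the flags listed above; this is only a sketch, and the local model paths are placeholders:
python -m sglang.launch_server \
  --model-path /ch/pretrained_models/Llama-3.1-8B-Instruct \
  --speculative-draft-model-path /ch/pretrained_models/sglang-EAGLE-LLaMA3-Instruct-8B \
  --speculative-algorithm EAGLE3 \
  --speculative-num-steps 4 \
  --speculative-eagle-topk 6 \
  --speculative-num-draft-tokens 24 \
  --trust-remote-code \
  --cuda-graph-max-bs 1 \
  --tp 1 \
  --mem-fraction-static 0.7 \
  --host 0.0.0.0 \
  --port 8080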
6. vLLM results: Qwen2.5-VL EAGLE3 behaves correctly
I also tested Qwen2.5-VL EAGLE3 in vLLM, using the configs from vllm-project/vllm#22872.
Example command:
vllm serve \
/ch/pretrained_models/Qwen2.5-VL-7B-Instruct \
--port 5580 --host 0.0.0.0 \
--max-num-seqs 128 --dtype bfloat16 --max-model-len=8192 \
--no-enable-prefix-caching --trust-remote-code -tp 1 \
--speculative-config '{"method": "eagle3", "model": "/ch/pretrained_models/qwen2.5-vl-7b-eagle3-sgl", "prefill_token_shift": false, "num_speculative_tokens": 3, "draft_tensor_parallel_size": 1, "max_model_len": 8192}' \
--num-lookahead-slots=3 \
--gpu-memory-utilization=0.93
Results:
- with EAGLE3: Output token throughput (tok/s) = 135.67
- without EAGLE3: Output token throughput (tok/s) = 97.92
- end-to-end speedup ≈ 1.385× → ✔ expected behavior
Meaning:
The Qwen2.5-VL EAGLE3 draft model itself is fine, but SGLang's integration leads to an extremely low accept_len.
7. My questions
- Is my configuration missing anything specific for multimodal models?
- Are additional modifications needed beyond PR #8801 to fully support Qwen VL EAGLE3?
Any guidance or hints would be greatly appreciated.
Thank you very much for your help!