
Commit 30038d9

Qiaolin-Yu, jhinpan, and hiyouga authored
[inference] support sglang backend (#7278)
* Mimic SGLang offline Engine
* Add more tests and args
* Pass all current tests
* Clean Code
* fix sample_params
* clean code
* Fix Stream Chat
* change sglang from engine mode to server mode
* fix
* Fix Review Issues
* Use SGLang Built-In Utilities
* Fix test SGLang
* Some Doc Issue
* fix sglang engine
* add readme

Co-authored-by: Jin Pan <[email protected]>
Co-authored-by: hiyouga <[email protected]>
1 parent ef5f1c1 commit 30038d9

File tree

15 files changed (+433, -27)

README.md (+5, -3)
@@ -79,8 +79,8 @@ Choose your path:
 - **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and PiSSA.
 - **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), RoPE scaling, NEFTune and rsLoRA.
 - **Wide tasks**: Multi-turn dialogue, tool using, image understanding, visual grounding, video recognition, audio understanding, etc.
-- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, SwanLab, etc.
-- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
+- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, [SwanLab](https://github.com/SwanHubX/SwanLab), etc.
+- **Faster inference**: OpenAI-style API, Gradio UI and CLI with [vLLM worker](https://github.com/vllm-project/vllm) or [SGLang worker](https://github.com/sgl-project/sglang).
 
 ### Day-N Support for Fine-Tuning Cutting-Edge Models
 

@@ -106,6 +106,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 
 ## Changelog
 
+[25/03/15] We supported **[SGLang](https://github.com/sgl-project/sglang)** as an inference backend. Try `infer_backend: sglang` to accelerate inference.
+
 [25/03/12] We supported fine-tuning the **[Gemma-3](https://huggingface.co/blog/gemma3)** model.
 
 [25/02/24] Announcing **[EasyR1](https://github.com/hiyouga/EasyR1)**, an efficient, scalable and multi-modality RL training framework for efficient GRPO training.
@@ -437,7 +439,7 @@ cd LLaMA-Factory
 pip install -e ".[torch,metrics]"
 ```
 
-Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, apollo, badam, adam-mini, qwen, minicpm_v, modelscope, openmind, swanlab, quality
+Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, sglang, galore, apollo, badam, adam-mini, qwen, minicpm_v, modelscope, openmind, swanlab, quality
 
 > [!TIP]
 > Use `pip install --no-deps -e .` to resolve package conflicts.
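Taken together, the README changes describe the full opt-in path: install the new `sglang` extra (for example `pip install -e ".[torch,metrics,sglang]"`, extending the install command shown above), then set `infer_backend: sglang` in the inference config. The example file added below, `examples/inference/llama3_sglang.yaml`, shows the minimal configuration.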

README_zh.md (+5, -3)
@@ -81,8 +81,8 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 - **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and PiSSA.
 - **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), RoPE scaling, NEFTune and rsLoRA.
 - **Wide tasks**: Multi-turn dialogue, tool calling, image understanding, visual grounding, video recognition, audio understanding, and more.
-- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, SwanLab, and more.
-- **Fast inference**: OpenAI-style API, web UI and CLI based on vLLM.
+- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, [SwanLab](https://github.com/SwanHubX/SwanLab), and more.
+- **Fast inference**: OpenAI-style API, web UI and CLI based on [vLLM](https://github.com/vllm-project/vllm) or [SGLang](https://github.com/sgl-project/sglang).
 
 ### Day-N Fine-Tuning Adaptation for the Latest Models
 

@@ -108,6 +108,8 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 
 ## Changelog
 
+[25/03/15] We added support for the **[SGLang](https://github.com/sgl-project/sglang)** inference backend; use `infer_backend: sglang` to enable it.
+
 [25/03/12] We added support for fine-tuning the **[Gemma-3](https://huggingface.co/blog/gemma3)** model.
 
 [25/02/24] We announced the open-sourcing of **[EasyR1](https://github.com/hiyouga/EasyR1)**, an efficient and scalable multi-modal reinforcement learning framework that supports efficient GRPO training.
@@ -439,7 +441,7 @@ cd LLaMA-Factory
 pip install -e ".[torch,metrics]"
 ```
 
-Optional extra dependencies: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, apollo, badam, adam-mini, qwen, minicpm_v, modelscope, openmind, swanlab, quality
+Optional extra dependencies: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, sglang, galore, apollo, badam, adam-mini, qwen, minicpm_v, modelscope, openmind, swanlab, quality
 
 > [!TIP]
 > When you run into package conflicts, use `pip install --no-deps -e .` to resolve them.

examples/inference/llama3_sglang.yaml (+4, -0)
@@ -0,0 +1,4 @@
+model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
+template: llama3
+infer_backend: sglang
+trust_remote_code: true
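This file is a complete inference config; it can be passed to the project's CLI (for example `llamafactory-cli chat examples/inference/llama3_sglang.yaml`). For programmatic use, the sketch below drives the same settings through the `ChatModel` class touched in this commit. It is a minimal sketch under assumptions: the dict-based constructor matches the `args` parameter visible in the diff below, while the message format and the `response_text` attribute on the returned objects come from elsewhere in the codebase, not from this diff.

```python
# Minimal sketch: run the new SGLang backend through ChatModel.
# Assumes the package is installed with the `sglang` extra and the
# config keys mirror the YAML above.
from llamafactory.chat import ChatModel

chat_model = ChatModel({
    "model_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct",
    "template": "llama3",
    "infer_backend": "sglang",  # routes to SGLangEngine (see chat_model.py below)
    "trust_remote_code": True,
})

messages = [{"role": "user", "content": "Hello!"}]
for response in chat_model.chat(messages):  # assumed to return Response objects
    print(response.response_text)
```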

setup.py (+1, -0)
@@ -54,6 +54,7 @@ def get_console_scripts() -> list[str]:
     "awq": ["autoawq"],
     "aqlm": ["aqlm[gpu]>=1.1.0"],
     "vllm": ["vllm>=0.4.3,<=0.7.3"],
+    "sglang": ["sglang>=0.4.4"],
     "galore": ["galore-torch"],
     "apollo": ["apollo-torch"],
     "badam": ["badam>=1.2.1"],

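One note on the change above: declaring an extra means the backend is not installed by default; `pip install -e ".[sglang]"` from the repository root resolves the `sglang>=0.4.4` pin, mirroring how the existing `vllm` extra gates its backend.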
src/llamafactory/chat/chat_model.py (+3, -0)
@@ -25,6 +25,7 @@
 from ..extras.misc import torch_gc
 from ..hparams import get_infer_args
 from .hf_engine import HuggingfaceEngine
+from .sglang_engine import SGLangEngine
 from .vllm_engine import VllmEngine
 
 
@@ -52,6 +53,8 @@ def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
             self.engine: BaseEngine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
         elif model_args.infer_backend == EngineName.VLLM:
             self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
+        elif model_args.infer_backend == EngineName.SGLANG:
+            self.engine: BaseEngine = SGLangEngine(model_args, data_args, finetuning_args, generating_args)
         else:
             raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
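The new `elif` branch implies that `SGLangEngine` exposes the same `BaseEngine` surface as `HuggingfaceEngine` and `VllmEngine`. A rough sketch of that contract follows; only the three method names are taken from the `hf_engine.py` call sites in this commit, while the signatures are abbreviated and illustrative rather than copied from the codebase.

```python
# Illustrative sketch of the engine contract the dispatch above relies on;
# method names come from this diff, signatures are abridged assumptions.
from collections.abc import AsyncGenerator
from typing import Any


class EngineContract:
    async def chat(self, messages: list[dict[str, str]], **kwargs: Any) -> list[Any]:
        """One-shot generation: return finished responses for a conversation."""
        raise NotImplementedError

    async def stream_chat(self, messages: list[dict[str, str]], **kwargs: Any) -> AsyncGenerator[str, None]:
        """Streaming generation: yield incremental text chunks."""
        raise NotImplementedError
        yield  # unreachable; marks this method as an async generator

    async def get_scores(self, batch_input: list[str], **kwargs: Any) -> list[float]:
        """Reward-model scoring for a batch of raw inputs."""
        raise NotImplementedError
```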

src/llamafactory/chat/hf_engine.py (+8, -15)
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import asyncio
-import concurrent.futures
 import os
 from collections.abc import AsyncGenerator
 from threading import Thread
@@ -349,7 +348,6 @@ async def chat(
         if not self.can_generate:
             raise ValueError("The current model does not support `chat`.")
 
-        loop = asyncio.get_running_loop()
         input_args = (
             self.model,
             self.tokenizer,
@@ -365,8 +363,7 @@
             input_kwargs,
         )
         async with self.semaphore:
-            with concurrent.futures.ThreadPoolExecutor() as pool:
-                return await loop.run_in_executor(pool, self._chat, *input_args)
+            return await asyncio.to_thread(self._chat, *input_args)
 
     @override
     async def stream_chat(
@@ -382,7 +379,6 @@ async def stream_chat(
         if not self.can_generate:
             raise ValueError("The current model does not support `stream_chat`.")
 
-        loop = asyncio.get_running_loop()
         input_args = (
             self.model,
             self.tokenizer,
@@ -398,13 +394,12 @@ async def stream_chat(
             input_kwargs,
         )
         async with self.semaphore:
-            with concurrent.futures.ThreadPoolExecutor() as pool:
-                stream = self._stream_chat(*input_args)
-                while True:
-                    try:
-                        yield await loop.run_in_executor(pool, stream)
-                    except StopAsyncIteration:
-                        break
+            stream = self._stream_chat(*input_args)
+            while True:
+                try:
+                    yield await asyncio.to_thread(stream)
+                except StopAsyncIteration:
+                    break
 
     @override
     async def get_scores(
@@ -415,8 +410,6 @@ async def get_scores(
         if self.can_generate:
             raise ValueError("Cannot get scores using an auto-regressive model.")
 
-        loop = asyncio.get_running_loop()
         input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
         async with self.semaphore:
-            with concurrent.futures.ThreadPoolExecutor() as pool:
-                return await loop.run_in_executor(pool, self._get_scores, *input_args)
+            return await asyncio.to_thread(self._get_scores, *input_args)
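Beyond the SGLang wiring, the substantive change in this file is replacing the per-call `ThreadPoolExecutor` + `run_in_executor` pattern with `asyncio.to_thread`, available in the standard library since Python 3.9. Both offload a blocking callable to a worker thread, but `to_thread` uses the event loop's default executor and avoids constructing and tearing down a pool on every request. A self-contained comparison, with a toy `blocking_generate` standing in for the blocking model call:

```python
import asyncio
import concurrent.futures


def blocking_generate(prompt: str) -> str:
    """Toy stand-in for a blocking model.generate() call."""
    return prompt.upper()


async def main() -> None:
    # Old pattern: build a fresh pool per call and dispatch into it explicitly.
    loop = asyncio.get_running_loop()
    with concurrent.futures.ThreadPoolExecutor() as pool:
        print(await loop.run_in_executor(pool, blocking_generate, "hello"))

    # New pattern: one call, the loop's default executor, same thread offloading.
    print(await asyncio.to_thread(blocking_generate, "hello"))


asyncio.run(main())
```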
