-
Notifications
You must be signed in to change notification settings - Fork 107
Expand file tree
/
Copy pathrun_qwen3_omni_speech_server.py
More file actions
136 lines (110 loc) · 4.09 KB
/
run_qwen3_omni_speech_server.py
File metadata and controls
136 lines (110 loc) · 4.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# SPDX-License-Identifier: Apache-2.0
"""Launch an OpenAI-compatible server for Qwen3-Omni with speech output.

Each stage runs in its own process with dedicated GPU placement.
Supports text + audio responses via the OpenAI chat completions API.

Usage::

    python examples/run_qwen3_omni_speech_server.py

    # Custom GPU placement:
    # (``\\`` renders as a literal shell line-continuation, matching the
    # curl example below; a single backslash would be a string-escape that
    # silently joins the two lines in --help output.)
    python examples/run_qwen3_omni_speech_server.py \\
        --gpu-thinker 0 --gpu-talker 1 --gpu-code-predictor 2

    # Then test:
    curl http://localhost:8000/v1/chat/completions \\
      -H "Content-Type: application/json" \\
      -d '{
        "model": "qwen3-omni",
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 64,
        "stream": true,
        "modalities": ["text", "audio"]
      }'
"""
from __future__ import annotations
import argparse
import asyncio
import logging
import multiprocessing as mp
import os
# Configure root logging once at import time.
# The LOGLEVEL environment variable overrides the default of INFO.
_log_level = os.environ.get("LOGLEVEL", "INFO").upper()
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=_log_level,
)

# Module-scoped logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the speech server.

    Args:
        argv: Explicit argument list to parse. Defaults to ``None``, in
            which case ``argparse`` falls back to ``sys.argv[1:]`` — so
            existing callers (``parse_args()``) are unaffected, while
            tests and embedders can pass a list directly.

    Returns:
        Namespace with model path, per-stage GPU placement, pipeline
        tuning, and HTTP server settings.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # Raw formatting keeps the usage/curl examples in the module
        # docstring readable in --help output.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--model-path", type=str, default="Qwen/Qwen3-Omni-30B-A3B-Instruct"
    )
    # GPU placement: one device index per pipeline stage.
    parser.add_argument("--gpu-thinker", type=int, default=0)
    parser.add_argument("--gpu-talker", type=int, default=1)
    parser.add_argument("--gpu-code-predictor", type=int, default=2)
    parser.add_argument("--gpu-code2wav", type=int, default=0)
    parser.add_argument("--gpu-image-encoder", type=int, default=0)
    parser.add_argument("--gpu-audio-encoder", type=int, default=0)
    # Pipeline tuning.
    parser.add_argument(
        "--relay-backend", type=str, default="shm", choices=["nixl", "shm"]
    )
    parser.add_argument(
        "--mem-fraction-static",
        type=float,
        default=0.7,
        help="Static memory fraction for SGLang-backed AR stages.",
    )
    # HTTP server settings.
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--model-name", type=str, default="qwen3-omni")
    return parser.parse_args(argv)
async def main_async(args: argparse.Namespace) -> None:
    """Build the multiprocess speech pipeline, then serve the OpenAI API.

    Starts the pipeline runner, wraps it in a client + FastAPI app, and
    blocks on uvicorn until shutdown; the pipeline is always stopped on
    exit via the ``finally`` clause.
    """
    # Heavy project imports are deferred to call time (function scope).
    import uvicorn
    from sglang_omni.client import Client
    from sglang_omni.models.qwen3_omni.config import Qwen3OmniSpeechPipelineConfig
    from sglang_omni.pipeline.mp_runner import MultiProcessPipelineRunner
    from sglang_omni.serve.openai_api import create_app

    # Map each pipeline stage name to the GPU index chosen on the CLI.
    # NOTE(review): --gpu-image-encoder / --gpu-audio-encoder are parsed
    # but not placed here — confirm whether the config expects them.
    placement = {
        "thinker": args.gpu_thinker,
        "talker_ar": args.gpu_talker,
        "code_predictor": args.gpu_code_predictor,
        "code2wav": args.gpu_code2wav,
    }
    pipeline_config = Qwen3OmniSpeechPipelineConfig(
        model_path=args.model_path,
        relay_backend=args.relay_backend,
        gpu_placement=placement,
    )

    # Forward the static memory fraction only to the SGLang-backed AR stages.
    overrides = {"mem_fraction_static": args.mem_fraction_static}
    ar_stages = {"thinker", "talker_ar"}
    for stage in pipeline_config.stages:
        if stage.name not in ar_stages:
            continue
        stage.executor.args.setdefault("server_args_overrides", {}).update(overrides)

    runner = MultiProcessPipelineRunner(pipeline_config)
    logger.info("Starting 9-stage speech pipeline (multiprocess)...")
    await runner.start(timeout=600)
    logger.info("Pipeline ready.")
    try:
        api_app = create_app(Client(runner.coordinator), model_name=args.model_name)
        uv_config = uvicorn.Config(
            api_app,
            host=args.host,
            port=args.port,
            log_level="info",
        )
        await uvicorn.Server(uv_config).serve()
    finally:
        logger.info("Shutting down pipeline...")
        await runner.stop()
        logger.info("Pipeline stopped.")
def main() -> None:
    """Synchronous CLI entry point: parse args and run the server loop."""
    # Force the "spawn" start method before any worker is created.
    # NOTE(review): presumably needed so forked state (e.g. CUDA context)
    # is never inherited by pipeline workers — confirm with the runner.
    mp.set_start_method("spawn", force=True)
    cli_args = parse_args()
    asyncio.run(main_async(cli_args))


if __name__ == "__main__":
    main()