Commit c4da436

[Feature] Add tensor parallelism support for ming-omni (#270)
* [Core] Add TP driver-worker infrastructure for omni thinker
  - Add TP follower module with register_omni_models(), follower_worker_loop(), and spawn_followers() for TP ranks > 0
  - Expose tp_cpu_group and tp_size properties on ModelWorker
  - Spawn TP follower processes in create_sglang_ar_engine before rank 0 init
  - Manage follower lifecycle (stop signal + terminate) in OmniEngine
  - Broadcast a serialization-safe ModelWorkerBatch to followers each step (sketched below)
* [Bugfix] Fix TP runtime issues (daemon, pickle, broadcast, model config)
  - Add model_override_args for Ming TP>1 (flattened llm_config fields)
  - Fix broadcast_pyobj list wrapping and ServerArgs field name
  - Set daemon=False for TP stages to allow follower subprocess spawning
  - Remove mp.Event from follower spawn (unpicklable in nested spawn)
* [TP] Serialization-safe follower batch with device relocation
  - Add serialization.py with make_follower_batch() and first-call pickle verification that diagnoses unpicklable fields
  - Strip reqs, sampling_info, and launch_done before broadcast
  - Add relocate_batch_tensors() to move all tensors (including nested multimodal inputs like mrope_position_delta) to the follower device
  - Extract patch_batch_for_follower() to module scope for testability
  - Pass the real vocab_size to the SamplingBatchInfo stub
* [Test] TP validation, serialization tests, and review hardening
  - Add test_tp_follower.py: registration, picklability, batch patching
  - Add test_tp_batch_serialization.py: pickle round-trip regression tests
  - Add scripts/test_ming_tp.py: TP=1 vs TP=2 output consistency harness
* [Config] Expose TP in Qwen3 pipeline config and example entrypoints
  - Inject server_args_overrides into the thinker stage in Qwen3 pipeline configs
  - Add --tp-size, --cpu-offload-gb, --mem-fraction-static CLI flags to run_qwen3_omni_server.py and run_qwen3_omni_text_first.py
* [Style] Pre-commit formatting fixes
* [TP] Sync page tables and harden follower runtime
* [TP] Mirror multimodal thinker forward paths on followers
* [Bugfix] Tear down torch.distributed in MingImageEncoder cleanup
* [TP] Enable tensor parallelism for Ming-flash-omni vision encoder
* [TP] Fix follower shutdown ordering
* [Style] Fix lint
* [Style] Restore note attribution
* [TP] Fix Qwen3 speech TP validation and factory test mocks
* [TP] Fix _DummyEngine stub to accept follower_processes
* [chore] Trim comments
* fix
* [TP] Address review on PR #270
* [Test] Fix black formatting
1 parent 922b9aa commit c4da436

29 files changed: 1733 additions & 252 deletions
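
The commit message above describes a broadcast-driven step protocol between the TP driver (rank 0) and its followers. Below is a hedged sketch of that protocol using plain torch.distributed rather than the actual sglang_omni helpers; follower_worker_loop and relocate_batch_tensors are names taken from the message, and their wiring here (including the forward_batch_generation call) is an illustrative assumption, not the real implementation.

import torch.distributed as dist


def driver_broadcast_step(batch, cpu_group):
    # Rank 0: ship the serialization-safe batch (reqs, sampling_info, and
    # launch_done already stripped) to every TP follower over the CPU group.
    payload = [batch]  # broadcast_object_list mutates a list in place
    dist.broadcast_object_list(payload, src=0, group=cpu_group)


def follower_worker_loop(model_worker, cpu_group, device):
    # Ranks > 0: block on the broadcast each step, move tensors onto the
    # local device, then mirror the driver's forward pass.
    while True:
        payload = [None]
        dist.broadcast_object_list(payload, src=0, group=cpu_group)
        batch = payload[0]
        if batch is None:  # stop signal sent by the driver at shutdown
            break
        # relocate_batch_tensors(batch, device)  # incl. nested multimodal inputs
        model_worker.forward_batch_generation(batch)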

examples/run_ming_omni_server.py

Lines changed: 3 additions & 0 deletions
@@ -23,6 +23,7 @@

 import argparse
 import logging
+import multiprocessing as mp
 import os

 from sglang_omni.models.ming_omni.config import MingOmniPipelineConfig
@@ -100,6 +101,7 @@ def main() -> None:
     overrides = {}
     if args.tp_size and args.tp_size > 1:
         overrides["tp_size"] = args.tp_size
+        overrides["disable_custom_all_reduce"] = True
     if args.quantization:
         overrides["quantization"] = args.quantization
     if args.cpu_offload_gb:
@@ -122,4 +124,5 @@


 if __name__ == "__main__":
+    mp.set_start_method("spawn", force=True)
    main()
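
Two of the commit's bullets meet in this file: the spawn start method set above, and daemon=False for TP stages. A minimal, self-contained illustration (not sglang_omni code) of why both are needed when a stage process must itself spawn follower subprocesses:

import multiprocessing as mp


def follower_entry():
    pass  # placeholder for a real follower_worker_loop()


def stage_entry():
    # A parent started with daemon=True would raise "daemonic processes are
    # not allowed to have children" here; daemon=False permits nested spawning.
    follower = mp.Process(target=follower_entry)
    follower.start()
    follower.join()


if __name__ == "__main__":
    # "spawn" gives children a fresh interpreter; CUDA contexts do not
    # survive os.fork(), so forked GPU workers would break at init time.
    mp.set_start_method("spawn", force=True)
    stage = mp.Process(target=stage_entry, daemon=False)
    stage.start()
    stage.join()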

examples/run_ming_omni_speech.py

Lines changed: 5 additions & 0 deletions
@@ -75,6 +75,9 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument("--timeout", type=float, default=300.0)
     parser.add_argument("--cpu-offload-gb", type=float, default=0)
     parser.add_argument("--mem-fraction-static", type=float, default=None)
+    parser.add_argument(
+        "--tp-size", type=int, default=1, help="Tensor parallel size for thinker"
+    )
     return parser.parse_args()


@@ -89,6 +92,8 @@ async def main_async(args: argparse.Namespace) -> None:
     }

     overrides = {}
+    if args.tp_size > 1:
+        overrides["tp_size"] = args.tp_size
     if args.cpu_offload_gb:
         overrides["cpu_offload_gb"] = args.cpu_offload_gb
     if args.mem_fraction_static is not None:

examples/run_qwen3_omni_server.py

Lines changed: 28 additions & 1 deletion
@@ -44,7 +44,19 @@ def parse_args() -> argparse.Namespace:
         default="Qwen/Qwen3-Omni-30B-A3B-Instruct",
         help="Hugging Face model id or local path",
     )
-    parser.add_argument("--thinker-max-seq-len", type=int, default=8192)
+    parser.add_argument("--thinker-max-seq-len", type=int, default=None)
+    parser.add_argument(
+        "--cpu-offload-gb",
+        type=int,
+        default=0,
+        help="GB of model weights to offload to CPU",
+    )
+    parser.add_argument(
+        "--mem-fraction-static",
+        type=float,
+        default=None,
+        help="Fraction of GPU memory for KV cache",
+    )

     # Pipeline options
     parser.add_argument(
@@ -71,11 +83,26 @@
 def main() -> None:
     args = parse_args()

+    overrides = {}
+    if args.cpu_offload_gb:
+        overrides["cpu_offload_gb"] = args.cpu_offload_gb
+    if args.mem_fraction_static is not None:
+        overrides["mem_fraction_static"] = args.mem_fraction_static
+
     config = Qwen3OmniPipelineConfig(
         model_path=args.model_path,
         relay_backend=args.relay_backend,
+        server_args_overrides=overrides or None,
     )

+    # Override thinker_max_seq_len in stage executor args if provided
+    if args.thinker_max_seq_len is not None:
+        for stage in config.stages:
+            if stage.name == "thinker":
+                if stage.executor.args is None:
+                    stage.executor.args = {}
+                stage.executor.args["thinker_max_seq_len"] = args.thinker_max_seq_len
+
     launch_server(
         config,
         host=args.host,

examples/run_qwen3_omni_text_first.py

Lines changed: 29 additions & 2 deletions
@@ -27,8 +27,7 @@ def parse_args() -> argparse.Namespace:
         help="Hugging Face model id",
     )
     parser.add_argument("--prompt", type=str, default="Describe this input.")
-    parser.add_argument("--dtype", type=str, default="bfloat16")
-    parser.add_argument("--thinker-max-seq-len", type=int, default=8192)
+    parser.add_argument("--thinker-max-seq-len", type=int, default=None)
     parser.add_argument("--max-new-tokens", type=int, default=1024)
     parser.add_argument("--temperature", type=float, default=0.8)
     parser.add_argument("--image-path", type=str, default=None)
@@ -40,14 +39,42 @@
     parser.add_argument(
         "--relay-backend", type=str, default="nixl", choices=["nixl", "shm"]
     )
+    parser.add_argument(
+        "--cpu-offload-gb",
+        type=int,
+        default=0,
+        help="GB of model weights to offload to CPU",
+    )
+    parser.add_argument(
+        "--mem-fraction-static",
+        type=float,
+        default=None,
+        help="Fraction of GPU memory for KV cache",
+    )
     return parser.parse_args()


 async def main_async(args: argparse.Namespace) -> None:
+    overrides = {}
+    if args.cpu_offload_gb:
+        overrides["cpu_offload_gb"] = args.cpu_offload_gb
+    if args.mem_fraction_static is not None:
+        overrides["mem_fraction_static"] = args.mem_fraction_static
+
     config = Qwen3OmniPipelineConfig(
         model_path=args.model_path,
         relay_backend=args.relay_backend,
+        server_args_overrides=overrides or None,
     )
+
+    # Override thinker_max_seq_len in stage executor args if provided
+    if args.thinker_max_seq_len is not None:
+        for stage in config.stages:
+            if stage.name == "thinker":
+                if stage.executor.args is None:
+                    stage.executor.args = {}
+                stage.executor.args["thinker_max_seq_len"] = args.thinker_max_seq_len
+
     runner = build_pipeline_runner(config)

     await runner.start()

scripts/test_ming_tp.py

Lines changed: 181 additions & 0 deletions
@@ -0,0 +1,181 @@
+#!/usr/bin/env python3
+"""Validate Ming Omni thinker output consistency across TP configurations.
+
+Usage:
+    python scripts/test_ming_tp.py run --tp 1 --cpu-offload-gb 150
+    python scripts/test_ming_tp.py run --tp 2 --cpu-offload-gb 40
+    python scripts/test_ming_tp.py compare tp1_results.json tp2_results.json
+"""
+from __future__ import annotations
+
+import argparse
+import asyncio
+import json
+import logging
+import multiprocessing as mp
+import os
+import sys
+
+logging.basicConfig(
+    level=os.environ.get("LOGLEVEL", "INFO").upper(),
+    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
+)
+logger = logging.getLogger(__name__)
+
+TEST_PROMPTS = [
+    "What is 1+1?",
+    "What is the capital of France?",
+    "What is the capital of Japan?",
+    "Explain quantum computing in one sentence.",
+]
+
+
+async def run_thinker(
+    tp_size: int,
+    cpu_offload_gb: int,
+    mem_fraction: float,
+    output_file: str,
+    attention_backend: str | None = None,
+):
+    from sglang_omni.models.ming_omni.config import MingOmniPipelineConfig
+    from sglang_omni.pipeline.mp_runner import MultiProcessPipelineRunner
+    from sglang_omni.proto import OmniRequest
+
+    overrides = {
+        "tp_size": tp_size,
+        "cpu_offload_gb": cpu_offload_gb,
+        "mem_fraction_static": mem_fraction,
+    }
+    if attention_backend is not None:
+        overrides["attention_backend"] = attention_backend
+
+    config = MingOmniPipelineConfig(
+        model_path="inclusionAI/Ming-flash-omni-2.0",
+        relay_backend="shm",
+        server_args_overrides=overrides,
+    )
+
+    runner = MultiProcessPipelineRunner(config)
+    logger.info(
+        "Starting pipeline with TP=%d, cpu_offload_gb=%d, attention_backend=%s ...",
+        tp_size,
+        cpu_offload_gb,
+        attention_backend,
+    )
+    await runner.start(timeout=600)
+
+    results = []
+    try:
+        for i, prompt in enumerate(TEST_PROMPTS):
+            logger.info("[%d/%d] Prompt: %s", i + 1, len(TEST_PROMPTS), prompt)
+            request = {
+                "messages": [
+                    {
+                        "role": "system",
+                        "content": "You are a friendly AI assistant. Please answer concisely.",
+                    },
+                    {"role": "user", "content": prompt},
+                ],
+                "audios": [],
+            }
+            result = await asyncio.wait_for(
+                runner.coordinator.submit(
+                    f"tp-test-{i}",
+                    OmniRequest(
+                        inputs=request,
+                        params={"max_new_tokens": 64, "temperature": 0.0},
+                    ),
+                ),
+                timeout=120,
+            )
+            text = ""
+            if isinstance(result, dict):
+                for stage_name, payload in result.items():
+                    data = (
+                        payload
+                        if isinstance(payload, dict)
+                        else getattr(payload, "data", {})
+                    )
+                    if isinstance(data, dict) and "text" in data:
+                        text = data["text"]
+                        break
+            assert text, f"Empty output for prompt: {prompt}"
+            results.append({"prompt": prompt, "output": text})
+            logger.info(" Output: %s", text[:200])
+    finally:
+        await runner.stop()
+
+    with open(output_file, "w") as f:
+        json.dump(
+            {"tp_size": tp_size, "results": results}, f, indent=2, ensure_ascii=False
+        )
+    logger.info("Results saved to %s", output_file)
+
+
+def compare_outputs(file1: str, file2: str):
+    with open(file1) as f:
+        data1 = json.load(f)
+    with open(file2) as f:
+        data2 = json.load(f)
+
+    print(f"\n{'='*60}")
+    print(f"Comparing TP={data1['tp_size']} vs TP={data2['tp_size']}")
+    print(f"{'='*60}")
+
+    all_match = True
+    for r1, r2 in zip(data1["results"], data2["results"]):
+        match = r1["output"].strip() == r2["output"].strip()
+        status = "MATCH" if match else "MISMATCH"
+        if not match:
+            all_match = False
+        print(f"\n[{status}] Prompt: {r1['prompt']}")
+        print(f" TP={data1['tp_size']}: {r1['output'][:120]}")
+        print(f" TP={data2['tp_size']}: {r2['output'][:120]}")
+
+    print(f"\n{'='*60}")
+    if all_match:
+        print("ALL OUTPUTS MATCH - TP validation PASSED")
+    else:
+        print("OUTPUTS DIFFER - TP validation FAILED, needs investigation")
+    print(f"{'='*60}")
+    return all_match
+
+
+def main():
+    mp.set_start_method("spawn", force=True)
+
+    parser = argparse.ArgumentParser(description=__doc__)
+    sub = parser.add_subparsers(dest="cmd")
+
+    run_p = sub.add_parser("run")
+    run_p.add_argument("--tp", type=int, required=True)
+    run_p.add_argument("--cpu-offload-gb", type=int, default=80)
+    run_p.add_argument("--mem-fraction", type=float, default=0.80)
+    run_p.add_argument("--attention-backend", type=str, default=None)
+    run_p.add_argument("--output", type=str, default=None)
+
+    cmp_p = sub.add_parser("compare")
+    cmp_p.add_argument("file1")
+    cmp_p.add_argument("file2")
+
+    args = parser.parse_args()
+
+    if args.cmd == "run":
+        output = args.output or f"tp{args.tp}_results.json"
+        asyncio.run(
+            run_thinker(
+                args.tp,
+                args.cpu_offload_gb,
+                args.mem_fraction,
+                output,
+                args.attention_backend,
+            )
+        )
+    elif args.cmd == "compare":
+        sys.exit(0 if compare_outputs(args.file1, args.file2) else 1)
+    else:
+        parser.print_help()
+
+
+if __name__ == "__main__":
+    main()

sglang_omni/engines/ar/sglang_backend/model_runner.py

Lines changed: 3 additions & 24 deletions
@@ -39,7 +39,9 @@ def __init__(
         weight_prefix: str | None = None,
     ) -> None:
         self._weight_prefix = weight_prefix
-        self._register_omni_model()
+        from sglang_omni.models.sglang_registry import register_omni_models_in_sglang
+
+        register_omni_models_in_sglang()

         port_args = PortArgs.init_new(server_args)
         tp_size = server_args.tp_size
@@ -61,26 +63,3 @@
             nccl_port=nccl_port,
             server_args=server_args,
         )
-
-    def _register_omni_model(self):
-        # Register sglang_omni model classes directly in SGLang's model registry.
-        from sglang.srt.models.registry import ModelRegistry
-
-        from sglang_omni.models.fishaudio_s2_pro.sglang_model import (
-            S2ProSGLangTextModel,
-        )
-        from sglang_omni.models.ming_omni.thinker import (
-            BailingMM2Config,
-            BailingMoeV2ForCausalLM,
-        )
-        from sglang_omni.models.qwen3_omni.talker import Qwen3OmniTalker
-
-        ModelRegistry.models["S2ProSGLangTextModel"] = S2ProSGLangTextModel
-        ModelRegistry.models["Qwen3OmniTalker"] = Qwen3OmniTalker
-        ModelRegistry.models["BailingMoeV2ForCausalLM"] = BailingMoeV2ForCausalLM
-
-        # Register BailingMM2Config with AutoConfig so SGLang can load
-        # config.json from HF repos missing configuration_bailingmm2.py.
-        from transformers import AutoConfig
-
-        AutoConfig.register("bailingmm_moe_v2_lite", BailingMM2Config)
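
The registration logic removed above moves to a shared module so follower processes can call it too. A hypothetical sketch of sglang_omni/models/sglang_registry.py, inferred directly from the removed method; the module-level idempotency guard is an assumption, since the function may now run once per driver and per follower process:

# sglang_omni/models/sglang_registry.py (inferred sketch, not the real file)
_REGISTERED = False


def register_omni_models_in_sglang() -> None:
    global _REGISTERED
    if _REGISTERED:  # assumed guard: safe to call from every TP rank
        return
    _REGISTERED = True

    from sglang.srt.models.registry import ModelRegistry

    from sglang_omni.models.fishaudio_s2_pro.sglang_model import S2ProSGLangTextModel
    from sglang_omni.models.ming_omni.thinker import (
        BailingMM2Config,
        BailingMoeV2ForCausalLM,
    )
    from sglang_omni.models.qwen3_omni.talker import Qwen3OmniTalker

    ModelRegistry.models["S2ProSGLangTextModel"] = S2ProSGLangTextModel
    ModelRegistry.models["Qwen3OmniTalker"] = Qwen3OmniTalker
    ModelRegistry.models["BailingMoeV2ForCausalLM"] = BailingMoeV2ForCausalLM

    # Register BailingMM2Config with AutoConfig so SGLang can load config.json
    # from HF repos missing configuration_bailingmm2.py.
    from transformers import AutoConfig

    AutoConfig.register("bailingmm_moe_v2_lite", BailingMM2Config)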

sglang_omni/engines/ar/sglang_backend/model_worker.py

Lines changed: 9 additions & 0 deletions
@@ -53,6 +53,15 @@ def __init__(
         )[0]
         set_random_seed(self.random_seed)

+    @property
+    def tp_cpu_group(self):
+        """NCCL CPU process group for TP broadcast operations."""
+        return self.model_runner.tp_group.cpu_group
+
+    @property
+    def tp_size(self) -> int:
+        return self.server_args.tp_size
+
     def _init_model_config(self):
         from sglang.srt.configs.model_config import ModelConfig
