Commit 3f23d8c

added support for tied weights in qwen pipeline parallelism (#6546)
1 parent 1a39979 commit 3f23d8c
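For context: with tie_word_embeddings, the LM head reuses the input-embedding matrix, but under pipeline parallelism embed_tokens lives on the first stage while lm_head lives on the last, so the tied weight has to be shipped across stages at startup. The sketch below illustrates that transfer with plain torch.distributed point-to-point calls; the helper name, rank mapping, and shapes are illustrative assumptions, not the sglang API used in this commit.

# Illustrative sketch only (not the code from this commit): assumes one process
# per pipeline stage, torch.distributed already initialized, and rank 0 /
# rank world_size - 1 acting as the first / last pipeline stage.
import torch
import torch.distributed as dist

def tie_lm_head_to_embeddings(embed_weight, lm_head_weight, vocab_size, hidden_size, dtype):
    rank, world = dist.get_rank(), dist.get_world_size()
    if world == 1:
        # Single stage: the LM head can simply alias the embedding matrix.
        return embed_weight
    if rank == 0:
        # First stage owns embed_tokens; ship its weight to the last stage.
        dist.send(embed_weight, dst=world - 1)
    elif rank == world - 1:
        # Last stage owns lm_head; receive the tied weight and copy it in place.
        buf = torch.empty(vocab_size, hidden_size, dtype=dtype)
        dist.recv(buf, src=0)
        with torch.no_grad():
            lm_head_weight.copy_(buf)
    return lm_head_weight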

File tree

4 files changed: +134 -20 lines

.github/workflows/pr-test.yml

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ jobs:
           bash scripts/ci_install_dependency.sh

       - name: Run test
-        timeout-minutes: 25
+        timeout-minutes: 30
         run: |
           cd test/srt
           python3 run_suite.py --suite per-commit-2-gpu

python/sglang/srt/models/qwen2.py

Lines changed: 38 additions & 9 deletions
@@ -386,15 +386,36 @@ def __init__(
         self.model = Qwen2Model(
             config, quant_config=quant_config, prefix=add_prefix("model", prefix)
         )
-        if config.tie_word_embeddings:
-            self.lm_head = self.model.embed_tokens
+
+        # handle the lm head on different pp ranks
+        if self.pp_group.is_last_rank:
+            if self.pp_group.world_size == 1 and config.tie_word_embeddings:
+                self.lm_head = self.model.embed_tokens
+            else:
+                self.lm_head = ParallelLMHead(
+                    config.vocab_size,
+                    config.hidden_size,
+                    quant_config=quant_config,
+                    prefix=add_prefix("lm_head", prefix),
+                )
         else:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix=add_prefix("lm_head", prefix),
-            )
+            # ranks other than the last rank will have a placeholder layer
+            self.lm_head = PPMissingLayer()
+
+        # perform weight tying for PP
+        if self.pp_group.world_size > 1 and config.tie_word_embeddings:
+            if self.pp_group.is_first_rank:
+                self.pp_group.send(
+                    self.model.embed_tokens.weight, dst=self.pp_group.last_rank
+                )
+            else:
+                emb_token_weight = self.pp_group.recv(
+                    size=(config.vocab_size, config.hidden_size),
+                    dtype=next(self.model.parameters()).dtype,
+                    src=self.pp_group.first_rank,
+                )
+                self.lm_head.weight.copy_(emb_token_weight)
+
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

@@ -470,7 +491,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # the checkpoint. Skip them.
                 continue
             if self.config.tie_word_embeddings and "lm_head.weight" in name:
-                continue
+                if self.pp_group.world_size > 1 and self.pp_group.is_last_rank:
+                    # Handle pp weight tying here
+                    # find the embed_tokens.weight in the weights
+                    embed_token_weights = next(
+                        filter(lambda x: x[0] == "model.embed_tokens.weight", weights)
+                    )[1]
+                    loaded_weight = embed_token_weights
+                else:
+                    continue
             if name.startswith("model.vision_tower") and name not in params_dict:
                 continue
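Checkpoint loading needs the same care: on the last PP rank, lm_head.weight is no longer skipped but filled from model.embed_tokens.weight found in the checkpoint stream. The minimal helper below is hypothetical (resolve_tied_lm_head is not part of this commit); it only restates the redirection over a materialized name-to-tensor map, so the lookup does not depend on the position of the weight iterator.

# Hypothetical helper, for illustration only: mirrors the load_weights() change
# above, but over a dict rather than a shared iterator.
from typing import Dict, Iterable, Tuple
import torch

def resolve_tied_lm_head(
    weights: Iterable[Tuple[str, torch.Tensor]],
    tie_word_embeddings: bool,
    pp_world_size: int,
    is_last_rank: bool,
) -> Dict[str, torch.Tensor]:
    named = dict(weights)
    if tie_word_embeddings and pp_world_size > 1 and is_last_rank:
        # The checkpoint has no separate lm_head weight when embeddings are
        # tied; point it at the embedding matrix instead.
        named["lm_head.weight"] = named["model.embed_tokens.weight"]
    return named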

python/sglang/srt/models/qwen3.py

Lines changed: 39 additions & 10 deletions
@@ -21,7 +21,7 @@
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
-from sglang.srt.layers.utils import get_layer_id
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader

@@ -249,15 +249,36 @@ def __init__(
         self.model = Qwen3Model(
             config, quant_config=quant_config, prefix=add_prefix("model", prefix)
         )
-        if config.tie_word_embeddings:
-            self.lm_head = self.model.embed_tokens
+
+        # handle the lm head on different pp ranks
+        if self.pp_group.is_last_rank:
+            if self.pp_group.world_size == 1 and config.tie_word_embeddings:
+                self.lm_head = self.model.embed_tokens
+            else:
+                self.lm_head = ParallelLMHead(
+                    config.vocab_size,
+                    config.hidden_size,
+                    quant_config=quant_config,
+                    prefix=add_prefix("lm_head", prefix),
+                )
         else:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix=add_prefix("lm_head", prefix),
-            )
+            # ranks other than the last rank will have a placeholder layer
+            self.lm_head = PPMissingLayer()
+
+        # perform weight tying for PP
+        if self.pp_group.world_size > 1 and config.tie_word_embeddings:
+            if self.pp_group.is_first_rank:
+                self.pp_group.send(
+                    self.model.embed_tokens.weight, dst=self.pp_group.last_rank
+                )
+            else:
+                emb_token_weight = self.pp_group.recv(
+                    size=(config.vocab_size, config.hidden_size),
+                    dtype=next(self.model.parameters()).dtype,
+                    src=self.pp_group.first_rank,
+                )
+                self.lm_head.weight.copy_(emb_token_weight)
+
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

@@ -330,7 +351,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # the checkpoint. Skip them.
                 continue
             if self.config.tie_word_embeddings and "lm_head.weight" in name:
-                continue
+                if self.pp_group.world_size > 1 and self.pp_group.is_last_rank:
+                    # Handle pp weight tying here
+                    # find the embed_tokens.weight in the weights
+                    embed_token_weights = next(
+                        filter(lambda x: x[0] == "model.embed_tokens.weight", weights)
+                    )[1]
+                    loaded_weight = embed_token_weights
+                else:
+                    continue
             if name.startswith("model.vision_tower") and name not in params_dict:
                 continue

test/srt/test_pp_single_node.py

Lines changed: 56 additions & 0 deletions
@@ -116,6 +116,62 @@ def test_pp_consistency(self):
         )


+class TestQwenPPTieWeightsAccuracy(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.base_url = "http://127.0.0.1:23334"  # different ports to avoid conflicts
+        cls.model_name = (
+            "Qwen/Qwen3-0.6B"  # qwen3 < 8B all have tie_word_embeddings = True
+        )
+
+    def run_gsm8k_test(self, pp_size):
+        process = popen_launch_server(
+            self.model_name,
+            self.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--pp-size",
+                pp_size,
+                "--chunked-prefill-size",
+                256,
+            ],
+        )
+
+        try:
+            args = SimpleNamespace(
+                num_shots=5,
+                data_path=None,
+                num_questions=200,
+                max_new_tokens=512,
+                parallel=128,
+                host="http://127.0.0.1",
+                port=int(self.base_url.split(":")[-1]),
+            )
+            metrics = run_eval(args)
+            time.sleep(5)
+            return metrics
+        finally:
+            kill_process_tree(process.pid)
+
+    def test_baseline_accuracy(self):
+        metrics = self.run_gsm8k_test(pp_size=1)
+        print(f"[Qwen Baseline] {metrics=}")
+        self.assertGreater(metrics["accuracy"], 0.39)
+
+    def test_pp_consistency(self):
+        baseline = self.run_gsm8k_test(pp_size=1)
+        pp_metrics = self.run_gsm8k_test(pp_size=2)
+
+        print(f"[Qwen PP Comparison] Baseline: {baseline} | PP: {pp_metrics}")
+
+        self.assertAlmostEqual(
+            pp_metrics["accuracy"],
+            baseline["accuracy"],
+            delta=0.01,
+            msg=f"PP accuracy exceeds 1% (baseline: {baseline['accuracy']}, pp: {pp_metrics['accuracy']})",
+        )
+
+
 class TestFixedBugs(unittest.TestCase):
     def test_chunked_prefill_with_small_bs(self):
         model = DEFAULT_MODEL_NAME_FOR_TEST
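For reference, a hedged usage sketch of how the configuration exercised by this test could be launched by hand. It assumes the standard sglang server entry point (python3 -m sglang.launch_server) and reuses only the flags shown in the test above (--pp-size, --chunked-prefill-size); the port and model path are placeholders taken from the test.

# Usage sketch (assumed entry point, placeholder port/model path):
import subprocess

server = subprocess.Popen([
    "python3", "-m", "sglang.launch_server",
    "--model-path", "Qwen/Qwen3-0.6B",   # tie_word_embeddings = True
    "--pp-size", "2",                    # two pipeline-parallel stages
    "--chunked-prefill-size", "256",
    "--port", "23334",
])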
