added support for tied weights in qwen pipeline parallelism #6546

Merged: 4 commits, May 25, 2025

2 changes: 1 addition & 1 deletion .github/workflows/pr-test.yml
@@ -84,7 +84,7 @@ jobs:
bash scripts/ci_install_dependency.sh

- name: Run test
timeout-minutes: 25
timeout-minutes: 30
run: |
cd test/srt
python3 run_suite.py --suite per-commit-2-gpu
47 changes: 38 additions & 9 deletions python/sglang/srt/models/qwen2.py
@@ -386,15 +386,36 @@ def __init__(
self.model = Qwen2Model(
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
)
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens

# handle the lm head on different pp ranks
if self.pp_group.is_last_rank:
if self.pp_group.world_size == 1 and config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
)
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
)
# ranks other than the last rank will have a placeholder layer
self.lm_head = PPMissingLayer()

# perform weight tying for PP
if self.pp_group.world_size > 1 and config.tie_word_embeddings:
if self.pp_group.is_first_rank:
self.pp_group.send(
self.model.embed_tokens.weight, dst=self.pp_group.last_rank
)
else:
emb_token_weight = self.pp_group.recv(
size=(config.vocab_size, config.hidden_size),
dtype=next(self.model.parameters()).dtype,
src=self.pp_group.first_rank,
)
self.lm_head.weight.copy_(emb_token_weight)

self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
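
For reference, the pipeline-parallel weight tying added above (the first PP rank sends model.embed_tokens.weight, the last rank copies it into lm_head) can be exercised in isolation. The snippet below is only an illustrative sketch of that pattern, not part of this PR's diff: it uses plain torch.distributed instead of SGLang's pp_group wrapper, and the toy sizes, function name, and gloo/torchrun launch are assumptions.

import torch
import torch.distributed as dist

VOCAB, HIDDEN = 1024, 64  # toy stand-ins for config.vocab_size / config.hidden_size

def tie_lm_head_to_embedding(embed_weight, lm_head_weight):
    # First PP rank ships the embedding table to the last rank, which copies it
    # into the LM head, mirroring the tie performed in __init__ above.
    rank, world = dist.get_rank(), dist.get_world_size()
    if world == 1:
        return  # single rank: the head can simply alias the embedding
    if rank == 0:
        dist.send(embed_weight, dst=world - 1)
    elif rank == world - 1:
        buf = torch.empty(VOCAB, HIDDEN, dtype=lm_head_weight.dtype)
        dist.recv(buf, src=0)
        with torch.no_grad():
            lm_head_weight.copy_(buf)

if __name__ == "__main__":
    # Example launch: torchrun --nproc_per_node=2 tie_weights_sketch.py
    dist.init_process_group(backend="gloo")
    rank, world = dist.get_rank(), dist.get_world_size()
    embed = torch.randn(VOCAB, HIDDEN) if rank == 0 else None
    head = torch.zeros(VOCAB, HIDDEN) if rank == world - 1 else None
    tie_lm_head_to_embedding(embed, head)
    dist.destroy_process_group()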

@@ -470,7 +491,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# the checkpoint. Skip them.
continue
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
if self.pp_group.world_size > 1 and self.pp_group.is_last_rank:
# Handle pp weight tying here
# find the embed_tokens.weight in the weights
embed_token_weights = next(
filter(lambda x: x[0] == "model.embed_tokens.weight", weights)
)[1]
loaded_weight = embed_token_weights
else:
continue
if name.startswith("model.vision_tower") and name not in params_dict:
continue
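
The branch above covers checkpoints that still ship an lm_head.weight entry even though the weights are tied: ranks other than the last one skip it as before, while the last PP rank, which does not hold embed_tokens, fills the LM head from the checkpoint's model.embed_tokens.weight. The following is an illustrative-only restatement of that resolution logic with hypothetical names, operating on an in-memory list of (name, tensor) pairs rather than the model's weight loader.

from typing import Dict, List, Tuple

import torch

def resolve_tied_lm_head(
    named_weights: List[Tuple[str, torch.Tensor]],
    tie_word_embeddings: bool,
    pp_world_size: int,
    is_last_pp_rank: bool,
) -> List[Tuple[str, torch.Tensor]]:
    # Illustrative sketch: mirrors the load_weights branch above on a plain list.
    lookup: Dict[str, torch.Tensor] = dict(named_weights)
    resolved: List[Tuple[str, torch.Tensor]] = []
    for name, tensor in named_weights:
        if tie_word_embeddings and name == "lm_head.weight":
            if pp_world_size > 1 and is_last_pp_rank:
                # The last PP rank has no embed_tokens, so reuse the
                # checkpoint's embedding table for the LM head.
                tensor = lookup["model.embed_tokens.weight"]
            else:
                # Tie already resolved by aliasing embed_tokens in __init__.
                continue
        resolved.append((name, tensor))
    return resolved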

49 changes: 39 additions & 10 deletions python/sglang/srt/models/qwen3.py
@@ -21,7 +21,7 @@
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import get_layer_id
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -249,15 +249,36 @@ def __init__(
self.model = Qwen3Model(
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
)
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens

# handle the lm head on different pp ranks
if self.pp_group.is_last_rank:
if self.pp_group.world_size == 1 and config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
)
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
)
# ranks other than the last rank will have a placeholder layer
self.lm_head = PPMissingLayer()

# perform weight tying for PP
if self.pp_group.world_size > 1 and config.tie_word_embeddings:
if self.pp_group.is_first_rank:
self.pp_group.send(
self.model.embed_tokens.weight, dst=self.pp_group.last_rank
)
else:
emb_token_weight = self.pp_group.recv(
size=(config.vocab_size, config.hidden_size),
dtype=next(self.model.parameters()).dtype,
src=self.pp_group.first_rank,
)
self.lm_head.weight.copy_(emb_token_weight)

self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

@@ -330,7 +351,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# the checkpoint. Skip them.
continue
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
if self.pp_group.world_size > 1 and self.pp_group.is_last_rank:
# Handle pp weight tying here
# find the embed_tokens.weight in the weights
embed_token_weights = next(
filter(lambda x: x[0] == "model.embed_tokens.weight", weights)
)[1]
loaded_weight = embed_token_weights
else:
continue
if name.startswith("model.vision_tower") and name not in params_dict:
continue

56 changes: 56 additions & 0 deletions test/srt/test_pp_single_node.py
@@ -116,6 +116,62 @@ def test_pp_consistency(self):
)


class TestQwenPPTieWeightsAccuracy(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.base_url = "http://127.0.0.1:23334" # different ports to avoid conflicts
cls.model_name = (
"Qwen/Qwen3-0.6B" # qwen3 < 8B all have tie_word_embeddings = True
)

def run_gsm8k_test(self, pp_size):
process = popen_launch_server(
self.model_name,
self.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--pp-size",
pp_size,
"--chunked-prefill-size",
256,
],
)

try:
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval(args)
time.sleep(5)
return metrics
finally:
kill_process_tree(process.pid)

def test_baseline_accuracy(self):
metrics = self.run_gsm8k_test(pp_size=1)
print(f"[Qwen Baseline] {metrics=}")
self.assertGreater(metrics["accuracy"], 0.39)

def test_pp_consistency(self):
baseline = self.run_gsm8k_test(pp_size=1)
pp_metrics = self.run_gsm8k_test(pp_size=2)

print(f"[Qwen PP Comparison] Baseline: {baseline} | PP: {pp_metrics}")

self.assertAlmostEqual(
pp_metrics["accuracy"],
baseline["accuracy"],
delta=0.01,
msg=f"PP accuracy exceeds 1% (baseline: {baseline['accuracy']}, pp: {pp_metrics['accuracy']})",
)
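
For local verification of the new class (assuming a host with at least 2 GPUs and the repository's test/srt layout), a hypothetical runner script, not part of this PR, could look like the following; running it from test/srt keeps the module imports identical to the CI suite.

# Hypothetical local runner: executes only the new tied-weights accuracy tests.
import unittest

from test_pp_single_node import TestQwenPPTieWeightsAccuracy  # run from test/srt

if __name__ == "__main__":
    suite = unittest.TestLoader().loadTestsFromTestCase(TestQwenPPTieWeightsAccuracy)
    unittest.TextTestRunner(verbosity=2).run(suite)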


class TestFixedBugs(unittest.TestCase):
def test_chunked_prefill_with_small_bs(self):
model = DEFAULT_MODEL_NAME_FOR_TEST