diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml
index 9dd1dbc662..62471c13cb 100644
--- a/.github/workflows/pr-test.yml
+++ b/.github/workflows/pr-test.yml
@@ -84,7 +84,7 @@ jobs:
           bash scripts/ci_install_dependency.sh
 
       - name: Run test
-        timeout-minutes: 25
+        timeout-minutes: 30
         run: |
           cd test/srt
           python3 run_suite.py --suite per-commit-2-gpu
diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py
index 0e9a77f04f..10ac84eccc 100644
--- a/python/sglang/srt/models/qwen2.py
+++ b/python/sglang/srt/models/qwen2.py
@@ -386,15 +386,36 @@ def __init__(
         self.model = Qwen2Model(
             config, quant_config=quant_config, prefix=add_prefix("model", prefix)
         )
-        if config.tie_word_embeddings:
-            self.lm_head = self.model.embed_tokens
+
+        # handle the lm head on different pp ranks
+        if self.pp_group.is_last_rank:
+            if self.pp_group.world_size == 1 and config.tie_word_embeddings:
+                self.lm_head = self.model.embed_tokens
+            else:
+                self.lm_head = ParallelLMHead(
+                    config.vocab_size,
+                    config.hidden_size,
+                    quant_config=quant_config,
+                    prefix=add_prefix("lm_head", prefix),
+                )
         else:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix=add_prefix("lm_head", prefix),
-            )
+            # ranks other than the last rank will have a placeholder layer
+            self.lm_head = PPMissingLayer()
+
+        # perform weight tying for PP
+        if self.pp_group.world_size > 1 and config.tie_word_embeddings:
+            if self.pp_group.is_first_rank:
+                self.pp_group.send(
+                    self.model.embed_tokens.weight, dst=self.pp_group.last_rank
+                )
+            else:
+                emb_token_weight = self.pp_group.recv(
+                    size=(config.vocab_size, config.hidden_size),
+                    dtype=next(self.model.parameters()).dtype,
+                    src=self.pp_group.first_rank,
+                )
+                self.lm_head.weight.copy_(emb_token_weight)
+
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@@ -470,7 +491,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # the checkpoint. Skip them.
                 continue
             if self.config.tie_word_embeddings and "lm_head.weight" in name:
-                continue
+                if self.pp_group.world_size > 1 and self.pp_group.is_last_rank:
+                    # Handle pp weight tying here
+                    # find the embed_tokens.weight in the weights
+                    embed_token_weights = next(
+                        filter(lambda x: x[0] == "model.embed_tokens.weight", weights)
+                    )[1]
+                    loaded_weight = embed_token_weights
+                else:
+                    continue
             if name.startswith("model.vision_tower") and name not in params_dict:
                 continue
 
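For reference, a minimal standalone sketch of the tying pattern the qwen2.py change implements: the first pipeline stage sends `embed_tokens.weight` to the last stage, which copies it into `lm_head`, while intermediate stages only hold a placeholder. This is written against raw `torch.distributed` point-to-point ops rather than sglang's `pp_group` helpers, and all names and shapes below are made-up placeholders, so treat it as an illustration of the idea, not the actual implementation.

```python
# Illustrative sketch, not part of the patch. Assumes torch.distributed is
# already initialized (e.g. via torchrun) with a backend supporting p2p ops.
from typing import Optional

import torch
import torch.distributed as dist

VOCAB_SIZE, HIDDEN_SIZE = 1024, 64  # placeholder dimensions, not Qwen-specific


def tie_lm_head_across_pp(embed_weight: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Return the lm_head weight this rank should use, or None if it holds none."""
    rank, world = dist.get_rank(), dist.get_world_size()
    first_rank, last_rank = 0, world - 1
    if world == 1:
        return embed_weight  # single stage: lm_head simply aliases embed_tokens
    if rank == first_rank:
        dist.send(embed_weight, dst=last_rank)  # embedding lives on the first stage
        return None
    if rank == last_rank:
        lm_head_weight = torch.empty(VOCAB_SIZE, HIDDEN_SIZE)  # dtype left at default here
        dist.recv(lm_head_weight, src=first_rank)  # fill lm_head with the tied weight
        return lm_head_weight
    return None  # intermediate stages keep only a placeholder (PPMissingLayer in the patch)
```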
diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py
index 181802a091..de7db4c32e 100644
--- a/python/sglang/srt/models/qwen3.py
+++ b/python/sglang/srt/models/qwen3.py
@@ -21,7 +21,7 @@
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
-from sglang.srt.layers.utils import get_layer_id
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -249,15 +249,36 @@ def __init__(
         self.model = Qwen3Model(
             config, quant_config=quant_config, prefix=add_prefix("model", prefix)
         )
-        if config.tie_word_embeddings:
-            self.lm_head = self.model.embed_tokens
+
+        # handle the lm head on different pp ranks
+        if self.pp_group.is_last_rank:
+            if self.pp_group.world_size == 1 and config.tie_word_embeddings:
+                self.lm_head = self.model.embed_tokens
+            else:
+                self.lm_head = ParallelLMHead(
+                    config.vocab_size,
+                    config.hidden_size,
+                    quant_config=quant_config,
+                    prefix=add_prefix("lm_head", prefix),
+                )
         else:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix=add_prefix("lm_head", prefix),
-            )
+            # ranks other than the last rank will have a placeholder layer
+            self.lm_head = PPMissingLayer()
+
+        # perform weight tying for PP
+        if self.pp_group.world_size > 1 and config.tie_word_embeddings:
+            if self.pp_group.is_first_rank:
+                self.pp_group.send(
+                    self.model.embed_tokens.weight, dst=self.pp_group.last_rank
+                )
+            else:
+                emb_token_weight = self.pp_group.recv(
+                    size=(config.vocab_size, config.hidden_size),
+                    dtype=next(self.model.parameters()).dtype,
+                    src=self.pp_group.first_rank,
+                )
+                self.lm_head.weight.copy_(emb_token_weight)
+
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@@ -330,7 +351,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # the checkpoint. Skip them.
                 continue
             if self.config.tie_word_embeddings and "lm_head.weight" in name:
-                continue
+                if self.pp_group.world_size > 1 and self.pp_group.is_last_rank:
+                    # Handle pp weight tying here
+                    # find the embed_tokens.weight in the weights
+                    embed_token_weights = next(
+                        filter(lambda x: x[0] == "model.embed_tokens.weight", weights)
+                    )[1]
+                    loaded_weight = embed_token_weights
+                else:
+                    continue
             if name.startswith("model.vision_tower") and name not in params_dict:
                 continue
 
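The qwen3.py changes mirror qwen2.py: the import hunk brings in `PPMissingLayer`, and the constructor and `load_weights` logic are identical. As a small illustration of the `load_weights` branch, the snippet below exercises the same `next(filter(...))` lookup on a toy checkpoint; the tensor shapes are arbitrary, and it assumes `weights` is a re-scannable sequence (e.g. a list) so the embedding entry is still available when `lm_head.weight` is reached.

```python
# Illustrative sketch, not part of the patch: the checkpoint lookup used when
# tie_word_embeddings is set and this is the last PP rank. Shapes are placeholders.
import torch

toy_weights = [
    ("model.embed_tokens.weight", torch.randn(16, 8)),
    ("model.layers.0.mlp.gate_proj.weight", torch.randn(24, 8)),
]

# Same pattern as the patch: substitute embed_tokens.weight for the tied
# lm_head.weight instead of skipping the entry.
embed_token_weights = next(
    filter(lambda x: x[0] == "model.embed_tokens.weight", toy_weights)
)[1]
loaded_weight = embed_token_weights
assert loaded_weight.shape == (16, 8)
```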
diff --git a/test/srt/test_pp_single_node.py b/test/srt/test_pp_single_node.py
index b7fdae2d60..51cc981081 100644
--- a/test/srt/test_pp_single_node.py
+++ b/test/srt/test_pp_single_node.py
@@ -116,6 +116,62 @@ def test_pp_consistency(self):
         )
 
 
+class TestQwenPPTieWeightsAccuracy(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.base_url = "http://127.0.0.1:23334"  # different ports to avoid conflicts
+        cls.model_name = (
+            "Qwen/Qwen3-0.6B"  # qwen3 < 8B all have tie_word_embeddings = True
+        )
+
+    def run_gsm8k_test(self, pp_size):
+        process = popen_launch_server(
+            self.model_name,
+            self.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--pp-size",
+                pp_size,
+                "--chunked-prefill-size",
+                256,
+            ],
+        )
+
+        try:
+            args = SimpleNamespace(
+                num_shots=5,
+                data_path=None,
+                num_questions=200,
+                max_new_tokens=512,
+                parallel=128,
+                host="http://127.0.0.1",
+                port=int(self.base_url.split(":")[-1]),
+            )
+            metrics = run_eval(args)
+            time.sleep(5)
+            return metrics
+        finally:
+            kill_process_tree(process.pid)
+
+    def test_baseline_accuracy(self):
+        metrics = self.run_gsm8k_test(pp_size=1)
+        print(f"[Qwen Baseline] {metrics=}")
+        self.assertGreater(metrics["accuracy"], 0.39)
+
+    def test_pp_consistency(self):
+        baseline = self.run_gsm8k_test(pp_size=1)
+        pp_metrics = self.run_gsm8k_test(pp_size=2)
+
+        print(f"[Qwen PP Comparison] Baseline: {baseline} | PP: {pp_metrics}")
+
+        self.assertAlmostEqual(
+            pp_metrics["accuracy"],
+            baseline["accuracy"],
+            delta=0.01,
+            msg=f"PP accuracy exceeds 1% (baseline: {baseline['accuracy']}, pp: {pp_metrics['accuracy']})",
+        )
+
+
 class TestFixedBugs(unittest.TestCase):
     def test_chunked_prefill_with_small_bs(self):
         model = DEFAULT_MODEL_NAME_FOR_TEST
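The comment in `setUpClass` about small Qwen3 checkpoints tying their embeddings can be spot-checked with a snippet like the one below; this is a hypothetical verification helper, not part of the test, and it needs the `transformers` package plus Hub access or a local cache.

```python
# Illustrative sketch, not part of the patch: confirm that the chosen checkpoint
# actually ties input embeddings and lm_head before relying on the PP tying path.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("Qwen/Qwen3-0.6B")
print(cfg.tie_word_embeddings)  # expected to print True for this checkpoint
```

Locally, the new class can presumably be run on a machine with at least two GPUs via something like `cd test/srt && python3 -m unittest test_pp_single_node.TestQwenPPTieWeightsAccuracy -v`; the timeout bump in pr-test.yml suggests it is exercised by the per-commit-2-gpu suite in CI.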