
Commit a15fc00

Add Qwen3 MaxText to vLLM weight mapping
PiperOrigin-RevId: 822234220
1 parent 42cb7ed commit a15fc00

2 files changed: +135 -1 lines

src/MaxText/integration/tunix/weight_mapping/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -20,14 +20,17 @@
 """

 from MaxText.integration.tunix.weight_mapping.llama3 import LLAMA3_VLLM_MAPPING
+from MaxText.integration.tunix.weight_mapping.qwen3 import QWEN3_VLLM_MAPPING


 class VLLM_WEIGHT_MAPPING:
-  """Mapping MaxText model weights to vLLM's model weights"""
+  """Mapping MaxText model weights to vLLM's model weights."""

   def __getattr__(self, name):
     if name.startswith("llama3.1"):
       return LLAMA3_VLLM_MAPPING
+    elif name.startswith("qwen3"):
+      return QWEN3_VLLM_MAPPING
     else:
       raise ValueError(f"{name} vLLM weight mapping not found.")
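For context, a minimal usage sketch of the dispatch added above. The instance, the model-name strings ("qwen3-8b", "llama3.1-8b"), and the getattr-style lookup are illustrative assumptions, not part of this commit:

    # Illustrative only: the model-name strings below are assumptions.
    from MaxText.integration.tunix.weight_mapping import VLLM_WEIGHT_MAPPING

    mappings = VLLM_WEIGHT_MAPPING()
    # __getattr__ dispatches on the model-name prefix.
    qwen3_mapping = getattr(mappings, "qwen3-8b")     # -> QWEN3_VLLM_MAPPING
    llama_mapping = getattr(mappings, "llama3.1-8b")  # -> LLAMA3_VLLM_MAPPING
    # Any other prefix raises ValueError("... vLLM weight mapping not found.")

    param_map = qwen3_mapping.to_hf_mapping()  # MaxText name -> (vLLM name, sharding)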

src/MaxText/integration/tunix/weight_mapping/qwen3.py (new file)

Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Defines the weight mapping from MaxText's Qwen3 model to a vLLM-compatible format.
+
+This module provides the `QWEN3_VLLM_MAPPING` dataclass, which contains all the
+necessary configurations to convert MaxText's Qwen3 model weights into a
+format that can be loaded by HuggingFace's vLLM. This includes:
+- A direct mapping of parameter names.
+- Sharding specifications for distributed environments.
+"""
+
+from dataclasses import dataclass
+
+
+@dataclass
+class QWEN3_VLLM_MAPPING:
+  """Mapping MaxText Qwen3-8 weights to vLLM's Qwen3-8 weights."""
+
+  @staticmethod
+  def to_hf_hook_fns():
+    """Returns a dictionary of hook functions to be applied to MaxText weights.
+
+    Returns:
+      An empty dictionary, as no hook functions are needed for this mapping.
+    """
+
+    return {}
+
+  @staticmethod
+  def to_hf_transpose_keys():
+    """Returns a list of keys for weights that need to be transposed.
+
+    Returns:
+      An empty dictionary, as no keys require transposition for this mapping.
+    """
+    return {}
+
+  @staticmethod
+  def lora_to_hf_mappings():
+    """Provides the mapping for LoRA (Low-Rank Adaptation) weights.
+
+    Returns:
+      None, as LoRA mappings are not defined for this model.
+    """
+    return None
+
+  @staticmethod
+  def to_hf_mapping():
+    """Mapping from MaxText model to HuggingFace vLLM model.
+
+    Currently, the param mapping conforms to the Tunix API, which combines the
+    param name & sharding in one dictionary.
+    This is subject to change in the future where we can decouple the two.
+    """
+    return {
+        # Token embeddings - shard vocab dimension
+        "base.token_embedder.embedding": (
+            "model.embed.embedding",
+            ("model", None),
+        ),
+        # Final layer norm - no sharding needed
+        "base.decoder.decoder_norm.scale": (
+            "model.norm.scale",
+            (None,),
+        ),
+        # LM head (logits projection) - shard vocab dimension
+        "base.decoder.logits_dense.kernel": (
+            "model.lm_head",
+            (None, "model"),
+        ),
+        # Layer-specific mappings (scanned -> unscanned)
+        # MLP components - shard hidden dimensions
+        "base.decoder.layers.mlp.wi_0.kernel": (
+            "model.layers.*.mlp.gate_proj.kernel",
+            (None, "layer", "model"),
+        ),
+        "base.decoder.layers.mlp.wi_1.kernel": (
+            "model.layers.*.mlp.up_proj.kernel",
+            (None, "layer", "model"),
+        ),
+        "base.decoder.layers.mlp.wo.kernel": (
+            "model.layers.*.mlp.down_proj.kernel",
+            ("model", "layer", None),
+        ),
+        # Layer norms - no sharding needed
+        "base.decoder.layers.pre_self_attention_layer_norm.scale": (
+            "model.layers.*.input_layernorm.scale",
+            (None, "layer"),
+        ),
+        "base.decoder.layers.post_self_attention_layer_norm.scale": (
+            "model.layers.*.post_attention_layernorm.scale",
+            (None, "layer"),
+        ),
+        # Attention components - shard head dimensions
+        "base.decoder.layers.self_attention.query.kernel": (
+            "model.layers.*.self_attn.q_proj.kernel",
+            (None, "layer", "model", None),
+        ),
+        "base.decoder.layers.self_attention.key.kernel": (
+            "model.layers.*.self_attn.k_proj.kernel",
+            (None, "layer", "model", None),
+        ),
+        "base.decoder.layers.self_attention.value.kernel": (
+            "model.layers.*.self_attn.v_proj.kernel",
+            (None, "layer", "model", None),
+        ),
+        "base.decoder.layers.self_attention.out.kernel": (
+            "model.layers.*.self_attn.o_proj.kernel",
+            ("model", "layer", None, None),
+        ),
+        "base.decoder.layers.self_attention.query_norm.scale": (
+            "model.layers.*.self_attn.q_norm.scale",
+            (None, "layer"),
+        ),
+        "base.decoder.layers.self_attention.key_norm.scale": (
+            "model.layers.*.self_attn.k_norm.scale",
+            (None, "layer"),
+        ),
+    }
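To make the structure of the entries above concrete, here is a hedged consumption sketch: it expands the "model.layers.*" wildcard into per-layer names and turns each sharding tuple into a jax.sharding.PartitionSpec. The helper name expand_mapping, the per-layer key format, treating the "layer" axis as the scanned dimension that is dropped once layers are unstacked, and the layer count of 36 are assumptions for illustration, not part of this commit:

    # Sketch only: expand_mapping, the per-layer key format, and num_layers are
    # assumptions about how a consumer might use the mapping.
    from jax.sharding import PartitionSpec

    from MaxText.integration.tunix.weight_mapping.qwen3 import QWEN3_VLLM_MAPPING


    def expand_mapping(mapping, num_layers):
      """Expands scanned (stacked-layer) entries into per-layer vLLM entries."""
      expanded = {}
      for maxtext_name, (vllm_name, spec) in mapping.items():
        if "*" in vllm_name:
          # Scanned MaxText params carry an extra "layer" axis; the unscanned
          # vLLM side gets one entry per layer index instead.
          per_layer_spec = PartitionSpec(*(ax for ax in spec if ax != "layer"))
          for i in range(num_layers):
            expanded[f"{maxtext_name}[{i}]"] = (
                vllm_name.replace("*", str(i)),
                per_layer_spec,
            )
        else:
          expanded[maxtext_name] = (vllm_name, PartitionSpec(*spec))
      return expanded


    per_layer = expand_mapping(QWEN3_VLLM_MAPPING.to_hf_mapping(), num_layers=36)

Whether Tunix actually consumes the tuples this way is an internal detail of the Tunix API; the sketch only illustrates the (target name, sharding axes) shape of each value in the dictionary above.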
