|
| 1 | +# Copyright 2023–2025 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Defines the weight mapping from MaxText's Qwen3 model to a vLLM-compatible format. |
| 16 | +
|
| 17 | +This module provides the `QWEN3_VLLM_MAPPING` dataclass, which contains all the |
| 18 | +necessary configurations to convert MaxText's Qwen3 model weights into a |
| 19 | +format that can be loaded by HuggingFace's vLLM. This includes: |
| 20 | +- A direct mapping of parameter names. |
| 21 | +- Sharding specifications for distributed environments. |
| 22 | +""" |
| 23 | + |
| 24 | +from dataclasses import dataclass |
| 25 | + |
| 26 | + |
| 27 | +@dataclass |
| 28 | +class QWEN3_VLLM_MAPPING: |
| 29 | + """Mapping MaxText Qwen3-8 weights to vLLM's Qwen3-8 weights.""" |
| 30 | + |
| 31 | + @staticmethod |
| 32 | + def to_hf_hook_fns(): |
| 33 | + """Returns a dictionary of hook functions to be applied to MaxText weights. |
| 34 | +
|
| 35 | + Returns: |
| 36 | + An empty dictionary, as no hook functions are needed for this mapping. |
| 37 | + """ |
| 38 | + |
| 39 | + return {} |
| 40 | + |
| 41 | + @staticmethod |
| 42 | + def to_hf_transpose_keys(): |
| 43 | + """Returns a list of keys for weights that need to be transposed. |
| 44 | +
|
| 45 | + Returns: |
| 46 | + An empty dictionary, as no keys require transposition for this mapping. |
| 47 | + """ |
| 48 | + return {} |
| 49 | + |
| 50 | + @staticmethod |
| 51 | + def lora_to_hf_mappings(): |
| 52 | + """Provides the mapping for LoRA (Low-Rank Adaptation) weights. |
| 53 | +
|
| 54 | + Returns: |
| 55 | + None, as LoRA mappings are not defined for this model. |
| 56 | + """ |
| 57 | + return None |
| 58 | + |
| 59 | + @staticmethod |
| 60 | + def to_hf_mapping(): |
| 61 | + """Mapping from MaxText model to HuggingFace vLLM model. |
| 62 | +
|
| 63 | + Currently, the param mapping conforms to the Tunix API, which combines the |
| 64 | + param name & sharding in one dictionary. |
| 65 | + This is subject to change in the future where we can decouple the two. |
| 66 | + """ |
| 67 | + return { |
| 68 | + # Token embeddings - shard vocab dimension |
| 69 | + "base.token_embedder.embedding": ( |
| 70 | + "model.embed.embedding", |
| 71 | + ("model", None), |
| 72 | + ), |
| 73 | + # Final layer norm - no sharding needed |
| 74 | + "base.decoder.decoder_norm.scale": ( |
| 75 | + "model.norm.scale", |
| 76 | + (None,), |
| 77 | + ), |
| 78 | + # LM head (logits projection) - shard vocab dimension |
| 79 | + "base.decoder.logits_dense.kernel": ( |
| 80 | + "model.lm_head", |
| 81 | + (None, "model"), |
| 82 | + ), |
| 83 | + # Layer-specific mappings (scanned -> unscanned) |
| 84 | + # MLP components - shard hidden dimensions |
| 85 | + "base.decoder.layers.mlp.wi_0.kernel": ( |
| 86 | + "model.layers.*.mlp.gate_proj.kernel", |
| 87 | + (None, "layer", "model"), |
| 88 | + ), |
| 89 | + "base.decoder.layers.mlp.wi_1.kernel": ( |
| 90 | + "model.layers.*.mlp.up_proj.kernel", |
| 91 | + (None, "layer", "model"), |
| 92 | + ), |
| 93 | + "base.decoder.layers.mlp.wo.kernel": ( |
| 94 | + "model.layers.*.mlp.down_proj.kernel", |
| 95 | + ("model", "layer", None), |
| 96 | + ), |
| 97 | + # Layer norms - no sharding needed |
| 98 | + "base.decoder.layers.pre_self_attention_layer_norm.scale": ( |
| 99 | + "model.layers.*.input_layernorm.scale", |
| 100 | + (None, "layer"), |
| 101 | + ), |
| 102 | + "base.decoder.layers.post_self_attention_layer_norm.scale": ( |
| 103 | + "model.layers.*.post_attention_layernorm.scale", |
| 104 | + (None, "layer"), |
| 105 | + ), |
| 106 | + # Attention components - shard head dimensions |
| 107 | + "base.decoder.layers.self_attention.query.kernel": ( |
| 108 | + "model.layers.*.self_attn.q_proj.kernel", |
| 109 | + (None, "layer", "model", None), |
| 110 | + ), |
| 111 | + "base.decoder.layers.self_attention.key.kernel": ( |
| 112 | + "model.layers.*.self_attn.k_proj.kernel", |
| 113 | + (None, "layer", "model", None), |
| 114 | + ), |
| 115 | + "base.decoder.layers.self_attention.value.kernel": ( |
| 116 | + "model.layers.*.self_attn.v_proj.kernel", |
| 117 | + (None, "layer", "model", None), |
| 118 | + ), |
| 119 | + "base.decoder.layers.self_attention.out.kernel": ( |
| 120 | + "model.layers.*.self_attn.o_proj.kernel", |
| 121 | + ("model", "layer", None, None), |
| 122 | + ), |
| 123 | + "base.decoder.layers.self_attention.query_norm.scale": ( |
| 124 | + "model.layers.*.self_attn.q_norm.scale", |
| 125 | + (None, "layer"), |
| 126 | + ), |
| 127 | + "base.decoder.layers.self_attention.key_norm.scale": ( |
| 128 | + "model.layers.*.self_attn.k_norm.scale", |
| 129 | + (None, "layer"), |
| 130 | + ), |
| 131 | + } |
0 commit comments