|
25 | 25 | from jax.interpreters import pxla |
26 | 26 | import jax.sharding as shd |
27 | 27 | import jaxtyping |
| 28 | +import numpy as np |
28 | 29 | from tunix.generate.mappings import BackendMappingMixin |
29 | 30 | from tunix.models.qwen3vl.vision import VisionEmbeddings |
30 | 31 | from tunix.models.qwen3vl.vision import VisionGridData |
@@ -107,116 +108,6 @@ class ModelConfig: |
107 | 108 | param_dtype: jnp.dtype = jnp.bfloat16 |
108 | 109 | vision_config: VisionModelConfig | None = None |
109 | 110 |
|
110 | | - @classmethod |
111 | | - def qwen3_0p6b(cls): # qwen3-0.6B |
112 | | - return cls( |
113 | | - num_layers=28, |
114 | | - vocab_size=151936, |
115 | | - embed_dim=1024, |
116 | | - hidden_dim=3072, |
117 | | - num_heads=16, |
118 | | - head_dim=128, |
119 | | - num_kv_heads=8, |
120 | | - norm_eps=1e-06, |
121 | | - rope_theta=1_000_000, |
122 | | - ) |
123 | | - |
124 | | - @classmethod |
125 | | - def qwen3_1p7b(cls): # qwen3-1.7B |
126 | | - return cls( |
127 | | - num_layers=28, |
128 | | - vocab_size=151936, |
129 | | - embed_dim=2048, |
130 | | - hidden_dim=6144, |
131 | | - num_heads=16, |
132 | | - head_dim=128, |
133 | | - num_kv_heads=8, |
134 | | - norm_eps=1e-06, |
135 | | - rope_theta=1_000_000, |
136 | | - ) |
137 | | - |
138 | | - @classmethod |
139 | | - def qwen3_4b(cls): # qwen3-4B |
140 | | - return cls( |
141 | | - num_layers=36, |
142 | | - vocab_size=151936, |
143 | | - embed_dim=2560, |
144 | | - hidden_dim=9728, |
145 | | - num_heads=32, |
146 | | - head_dim=128, |
147 | | - num_kv_heads=8, |
148 | | - norm_eps=1e-06, |
149 | | - rope_theta=1_000_000, |
150 | | - use_tied_embedding=True, |
151 | | - ) |
152 | | - |
153 | | - @classmethod |
154 | | - def _qwen3_4b_2507(cls): # Qwen3-4B-Instruct-2507 and Qwen3-4B-Thinking-2507 |
155 | | - return cls( |
156 | | - num_layers=36, |
157 | | - vocab_size=151936, |
158 | | - embed_dim=2560, |
159 | | - hidden_dim=9728, |
160 | | - num_heads=32, |
161 | | - head_dim=128, |
162 | | - num_kv_heads=8, |
163 | | - norm_eps=1e-06, |
164 | | - rope_theta=5_000_000, |
165 | | - use_tied_embedding=True, |
166 | | - ) |
167 | | - |
168 | | - @classmethod |
169 | | - def qwen3_4b_instruct_2507(cls): # Qwen3-4B-Instruct-2507 |
170 | | - return cls._qwen3_4b_2507() |
171 | | - |
172 | | - @classmethod |
173 | | - def qwen3_4b_thinking_2507(cls): # Qwen3-4B-Thinking-2507 |
174 | | - return cls._qwen3_4b_2507() |
175 | | - |
176 | | - @classmethod |
177 | | - def qwen3_8b(cls): # qwen3-8B |
178 | | - return cls( |
179 | | - num_layers=36, |
180 | | - vocab_size=151936, |
181 | | - embed_dim=4096, |
182 | | - hidden_dim=12288, |
183 | | - num_heads=32, |
184 | | - head_dim=128, |
185 | | - num_kv_heads=8, |
186 | | - norm_eps=1e-06, |
187 | | - rope_theta=1_000_000, |
188 | | - ) |
189 | | - |
190 | | - @classmethod |
191 | | - def qwen3_14b(cls): # qwen3-14B |
192 | | - return cls( |
193 | | - num_layers=40, |
194 | | - vocab_size=151936, |
195 | | - embed_dim=5120, |
196 | | - hidden_dim=17408, |
197 | | - num_heads=40, |
198 | | - head_dim=128, |
199 | | - num_kv_heads=8, |
200 | | - norm_eps=1e-06, |
201 | | - rope_theta=1_000_000, |
202 | | - ) |
203 | | - |
204 | | - @classmethod |
205 | | - def qwen3_30b_a3b(cls): # qwen3-30B-a3b |
206 | | - return cls( |
207 | | - num_layers=48, |
208 | | - vocab_size=151936, |
209 | | - embed_dim=2048, |
210 | | - hidden_dim=768, |
211 | | - num_heads=32, |
212 | | - head_dim=128, |
213 | | - num_kv_heads=4, |
214 | | - norm_eps=1e-06, |
215 | | - rope_theta=1_000_000, |
216 | | - num_experts=128, |
217 | | - num_experts_per_tok=8, |
218 | | - ) |
219 | | - |
220 | 111 | @classmethod |
221 | 112 | def qwen3vl_4b(cls): # qwen3-vl-4b |
222 | 113 | return cls( |
@@ -734,86 +625,6 @@ def num_kv_heads(self): |
734 | 625 | return self.k_proj.shape[1] |
735 | 626 |
|
736 | 627 |
|
737 | | -class MoELayer(nnx.Module): |
738 | | - """MoE layer.""" |
739 | | - |
740 | | - def __init__( |
741 | | - self, |
742 | | - config: ModelConfig, |
743 | | - *, |
744 | | - rngs: nnx.Rngs, |
745 | | - param_dtype: jnp.dtype = jnp.bfloat16, |
746 | | - ): |
747 | | - self.shd_config = config.shd_config |
748 | | - self.experts_per_tok = config.num_experts_per_tok |
749 | | - self.num_experts = config.num_experts |
750 | | - self.router = nnx.Linear( |
751 | | - in_features=config.embed_dim, |
752 | | - out_features=config.num_experts, |
753 | | - use_bias=False, |
754 | | - rngs=rngs, |
755 | | - param_dtype=param_dtype, |
756 | | - ) |
757 | | - self.gate_proj = nnx.Param( |
758 | | - nnx.initializers.normal(dtype=param_dtype)( |
759 | | - rngs.params(), |
760 | | - (config.num_experts, config.embed_dim, config.hidden_dim), |
761 | | - ), |
762 | | - sharding=self.shd_config.exp_weight_cdf, |
763 | | - ) |
764 | | - self.up_proj = nnx.Param( |
765 | | - nnx.initializers.normal(dtype=param_dtype)( |
766 | | - rngs.params(), |
767 | | - (config.num_experts, config.embed_dim, config.hidden_dim), |
768 | | - ), |
769 | | - sharding=self.shd_config.exp_weight_cdf, |
770 | | - ) |
771 | | - self.down_proj = nnx.Param( |
772 | | - nnx.initializers.normal(dtype=param_dtype)( |
773 | | - rngs.params(), |
774 | | - (config.num_experts, config.hidden_dim, config.embed_dim), |
775 | | - ), |
776 | | - sharding=self.shd_config.exp_weight_cfd, |
777 | | - ) |
778 | | - |
779 | | - def __call__(self, x): |
780 | | - scores = self.router(x).astype(jnp.float32) # [B,T,E] |
781 | | - routing_weights, routing_idx = jax.lax.top_k( |
782 | | - jax.nn.softmax(scores, axis=-1), self.experts_per_tok |
783 | | - ) |
784 | | - routing_weights = ( |
785 | | - routing_weights / jnp.sum(routing_weights, axis=-1, keepdims=True) |
786 | | - ).astype(x.dtype) |
787 | | - |
788 | | - dispatch_mask = jax.nn.one_hot( |
789 | | - routing_idx, num_classes=self.num_experts, dtype=x.dtype |
790 | | - ) # [B, T, K, E] |
791 | | - dispatch_mask = jnp.swapaxes(dispatch_mask, -1, -2) # [B, T, E, K] |
792 | | - |
793 | | - dispatched_input = jnp.einsum( |
794 | | - 'BTID,BTEK->BTED', x[:, :, None, :], dispatch_mask |
795 | | - ).astype(x.dtype) |
796 | | - |
797 | | - expert_outputs = [] |
798 | | - for i in range(self.num_experts): |
799 | | - expert_input = dispatched_input[:, :, i, :] |
800 | | - activations = nnx.silu( |
801 | | - jnp.einsum('BTD,DF->BTF', expert_input, self.gate_proj[i]) |
802 | | - ) * jnp.einsum('BTD,DF->BTF', expert_input, self.up_proj[i]) |
803 | | - activations = shard(activations, self.shd_config.act_btf) |
804 | | - expert_output = jnp.einsum('BTF,FD->BTD', activations, self.down_proj[i]) |
805 | | - expert_outputs.append(expert_output) |
806 | | - |
807 | | - stacked_outputs = jnp.stack(expert_outputs, axis=2) # [B, T, E, D] |
808 | | - routing_weights = jnp.tile( |
809 | | - routing_weights[:, :, None, :], (1, 1, self.num_experts, 1) |
810 | | - ) # [B, T, E, K] |
811 | | - routing_weights = dispatch_mask * routing_weights # [B, T, E, K] |
812 | | - |
813 | | - output = jnp.einsum('BTED,BTEK->BTD', stacked_outputs, routing_weights) |
814 | | - return output |
815 | | - |
816 | | - |
817 | 628 | class MLP(nnx.Module): |
818 | 629 | """MLP module.""" |
819 | 630 |
|
|
0 commit comments