Skip to content

Commit afa2d29

Browse files
committed
Remove unused code
1 parent f306866 commit afa2d29

File tree

1 file changed

+1
-190
lines changed

1 file changed

+1
-190
lines changed

tunix/models/qwen3vl/model.py

Lines changed: 1 addition & 190 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from jax.interpreters import pxla
2626
import jax.sharding as shd
2727
import jaxtyping
28+
import numpy as np
2829
from tunix.generate.mappings import BackendMappingMixin
2930
from tunix.models.qwen3vl.vision import VisionEmbeddings
3031
from tunix.models.qwen3vl.vision import VisionGridData
@@ -107,116 +108,6 @@ class ModelConfig:
107108
param_dtype: jnp.dtype = jnp.bfloat16
108109
vision_config: VisionModelConfig | None = None
109110

110-
@classmethod
111-
def qwen3_0p6b(cls): # qwen3-0.6B
112-
return cls(
113-
num_layers=28,
114-
vocab_size=151936,
115-
embed_dim=1024,
116-
hidden_dim=3072,
117-
num_heads=16,
118-
head_dim=128,
119-
num_kv_heads=8,
120-
norm_eps=1e-06,
121-
rope_theta=1_000_000,
122-
)
123-
124-
@classmethod
125-
def qwen3_1p7b(cls): # qwen3-1.7B
126-
return cls(
127-
num_layers=28,
128-
vocab_size=151936,
129-
embed_dim=2048,
130-
hidden_dim=6144,
131-
num_heads=16,
132-
head_dim=128,
133-
num_kv_heads=8,
134-
norm_eps=1e-06,
135-
rope_theta=1_000_000,
136-
)
137-
138-
@classmethod
139-
def qwen3_4b(cls): # qwen3-4B
140-
return cls(
141-
num_layers=36,
142-
vocab_size=151936,
143-
embed_dim=2560,
144-
hidden_dim=9728,
145-
num_heads=32,
146-
head_dim=128,
147-
num_kv_heads=8,
148-
norm_eps=1e-06,
149-
rope_theta=1_000_000,
150-
use_tied_embedding=True,
151-
)
152-
153-
@classmethod
154-
def _qwen3_4b_2507(cls): # Qwen3-4B-Instruct-2507 and Qwen3-4B-Thinking-2507
155-
return cls(
156-
num_layers=36,
157-
vocab_size=151936,
158-
embed_dim=2560,
159-
hidden_dim=9728,
160-
num_heads=32,
161-
head_dim=128,
162-
num_kv_heads=8,
163-
norm_eps=1e-06,
164-
rope_theta=5_000_000,
165-
use_tied_embedding=True,
166-
)
167-
168-
@classmethod
169-
def qwen3_4b_instruct_2507(cls): # Qwen3-4B-Instruct-2507
170-
return cls._qwen3_4b_2507()
171-
172-
@classmethod
173-
def qwen3_4b_thinking_2507(cls): # Qwen3-4B-Thinking-2507
174-
return cls._qwen3_4b_2507()
175-
176-
@classmethod
177-
def qwen3_8b(cls): # qwen3-8B
178-
return cls(
179-
num_layers=36,
180-
vocab_size=151936,
181-
embed_dim=4096,
182-
hidden_dim=12288,
183-
num_heads=32,
184-
head_dim=128,
185-
num_kv_heads=8,
186-
norm_eps=1e-06,
187-
rope_theta=1_000_000,
188-
)
189-
190-
@classmethod
191-
def qwen3_14b(cls): # qwen3-14B
192-
return cls(
193-
num_layers=40,
194-
vocab_size=151936,
195-
embed_dim=5120,
196-
hidden_dim=17408,
197-
num_heads=40,
198-
head_dim=128,
199-
num_kv_heads=8,
200-
norm_eps=1e-06,
201-
rope_theta=1_000_000,
202-
)
203-
204-
@classmethod
205-
def qwen3_30b_a3b(cls): # qwen3-30B-a3b
206-
return cls(
207-
num_layers=48,
208-
vocab_size=151936,
209-
embed_dim=2048,
210-
hidden_dim=768,
211-
num_heads=32,
212-
head_dim=128,
213-
num_kv_heads=4,
214-
norm_eps=1e-06,
215-
rope_theta=1_000_000,
216-
num_experts=128,
217-
num_experts_per_tok=8,
218-
)
219-
220111
@classmethod
221112
def qwen3vl_4b(cls): # qwen3-vl-4b
222113
return cls(
@@ -734,86 +625,6 @@ def num_kv_heads(self):
734625
return self.k_proj.shape[1]
735626

736627

737-
class MoELayer(nnx.Module):
738-
"""MoE layer."""
739-
740-
def __init__(
741-
self,
742-
config: ModelConfig,
743-
*,
744-
rngs: nnx.Rngs,
745-
param_dtype: jnp.dtype = jnp.bfloat16,
746-
):
747-
self.shd_config = config.shd_config
748-
self.experts_per_tok = config.num_experts_per_tok
749-
self.num_experts = config.num_experts
750-
self.router = nnx.Linear(
751-
in_features=config.embed_dim,
752-
out_features=config.num_experts,
753-
use_bias=False,
754-
rngs=rngs,
755-
param_dtype=param_dtype,
756-
)
757-
self.gate_proj = nnx.Param(
758-
nnx.initializers.normal(dtype=param_dtype)(
759-
rngs.params(),
760-
(config.num_experts, config.embed_dim, config.hidden_dim),
761-
),
762-
sharding=self.shd_config.exp_weight_cdf,
763-
)
764-
self.up_proj = nnx.Param(
765-
nnx.initializers.normal(dtype=param_dtype)(
766-
rngs.params(),
767-
(config.num_experts, config.embed_dim, config.hidden_dim),
768-
),
769-
sharding=self.shd_config.exp_weight_cdf,
770-
)
771-
self.down_proj = nnx.Param(
772-
nnx.initializers.normal(dtype=param_dtype)(
773-
rngs.params(),
774-
(config.num_experts, config.hidden_dim, config.embed_dim),
775-
),
776-
sharding=self.shd_config.exp_weight_cfd,
777-
)
778-
779-
def __call__(self, x):
780-
scores = self.router(x).astype(jnp.float32) # [B,T,E]
781-
routing_weights, routing_idx = jax.lax.top_k(
782-
jax.nn.softmax(scores, axis=-1), self.experts_per_tok
783-
)
784-
routing_weights = (
785-
routing_weights / jnp.sum(routing_weights, axis=-1, keepdims=True)
786-
).astype(x.dtype)
787-
788-
dispatch_mask = jax.nn.one_hot(
789-
routing_idx, num_classes=self.num_experts, dtype=x.dtype
790-
) # [B, T, K, E]
791-
dispatch_mask = jnp.swapaxes(dispatch_mask, -1, -2) # [B, T, E, K]
792-
793-
dispatched_input = jnp.einsum(
794-
'BTID,BTEK->BTED', x[:, :, None, :], dispatch_mask
795-
).astype(x.dtype)
796-
797-
expert_outputs = []
798-
for i in range(self.num_experts):
799-
expert_input = dispatched_input[:, :, i, :]
800-
activations = nnx.silu(
801-
jnp.einsum('BTD,DF->BTF', expert_input, self.gate_proj[i])
802-
) * jnp.einsum('BTD,DF->BTF', expert_input, self.up_proj[i])
803-
activations = shard(activations, self.shd_config.act_btf)
804-
expert_output = jnp.einsum('BTF,FD->BTD', activations, self.down_proj[i])
805-
expert_outputs.append(expert_output)
806-
807-
stacked_outputs = jnp.stack(expert_outputs, axis=2) # [B, T, E, D]
808-
routing_weights = jnp.tile(
809-
routing_weights[:, :, None, :], (1, 1, self.num_experts, 1)
810-
) # [B, T, E, K]
811-
routing_weights = dispatch_mask * routing_weights # [B, T, E, K]
812-
813-
output = jnp.einsum('BTED,BTEK->BTD', stacked_outputs, routing_weights)
814-
return output
815-
816-
817628
class MLP(nnx.Module):
818629
"""MLP module."""
819630

0 commit comments

Comments
 (0)