Skip to content

Commit 0224a74

Browse files
authored
Add Qwen3 MoE (#2934)
* qwen-moe rebase * lint * fixed rebase error * swapped normal MoE model with CausalMoE Model in example, and swapped the tie word embeddings if statement * updated readme
1 parent cd7b877 commit 0224a74

File tree

4 files changed

+393
-1
lines changed

4 files changed

+393
-1
lines changed

candle-examples/examples/qwen/README.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,28 @@ def print_prime(n: int): # n is the number of primes to be printed
2525
print(i)
2626
```
2727
28+
The qwen3 MoE variant is also an option.
29+
30+
```bash
31+
$ cargo run --example qwen --features metal --release -- --prompt "Write a poem about butterflies. <think></think>." --model "3-moe-a3b"
32+
> In morning's hush, where daisies sleep,
33+
> A fleeting dance through sunlit deep—
34+
> They flutter soft on gossamer thread,
35+
> The messengers of spring’s own head.
36+
>
37+
> With painted sails and delicate grace,
38+
> They drift from bloom to blossom's face.
39+
> Each wing a tale in hues unseen,
40+
> Of ancient dreams and secrets between.
41+
>
42+
> No sound they make, yet still they speak—
43+
> Of time that flies, of life so brief.
44+
> A fleeting kiss on summer’s breath,
45+
> A whisper lost before death.
46+
>
47+
> Yet in their flight, the soul takes wing,
48+
> And for a moment, all is spring.
49+
> For though they fade, they never die—
50+
> Their beauty lives where hearts can fly.
51+
> 161 tokens generated (3.00 token/s)
52+
```

candle-examples/examples/qwen/main.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use clap::Parser;
1010
use candle_transformers::models::qwen2::{Config as ConfigBase, ModelForCausalLM as ModelBase};
1111
use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe};
1212
use candle_transformers::models::qwen3::{Config as Config3, ModelForCausalLM as Model3};
13+
use candle_transformers::models::qwen3_moe::{Config as ConfigMoe3, ModelForCausalLM as ModelMoe3};
1314

1415
use candle::{DType, Device, Tensor};
1516
use candle_examples::token_output_stream::TokenOutputStream;
@@ -22,6 +23,7 @@ enum Model {
2223
Base(ModelBase),
2324
Moe(ModelMoe),
2425
Base3(Model3),
26+
Moe3(ModelMoe3),
2527
}
2628

2729
impl Model {
@@ -30,6 +32,7 @@ impl Model {
3032
Self::Moe(ref mut m) => m.forward(xs, s),
3133
Self::Base(ref mut m) => m.forward(xs, s),
3234
Self::Base3(ref mut m) => m.forward(xs, s),
35+
Self::Moe3(ref mut m) => m.forward(xs, s),
3336
}
3437
}
3538
}
@@ -167,6 +170,8 @@ enum WhichModel {
167170
W3_4b,
168171
#[value(name = "3-8b")]
169172
W3_8b,
173+
#[value(name = "3-moe-a3b")]
174+
W3MoeA3b,
170175
}
171176

172177
#[derive(Parser, Debug)]
@@ -273,6 +278,7 @@ fn main() -> Result<()> {
273278
WhichModel::W3_1_7b => ("3", "1.7B"),
274279
WhichModel::W3_4b => ("3", "4B"),
275280
WhichModel::W3_8b => ("3", "8B"),
281+
WhichModel::W3MoeA3b => ("3", "30B-A3B"),
276282
};
277283
format!("Qwen/Qwen{version}-{size}")
278284
}
@@ -308,7 +314,8 @@ fn main() -> Result<()> {
308314
| WhichModel::MoeA27b
309315
| WhichModel::W3_1_7b
310316
| WhichModel::W3_4b
311-
| WhichModel::W3_8b => {
317+
| WhichModel::W3_8b
318+
| WhichModel::W3MoeA3b => {
312319
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
313320
}
314321
},
@@ -334,6 +341,10 @@ fn main() -> Result<()> {
334341
let config: Config3 = serde_json::from_slice(&std::fs::read(config_file)?)?;
335342
Model::Base3(Model3::new(&config, vb)?)
336343
}
344+
WhichModel::W3MoeA3b => {
345+
let config: ConfigMoe3 = serde_json::from_slice(&std::fs::read(config_file)?)?;
346+
Model::Moe3(ModelMoe3::new(&config, vb)?)
347+
}
337348
_ => {
338349
let config: ConfigBase = serde_json::from_slice(&std::fs::read(config_file)?)?;
339350
Model::Base(ModelBase::new(&config, vb)?)

candle-transformers/src/models/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ pub mod quantized_t5;
100100
pub mod qwen2;
101101
pub mod qwen2_moe;
102102
pub mod qwen3;
103+
pub mod qwen3_moe;
103104
pub mod recurrent_gemma;
104105
pub mod repvgg;
105106
pub mod resnet;

0 commit comments

Comments
 (0)