@@ -10,6 +10,7 @@ use clap::Parser;
1010use candle_transformers:: models:: qwen2:: { Config as ConfigBase , ModelForCausalLM as ModelBase } ;
1111use candle_transformers:: models:: qwen2_moe:: { Config as ConfigMoe , Model as ModelMoe } ;
1212use candle_transformers:: models:: qwen3:: { Config as Config3 , ModelForCausalLM as Model3 } ;
13+ use candle_transformers:: models:: qwen3_moe:: { Config as ConfigMoe3 , ModelForCausalLM as ModelMoe3 } ;
1314
1415use candle:: { DType , Device , Tensor } ;
1516use candle_examples:: token_output_stream:: TokenOutputStream ;
@@ -22,6 +23,7 @@ enum Model {
2223 Base ( ModelBase ) ,
2324 Moe ( ModelMoe ) ,
2425 Base3 ( Model3 ) ,
26+ Moe3 ( ModelMoe3 ) ,
2527}
2628
2729impl Model {
@@ -30,6 +32,7 @@ impl Model {
3032 Self :: Moe ( ref mut m) => m. forward ( xs, s) ,
3133 Self :: Base ( ref mut m) => m. forward ( xs, s) ,
3234 Self :: Base3 ( ref mut m) => m. forward ( xs, s) ,
35+ Self :: Moe3 ( ref mut m) => m. forward ( xs, s) ,
3336 }
3437 }
3538}
@@ -167,6 +170,8 @@ enum WhichModel {
167170 W3_4b ,
168171 #[ value( name = "3-8b" ) ]
169172 W3_8b ,
173+ #[ value( name = "3-moe-a3b" ) ]
174+ W3MoeA3b ,
170175}
171176
172177#[ derive( Parser , Debug ) ]
@@ -273,6 +278,7 @@ fn main() -> Result<()> {
273278 WhichModel :: W3_1_7b => ( "3" , "1.7B" ) ,
274279 WhichModel :: W3_4b => ( "3" , "4B" ) ,
275280 WhichModel :: W3_8b => ( "3" , "8B" ) ,
281+ WhichModel :: W3MoeA3b => ( "3" , "30B-A3B" ) ,
276282 } ;
277283 format ! ( "Qwen/Qwen{version}-{size}" )
278284 }
@@ -308,7 +314,8 @@ fn main() -> Result<()> {
308314 | WhichModel :: MoeA27b
309315 | WhichModel :: W3_1_7b
310316 | WhichModel :: W3_4b
311- | WhichModel :: W3_8b => {
317+ | WhichModel :: W3_8b
318+ | WhichModel :: W3MoeA3b => {
312319 candle_examples:: hub_load_safetensors ( & repo, "model.safetensors.index.json" ) ?
313320 }
314321 } ,
@@ -334,6 +341,10 @@ fn main() -> Result<()> {
334341 let config: Config3 = serde_json:: from_slice ( & std:: fs:: read ( config_file) ?) ?;
335342 Model :: Base3 ( Model3 :: new ( & config, vb) ?)
336343 }
344+ WhichModel :: W3MoeA3b => {
345+ let config: ConfigMoe3 = serde_json:: from_slice ( & std:: fs:: read ( config_file) ?) ?;
346+ Model :: Moe3 ( ModelMoe3 :: new ( & config, vb) ?)
347+ }
337348 _ => {
338349 let config: ConfigBase = serde_json:: from_slice ( & std:: fs:: read ( config_file) ?) ?;
339350 Model :: Base ( ModelBase :: new ( & config, vb) ?)
0 commit comments