Skip to content

Commit 54ff971

Browse files
Support for the new Qwen2 models. (#2257)
* Support for the new Qwen2 models.
* Add more models.
1 parent b9fac7e commit 54ff971

File tree

2 files changed

+32
-12
lines changed

2 files changed

+32
-12
lines changed

candle-examples/examples/qwen/main.rs

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,14 @@ enum WhichModel {
144144
W72b,
145145
#[value(name = "moe-a2.7b")]
146146
MoeA27b,
147+
#[value(name = "2-0.5b")]
148+
W2_0_5b,
149+
#[value(name = "2-1.5b")]
150+
W2_1_5b,
151+
#[value(name = "2-7b")]
152+
W2_7b,
153+
#[value(name = "2-72b")]
154+
W2_72b,
147155
}
148156

149157
#[derive(Parser, Debug)]
@@ -234,16 +242,20 @@ fn main() -> Result<()> {
234242
let model_id = match args.model_id {
235243
Some(model_id) => model_id,
236244
None => {
237-
let size = match args.model {
238-
WhichModel::W0_5b => "0.5B",
239-
WhichModel::W1_8b => "1.8B",
240-
WhichModel::W4b => "4B",
241-
WhichModel::W7b => "7B",
242-
WhichModel::W14b => "14B",
243-
WhichModel::W72b => "72B",
244-
WhichModel::MoeA27b => "MoE-A2.7B",
245+
let (version, size) = match args.model {
246+
WhichModel::W2_0_5b => ("2", "0.5B"),
247+
WhichModel::W2_1_5b => ("2", "1.5B"),
248+
WhichModel::W2_7b => ("2", "7B"),
249+
WhichModel::W2_72b => ("2", "72B"),
250+
WhichModel::W0_5b => ("1.5", "0.5B"),
251+
WhichModel::W1_8b => ("1.5", "1.8B"),
252+
WhichModel::W4b => ("1.5", "4B"),
253+
WhichModel::W7b => ("1.5", "7B"),
254+
WhichModel::W14b => ("1.5", "14B"),
255+
WhichModel::W72b => ("1.5", "72B"),
256+
WhichModel::MoeA27b => ("1.5", "MoE-A2.7B"),
245257
};
246-
format!("Qwen/Qwen1.5-{size}")
258+
format!("Qwen/Qwen{version}-{size}")
247259
}
248260
};
249261
let repo = api.repo(Repo::with_revision(
@@ -261,11 +273,15 @@ fn main() -> Result<()> {
261273
.map(std::path::PathBuf::from)
262274
.collect::<Vec<_>>(),
263275
None => match args.model {
264-
WhichModel::W0_5b | WhichModel::W1_8b => vec![repo.get("model.safetensors")?],
276+
WhichModel::W0_5b | WhichModel::W2_0_5b | WhichModel::W2_1_5b | WhichModel::W1_8b => {
277+
vec![repo.get("model.safetensors")?]
278+
}
265279
WhichModel::W4b
266280
| WhichModel::W7b
281+
| WhichModel::W2_7b
267282
| WhichModel::W14b
268283
| WhichModel::W72b
284+
| WhichModel::W2_72b
269285
| WhichModel::MoeA27b => {
270286
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
271287
}

candle-transformers/src/models/qwen2.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -360,8 +360,12 @@ pub struct ModelForCausalLM {
360360

361361
impl ModelForCausalLM {
362362
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
363-
let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
364-
let base_model = Model::new(cfg, vb)?;
363+
let base_model = Model::new(cfg, vb.clone())?;
364+
let lm_head = if vb.contains_tensor("lm_head") {
365+
linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
366+
} else {
367+
Linear::from_weights(base_model.embed_tokens.embeddings().clone(), None)
368+
};
365369
Ok(Self {
366370
base_model,
367371
lm_head,

0 commit comments

Comments (0)