Skip to content

Commit 0710642

Browse files
committed
Fixes
1 parent beeadf9 commit 0710642

10 files changed

Lines changed: 40 additions & 36 deletions

File tree

candle-examples/examples/nvembed_v2/main.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -148,7 +148,7 @@ fn encode(
148148
device,
149149
)?;
150150
let b = attention_mask.dims()[0];
151-
attention_mask.slice_assign(&[..b, ..instruction_lens], &zeros)?
151+
attention_mask.slice_assign(&[&(..b), &(..instruction_lens)], &zeros)?
152152
} else {
153153
attention_mask.clone()
154154
};

candle-nn/src/layer_norm.rs

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -361,3 +361,6 @@ pub fn rms_norm(
361361
) -> Result<RmsNorm<RmsNormNonQuantized>> {
362362
rms_norm_non_quant(size, eps, vb)
363363
}
364+
365+
/// Type alias for backward compatibility - non-quantized RmsNorm.
366+
pub type RmsNormDefault = RmsNorm<RmsNormNonQuantized>;

candle-nn/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,7 @@ pub use group_norm::{group_norm, GroupNorm};
5353
pub use init::Init;
5454
pub use layer_norm::{
5555
layer_norm, layer_norm_no_bias, rms_norm, rms_norm_non_quant, rms_norm_quant, LayerNorm,
56-
LayerNormConfig, RmsNorm,
56+
LayerNormConfig, RmsNorm, RmsNormDefault, RmsNormNonQuantized, RmsNormQuantized,
5757
};
5858
pub use linear::{linear, linear_b, linear_no_bias, Linear};
5959
pub use ops::Dropout;

candle-transformers/src/models/csm.rs

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
///
99
use crate::generation::LogitsProcessor;
1010
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
11-
use candle_nn::{embedding, linear_b, Embedding, Linear, RmsNorm, VarBuilder};
11+
use candle_nn::{embedding, linear_b, Embedding, Linear, RmsNorm, RmsNormNonQuantized, VarBuilder};
1212
use std::sync::Arc;
1313

1414
#[derive(serde::Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
@@ -142,9 +142,9 @@ impl RotaryEmbedding {
142142
Ok((q_embed, k_embed))
143143
}
144144
}
145-
fn rms_norm(hidden_size: usize, eps: f64, vb: VarBuilder) -> Result<RmsNorm> {
145+
fn rms_norm(hidden_size: usize, eps: f64, vb: VarBuilder) -> Result<RmsNorm<RmsNormNonQuantized>> {
146146
let weight = vb.get((hidden_size,), "scale")?;
147-
Ok(RmsNorm::new(weight, eps))
147+
Ok(RmsNorm::<RmsNormNonQuantized>::new(weight, eps))
148148
}
149149

150150
#[derive(Debug, Clone)]
@@ -274,8 +274,8 @@ impl Module for Mlp {
274274

275275
#[derive(Debug, Clone)]
276276
struct Layer {
277-
mlp_norm: RmsNorm,
278-
sa_norm: RmsNorm,
277+
mlp_norm: RmsNorm<RmsNormNonQuantized>,
278+
sa_norm: RmsNorm<RmsNormNonQuantized>,
279279
attn: Attention,
280280
mlp: Mlp,
281281
}
@@ -317,7 +317,7 @@ impl Layer {
317317
#[derive(Debug, Clone)]
318318
pub struct LlamaModel {
319319
layers: Vec<Layer>,
320-
norm: RmsNorm,
320+
norm: RmsNorm<RmsNormNonQuantized>,
321321
device: Device,
322322
dtype: DType,
323323
}

candle-transformers/src/models/deepseek2.rs

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@ use candle::{
66
shape::Dim, CpuStorage, CustomOp1, DType, Device, Error, IndexOp, Layout, Result, Shape,
77
Tensor, WithDType, D,
88
};
9-
use candle_nn::{embedding, rms_norm, Activation, Embedding, Linear, Module, RmsNorm, VarBuilder};
9+
use candle_nn::{embedding, rms_norm, Activation, Embedding, Linear, Module, RmsNorm, RmsNormNonQuantized, VarBuilder};
1010
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
1111
use serde::Deserialize;
1212

@@ -520,7 +520,7 @@ impl DeepSeekV2Config {
520520

521521
enum QProj {
522522
Plain(Linear),
523-
Lora { a: Linear, norm: RmsNorm, b: Linear },
523+
Lora { a: Linear, norm: RmsNorm<RmsNormNonQuantized>, b: Linear },
524524
}
525525

526526
impl QProj {
@@ -535,7 +535,7 @@ impl QProj {
535535
struct Attention {
536536
q: QProj,
537537
kv_a_proj_with_mqa: Linear,
538-
kv_a_layernorm: RmsNorm,
538+
kv_a_layernorm: RmsNorm<RmsNormNonQuantized>,
539539
kv_b_proj: Linear,
540540
o_proj: Linear,
541541
rotary_emb: Arc<DeepSeekV2RotaryEmbedding>,
@@ -905,8 +905,8 @@ impl MoeOrMlp {
905905
}
906906

907907
struct DecoderLayer {
908-
input_layernorm: RmsNorm,
909-
post_attention_layernorm: RmsNorm,
908+
input_layernorm: RmsNorm<RmsNormNonQuantized>,
909+
post_attention_layernorm: RmsNorm<RmsNormNonQuantized>,
910910
attn: Attention,
911911
moe_or_mlp: MoeOrMlp,
912912
}
@@ -976,7 +976,7 @@ impl DecoderLayer {
976976
pub struct DeepSeekV2 {
977977
lm_head: Linear,
978978
embed_tokens: Embedding,
979-
norm: RmsNorm,
979+
norm: RmsNorm<RmsNormNonQuantized>,
980980
layers: Vec<DecoderLayer>,
981981
dtype: DType,
982982
device: Device,

candle-transformers/src/models/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -91,6 +91,7 @@ pub mod quantized_mixformer;
9191
pub mod quantized_moondream;
9292
pub mod quantized_mpt;
9393
pub mod quantized_phi;
94+
pub mod quantized_phi3;
9495
pub mod quantized_qwen2;
9596
pub mod quantized_qwen3;
9697
pub mod quantized_recurrent_gemma;

candle-transformers/src/models/olmo2.rs

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
//!
77
//!
88
use candle::{DType, Device, Module, Result, Tensor, D};
9-
use candle_nn::{linear_b, linear_no_bias, rms_norm, Activation, Linear, RmsNorm, VarBuilder};
9+
use candle_nn::{linear_b, linear_no_bias, rms_norm, Activation, Linear, RmsNorm, RmsNormNonQuantized, VarBuilder};
1010
use std::sync::Arc;
1111

1212
#[derive(Debug, Clone, serde::Deserialize)]
@@ -106,8 +106,8 @@ struct Attention {
106106
k_proj: Linear,
107107
v_proj: Linear,
108108
o_proj: Linear,
109-
q_norm: RmsNorm,
110-
k_norm: RmsNorm,
109+
q_norm: RmsNorm<RmsNormNonQuantized>,
110+
k_norm: RmsNorm<RmsNormNonQuantized>,
111111
num_heads: usize,
112112
num_kv_heads: usize,
113113
num_kv_groups: usize,
@@ -217,8 +217,8 @@ impl Attention {
217217
struct DecoderLayer {
218218
self_attn: Attention,
219219
mlp: MLP,
220-
post_attention_layernorm: RmsNorm,
221-
post_feedforward_layernorm: RmsNorm,
220+
post_attention_layernorm: RmsNorm<RmsNormNonQuantized>,
221+
post_feedforward_layernorm: RmsNorm<RmsNormNonQuantized>,
222222
}
223223

224224
impl DecoderLayer {
@@ -268,7 +268,7 @@ impl DecoderLayer {
268268
pub struct Model {
269269
embed_tokens: candle_nn::Embedding,
270270
layers: Vec<DecoderLayer>,
271-
norm: RmsNorm,
271+
norm: RmsNorm<RmsNormNonQuantized>,
272272
lm_head: Linear,
273273
device: Device,
274274
dtype: DType,

candle-transformers/src/models/quantized_phi3.rs

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@ use std::collections::HashMap;
1818
use candle::quantized::gguf_file;
1919
use candle::quantized::QTensor;
2020
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
21-
use candle_nn::{kv_cache::KvCache, Embedding, RmsNorm};
21+
use candle_nn::{kv_cache::KvCache, Embedding, RmsNorm, RmsNormNonQuantized};
2222

2323
#[derive(Debug, Clone)]
2424
struct QLinear {
@@ -64,18 +64,18 @@ impl Module for Mlp {
6464
}
6565
}
6666

67-
fn rms_norm(w: QTensor, eps: f64) -> Result<RmsNorm> {
67+
fn rms_norm(w: QTensor, eps: f64) -> Result<RmsNorm<RmsNormNonQuantized>> {
6868
let w = w.dequantize(&w.device())?;
69-
let rms = RmsNorm::new(w, eps);
69+
let rms = RmsNorm::<RmsNormNonQuantized>::new(w, eps);
7070
Ok(rms)
7171
}
7272

7373
#[derive(Debug, Clone)]
7474
struct LayerWeights {
7575
attn_qkv: QLinear,
7676
attn_output: QLinear,
77-
attn_norm: RmsNorm,
78-
ffn_norm: RmsNorm,
77+
attn_norm: RmsNorm<RmsNormNonQuantized>,
78+
ffn_norm: RmsNorm<RmsNormNonQuantized>,
7979
mlp: Mlp,
8080
n_head: usize,
8181
n_kv_head: usize,
@@ -192,7 +192,7 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Ten
192192
pub struct ModelWeights {
193193
tok_embeddings: Embedding,
194194
layers: Vec<LayerWeights>,
195-
output_norm: RmsNorm,
195+
output_norm: RmsNorm<RmsNormNonQuantized>,
196196
output: QLinear,
197197
masks: HashMap<usize, Tensor>,
198198
span: tracing::Span,

candle-transformers/src/models/qwen3_vl/mod.rs

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -126,11 +126,11 @@ impl Qwen3VLModel {
126126
let chunk = image_embeds.narrow(0, offset, len)?;
127127
offset += len;
128128
input_embeds = input_embeds.slice_assign(
129-
&[batch..batch + 1, start..end, 0..hidden_dim],
129+
&[&(batch..batch + 1), &(start..end), &(0..hidden_dim)],
130130
&chunk.unsqueeze(0)?,
131131
)?;
132132
let ones = Tensor::ones((1, len), DType::F32, input_ids.device())?;
133-
image_mask = image_mask.slice_assign(&[batch..batch + 1, start..end], &ones)?;
133+
image_mask = image_mask.slice_assign(&[&(batch..batch + 1), &(start..end)], &ones)?;
134134
}
135135
}
136136
image_mask_opt = Some(image_mask.to_dtype(DType::U8)?);
@@ -175,11 +175,11 @@ impl Qwen3VLModel {
175175
let chunk = video_embeds.narrow(0, offset, len)?;
176176
offset += len;
177177
input_embeds = input_embeds.slice_assign(
178-
&[batch..batch + 1, start..end, 0..hidden_dim],
178+
&[&(batch..batch + 1), &(start..end), &(0..hidden_dim)],
179179
&chunk.unsqueeze(0)?,
180180
)?;
181181
let ones = Tensor::ones((1, len), DType::F32, input_ids.device())?;
182-
video_mask = video_mask.slice_assign(&[batch..batch + 1, start..end], &ones)?;
182+
video_mask = video_mask.slice_assign(&[&(batch..batch + 1), &(start..end)], &ones)?;
183183
}
184184
}
185185
video_mask_opt = Some(video_mask.to_dtype(DType::U8)?);

candle-transformers/src/models/qwen3_vl/text.rs

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@ use std::sync::{Arc, Mutex};
33
use candle::{DType, Device, IndexOp, Result, Tensor};
44
use candle_nn::{
55
embedding, kv_cache::KvCache, linear, linear_b, rms_norm, Activation, Embedding, Linear,
6-
Module, RmsNorm, VarBuilder,
6+
Module, RmsNorm, RmsNormNonQuantized, VarBuilder,
77
};
88

99
use super::config::TextConfig;
@@ -96,8 +96,8 @@ struct Attention {
9696
k_proj: Linear,
9797
v_proj: Linear,
9898
o_proj: Linear,
99-
q_norm: RmsNorm,
100-
k_norm: RmsNorm,
99+
q_norm: RmsNorm<RmsNormNonQuantized>,
100+
k_norm: RmsNorm<RmsNormNonQuantized>,
101101
num_heads: usize,
102102
num_kv_heads: usize,
103103
head_dim: usize,
@@ -205,8 +205,8 @@ impl Attention {
205205
pub struct DecoderLayer {
206206
self_attn: Attention,
207207
mlp: Mlp,
208-
input_layernorm: RmsNorm,
209-
post_attention_layernorm: RmsNorm,
208+
input_layernorm: RmsNorm<RmsNormNonQuantized>,
209+
post_attention_layernorm: RmsNorm<RmsNormNonQuantized>,
210210
}
211211

212212
impl DecoderLayer {
@@ -251,7 +251,7 @@ impl DecoderLayer {
251251

252252
pub struct Qwen3VLTextModel {
253253
embed_tokens: Embedding,
254-
pub(super) norm: RmsNorm,
254+
pub(super) norm: RmsNorm<RmsNormNonQuantized>,
255255
layers: Vec<DecoderLayer>,
256256
lm_head: Linear,
257257
pub(super) dtype: DType,

0 commit comments

Comments
 (0)