@@ -6,7 +6,7 @@ use candle::{
66 shape:: Dim , CpuStorage , CustomOp1 , DType , Device , Error , IndexOp , Layout , Result , Shape ,
77 Tensor , WithDType , D ,
88} ;
9- use candle_nn:: { embedding, rms_norm, Activation , Embedding , Linear , Module , RmsNorm , VarBuilder } ;
9+ use candle_nn:: { embedding, rms_norm, Activation , Embedding , Linear , Module , RmsNorm , RmsNormNonQuantized , VarBuilder } ;
1010use rayon:: iter:: { IntoParallelRefIterator , ParallelIterator } ;
1111use serde:: Deserialize ;
1212
@@ -520,7 +520,7 @@ impl DeepSeekV2Config {
520520
521521enum QProj {
522522 Plain ( Linear ) ,
523- Lora { a : Linear , norm : RmsNorm , b : Linear } ,
523+ Lora { a : Linear , norm : RmsNorm < RmsNormNonQuantized > , b : Linear } ,
524524}
525525
526526impl QProj {
@@ -535,7 +535,7 @@ impl QProj {
535535struct Attention {
536536 q : QProj ,
537537 kv_a_proj_with_mqa : Linear ,
538- kv_a_layernorm : RmsNorm ,
538+ kv_a_layernorm : RmsNorm < RmsNormNonQuantized > ,
539539 kv_b_proj : Linear ,
540540 o_proj : Linear ,
541541 rotary_emb : Arc < DeepSeekV2RotaryEmbedding > ,
@@ -905,8 +905,8 @@ impl MoeOrMlp {
905905}
906906
907907struct DecoderLayer {
908- input_layernorm : RmsNorm ,
909- post_attention_layernorm : RmsNorm ,
908+ input_layernorm : RmsNorm < RmsNormNonQuantized > ,
909+ post_attention_layernorm : RmsNorm < RmsNormNonQuantized > ,
910910 attn : Attention ,
911911 moe_or_mlp : MoeOrMlp ,
912912}
@@ -976,7 +976,7 @@ impl DecoderLayer {
976976pub struct DeepSeekV2 {
977977 lm_head : Linear ,
978978 embed_tokens : Embedding ,
979- norm : RmsNorm ,
979+ norm : RmsNorm < RmsNormNonQuantized > ,
980980 layers : Vec < DecoderLayer > ,
981981 dtype : DType ,
982982 device : Device ,
0 commit comments