@@ -403,40 +403,40 @@ def scaled_dot_product_attention(
403403 return (attn_output , attn_weights ) if output_attentions else attn_output
404404
405405
class Qwen2RMSNorm(nn.Layer):
    """Root-mean-square normalization layer for Qwen2.

    The layer is a standalone ``nn.Layer``; it does not subclass another
    RMSNorm implementation. It owns a single learnable scale vector
    ``weight`` of shape ``[config.hidden_size]`` initialized to ones.

    Args:
        config: Model configuration. The attributes read here are
            ``hidden_size``, ``rms_norm_eps``, ``sequence_parallel`` and
            ``use_fused_rms_norm``.
    """

    def __init__(self, config: Qwen2Config):
        super().__init__()
        self.hidden_size = config.hidden_size
        # Scale parameter, created in the current default dtype and set to 1.0.
        self.weight = paddle.create_parameter(
            shape=[self.hidden_size],
            dtype=paddle.get_default_dtype(),
            default_initializer=nn.initializer.Constant(1.0),
        )
        self.variance_epsilon = config.rms_norm_eps
        self.config = config

        if config.sequence_parallel:
            mark_as_sequence_parallel_parameter(self.weight)

    def forward(self, hidden_states):
        """Normalize ``hidden_states`` over its last axis and apply ``weight``.

        Args:
            hidden_states: Input tensor; the mean of squares is taken over
                axis ``-1``.

        Returns:
            The result of ``fusion_ops.fusion_rms_norm`` when
            ``config.use_fused_rms_norm`` is set; otherwise the normalized
            input multiplied by ``weight``.
        """
        if self.config.use_fused_rms_norm:
            return fusion_ops.fusion_rms_norm(hidden_states, self.weight, self.variance_epsilon)

        if paddle.in_dynamic_mode():
            with paddle.amp.auto_cast(False):
                # Only the variance is computed from a float32 copy here; the
                # input itself is multiplied in the dtype it arrived with.
                variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
                hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
        else:
            # Outside dynamic mode the input is cast to float32 up front.
            hidden_states = hidden_states.astype("float32")
            variance = hidden_states.pow(2).mean(-1, keepdim=True)
            hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states

        # Cast back only when the scale parameter is half precision.
        if self.weight.dtype in [paddle.float16, paddle.bfloat16]:
            hidden_states = paddle.cast(hidden_states, self.weight.dtype)
        return hidden_states * self.weight
440440class Qwen2RotaryEmbedding (nn .Layer ):
441441 def __init__ (self , dim , max_position_embeddings = 2048 , base = 10000 ):
442442 super ().__init__ ()
0 commit comments