Skip to content

Commit e1a2fbe

Browse files
committed
modified: scripts/codestyle/qwen2/modeling_qwen2.py
1 parent f79a81a commit e1a2fbe

1 file changed

Lines changed: 29 additions & 29 deletions

File tree

scripts/codestyle/qwen2/modeling_qwen2.py

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -403,40 +403,40 @@ def scaled_dot_product_attention(
403403
return (attn_output, attn_weights) if output_attentions else attn_output
404404

405405

406-
class Qwen2RMSNorm(nn.Layer):
407-
"""Qwen2的RMSNorm,继承自LlamaRMSNorm"""
408-
def __init__(self, config: Qwen2Config):
409-
super().__init__()
410-
self.hidden_size = config.hidden_size
411-
self.weight = paddle.create_parameter(
412-
shape=[self.hidden_size],
413-
dtype=paddle.get_default_dtype(),
414-
default_initializer=nn.initializer.Constant(1.0),
415-
)
416-
self.variance_epsilon = config.rms_norm_eps
417-
self.config = config
406+
class Qwen2RMSNorm(nn.Layer):
407+
"""Qwen2的RMSNorm,继承自LlamaRMSNorm"""
408+
def __init__(self, config: Qwen2Config):
409+
super().__init__()
410+
self.hidden_size = config.hidden_size
411+
self.weight = paddle.create_parameter(
412+
shape=[self.hidden_size],
413+
dtype=paddle.get_default_dtype(),
414+
default_initializer=nn.initializer.Constant(1.0),
415+
)
416+
self.variance_epsilon = config.rms_norm_eps
417+
self.config = config
418418

419-
if config.sequence_parallel:
420-
mark_as_sequence_parallel_parameter(self.weight)
419+
if config.sequence_parallel:
420+
mark_as_sequence_parallel_parameter(self.weight)
421421

422-
def forward(self, hidden_states):
423-
if self.config.use_fused_rms_norm:
424-
return fusion_ops.fusion_rms_norm(hidden_states, self.weight, self.variance_epsilon)
422+
def forward(self, hidden_states):
423+
if self.config.use_fused_rms_norm:
424+
return fusion_ops.fusion_rms_norm(hidden_states, self.weight, self.variance_epsilon)
425425

426-
if paddle.in_dynamic_mode():
427-
with paddle.amp.auto_cast(False):
428-
# hidden_states = hidden_states.astype("float32")
429-
# variance = hidden_states.pow(2).mean(-1, keepdim=True)
430-
variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
431-
hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
432-
else:
433-
hidden_states = hidden_states.astype("float32")
434-
variance = hidden_states.pow(2).mean(-1, keepdim=True)
426+
if paddle.in_dynamic_mode():
427+
with paddle.amp.auto_cast(False):
428+
# hidden_states = hidden_states.astype("float32")
429+
# variance = hidden_states.pow(2).mean(-1, keepdim=True)
430+
variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
435431
hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
432+
else:
433+
hidden_states = hidden_states.astype("float32")
434+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
435+
hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
436436

437-
if self.weight.dtype in [paddle.float16, paddle.bfloat16]:
438-
hidden_states = paddle.cast(hidden_states, self.weight.dtype)
439-
return hidden_states * self.weight
437+
if self.weight.dtype in [paddle.float16, paddle.bfloat16]:
438+
hidden_states = paddle.cast(hidden_states, self.weight.dtype)
439+
return hidden_states * self.weight
440440
class Qwen2RotaryEmbedding(nn.Layer):
441441
def __init__(self, dim, max_position_embeddings=2048, base=10000):
442442
super().__init__()

0 commit comments

Comments
 (0)