Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions paddle/phi/ops/yaml/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4306,8 +4306,8 @@
data_type: w

- backward_op: rms_norm_grad
forward: rms_norm (Tensor x, Tensor scale, int64_t[] normalized_shape={}, double epsilon = 1e-5) -> Tensor(y), Tensor(invvar)
args: (Tensor x, Tensor scale, Tensor invvar, Tensor y_grad, int64_t[] normalized_shape={}, double epsilon = 1e-5)
forward: rms_norm (Tensor x, Tensor scale, int64_t[] normalized_shape={}, double epsilon = 1.19209289550781250e-7) -> Tensor(y), Tensor(invvar)
args: (Tensor x, Tensor scale, Tensor invvar, Tensor y_grad, int64_t[] normalized_shape={}, double epsilon = 1.19209289550781250e-7)
output: Tensor(x_grad), Tensor(scale_grad)
infer_meta:
func: RMSNormGradInferMeta
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6166,7 +6166,7 @@
traits : paddle::dialect::ForwardOnlyTrait

- op: rms_norm
args: (Tensor x, Tensor scale, int64_t[] normalized_shape={}, double epsilon= 1e-5)
args: (Tensor x, Tensor scale, int64_t[] normalized_shape={}, double epsilon=1.19209289550781250e-7)
output: Tensor(y), Tensor(invvar)
infer_meta:
func: RmsNormInferMeta
Expand Down
35 changes: 13 additions & 22 deletions python/paddle/nn/functional/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numbers
from typing import TYPE_CHECKING, Any

import numpy as np
from typing_extensions import overload

import paddle
Expand Down Expand Up @@ -465,9 +466,9 @@ def rms_norm(
input: Tensor,
normalized_shape: Sequence[int],
weight: Tensor | None = None,
eps: float = 1e-5,
eps: float | None = None,
name: str | None = None,
) -> tuple[Tensor, Tensor]:
) -> Tensor:
"""
    Applies RMS Normalization (root-mean-square normalization) over the last dimension of the input tensor using a fused CUDA implementation.

Expand All @@ -478,33 +479,23 @@ def rms_norm(
If it is a single integer, this module will normalize over the last dimension
which is expected to be of that specific size.
weight(Tensor, optional): The weight tensor of rms_norm. Default: None.
eps(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-05.
eps(float|None, optional): The small value added to the variance to prevent division by zero.
If None, uses machine epsilon for the compute dtype: ``float64`` inputs use
``np.finfo(np.float64).eps`` (double epsilon), all other dtypes use
``np.finfo(np.float32).eps`` (float epsilon). Default: None.
name (str, optional): Name of the operator.

Returns:
out (Tensor): Normalized tensor of same shape as input.
invvar (Tensor): Tensor of shape [rows], the inverse standard deviation of each row.
"""

if in_dynamic_or_pir_mode():
return _C_ops.rms_norm(input, weight, normalized_shape, eps)

helper = LayerHelper('rms_norm', **locals())
from paddle.base.data_feeder import convert_dtype

dtype = convert_dtype(input.dtype)
out = helper.create_variable_for_type_inference(dtype)
invvar = helper.create_variable_for_type_inference('float32')

inputs = {'input': input, 'weight': weight}
if eps is None:
if input.dtype == paddle.float64:
eps = np.finfo(np.float64).eps # ~2.22e-16, double machine epsilon
else:
eps = np.finfo(np.float32).eps # ~1.19e-7, float machine epsilon

helper.append_op(
type='rms_norm',
inputs=inputs,
outputs={'out': out, 'invvar': invvar},
attrs={"normalized_shape": normalized_shape, "eps": eps},
)
return out, invvar
return _C_ops.rms_norm(input, weight, normalized_shape, eps)[0]


def instance_norm(
Expand Down
89 changes: 83 additions & 6 deletions test/legacy_test/test_rms_norm_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def setUp(self):
self.outputs = {'y': y_ref, 'invvar': invvar_ref}

def rms_norm_wrapper(x, scale):
return rms_norm(x, scale.shape, scale, eps=self.epsilon)
from paddle import _C_ops

return _C_ops.rms_norm(x, scale, scale.shape, self.epsilon)

self.python_api = rms_norm_wrapper

Expand Down Expand Up @@ -124,15 +126,12 @@ def test_api_dygraph(self):
scale.stop_gradient = False

# Test forward
y_fused, invvar_fused = rms_norm(x, (cols,), scale)
y_ref, invvar_ref = self.rms_norm_reference(x, scale)
y_fused = rms_norm(x, (cols,), scale)
y_ref, _ = self.rms_norm_reference(x, scale)

np.testing.assert_allclose(
y_fused.numpy(), y_ref.numpy(), rtol=1e-5, atol=1e-5
)
np.testing.assert_allclose(
invvar_fused.numpy(), invvar_ref.numpy(), rtol=1e-5, atol=1e-5
)

# Test backward
loss = paddle.mean(y_fused)
Expand Down Expand Up @@ -174,5 +173,83 @@ def test_weight_shape_mismatch(self):
rms_norm(x, [3], weight=weight)


class TestRMSNormEpsNone(unittest.TestCase):
"""Tests that eps=None selects the correct machine epsilon per dtype."""

def _ref(self, x_np, scale_np, epsilon):
variance = np.mean(np.square(x_np), axis=-1, keepdims=True)
rms = np.sqrt(variance + epsilon)
return x_np / rms * scale_np

def test_eps_none_float32(self):
"""eps=None with float32 input should use float machine epsilon."""

rows, cols = 8, 16
x_np = np.random.randn(rows, cols).astype("float32")
scale_np = np.ones(cols, dtype="float32")

x = paddle.to_tensor(x_np)
scale = paddle.to_tensor(scale_np)

y_none = rms_norm(x, (cols,), scale, eps=None)
float_eps = 1.1920929e-07
y_explicit = rms_norm(x, (cols,), scale, eps=float_eps)

np.testing.assert_array_equal(y_none.numpy(), y_explicit.numpy())

y_ref = self._ref(x_np, scale_np, float_eps)
np.testing.assert_allclose(
y_none.numpy(), y_ref.astype("float32"), rtol=1e-5, atol=1e-5
)

def test_eps_none_float64(self):
"""eps=None with float64 input should use double machine epsilon."""
import sys

rows, cols = 8, 16
x_np = np.random.randn(rows, cols).astype("float64")
scale_np = np.ones(cols, dtype="float64")

x = paddle.to_tensor(x_np)
scale = paddle.to_tensor(scale_np)

y_none = rms_norm(x, (cols,), scale, eps=None)
double_eps = sys.float_info.epsilon # ~2.22e-16
y_explicit = rms_norm(x, (cols,), scale, eps=double_eps)

np.testing.assert_array_equal(y_none.numpy(), y_explicit.numpy())

y_ref = self._ref(x_np, scale_np, double_eps)
np.testing.assert_allclose(
y_none.numpy(), y_ref, rtol=1e-12, atol=1e-12
)

def test_eps_none_float32_differs_from_float64(self):
"""float32 and float64 defaults should be different epsilon values."""
import sys

float_eps = 1.1920929e-07
double_eps = sys.float_info.epsilon
self.assertNotAlmostEqual(float_eps, double_eps, places=10)

def test_eps_none_backward_float32(self):
"""eps=None should work through backward pass for float32."""
rows, cols = 8, 16
x_np = np.random.randn(rows, cols).astype("float32")
scale_np = np.ones(cols, dtype="float32")

x = paddle.to_tensor(x_np)
x.stop_gradient = False
scale = paddle.to_tensor(scale_np)
scale.stop_gradient = False

y = rms_norm(x, (cols,), scale, eps=None)
loss = paddle.mean(y)
loss.backward()

self.assertIsNotNone(x.grad)
self.assertIsNotNone(scale.grad)


if __name__ == '__main__':
unittest.main()
33 changes: 12 additions & 21 deletions test/xpu/test_rms_norm_xpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,54 +33,45 @@ def rms_norm_reference(self, x, scale, bias=None, epsilon=1e-5):
if bias is not None:
y = y + bias.reshape([1, -1])

return y, paddle.flatten(1.0 / rms)
return y

def test_2d_input(self):
rows, cols = 32, 64
x = paddle.randn([rows, cols])
scale = paddle.randn([cols])

y_fused, invvar_fused = rms_norm(x, (cols,), scale)
y_fused = rms_norm(x, (cols,), scale)

y_ref, invvar_ref = self.rms_norm_reference(x, scale)
y_ref = self.rms_norm_reference(x, scale)

np.testing.assert_allclose(y_fused, y_ref, rtol=1e-5, atol=1e-5)
np.testing.assert_allclose(
invvar_fused, invvar_ref, rtol=1e-5, atol=1e-5
)

def test_3d_input(self):
batch, rows, cols = 16, 32, 64
x = paddle.randn([batch, rows, cols])
scale = paddle.randn([cols])

y_fused, invvar_fused = rms_norm(x, (cols,), scale)
y_fused = rms_norm(x, (cols,), scale)

y_ref, invvar_ref = self.rms_norm_reference(x, scale)
y_ref = self.rms_norm_reference(x, scale)

np.testing.assert_allclose(
y_fused.astype("float32"),
y_ref.astype("float32"),
rtol=1e-5,
atol=1e-5,
)
np.testing.assert_allclose(
invvar_fused, invvar_ref, rtol=1e-5, atol=1e-5
)

def test_without_bias(self):
rows, cols = 32, 64
x = paddle.randn([rows, cols])
scale = paddle.randn([cols])

y_fused, invvar_fused = rms_norm(x, (cols,), scale)
y_fused = rms_norm(x, (cols,), scale)

y_ref, invvar_ref = self.rms_norm_reference(x, scale)
y_ref = self.rms_norm_reference(x, scale)

np.testing.assert_allclose(y_fused, y_ref, rtol=1e-5, atol=1e-5)
np.testing.assert_allclose(
invvar_fused, invvar_ref, rtol=1e-5, atol=1e-5
)

def test_3d_backward(self):
batch, rows, cols = 8, 16, 32
Expand All @@ -89,7 +80,7 @@ def test_3d_backward(self):
scale = paddle.randn([cols], dtype='float32')
scale.stop_gradient = False

y_fused, invvar = rms_norm(x, (cols,), scale)
y_fused = rms_norm(x, (cols,), scale)

loss = paddle.mean(y_fused)
loss.backward()
Expand All @@ -100,7 +91,7 @@ def test_3d_backward(self):
x.clear_gradient()
scale.clear_gradient()

y_ref, invvar_ref = self.rms_norm_reference(x, scale)
y_ref = self.rms_norm_reference(x, scale)
loss_ref = paddle.mean(y_ref)
loss_ref.backward()

Expand All @@ -126,7 +117,7 @@ def test_backward(self):
scale = paddle.randn([cols], dtype=scale_type)
scale.stop_gradient = False

y_fused, invvar = rms_norm(x, (cols,), scale)
y_fused = rms_norm(x, (cols,), scale)

loss = paddle.mean(y_fused)
loss.backward()
Expand All @@ -141,9 +132,9 @@ def test_backward(self):
# FIXME(yangjianbang): XPU sqrt_grad does not support bfloat16
x_fp32 = x.cast("float32")
scale_fp32 = scale.cast("float32")
y_ref, invvar_ref = self.rms_norm_reference(x_fp32, scale_fp32)
y_ref = self.rms_norm_reference(x_fp32, scale_fp32)
else:
y_ref, invvar_ref = self.rms_norm_reference(x, scale)
y_ref = self.rms_norm_reference(x, scale)
loss_ref = paddle.mean(y_ref)
loss_ref.backward()

Expand Down
Loading