diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml index cf3b8be50e8f14..16fc235f942fd2 100644 --- a/paddle/phi/ops/yaml/backward.yaml +++ b/paddle/phi/ops/yaml/backward.yaml @@ -4306,8 +4306,8 @@ data_type: w - backward_op: rms_norm_grad - forward: rms_norm (Tensor x, Tensor scale, int64_t[] normalized_shape={}, double epsilon = 1e-5) -> Tensor(y), Tensor(invvar) - args: (Tensor x, Tensor scale, Tensor invvar, Tensor y_grad, int64_t[] normalized_shape={}, double epsilon = 1e-5) + forward: rms_norm (Tensor x, Tensor scale, int64_t[] normalized_shape={}, double epsilon = 1.19209289550781250e-7) -> Tensor(y), Tensor(invvar) + args: (Tensor x, Tensor scale, Tensor invvar, Tensor y_grad, int64_t[] normalized_shape={}, double epsilon = 1.19209289550781250e-7) output: Tensor(x_grad), Tensor(scale_grad) infer_meta: func: RMSNormGradInferMeta diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index a1706ff28cb3a2..de1c21c66e6523 100644 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -6166,7 +6166,7 @@ traits : paddle::dialect::ForwardOnlyTrait - op: rms_norm - args: (Tensor x, Tensor scale, int64_t[] normalized_shape={}, double epsilon= 1e-5) + args: (Tensor x, Tensor scale, int64_t[] normalized_shape={}, double epsilon=1.19209289550781250e-7) output: Tensor(y), Tensor(invvar) infer_meta: func: RmsNormInferMeta diff --git a/python/paddle/nn/functional/norm.py b/python/paddle/nn/functional/norm.py index 4e0f88c25707b2..e1f238ba29b162 100644 --- a/python/paddle/nn/functional/norm.py +++ b/python/paddle/nn/functional/norm.py @@ -18,6 +18,7 @@ import numbers from typing import TYPE_CHECKING, Any +import numpy as np from typing_extensions import overload import paddle @@ -465,9 +466,9 @@ def rms_norm( input: Tensor, normalized_shape: Sequence[int], weight: Tensor | None = None, - eps: float = 1e-5, + eps: float | None = None, name: str | None = None, -) -> tuple[Tensor, Tensor]: +) -> Tensor: """ Applies Layer Normalization over the last dimension of the input tensor using CUDA implementation. @@ -478,33 +479,23 @@ def rms_norm( If it is a single integer, this module will normalize over the last dimension which is expected to be of that specific size. weight(Tensor, optional): The weight tensor of rms_norm. Default: None. - eps(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-05. + eps(float, optional): The small value added to the variance to prevent division by zero. + If None, uses machine epsilon for the compute dtype: ``float64`` inputs use + ``np.finfo(np.float64).eps`` (double epsilon), all other dtypes use + ``np.finfo(np.float32).eps`` (float epsilon). Default: None. name (str, optional): Name of the operator. Returns: out (Tensor): Normalized tensor of same shape as input. - invvar (Tensor): Tensor of shape [rows], the inverse standard deviation of each row. """ - if in_dynamic_or_pir_mode(): - return _C_ops.rms_norm(input, weight, normalized_shape, eps) - - helper = LayerHelper('rms_norm', **locals()) - from paddle.base.data_feeder import convert_dtype - - dtype = convert_dtype(input.dtype) - out = helper.create_variable_for_type_inference(dtype) - invvar = helper.create_variable_for_type_inference('float32') - - inputs = {'input': input, 'weight': weight} + if eps is None: + if input.dtype == paddle.float64: + eps = np.finfo(np.float64).eps # ~2.22e-16, double machine epsilon + else: + eps = np.finfo(np.float32).eps # ~1.19e-7, float machine epsilon - helper.append_op( - type='rms_norm', - inputs=inputs, - outputs={'out': out, 'invvar': invvar}, - attrs={"normalized_shape": normalized_shape, "eps": eps}, - ) - return out, invvar + return _C_ops.rms_norm(input, weight, normalized_shape, eps)[0] def instance_norm( diff --git a/test/legacy_test/test_rms_norm_op.py b/test/legacy_test/test_rms_norm_op.py index 43c5bbad049632..8ce5fd72c336ec 100644 --- a/test/legacy_test/test_rms_norm_op.py +++ b/test/legacy_test/test_rms_norm_op.py @@ -60,7 +60,9 @@ def setUp(self): self.outputs = {'y': y_ref, 'invvar': invvar_ref} def rms_norm_wrapper(x, scale): - return rms_norm(x, scale.shape, scale, eps=self.epsilon) + from paddle import _C_ops + + return _C_ops.rms_norm(x, scale, scale.shape, self.epsilon) self.python_api = rms_norm_wrapper @@ -124,15 +126,12 @@ def test_api_dygraph(self): scale.stop_gradient = False # Test forward - y_fused, invvar_fused = rms_norm(x, (cols,), scale) - y_ref, invvar_ref = self.rms_norm_reference(x, scale) + y_fused = rms_norm(x, (cols,), scale) + y_ref, _ = self.rms_norm_reference(x, scale) np.testing.assert_allclose( y_fused.numpy(), y_ref.numpy(), rtol=1e-5, atol=1e-5 ) - np.testing.assert_allclose( - invvar_fused.numpy(), invvar_ref.numpy(), rtol=1e-5, atol=1e-5 - ) # Test backward loss = paddle.mean(y_fused) @@ -174,5 +173,83 @@ def test_weight_shape_mismatch(self): rms_norm(x, [3], weight=weight) +class TestRMSNormEpsNone(unittest.TestCase): + """Tests that eps=None selects the correct machine epsilon per dtype.""" + + def _ref(self, x_np, scale_np, epsilon): + variance = np.mean(np.square(x_np), axis=-1, keepdims=True) + rms = np.sqrt(variance + epsilon) + return x_np / rms * scale_np + + def test_eps_none_float32(self): + """eps=None with float32 input should use float machine epsilon.""" + + rows, cols = 8, 16 + x_np = np.random.randn(rows, cols).astype("float32") + scale_np = np.ones(cols, dtype="float32") + + x = paddle.to_tensor(x_np) + scale = paddle.to_tensor(scale_np) + + y_none = rms_norm(x, (cols,), scale, eps=None) + float_eps = 1.1920929e-07 + y_explicit = rms_norm(x, (cols,), scale, eps=float_eps) + + np.testing.assert_array_equal(y_none.numpy(), y_explicit.numpy()) + + y_ref = self._ref(x_np, scale_np, float_eps) + np.testing.assert_allclose( + y_none.numpy(), y_ref.astype("float32"), rtol=1e-5, atol=1e-5 + ) + + def test_eps_none_float64(self): + """eps=None with float64 input should use double machine epsilon.""" + import sys + + rows, cols = 8, 16 + x_np = np.random.randn(rows, cols).astype("float64") + scale_np = np.ones(cols, dtype="float64") + + x = paddle.to_tensor(x_np) + scale = paddle.to_tensor(scale_np) + + y_none = rms_norm(x, (cols,), scale, eps=None) + double_eps = sys.float_info.epsilon # ~2.22e-16 + y_explicit = rms_norm(x, (cols,), scale, eps=double_eps) + + np.testing.assert_array_equal(y_none.numpy(), y_explicit.numpy()) + + y_ref = self._ref(x_np, scale_np, double_eps) + np.testing.assert_allclose( + y_none.numpy(), y_ref, rtol=1e-12, atol=1e-12 + ) + + def test_eps_none_float32_differs_from_float64(self): + """float32 and float64 defaults should be different epsilon values.""" + import sys + + float_eps = 1.1920929e-07 + double_eps = sys.float_info.epsilon + self.assertNotAlmostEqual(float_eps, double_eps, places=10) + + def test_eps_none_backward_float32(self): + """eps=None should work through backward pass for float32.""" + rows, cols = 8, 16 + x_np = np.random.randn(rows, cols).astype("float32") + scale_np = np.ones(cols, dtype="float32") + + x = paddle.to_tensor(x_np) + x.stop_gradient = False + scale = paddle.to_tensor(scale_np) + scale.stop_gradient = False + + y = rms_norm(x, (cols,), scale, eps=None) + loss = paddle.mean(y) + loss.backward() + + self.assertIsNotNone(x.grad) + self.assertIsNotNone(scale.grad) + + if __name__ == '__main__': unittest.main() diff --git a/test/xpu/test_rms_norm_xpu.py b/test/xpu/test_rms_norm_xpu.py index e465b969579685..4ce394a2ed54e9 100644 --- a/test/xpu/test_rms_norm_xpu.py +++ b/test/xpu/test_rms_norm_xpu.py @@ -33,30 +33,27 @@ def rms_norm_reference(self, x, scale, bias=None, epsilon=1e-5): if bias is not None: y = y + bias.reshape([1, -1]) - return y, paddle.flatten(1.0 / rms) + return y def test_2d_input(self): rows, cols = 32, 64 x = paddle.randn([rows, cols]) scale = paddle.randn([cols]) - y_fused, invvar_fused = rms_norm(x, (cols,), scale) + y_fused = rms_norm(x, (cols,), scale) - y_ref, invvar_ref = self.rms_norm_reference(x, scale) + y_ref = self.rms_norm_reference(x, scale) np.testing.assert_allclose(y_fused, y_ref, rtol=1e-5, atol=1e-5) - np.testing.assert_allclose( - invvar_fused, invvar_ref, rtol=1e-5, atol=1e-5 - ) def test_3d_input(self): batch, rows, cols = 16, 32, 64 x = paddle.randn([batch, rows, cols]) scale = paddle.randn([cols]) - y_fused, invvar_fused = rms_norm(x, (cols,), scale) + y_fused = rms_norm(x, (cols,), scale) - y_ref, invvar_ref = self.rms_norm_reference(x, scale) + y_ref = self.rms_norm_reference(x, scale) np.testing.assert_allclose( y_fused.astype("float32"), @@ -64,23 +61,17 @@ def test_3d_input(self): rtol=1e-5, atol=1e-5, ) - np.testing.assert_allclose( - invvar_fused, invvar_ref, rtol=1e-5, atol=1e-5 - ) def test_without_bias(self): rows, cols = 32, 64 x = paddle.randn([rows, cols]) scale = paddle.randn([cols]) - y_fused, invvar_fused = rms_norm(x, (cols,), scale) + y_fused = rms_norm(x, (cols,), scale) - y_ref, invvar_ref = self.rms_norm_reference(x, scale) + y_ref = self.rms_norm_reference(x, scale) np.testing.assert_allclose(y_fused, y_ref, rtol=1e-5, atol=1e-5) - np.testing.assert_allclose( - invvar_fused, invvar_ref, rtol=1e-5, atol=1e-5 - ) def test_3d_backward(self): batch, rows, cols = 8, 16, 32 @@ -89,7 +80,7 @@ def test_3d_backward(self): scale = paddle.randn([cols], dtype='float32') scale.stop_gradient = False - y_fused, invvar = rms_norm(x, (cols,), scale) + y_fused = rms_norm(x, (cols,), scale) loss = paddle.mean(y_fused) loss.backward() @@ -100,7 +91,7 @@ def test_3d_backward(self): x.clear_gradient() scale.clear_gradient() - y_ref, invvar_ref = self.rms_norm_reference(x, scale) + y_ref = self.rms_norm_reference(x, scale) loss_ref = paddle.mean(y_ref) loss_ref.backward() @@ -126,7 +117,7 @@ def test_backward(self): scale = paddle.randn([cols], dtype=scale_type) scale.stop_gradient = False - y_fused, invvar = rms_norm(x, (cols,), scale) + y_fused = rms_norm(x, (cols,), scale) loss = paddle.mean(y_fused) loss.backward() @@ -141,9 +132,9 @@ def test_backward(self): # FIXME(yangjianbang): XPU sqrt_grad does not support bfloat16 x_fp32 = x.cast("float32") scale_fp32 = scale.cast("float32") - y_ref, invvar_ref = self.rms_norm_reference(x_fp32, scale_fp32) + y_ref = self.rms_norm_reference(x_fp32, scale_fp32) else: - y_ref, invvar_ref = self.rms_norm_reference(x, scale) + y_ref = self.rms_norm_reference(x, scale) loss_ref = paddle.mean(y_ref) loss_ref.backward()