From 30ab06568f05dc931894733ccedc1c8cc0795e21 Mon Sep 17 00:00:00 2001
From: Scott Wolchok
Date: Fri, 17 Jan 2025 15:53:03 -0800
Subject: [PATCH] Update

[ghstack-poisoned]
---
 kernels/optimized/cpu/op_native_layer_norm.cpp |  2 +-
 kernels/portable/cpu/op_native_layer_norm.cpp  | 11 ++++++-----
 kernels/test/op_native_layer_norm_test.cpp     | 12 +++++++++++-
 3 files changed, 18 insertions(+), 7 deletions(-)

diff --git a/kernels/optimized/cpu/op_native_layer_norm.cpp b/kernels/optimized/cpu/op_native_layer_norm.cpp
index d04265f3367..3bbb37708ec 100644
--- a/kernels/optimized/cpu/op_native_layer_norm.cpp
+++ b/kernels/optimized/cpu/op_native_layer_norm.cpp
@@ -155,7 +155,7 @@ std::tuple<Tensor&, Tensor&, Tensor&> opt_native_layer_norm_out(
       InvalidArgument,
       ret_val);
 
-  ET_SWITCH_FLOAT_TYPES(
+  ET_SWITCH_FLOATHBF16_TYPES(
       input.scalar_type(), ctx, "native_layer_norm.out", CTYPE, [&]() {
         layer_norm<CTYPE>(
             input,
diff --git a/kernels/portable/cpu/op_native_layer_norm.cpp b/kernels/portable/cpu/op_native_layer_norm.cpp
index 36417e952de..788a844855f 100644
--- a/kernels/portable/cpu/op_native_layer_norm.cpp
+++ b/kernels/portable/cpu/op_native_layer_norm.cpp
@@ -66,15 +66,16 @@ void layer_norm(
     bias_data = nullptr;
   }
 
+  const CTYPE ct_normalized = static_cast<CTYPE>(normalized);
   for (int i = 0; i < leading; ++i) {
     const CTYPE* x = input_data + i * normalized;
     CTYPE* y = out_data + i * normalized;
 
     // compute E[X] and Var[x] = E[x^2] - E[x]^2
-    CTYPE sum = reduce_add(x, normalized);
-    CTYPE sq_sum = vec_powerf(x, normalized);
-    CTYPE mean_value = sum / normalized;
-    CTYPE variance = sq_sum / normalized - mean_value * mean_value;
+    CTYPE sum = reduce_add(x, ct_normalized);
+    CTYPE sq_sum = vec_powerf(x, ct_normalized);
+    CTYPE mean_value = sum / ct_normalized;
+    CTYPE variance = sq_sum / ct_normalized - mean_value * mean_value;
     CTYPE std = std::sqrt(variance + eps);
 
     // Calculate the elements of output
@@ -167,7 +168,7 @@ std::tuple<Tensor&, Tensor&, Tensor&> native_layer_norm_out(
       InvalidArgument,
       ret_val);
 
-  ET_SWITCH_FLOAT_TYPES(
+  ET_SWITCH_FLOATHBF16_TYPES(
       input.scalar_type(), ctx, "native_layer_norm.out", CTYPE, [&]() {
         layer_norm<CTYPE>(
             input,
diff --git a/kernels/test/op_native_layer_norm_test.cpp b/kernels/test/op_native_layer_norm_test.cpp
index fd1ca982d5b..99bf15d989d 100644
--- a/kernels/test/op_native_layer_norm_test.cpp
+++ b/kernels/test/op_native_layer_norm_test.cpp
@@ -95,7 +95,15 @@ class OpNativeLayerNormTest : public OperatorTest {
       EXPECT_TENSOR_CLOSE(out0, std::get<0>(result));
 
       Tensor expected = tf.make(test_case.sizes, test_case.expected_data);
-      EXPECT_TENSOR_CLOSE(out0, expected);
+      if constexpr (DTYPE == ScalarType::BFloat16) {
+        EXPECT_TENSOR_CLOSE_WITH_TOL(
+            out0,
+            expected,
+            1e-2,
+            executorch::runtime::testing::internal::kDefaultBFloat16Atol);
+      } else {
+        EXPECT_TENSOR_CLOSE(out0, expected);
+      }
     }
   }
 
@@ -393,6 +401,8 @@ std::vector<double> vector_32_to_64(std::vector<float> vector_32) {
 TEST_F(OpNativeLayerNormTest, FloatTensors) {
   run_floating_point_test_cases<ScalarType::Float>();
   run_floating_point_test_cases<ScalarType::Double>();
+  run_floating_point_test_cases<ScalarType::Half>();
+  run_floating_point_test_cases<ScalarType::BFloat16>();
 }
 
 TEST_F(OpNativeLayerNormTest, IntTensorsDies) {
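
Note on the change: the portable layer_norm kernel patched above computes, per row of length `normalized`, E[x] and Var[x] = E[x^2] - E[x]^2 and then normalizes by 1/sqrt(Var[x] + eps); the new `ct_normalized` hoists the integer-to-CTYPE conversion out of the loop so the divisions stay in CTYPE arithmetic once Half and BFloat16 are enabled via ET_SWITCH_FLOATHBF16_TYPES. Below is a minimal, self-contained sketch of that per-row computation, for illustration only: it is not ExecuTorch code, it omits the weight/bias (gamma/beta) step, and the helper name `naive_layer_norm` is made up.

    #include <cmath>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Sketch of the mean/variance/normalize step that op_native_layer_norm
    // performs per row; the element count is cast to the element type once,
    // mirroring ct_normalized in the patch.
    template <typename CTYPE>
    void naive_layer_norm(const CTYPE* x, std::size_t normalized, CTYPE eps, CTYPE* y) {
      const CTYPE ct_normalized = static_cast<CTYPE>(normalized);
      CTYPE sum = 0;
      CTYPE sq_sum = 0;
      for (std::size_t j = 0; j < normalized; ++j) {
        sum += x[j];
        sq_sum += x[j] * x[j];
      }
      const CTYPE mean = sum / ct_normalized;                       // E[x]
      const CTYPE variance = sq_sum / ct_normalized - mean * mean;  // E[x^2] - E[x]^2
      const CTYPE inv_std = CTYPE(1) / std::sqrt(variance + eps);
      for (std::size_t j = 0; j < normalized; ++j) {
        y[j] = (x[j] - mean) * inv_std;  // weight/bias application omitted
      }
    }

    int main() {
      std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f};
      std::vector<float> y(x.size());
      naive_layer_norm(x.data(), x.size(), 1e-5f, y.data());
      for (float v : y) {
        std::printf("%f\n", v);  // roughly -1.342, -0.447, 0.447, 1.342
      }
      return 0;
    }

For the sample row {1, 2, 3, 4} this prints values close to {-1.342, -0.447, 0.447, 1.342}, which is what the kernel's normalization produces for a single row before weight and bias are applied. The BFloat16 branch added to the test uses a relative tolerance of 1e-2 because this arithmetic is far less precise in bfloat16 than in float or double.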