Skip to content

Commit bff3ad3

Browse files
committed
Support Half/BFloat16 in native_layer_norm
ghstack-source-id: f8d98dddce9c5fc56d75eb3bb48f228686a4dc76 ghstack-comment-id: 2599385748 Pull Request resolved: #7752
1 parent 57a09f4 commit bff3ad3

File tree

3 files changed

+18
-7
lines changed

3 files changed

+18
-7
lines changed

kernels/optimized/cpu/op_native_layer_norm.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ std::tuple<Tensor&, Tensor&, Tensor&> opt_native_layer_norm_out(
155155
InvalidArgument,
156156
ret_val);
157157

158-
ET_SWITCH_FLOAT_TYPES(
158+
ET_SWITCH_FLOATHBF16_TYPES(
159159
input.scalar_type(), ctx, "native_layer_norm.out", CTYPE, [&]() {
160160
layer_norm<CTYPE>(
161161
input,

kernels/portable/cpu/op_native_layer_norm.cpp

+6-5
Original file line numberDiff line numberDiff line change
@@ -66,15 +66,16 @@ void layer_norm(
6666
bias_data = nullptr;
6767
}
6868

69+
const CTYPE ct_normalized = static_cast<CTYPE>(normalized);
6970
for (int i = 0; i < leading; ++i) {
7071
const CTYPE* x = input_data + i * normalized;
7172
CTYPE* y = out_data + i * normalized;
7273

7374
// compute E[X] and Var[x] = E[x^2] - E[x]^2
74-
CTYPE sum = reduce_add(x, normalized);
75-
CTYPE sq_sum = vec_powerf(x, normalized);
76-
CTYPE mean_value = sum / normalized;
77-
CTYPE variance = sq_sum / normalized - mean_value * mean_value;
75+
CTYPE sum = reduce_add(x, ct_normalized);
76+
CTYPE sq_sum = vec_powerf(x, ct_normalized);
77+
CTYPE mean_value = sum / ct_normalized;
78+
CTYPE variance = sq_sum / ct_normalized - mean_value * mean_value;
7879
CTYPE std = std::sqrt(variance + eps);
7980

8081
// Calculate the elements of output
@@ -167,7 +168,7 @@ std::tuple<Tensor&, Tensor&, Tensor&> native_layer_norm_out(
167168
InvalidArgument,
168169
ret_val);
169170

170-
ET_SWITCH_FLOAT_TYPES(
171+
ET_SWITCH_FLOATHBF16_TYPES(
171172
input.scalar_type(), ctx, "native_layer_norm.out", CTYPE, [&]() {
172173
layer_norm<CTYPE>(
173174
input,

kernels/test/op_native_layer_norm_test.cpp

+11-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,15 @@ class OpNativeLayerNormTest : public OperatorTest {
9595
EXPECT_TENSOR_CLOSE(out0, std::get<0>(result));
9696

9797
Tensor expected = tf.make(test_case.sizes, test_case.expected_data);
98-
EXPECT_TENSOR_CLOSE(out0, expected);
98+
if constexpr (DTYPE == ScalarType::BFloat16) {
99+
EXPECT_TENSOR_CLOSE_WITH_TOL(
100+
out0,
101+
expected,
102+
1e-2,
103+
executorch::runtime::testing::internal::kDefaultBFloat16Atol);
104+
} else {
105+
EXPECT_TENSOR_CLOSE(out0, expected);
106+
}
99107
}
100108
}
101109

@@ -393,6 +401,8 @@ std::vector<int64_t> vector_32_to_64(std::vector<int32_t> vector_32) {
393401
TEST_F(OpNativeLayerNormTest, FloatTensors) {
394402
run_floating_point_test_cases<ScalarType::Float>();
395403
run_floating_point_test_cases<ScalarType::Double>();
404+
run_floating_point_test_cases<ScalarType::Half>();
405+
run_floating_point_test_cases<ScalarType::BFloat16>();
396406
}
397407

398408
TEST_F(OpNativeLayerNormTest, IntTensorsDies) {

0 commit comments

Comments (0)