Skip to content

Commit

Permalink
Support Half/BFloat16 in native_layer_norm
Browse files Browse the repository at this point in the history
ghstack-source-id: f8d98dddce9c5fc56d75eb3bb48f228686a4dc76
ghstack-comment-id: 2599385748
Pull Request resolved: #7752
  • Loading branch information
swolchok committed Jan 17, 2025
1 parent 57a09f4 commit bff3ad3
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 7 deletions.
2 changes: 1 addition & 1 deletion kernels/optimized/cpu/op_native_layer_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ std::tuple<Tensor&, Tensor&, Tensor&> opt_native_layer_norm_out(
InvalidArgument,
ret_val);

- ET_SWITCH_FLOAT_TYPES(
+ ET_SWITCH_FLOATHBF16_TYPES(
input.scalar_type(), ctx, "native_layer_norm.out", CTYPE, [&]() {
layer_norm<CTYPE>(
input,
Expand Down
11 changes: 6 additions & 5 deletions kernels/portable/cpu/op_native_layer_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,16 @@ void layer_norm(
bias_data = nullptr;
}

+ const CTYPE ct_normalized = static_cast<CTYPE>(normalized);
for (int i = 0; i < leading; ++i) {
const CTYPE* x = input_data + i * normalized;
CTYPE* y = out_data + i * normalized;

// compute E[X] and Var[x] = E[x^2] - E[x]^2
- CTYPE sum = reduce_add(x, normalized);
- CTYPE sq_sum = vec_powerf(x, normalized);
- CTYPE mean_value = sum / normalized;
- CTYPE variance = sq_sum / normalized - mean_value * mean_value;
+ CTYPE sum = reduce_add(x, ct_normalized);
+ CTYPE sq_sum = vec_powerf(x, ct_normalized);
+ CTYPE mean_value = sum / ct_normalized;
+ CTYPE variance = sq_sum / ct_normalized - mean_value * mean_value;
CTYPE std = std::sqrt(variance + eps);

// Calculate the elements of output
Expand Down Expand Up @@ -167,7 +168,7 @@ std::tuple<Tensor&, Tensor&, Tensor&> native_layer_norm_out(
InvalidArgument,
ret_val);

- ET_SWITCH_FLOAT_TYPES(
+ ET_SWITCH_FLOATHBF16_TYPES(
input.scalar_type(), ctx, "native_layer_norm.out", CTYPE, [&]() {
layer_norm<CTYPE>(
input,
Expand Down
12 changes: 11 additions & 1 deletion kernels/test/op_native_layer_norm_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,15 @@ class OpNativeLayerNormTest : public OperatorTest {
EXPECT_TENSOR_CLOSE(out0, std::get<0>(result));

Tensor expected = tf.make(test_case.sizes, test_case.expected_data);
- EXPECT_TENSOR_CLOSE(out0, expected);
+ if constexpr (DTYPE == ScalarType::BFloat16) {
+ EXPECT_TENSOR_CLOSE_WITH_TOL(
+ out0,
+ expected,
+ 1e-2,
+ executorch::runtime::testing::internal::kDefaultBFloat16Atol);
+ } else {
+ EXPECT_TENSOR_CLOSE(out0, expected);
+ }
}
}

Expand Down Expand Up @@ -393,6 +401,8 @@ std::vector<int64_t> vector_32_to_64(std::vector<int32_t> vector_32) {
// Invokes run_floating_point_test_cases once for each floating-point
// ScalarType exercised by this test.
TEST_F(OpNativeLayerNormTest, FloatTensors) {
run_floating_point_test_cases<ScalarType::Float>();
run_floating_point_test_cases<ScalarType::Double>();
// Half and BFloat16 are the two lines this commit adds to the test
// ("Support Half/BFloat16 in native_layer_norm").
run_floating_point_test_cases<ScalarType::Half>();
run_floating_point_test_cases<ScalarType::BFloat16>();
}

TEST_F(OpNativeLayerNormTest, IntTensorsDies) {
Expand Down

0 comments on commit bff3ad3

Please sign in to comment.