Skip to content

Commit ad9e38d

Browse files
swolchokYIWENX14
authored andcommitted
Support Half/BFloat16 in abs/neg (#7760)
Partial fix for #7748.
1 parent d0a9ebe commit ad9e38d

File tree

5 files changed

+37
-21
lines changed

5 files changed

+37
-21
lines changed

kernels/optimized/cpu/op_neg.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ Tensor& opt_neg_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
2626
out,
2727
"Failed to resize output tensor.");
2828

29-
ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "neg.out", CTYPE, [&] {
29+
ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "neg.out", CTYPE, [&] {
3030
using Vec = executorch::vec::Vectorized<CTYPE>;
3131
executorch::vec::map<CTYPE>(
3232
[](Vec x) { return x.neg(); },

kernels/portable/cpu/op_abs.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ Tensor& abs_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
3131
ET_KERNEL_CHECK(
3232
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
3333

34-
ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "abs.out", CTYPE, [&] {
34+
ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "abs.out", CTYPE, [&] {
3535
apply_unary_map_fn(
3636
[](const CTYPE val_in) {
3737
if (val_in < 0) {

kernels/portable/cpu/op_neg.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ Tensor& neg_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
3333
ET_KERNEL_CHECK(
3434
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
3535

36-
ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "neg.out", CTYPE, [&] {
36+
ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "neg.out", CTYPE, [&] {
3737
apply_unary_map_fn(
3838
[](const CTYPE val_in) { return static_cast<CTYPE>(-val_in); },
3939
in.const_data_ptr<CTYPE>(),

kernels/test/op_abs_test.cpp

+17-9
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,27 @@ class OpAbsTest : public OperatorTest {
2424
Tensor& op_abs_out(const Tensor& self, Tensor& out) {
2525
return torch::executor::aten::abs_outf(context_, self, out);
2626
}
27-
};
2827

29-
TEST_F(OpAbsTest, SanityCheck) {
30-
TensorFactory<ScalarType::Float> tf;
28+
template <ScalarType DTYPE>
29+
void run_smoke_test() {
30+
TensorFactory<DTYPE> tf;
31+
32+
Tensor in = tf.make({1, 7}, {-3.0, -2.5, -1.01, 0.0, 1.01, 2.5, 3.0});
33+
Tensor out = tf.zeros({1, 7});
34+
Tensor expected = tf.make({1, 7}, {3.0, 2.5, 1.01, 0.0, 1.01, 2.5, 3.0});
3135

32-
Tensor in = tf.make({1, 7}, {-3.0, -2.5, -1.01, 0.0, 1.01, 2.5, 3.0});
33-
Tensor out = tf.zeros({1, 7});
34-
Tensor expected = tf.make({1, 7}, {3.0, 2.5, 1.01, 0.0, 1.01, 2.5, 3.0});
36+
Tensor ret = op_abs_out(in, out);
3537

36-
Tensor ret = op_abs_out(in, out);
38+
EXPECT_TENSOR_EQ(out, ret);
39+
EXPECT_TENSOR_EQ(out, expected);
40+
}
41+
};
3742

38-
EXPECT_TENSOR_EQ(out, ret);
39-
EXPECT_TENSOR_EQ(out, expected);
43+
TEST_F(OpAbsTest, SmokeTest) {
44+
#define RUN_SMOKE_TEST(ctype, dtype) run_smoke_test<ScalarType::dtype>();
45+
// TODO: cover all REALHBF16 types with generalized unary function test
46+
// harness.
47+
ET_FORALL_FLOATHBF16_TYPES(RUN_SMOKE_TEST);
4048
}
4149

4250
TEST_F(OpAbsTest, MemoryFormatCheck) {

kernels/test/op_neg_test.cpp

+17-9
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,25 @@ class OpNegTest : public OperatorTest {
2424
Tensor& op_neg_out(const Tensor& self, Tensor& out) {
2525
return torch::executor::aten::neg_outf(context_, self, out);
2626
}
27-
};
2827

29-
TEST_F(OpNegTest, SanityCheck) {
30-
TensorFactory<ScalarType::Float> tf;
28+
template <ScalarType DTYPE>
29+
void run_smoke_test() {
30+
TensorFactory<DTYPE> tf;
31+
32+
Tensor in = tf.make({1, 7}, {-3.0, -2.5, -1.01, 0.0, 1.01, 2.5, 3.0});
33+
Tensor out = tf.zeros({1, 7});
34+
Tensor expected = tf.make({1, 7}, {3.0, 2.5, 1.01, 0.0, -1.01, -2.5, -3.0});
3135

32-
Tensor in = tf.make({1, 7}, {-3.0, -2.5, -1.01, 0.0, 1.01, 2.5, 3.0});
33-
Tensor out = tf.zeros({1, 7});
34-
Tensor expected = tf.make({1, 7}, {3.0, 2.5, 1.01, 0.0, -1.01, -2.5, -3.0});
36+
Tensor ret = op_neg_out(in, out);
3537

36-
Tensor ret = op_neg_out(in, out);
38+
EXPECT_TENSOR_EQ(out, ret);
39+
EXPECT_TENSOR_EQ(out, expected);
40+
}
41+
};
3742

38-
EXPECT_TENSOR_EQ(out, ret);
39-
EXPECT_TENSOR_EQ(out, expected);
43+
TEST_F(OpNegTest, SmokeTest) {
44+
#define RUN_SMOKE_TEST(ctype, dtype) run_smoke_test<ScalarType::dtype>();
45+
// TODO: cover all REALHBF16 types with generalized unary function test
46+
// harness.
47+
ET_FORALL_FLOATHBF16_TYPES(RUN_SMOKE_TEST);
4048
}

0 commit comments

Comments
 (0)