Skip to content

Commit 03cba2c

Browse files
authored
Support Half/BFloat16 in pdist_forward (#7852)
Partial fix for #7748.
1 parent debafbe commit 03cba2c

File tree

2 files changed

+56
-33
lines changed

2 files changed

+56
-33
lines changed

kernels/portable/cpu/op_pdist_forward.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ Tensor& _pdist_forward_out(
4242
ScalarType in_type = in.scalar_type();
4343
constexpr auto name = "_pdist_forward.out";
4444

45-
ET_SWITCH_FLOAT_TYPES(
45+
ET_SWITCH_FLOATHBF16_TYPES(
4646
in_type, ctx, name, CTYPE, [&] { pdist<CTYPE>(in, out, p); });
4747

4848
return out;

kernels/test/op_pdist_forward_test.cpp

+55-32
Original file line numberDiff line numberDiff line change
@@ -33,45 +33,68 @@ class OpPdistForwardOutTest : public ::testing::Test {
3333
// first.
3434
torch::executor::runtime_init();
3535
}
36-
};
3736

38-
TEST_F(OpPdistForwardOutTest, SmokeTest) {
39-
TensorFactory<ScalarType::Float> tfFloat;
37+
template <ScalarType DTYPE>
38+
void test_dtype() {
39+
TensorFactory<DTYPE> tf;
40+
41+
Tensor in = tf.make({4, 5}, {0, 1, 2, 3, 5, 4, 3, 2, -1, 5,
42+
1, 1, -2, 1, 5, 4, 3, 2, -1, 5});
43+
Tensor out = tf.zeros({6});
4044

41-
Tensor in = tfFloat.make(
42-
{4, 5}, {0, 1, 2, 3, 5, 4, 3, 2, -1, 5, 1, 1, -2, 1, 5, 4, 3, 2, -1, 5});
43-
Tensor out = tfFloat.zeros({6});
45+
Tensor l0 = tf.make({6}, {3., 3., 3., 4., 0., 4.});
46+
op_pdist_forward_out(in, 0.0, out);
47+
EXPECT_TENSOR_CLOSE(out, l0);
4448

45-
Tensor l0 = tfFloat.make({6}, {3., 3., 3., 4., 0., 4.});
46-
op_pdist_forward_out(in, 0.0, out);
47-
EXPECT_TENSOR_CLOSE(out, l0);
49+
Tensor l0p5 = tf.make(
50+
{6},
51+
{29.31370926, 19.48528290, 29.31370926, 43.03986740, 0.0, 43.03986740});
52+
op_pdist_forward_out(in, 0.5, out);
53+
if (DTYPE == ScalarType::Half || DTYPE == ScalarType::BFloat16) {
54+
EXPECT_TENSOR_CLOSE_WITH_TOL(
55+
out,
56+
l0p5,
57+
1e-2,
58+
executorch::runtime::testing::internal::kDefaultAtol);
59+
} else {
60+
EXPECT_TENSOR_CLOSE(out, l0p5);
61+
}
4862

49-
Tensor l0p5 = tfFloat.make(
50-
{6},
51-
{29.31370926, 19.48528290, 29.31370926, 43.03986740, 0.0, 43.03986740});
52-
op_pdist_forward_out(in, 0.5, out);
53-
EXPECT_TENSOR_CLOSE(out, l0p5);
63+
Tensor l1 = tf.make({6}, {10., 7., 10., 11., 0., 11.});
64+
op_pdist_forward_out(in, 1.0, out);
65+
EXPECT_TENSOR_CLOSE(out, l1);
5466

55-
Tensor l1 = tfFloat.make({6}, {10., 7., 10., 11., 0., 11.});
56-
op_pdist_forward_out(in, 1.0, out);
57-
EXPECT_TENSOR_CLOSE(out, l1);
67+
Tensor l1p5 = tf.make(
68+
{6}, {7.07743692, 5.19140196, 7.07743692, 7.08359480, 0.0, 7.08359480});
69+
op_pdist_forward_out(in, 1.5, out);
70+
if (DTYPE == ScalarType::Half || DTYPE == ScalarType::BFloat16) {
71+
EXPECT_TENSOR_CLOSE_WITH_TOL(
72+
out,
73+
l1p5,
74+
1e-2,
75+
executorch::runtime::testing::internal::kDefaultAtol);
76+
} else {
77+
EXPECT_TENSOR_CLOSE(out, l1p5);
78+
}
5879

59-
Tensor l1p5 = tfFloat.make(
60-
{6}, {7.07743692, 5.19140196, 7.07743692, 7.08359480, 0.0, 7.08359480});
61-
op_pdist_forward_out(in, 1.5, out);
62-
EXPECT_TENSOR_CLOSE(out, l1p5);
80+
Tensor l2 =
81+
tf.make({6}, {6.0, 4.58257580, 6.0, 5.74456263, 0.0, 5.74456263});
82+
op_pdist_forward_out(in, 2.0, out);
83+
EXPECT_TENSOR_CLOSE(out, l2);
6384

64-
Tensor l2 =
65-
tfFloat.make({6}, {6.0, 4.58257580, 6.0, 5.74456263, 0.0, 5.74456263});
66-
op_pdist_forward_out(in, 2.0, out);
67-
EXPECT_TENSOR_CLOSE(out, l2);
85+
Tensor l3 = tf.make(
86+
{6}, {5.14256334, 4.17933941, 5.14256334, 4.74745941, 0.0, 4.74745941});
87+
op_pdist_forward_out(in, 3.0, out);
88+
EXPECT_TENSOR_CLOSE(out, l3);
6889

69-
Tensor l3 = tfFloat.make(
70-
{6}, {5.14256334, 4.17933941, 5.14256334, 4.74745941, 0.0, 4.74745941});
71-
op_pdist_forward_out(in, 3.0, out);
72-
EXPECT_TENSOR_CLOSE(out, l3);
90+
Tensor linf = tf.make({6}, {4., 4., 4., 4., 0., 4.});
91+
op_pdist_forward_out(in, INFINITY, out);
92+
EXPECT_TENSOR_CLOSE(out, linf);
93+
}
94+
};
7395

74-
Tensor linf = tfFloat.make({6}, {4., 4., 4., 4., 0., 4.});
75-
op_pdist_forward_out(in, INFINITY, out);
76-
EXPECT_TENSOR_CLOSE(out, linf);
96+
TEST_F(OpPdistForwardOutTest, SmokeTest) {
97+
#define TEST_ENTRY(ctype, dtype) test_dtype<ScalarType::dtype>();
98+
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY)
99+
#undef TEST_ENTRY
77100
}

0 commit comments

Comments
 (0)