Skip to content

Commit be6802f

Browse files
swolchokYIWENX14
authored andcommitted
Support Half/BFloat16 in topk (#7755)
Partial fix for #7748.
1 parent 2a4418f commit be6802f

File tree

3 files changed

+32
-24
lines changed

3 files changed

+32
-24
lines changed

kernels/portable/cpu/op_topk.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ std::tuple<Tensor&, Tensor&> topk_values(
186186

187187
bool temp_mem_allocated = false;
188188

189-
ET_SWITCH_REALH_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
189+
ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
190190
using elem_t = std::pair<CTYPE, int64_t>;
191191
size_t temp_mem_size = nonempty_size(in, dim) * sizeof(elem_t);
192192

kernels/test/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@ set(all_test_sources
220220
"op_tan_test.cpp"
221221
"op_tanh_test.cpp"
222222
"op_to_copy_test.cpp"
223+
"op_topk_test.cpp"
223224
"op_transpose_copy_test.cpp"
224225
"op_tril_test.cpp"
225226
"op_trunc_test.cpp"

kernels/test/op_topk_test.cpp

+30-23
Original file line numberDiff line numberDiff line change
@@ -118,32 +118,39 @@ class OpTopkValuesTest : public ::testing::Test {
118118
// first.
119119
torch::executor::runtime_init();
120120
}
121+
122+
template <ScalarType DTYPE>
123+
void run_smoke_test() {
124+
TensorFactory<DTYPE> tfDtype;
125+
TensorFactory<ScalarType::Long> tfLong;
126+
127+
Tensor input =
128+
tfDtype.make({3, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
129+
int64_t k = 2;
130+
int64_t dim = 0;
131+
bool largest = true;
132+
bool sorted = true;
133+
Tensor values = tfDtype.zeros({2, 2, 2});
134+
Tensor indices = tfLong.zeros({2, 2, 2});
135+
Tensor values_expected =
136+
tfDtype.make({2, 2, 2}, {9, 10, 11, 12, 5, 6, 7, 8});
137+
Tensor indices_expected = tfLong.make({2, 2, 2}, {2, 2, 2, 2, 1, 1, 1, 1});
138+
op_topk_values(input, k, dim, largest, sorted, values, indices);
139+
EXPECT_TENSOR_CLOSE(values, values_expected);
140+
EXPECT_TENSOR_EQ(indices, indices_expected);
141+
142+
largest = false;
143+
values_expected = tfDtype.make({2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
144+
indices_expected = tfLong.make({2, 2, 2}, {0, 0, 0, 0, 1, 1, 1, 1});
145+
op_topk_values(input, k, dim, largest, sorted, values, indices);
146+
EXPECT_TENSOR_CLOSE(values, values_expected);
147+
EXPECT_TENSOR_EQ(indices, indices_expected);
148+
}
121149
};
122150

123151
TEST_F(OpTopkValuesTest, SmokeTest) {
124-
TensorFactory<ScalarType::Float> tfFloat;
125-
TensorFactory<ScalarType::Long> tfLong;
126-
127-
Tensor input =
128-
tfFloat.make({3, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
129-
int64_t k = 2;
130-
int64_t dim = 0;
131-
bool largest = true;
132-
bool sorted = true;
133-
Tensor values = tfFloat.zeros({2, 2, 2});
134-
Tensor indices = tfLong.zeros({2, 2, 2});
135-
Tensor values_expected = tfFloat.make({2, 2, 2}, {9, 10, 11, 12, 5, 6, 7, 8});
136-
Tensor indices_expected = tfLong.make({2, 2, 2}, {2, 2, 2, 2, 1, 1, 1, 1});
137-
op_topk_values(input, k, dim, largest, sorted, values, indices);
138-
EXPECT_TENSOR_CLOSE(values, values_expected);
139-
EXPECT_TENSOR_EQ(indices, indices_expected);
140-
141-
largest = false;
142-
values_expected = tfFloat.make({2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
143-
indices_expected = tfLong.make({2, 2, 2}, {0, 0, 0, 0, 1, 1, 1, 1});
144-
op_topk_values(input, k, dim, largest, sorted, values, indices);
145-
EXPECT_TENSOR_CLOSE(values, values_expected);
146-
EXPECT_TENSOR_EQ(indices, indices_expected);
152+
#define RUN_SMOKE_TEST(ctype, dtype) run_smoke_test<ScalarType::dtype>();
153+
ET_FORALL_REALHBF16_TYPES(RUN_SMOKE_TEST);
147154
}
148155

149156
TEST_F(OpTopkValuesTest, NonPartialSort) {

0 commit comments

Comments
 (0)