@@ -118,32 +118,39 @@ class OpTopkValuesTest : public ::testing::Test {
118
118
// first.
119
119
torch::executor::runtime_init ();
120
120
}
121
+
122
+ template <ScalarType DTYPE>
123
+ void run_smoke_test () {
124
+ TensorFactory<DTYPE> tfDtype;
125
+ TensorFactory<ScalarType::Long> tfLong;
126
+
127
+ Tensor input =
128
+ tfDtype.make ({3 , 2 , 2 }, {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
129
+ int64_t k = 2 ;
130
+ int64_t dim = 0 ;
131
+ bool largest = true ;
132
+ bool sorted = true ;
133
+ Tensor values = tfDtype.zeros ({2 , 2 , 2 });
134
+ Tensor indices = tfLong.zeros ({2 , 2 , 2 });
135
+ Tensor values_expected =
136
+ tfDtype.make ({2 , 2 , 2 }, {9 , 10 , 11 , 12 , 5 , 6 , 7 , 8 });
137
+ Tensor indices_expected = tfLong.make ({2 , 2 , 2 }, {2 , 2 , 2 , 2 , 1 , 1 , 1 , 1 });
138
+ op_topk_values (input, k, dim, largest, sorted, values, indices);
139
+ EXPECT_TENSOR_CLOSE (values, values_expected);
140
+ EXPECT_TENSOR_EQ (indices, indices_expected);
141
+
142
+ largest = false ;
143
+ values_expected = tfDtype.make ({2 , 2 , 2 }, {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 });
144
+ indices_expected = tfLong.make ({2 , 2 , 2 }, {0 , 0 , 0 , 0 , 1 , 1 , 1 , 1 });
145
+ op_topk_values (input, k, dim, largest, sorted, values, indices);
146
+ EXPECT_TENSOR_CLOSE (values, values_expected);
147
+ EXPECT_TENSOR_EQ (indices, indices_expected);
148
+ }
121
149
};
122
150
123
151
TEST_F (OpTopkValuesTest, SmokeTest) {
124
- TensorFactory<ScalarType::Float> tfFloat;
125
- TensorFactory<ScalarType::Long> tfLong;
126
-
127
- Tensor input =
128
- tfFloat.make ({3 , 2 , 2 }, {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 });
129
- int64_t k = 2 ;
130
- int64_t dim = 0 ;
131
- bool largest = true ;
132
- bool sorted = true ;
133
- Tensor values = tfFloat.zeros ({2 , 2 , 2 });
134
- Tensor indices = tfLong.zeros ({2 , 2 , 2 });
135
- Tensor values_expected = tfFloat.make ({2 , 2 , 2 }, {9 , 10 , 11 , 12 , 5 , 6 , 7 , 8 });
136
- Tensor indices_expected = tfLong.make ({2 , 2 , 2 }, {2 , 2 , 2 , 2 , 1 , 1 , 1 , 1 });
137
- op_topk_values (input, k, dim, largest, sorted, values, indices);
138
- EXPECT_TENSOR_CLOSE (values, values_expected);
139
- EXPECT_TENSOR_EQ (indices, indices_expected);
140
-
141
- largest = false ;
142
- values_expected = tfFloat.make ({2 , 2 , 2 }, {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 });
143
- indices_expected = tfLong.make ({2 , 2 , 2 }, {0 , 0 , 0 , 0 , 1 , 1 , 1 , 1 });
144
- op_topk_values (input, k, dim, largest, sorted, values, indices);
145
- EXPECT_TENSOR_CLOSE (values, values_expected);
146
- EXPECT_TENSOR_EQ (indices, indices_expected);
152
+ #define RUN_SMOKE_TEST (ctype, dtype ) run_smoke_test<ScalarType::dtype>();
153
+ ET_FORALL_REALHBF16_TYPES (RUN_SMOKE_TEST);
147
154
}
148
155
149
156
TEST_F (OpTopkValuesTest, NonPartialSort) {
0 commit comments