@@ -45,6 +45,24 @@ class OpProdOutTest : public ::testing::Test {
45
45
// first.
46
46
torch::executor::runtime_init ();
47
47
}
48
+
49
+ template <ScalarType DTYPE>
50
+ void test_dtype () {
51
+ TensorFactory<DTYPE> tf;
52
+ TensorFactory<
53
+ executorch::runtime::isIntegralType (DTYPE, /* includeBool*/ true )
54
+ ? ScalarType::Long
55
+ : DTYPE>
56
+ tf_out;
57
+
58
+ Tensor self = tf.make ({2 , 3 }, {1 , 2 , 3 , 4 , 5 , 6 });
59
+ optional<ScalarType> dtype{};
60
+ Tensor out = tf_out.zeros ({});
61
+ Tensor out_expected =
62
+ tf_out.make ({}, {DTYPE == ScalarType::Bool ? 1 : 720 });
63
+ op_prod_out (self, dtype, out);
64
+ EXPECT_TENSOR_CLOSE (out, out_expected);
65
+ }
48
66
};
49
67
50
68
class OpProdIntOutTest : public ::testing::Test {
@@ -54,30 +72,32 @@ class OpProdIntOutTest : public ::testing::Test {
54
72
// first.
55
73
torch::executor::runtime_init ();
56
74
}
57
- };
58
75
59
- TEST_F (OpProdOutTest, SmokeTest) {
60
- TensorFactory<ScalarType::Float> tfFloat;
76
+ template <ScalarType DTYPE>
77
+ void test_dtype () {
78
+ TensorFactory<DTYPE> tf;
61
79
62
- Tensor self = tfFloat.make ({2 , 3 }, {1 , 2 , 3 , 4 , 5 , 6 });
63
- optional<ScalarType> dtype{};
64
- Tensor out = tfFloat.zeros ({});
65
- Tensor out_expected = tfFloat.make ({}, {720 });
66
- op_prod_out (self, dtype, out);
67
- EXPECT_TENSOR_CLOSE (out, out_expected);
68
- }
80
+ Tensor self = tf.make ({2 , 3 }, {1 , 2 , 3 , 4 , 5 , 6 });
81
+ int64_t dim = 0 ;
82
+ bool keepdim = false ;
83
+ optional<ScalarType> dtype{};
84
+ Tensor out = tf.zeros ({3 });
85
+ Tensor out_expected = tf.make ({3 }, {4 , 10 , 18 });
86
+ op_prod_int_out (self, dim, keepdim, dtype, out);
87
+ EXPECT_TENSOR_CLOSE (out, out_expected);
88
+ }
89
+ };
69
90
70
- TEST_F (OpProdIntOutTest, SmokeTest) {
71
- TensorFactory<ScalarType::Float> tfFloat;
91
+ TEST_F (OpProdOutTest, SmokeTest){
92
+ #define TEST_ENTRY (ctype, dtype ) test_dtype<ScalarType::dtype>();
93
+ ET_FORALL_REALHBBF16_TYPES (TEST_ENTRY)
94
+ #undef TEST_ENTRY
95
+ }
72
96
73
- Tensor self = tfFloat.make ({2 , 3 }, {1 , 2 , 3 , 4 , 5 , 6 });
74
- int64_t dim = 0 ;
75
- bool keepdim = false ;
76
- optional<ScalarType> dtype{};
77
- Tensor out = tfFloat.zeros ({3 });
78
- Tensor out_expected = tfFloat.make ({3 }, {4 , 10 , 18 });
79
- op_prod_int_out (self, dim, keepdim, dtype, out);
80
- EXPECT_TENSOR_CLOSE (out, out_expected);
97
+ TEST_F (OpProdIntOutTest, SmokeTest){
98
+ #define TEST_ENTRY (ctype, dtype ) test_dtype<ScalarType::dtype>();
99
+ ET_FORALL_REALHBBF16_TYPES (TEST_ENTRY)
100
+ #undef TEST_ENTRY
81
101
}
82
102
83
103
TEST_F (OpProdIntOutTest, SmokeTestKeepdim) {
0 commit comments