|
13 | 13 | // limitations under the License. |
14 | 14 |
|
15 | 15 | #include <functional> |
| 16 | +#include <cstdint> |
16 | 17 | #include <random> |
17 | 18 | #include <string> |
18 | 19 | #include <thread> |
@@ -135,7 +136,8 @@ TEST(DistanceMatrix, SquaredEuclidean_General) { |
135 | 136 |
|
136 | 137 | template <size_t M, size_t N> |
137 | 138 | void TestEuclideanMatrix(void) { |
138 | | - std::mt19937 gen((std::random_device())()); |
| 139 | + std::mt19937 gen(static_cast<uint32_t>(0x5EED1234u + M * 131u + N * 17u)); |
| 140 | + constexpr int kFp16MatrixUlpTolerance = 20000; |
139 | 141 |
|
140 | 142 | const size_t batch_size = M; |
141 | 143 | const size_t query_size = N; |
@@ -174,13 +176,15 @@ void TestEuclideanMatrix(void) { |
174 | 176 |
|
175 | 177 | for (size_t i = 0; i < batch_size * query_size; ++i) { |
176 | 178 | // EXPECT_FLOAT_EQ(result1[i], result2[i]); |
177 | | - EXPECT_TRUE(MathHelper::IsAlmostEqual(result1[i], result2[i], 10000)); |
| 179 | + EXPECT_TRUE(MathHelper::IsAlmostEqual( |
| 180 | + result1[i], result2[i], kFp16MatrixUlpTolerance)); |
178 | 181 | } |
179 | 182 | } |
180 | 183 |
|
181 | 184 | template <size_t M, size_t N> |
182 | 185 | void TestSquaredEuclideanMatrix(void) { |
183 | | - std::mt19937 gen((std::random_device())()); |
| 186 | + std::mt19937 gen(static_cast<uint32_t>(0x5EED5678u + M * 131u + N * 17u)); |
| 187 | + constexpr int kFp16MatrixUlpTolerance = 20000; |
184 | 188 |
|
185 | 189 | const size_t batch_size = M; |
186 | 190 | const size_t query_size = N; |
@@ -219,7 +223,8 @@ void TestSquaredEuclideanMatrix(void) { |
219 | 223 |
|
220 | 224 | for (size_t i = 0; i < batch_size * query_size; ++i) { |
221 | 225 | // EXPECT_FLOAT_EQ(result1[i], result2[i]); |
222 | | - EXPECT_TRUE(MathHelper::IsAlmostEqual(result1[i], result2[i], 10000)); |
| 226 | + EXPECT_TRUE(MathHelper::IsAlmostEqual( |
| 227 | + result1[i], result2[i], kFp16MatrixUlpTolerance)); |
223 | 228 | } |
224 | 229 | } |
225 | 230 |
|
|
0 commit comments