33// This source code is licensed under the BSD-style license found in the
44// LICENSE file in the root directory of this source tree.
55
6+ #include < algorithm>
67#include < cassert>
7- #include < chrono>
8+ #include < chrono> // NOLINT(build/c++11)
89#include < cstddef>
910#include < cstdlib>
1011#include < random>
2021
2122namespace xnnpack {
2223
24+ namespace {
25+
26+ static const float kMaxR = 10 .0f ;
27+ static const float kMaxI = 1.0 ;
28+
29+ }; // namespace
30+
2331template <typename T>
2432Tensor<T> ReferenceImpl (Tensor<T> x, Tensor<T> w) {
2533 assert (x.rank () == 5 );
@@ -76,8 +84,8 @@ void TestImpl() {
7684 Tensor<T> input ({batch_size, tokens, heads, 2 , channels},
7785 PaddingBytes{XNN_EXTRA_BYTES});
7886 Tensor<T> weights ({max_tokens, 2 , channels}, PaddingBytes{XNN_EXTRA_BYTES});
79- DatatypeGenerator<T> gen_r (1 .0f , 10 . 0f );
80- DatatypeGenerator<T> gen_i (0 .01f , 1 . 0f );
87+ DatatypeGenerator<T> gen_r (1 .0f , kMaxR );
88+ DatatypeGenerator<T> gen_i (0 .01f , kMaxI );
8189 input.slice (3 , 0 ).generate ([&]() { return gen_r (rng); });
8290 input.slice (3 , 1 ).generate ([&]() { return gen_i (rng); });
8391 weights.slice (1 , 0 ).generate ([&]() { return gen_r (rng); });
@@ -98,12 +106,9 @@ void TestImpl() {
98106 .InvokeRuntime ();
99107
100108 // Verify results.
101- T max_abs_val = 0 .0f ;
102- for (const T& val : output) {
103- max_abs_val = std::max<T>(max_abs_val, std::abs (val));
104- }
109+ const float max_input_val = std::max (kMaxR , kMaxR );
105110 const float abs_tol =
106- max_abs_val * 2 . 0f * xnnpack::epsilon (xnn_datatype_of<T>());
111+ max_input_val * max_input_val * xnnpack::epsilon (xnn_datatype_of<T>());
107112 for (const auto & i : EnumerateIndices (output.extents ())) {
108113 ASSERT_NEAR (output (i), expected (i), abs_tol);
109114 }
0 commit comments