Skip to content

Commit 28667b8

Browse files
xnnpack-botJayakrishna T N
authored andcommitted
Bulk sync to github
PiperOrigin-RevId: 743060532
1 parent 78e889c commit 28667b8

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

test/subgraph/rope.cc

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
// This source code is licensed under the BSD-style license found in the
44
// LICENSE file in the root directory of this source tree.
55

6+
#include <algorithm>
67
#include <cassert>
7-
#include <chrono>
8+
#include <chrono> // NOLINT(build/c++11)
89
#include <cstddef>
910
#include <cstdlib>
1011
#include <random>
@@ -20,6 +21,13 @@
2021

2122
namespace xnnpack {
2223

24+
namespace {
25+
26+
static const float kMaxR = 10.0f;
27+
static const float kMaxI = 1.0;
28+
29+
}; // namespace
30+
2331
template <typename T>
2432
Tensor<T> ReferenceImpl(Tensor<T> x, Tensor<T> w) {
2533
assert(x.rank() == 5);
@@ -76,8 +84,8 @@ void TestImpl() {
7684
Tensor<T> input({batch_size, tokens, heads, 2, channels},
7785
PaddingBytes{XNN_EXTRA_BYTES});
7886
Tensor<T> weights({max_tokens, 2, channels}, PaddingBytes{XNN_EXTRA_BYTES});
79-
DatatypeGenerator<T> gen_r(1.0f, 10.0f);
80-
DatatypeGenerator<T> gen_i(0.01f, 1.0f);
87+
DatatypeGenerator<T> gen_r(1.0f, kMaxR);
88+
DatatypeGenerator<T> gen_i(0.01f, kMaxI);
8189
input.slice(3, 0).generate([&]() { return gen_r(rng); });
8290
input.slice(3, 1).generate([&]() { return gen_i(rng); });
8391
weights.slice(1, 0).generate([&]() { return gen_r(rng); });
@@ -98,12 +106,9 @@ void TestImpl() {
98106
.InvokeRuntime();
99107

100108
// Verify results.
101-
T max_abs_val = 0.0f;
102-
for (const T& val : output) {
103-
max_abs_val = std::max<T>(max_abs_val, std::abs(val));
104-
}
109+
const float max_input_val = std::max(kMaxR, kMaxR);
105110
const float abs_tol =
106-
max_abs_val * 2.0f * xnnpack::epsilon(xnn_datatype_of<T>());
111+
max_input_val * max_input_val * xnnpack::epsilon(xnn_datatype_of<T>());
107112
for (const auto& i : EnumerateIndices(output.extents())) {
108113
ASSERT_NEAR(output(i), expected(i), abs_tol);
109114
}

0 commit comments

Comments
 (0)