Fix more warnings and type safety issues
PiperOrigin-RevId: 728877342
dsharletg authored and xnnpack-bot committed Feb 20, 2025
1 parent 6b08d6e commit 9f9f69c
Showing 14 changed files with 74 additions and 52 deletions.
10 changes: 6 additions & 4 deletions bench/models/qd8-attention.cc
@@ -33,6 +33,8 @@ xnn_subgraph_t QD8Attention(size_t batch_size, size_t seq_len,
   // Scales must be positive.
   auto f32rng = std::bind(std::uniform_real_distribution<float>(0.01f, +1.0f),
                           std::ref(rng));
+  auto i8rng = std::bind(std::uniform_int_distribution<int>(-127, 127),
+                         std::ref(rng));
 
   // External inputs and outputs.
   uint32_t input_id = XNN_INVALID_VALUE_ID;
@@ -83,7 +85,7 @@ xnn_subgraph_t QD8Attention(size_t batch_size, size_t seq_len,
   std::generate(weights.value_scale.begin(), weights.value_scale.end(),
                 std::ref(f32rng));
   std::generate(weights.value_data.begin(), weights.value_data.end(),
-                std::ref(f32rng));
+                std::ref(i8rng));
   status = xnn_define_channelwise_quantized_tensor_value(
       subgraph, xnn_datatype_qcint8, weights.value_scale.data(),
       value_dims.size(), value_dims.size() - 2, value_dims.data(),
@@ -100,7 +102,7 @@ xnn_subgraph_t QD8Attention(size_t batch_size, size_t seq_len,
   std::generate(weights.query_scale.begin(), weights.query_scale.end(),
                 std::ref(f32rng));
   std::generate(weights.query_data.begin(), weights.query_data.end(),
-                std::ref(f32rng));
+                std::ref(i8rng));
   status = xnn_define_channelwise_quantized_tensor_value(
       subgraph, xnn_datatype_qcint8, weights.query_scale.data(),
       query_dims.size(), query_dims.size() - 2, query_dims.data(),
@@ -117,7 +119,7 @@ xnn_subgraph_t QD8Attention(size_t batch_size, size_t seq_len,
   std::generate(weights.key_scale.begin(), weights.key_scale.end(),
                 std::ref(f32rng));
   std::generate(weights.key_data.begin(), weights.key_data.end(),
-                std::ref(f32rng));
+                std::ref(i8rng));
   status = xnn_define_channelwise_quantized_tensor_value(
       subgraph, xnn_datatype_qcint8, weights.key_scale.data(), key_dims.size(),
       key_dims.size() - 2, key_dims.data(), weights.key_data.data(),
@@ -311,7 +313,7 @@ xnn_subgraph_t QD8Attention(size_t batch_size, size_t seq_len,
   std::generate(weights.post_proj_scale.begin(), weights.post_proj_scale.end(),
                 std::ref(f32rng));
   std::generate(weights.post_proj_data.begin(), weights.post_proj_data.end(),
-                std::ref(f32rng));
+                std::ref(i8rng));
   status = xnn_define_channelwise_quantized_tensor_value(
       subgraph, xnn_datatype_qcint8, weights.post_proj_scale.data(),
       post_proj_dims.size(), post_proj_dims.size() - 2, post_proj_dims.data(),
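
Aside: the substitution of i8rng for f32rng above matters because the qcint8 weight buffers hold int8_t values, and generating them from a float distribution assigns a float into every int8_t element, exactly the kind of implicit narrowing this commit cleans up. A minimal, self-contained sketch of the intended pattern (not the benchmark's actual buffers or shapes):

    #include <algorithm>
    #include <cstdint>
    #include <functional>
    #include <random>
    #include <vector>

    int main() {
      std::mt19937 rng(0);
      // Scales stay positive floats; quantized weight bytes come from an
      // integer distribution bounded to the signed 8-bit range.
      auto f32rng = std::bind(std::uniform_real_distribution<float>(0.01f, 1.0f),
                              std::ref(rng));
      auto i8rng = std::bind(std::uniform_int_distribution<int>(-127, 127),
                             std::ref(rng));
      std::vector<float> scales(64);
      std::vector<int8_t> weights(64 * 32);
      std::generate(scales.begin(), scales.end(), std::ref(f32rng));
      std::generate(weights.begin(), weights.end(),
                    [&]() { return static_cast<int8_t>(i8rng()); });
    }
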
3 changes: 2 additions & 1 deletion bench/unary.cc
@@ -163,7 +163,8 @@ static void benchmark_unary_operator(benchmark::State& state,
 
   xnnpack::Buffer<In> input(batch_size + XNN_EXTRA_BYTES / sizeof(In));
   xnnpack::Buffer<Out> output(batch_size);
-  std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
+  std::generate(input.begin(), input.end(),
+                [&]() { return static_cast<In>(f32dist(rng)); });
 
   xnn_status status = xnn_initialize(nullptr /* allocator */);
   if (status != xnn_status_success) {
5 changes: 2 additions & 3 deletions test/binary-elementwise-nd.cc
@@ -312,7 +312,6 @@ class BinaryElementwiseOperatorTester {
     MinMaxLow limits = DatatypeMinMaxLow(datatype());
 
     xnnpack::ReplicableRandomDevice rng;
-    std::uniform_real_distribution<double> dist(limits.min, limits.max);
 
     // Compute generalized shapes.
     std::array<size_t, XNN_MAX_TENSOR_DIMS> input1_dims;
@@ -358,8 +357,8 @@ class BinaryElementwiseOperatorTester {
     xnnpack::Buffer<T> input2(XNN_EXTRA_BYTES + num_input2_elements());
     xnnpack::Buffer<T> output(num_output_elements);
     for (size_t iteration = 0; iteration < iterations(); iteration++) {
-      xnnpack::randomize_buffer(datatype(), rng, dist, input1);
-      xnnpack::randomize_buffer(datatype(), rng, dist, input2);
+      xnnpack::randomize_buffer(datatype(), rng, limits.min, limits.max, input1);
+      xnnpack::randomize_buffer(datatype(), rng, limits.min, limits.max, input2);
 
       if (mode == RunMode::kCreateReshapeRun) {
         // Create, setup, run, and destroy a binary elementwise operator.
5 changes: 2 additions & 3 deletions test/binary.cc
@@ -183,9 +183,8 @@ void MatchesOperatorApi(xnn_datatype datatype, xnn_binary_operator binary_op) {
       assert(false);
       break;
  }
-  std::uniform_real_distribution<double> dist(datatype_min, datatype_max);
-  randomize_buffer(datatype, rng, dist, input0);
-  randomize_buffer(datatype, rng, dist, input1);
+  randomize_buffer(datatype, rng, datatype_min, datatype_max, input0);
+  randomize_buffer(datatype, rng, datatype_min, datatype_max, input1);
 
   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
 
6 changes: 3 additions & 3 deletions test/dwconv-microkernel-tester.cc
@@ -972,7 +972,7 @@ void DWConvMicrokernelTester::Test(
   xnnpack::Buffer<xnn_float16> bias(channels());
   xnnpack::Buffer<xnn_float16, XNN_ALLOCATION_ALIGNMENT> packed_weights(
       (kernel_tile() + 1) * packed_channels());
-  xnnpack::Buffer<xnn_float16> zero(channels() + XNN_EXTRA_BYTES / sizeof(xnn_float16), 0);
+  xnnpack::Buffer<xnn_float16> zero(channels() + XNN_EXTRA_BYTES / sizeof(xnn_float16), 0.0f);
   xnnpack::Buffer<xnn_float16> output((width() - 1) * output_stride() + channels());
   xnnpack::Buffer<float> output_ref(width() * channels());
 
@@ -984,7 +984,7 @@ void DWConvMicrokernelTester::Test(
     std::generate(bias.begin(), bias.end(),
                   [&]() { return f32dist(rng); });
 
-    std::fill(packed_weights.begin(), packed_weights.end(), 0);
+    std::fill(packed_weights.begin(), packed_weights.end(), 0.0f);
     xnn_pack_f16_dwconv_ghw_w(
         kernel_tile(), 0, 0, kernel_tile(), 1, channels(), channel_tile(),
         channel_tile(), channel_tile(),
@@ -1101,7 +1101,7 @@ void DWConvMicrokernelTester::Test(
     std::generate(bias.begin(), bias.end(),
                   [&]() { return f32dist(rng); });
 
-    std::fill(packed_weights.begin(), packed_weights.end(), 0);
+    std::fill(packed_weights.begin(), packed_weights.end(), 0.0f);
     xnn_pack_f16_dwconv_ghw_w(
         first_pass_tile(), middle_pass_tile(), last_pass_tile(), kernel_size(),
         1, channels(), channel_tile(), channel_subtile(), channel_round(),
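
A recurring fix in this commit is matching the std::fill value's type to the buffer's element type. std::fill copies its third argument into every element, so an int literal 0 written into an xnn_float16 buffer has to go through an extra conversion whose exact diagnostic depends on how that wrapper's constructors are declared. The sketch below uses a hypothetical Half stand-in, not XNNPACK's real xnn_float16, to show the shape of the fix under that assumption:

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Hypothetical binary16 wrapper that converts from float (stand-in only).
    struct Half {
      uint16_t bits = 0;
      Half() = default;
      Half(float /*value*/) {}  // a real wrapper would round to binary16 here
    };

    int main() {
      std::vector<Half> packed_weights(128);
      // Filling with the int literal 0 would need an int -> float -> Half
      // conversion chain per element; 0.0f converts in a single step.
      std::fill(packed_weights.begin(), packed_weights.end(), 0.0f);
    }
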
2 changes: 1 addition & 1 deletion test/fully-connected-operator-tester.h
@@ -447,7 +447,7 @@ class FullyConnectedOperatorTester {
    }
 
    // Compute reference results, without renormalization.
-   std::fill(output_ref.begin(), output_ref.end(), 0);
+   std::fill(output_ref.begin(), output_ref.end(), 0.0f);
 
    // TODO: Not supported right now.
    assert (transpose_weights() == false);
17 changes: 11 additions & 6 deletions test/fully-connected.cc
@@ -1416,9 +1416,12 @@ TEST_F(FullyConnectedTestBF16F32, matches_operator_api) {
   // }); std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng);
   // });
   int counter = 0;
-  std::generate(input.begin(), input.end(), [&]() { return counter++ % 10; });
-  std::generate(kernel.begin(), kernel.end(), [&]() { return counter++ % 10; });
-  std::generate(bias.begin(), bias.end(), [&]() { return counter++ % 10; });
+  std::generate(input.begin(), input.end(),
+                [&]() { return static_cast<float>(counter++ % 10); });
+  std::generate(kernel.begin(), kernel.end(),
+                [&]() { return static_cast<float>(counter++ % 10); });
+  std::generate(bias.begin(), bias.end(),
+                [&]() { return static_cast<float>(counter++ % 10); });
 
   // Call operator API.
   const xnn_status status = xnn_create_fully_connected_nc_bf16_f32(
@@ -3607,8 +3610,9 @@ TEST_F(FullyConnectedTestQD8F32QC4W,
   // 2nd inference: The dq-params should be properly allocated to handle a
   // resize without memory retrigger
   input_dims[0] += 2;
-  size_t batch_size2 = std::accumulate(input_dims.begin(), input_dims.end() - 1,
-                                       1, std::multiplies<size_t>());
+  size_t batch_size2 =
+      std::accumulate(input_dims.begin(), input_dims.end() - 1,
+                      static_cast<size_t>(1), std::multiplies<size_t>());
   xnnpack::Buffer<float> convert_input2(batch_size2 * input_channels +
                                         XNN_EXTRA_BYTES / sizeof(float));
   std::generate(convert_input2.begin(), convert_input2.end(),
@@ -3629,7 +3633,8 @@ TEST_F(FullyConnectedTestQD8F32QC4W,
   // retrigger
   input_dims[0] += 2;  // +4 total
   size_t batch_size3 = std::accumulate(input_dims.begin(), input_dims.end() - 1,
-                                       1, std::multiplies<size_t>());
+                                       static_cast<size_t>(1),
+                                       std::multiplies<size_t>());
   xnnpack::Buffer<float> convert_input3(batch_size3 * input_channels +
                                         XNN_EXTRA_BYTES / sizeof(float));
   std::generate(convert_input3.begin(), convert_input3.end(),
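
The std::accumulate changes in this file (and in test/unary.cc below) address the same trap: the accumulator's type is deduced from the initial value, so passing the int literal 1 means every partial product is stored back into an int even though std::multiplies<size_t> is supplied, narrowing a size_t at each step and silently truncating large shapes. A small stand-alone illustration, assuming a typical 64-bit target:

    #include <cstddef>
    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <numeric>
    #include <vector>

    int main() {
      std::vector<size_t> dims = {1 << 16, 1 << 16, 2};  // product is 2^33
      // The init value's type is the accumulator type: with int 1 each size_t
      // product is narrowed back to int; with size_t 1 it stays 64-bit.
      long long narrowed = std::accumulate(dims.begin(), dims.end(), 1,
                                           std::multiplies<size_t>());
      size_t correct = std::accumulate(dims.begin(), dims.end(),
                                       static_cast<size_t>(1),
                                       std::multiplies<size_t>());
      std::cout << narrowed << " vs " << correct << "\n";  // 0 vs 8589934592
    }
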
18 changes: 9 additions & 9 deletions test/gemm-microkernel-tester.cc
@@ -1316,7 +1316,7 @@ void GemmMicrokernelTester::Test(
       /* bias */ packed_n() * sizeof(float));
 
   xnnpack::Buffer<xnn_float16> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1);
-  xnnpack::Buffer<float> c_ref(m() * n(), 0);
+  xnnpack::Buffer<float> c_ref(m() * n());
 
   for (size_t iteration = 0; iteration < kIterations; iteration++) {
     std::generate(input.begin(), input.end(), std::ref(f32rng));
@@ -1377,7 +1377,7 @@ void GemmMicrokernelTester::Test(
         (void*) start);
 
     // Compute 32-bit results and output quantization arguments.
-    std::fill(c_ref.begin(), c_ref.end(), 0);
+    std::fill(c_ref.begin(), c_ref.end(), 0.0f);
     for (size_t m_index = 0; m_index < m(); m_index++) {
       for (size_t n_index = 0; n_index < n(); n_index++) {
         float kfsum = 0.0;
@@ -1473,7 +1473,7 @@ void GemmMicrokernelTester::Test(
       packed_n() * (sizeof(int32_t) + sizeof(float) * 2));
   xnnpack::Buffer<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1);
   xnnpack::Buffer<int32_t> acc(m() * n());
-  xnnpack::Buffer<float> c_ref(m() * n(), 0);
+  xnnpack::Buffer<float> c_ref(m() * n());
 
   for (size_t iteration = 0; iteration < kIterations; iteration++) {
     std::generate(input.begin(), input.end(), std::ref(f32rng));
@@ -1526,7 +1526,7 @@ void GemmMicrokernelTester::Test(
         (void*) ((uintptr_t) packed_w.data() + nr() * (ks() * packed_k_bytes + 2 * sizeof(float))));
 
     // Compute 32-bit results and output quantization arguments.
-    std::fill(c_ref.begin(), c_ref.end(), 0);
+    std::fill(c_ref.begin(), c_ref.end(), 0.0f);
     for (size_t m_index = 0; m_index < m(); m_index++) {
       for (size_t n_index = 0; n_index < n(); n_index++) {
         int32_t ksum = 0;
@@ -1625,7 +1625,7 @@ void GemmMicrokernelTester::Test(
       /* bias */ packed_n() * sizeof(float));
 
   xnnpack::Buffer<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1);
-  xnnpack::Buffer<float> c_ref(m() * n(), 0);
+  xnnpack::Buffer<float> c_ref(m() * n());
 
   for (size_t iteration = 0; iteration < 1 /* kIterations */; iteration++) {
     std::generate(input.begin(), input.end(), std::ref(f32rng));
@@ -1685,7 +1685,7 @@ void GemmMicrokernelTester::Test(
         (void*) start);
 
     // Compute 32-bit results and output quantization arguments.
-    std::fill(c_ref.begin(), c_ref.end(), 0);
+    std::fill(c_ref.begin(), c_ref.end(), 0.0f);
     for (size_t m_index = 0; m_index < m(); m_index++) {
       for (size_t n_index = 0; n_index < n(); n_index++) {
         float kfsum = 0.0;
@@ -2589,7 +2589,7 @@ void GemmMicrokernelTester::Test(
     std::generate(bias.begin(), bias.end(), [&] { return f32rng(rng); });
     std::fill(c_ref.begin(), c_ref.end(), 0.0f);
 
-    std::fill(packed_w.begin(), packed_w.end(), 0);
+    std::fill(packed_w.begin(), packed_w.end(), 0.0f);
     pack(/*g=*/1, n(), k(), nr(), kr(), sr(),
          b.data(),
         bias.data(), /*scale=*/nullptr,
@@ -2669,7 +2669,7 @@ void GemmMicrokernelTester::Test(
     std::generate(bias.begin(), bias.end(), [&] { return f32rng(rng); });
     std::fill(c_ref.begin(), c_ref.end(), 0.0f);
 
-    std::fill(packed_w.begin(), packed_w.end(), 0);
+    std::fill(packed_w.begin(), packed_w.end(), 0.0f);
     pack(/*g=*/1, n(), k(), nr(), kr(), sr(),
         reinterpret_cast<const uint16_t*>(b.data()),
        reinterpret_cast<const uint16_t*>(bias.data()), /*scale=*/nullptr,
@@ -2749,7 +2749,7 @@ void GemmMicrokernelTester::Test(
     std::generate(bias.begin(), bias.end(), f32rng);
     std::fill(c_ref.begin(), c_ref.end(), 0.0f);
 
-    std::fill(packed_w.begin(), packed_w.end(), 0);
+    std::fill(packed_w.begin(), packed_w.end(), 0.0f);
     pack(/*g=*/1, n(), k(), nr(), kr(), sr(),
         reinterpret_cast<const uint16_t*>(b.data()),
        reinterpret_cast<const uint16_t*>(bias.data()),
46 changes: 30 additions & 16 deletions test/operator-test-utils.h
@@ -13,38 +13,52 @@
 
 namespace xnnpack {
 
+template <typename T, typename Buffer>
+void randomize_int_buffer(xnn_datatype datatype,
+                          xnnpack::ReplicableRandomDevice& rng, double min,
+                          double max, Buffer& buf) {
+  std::uniform_int_distribution<int> dist(static_cast<int>(min),
+                                          static_cast<int>(max));
+  const auto f = [&]() { return static_cast<T>(dist(rng)); };
+  std::generate(reinterpret_cast<T*>(buf.begin()),
+                reinterpret_cast<T*>(buf.end()), f);
+}
+
+template <typename T, typename Buffer>
+void randomize_float_buffer(xnn_datatype datatype,
+                            xnnpack::ReplicableRandomDevice& rng, double min,
+                            double max, Buffer& buf) {
+  std::uniform_real_distribution<float> dist(static_cast<float>(min),
+                                             static_cast<float>(max));
+  const auto f = [&]() { return dist(rng); };
+  std::generate(reinterpret_cast<T*>(buf.begin()),
+                reinterpret_cast<T*>(buf.end()), f);
+}
+
 // Given ann xnnpack::Buffer<char> type, initialize it with
 // the given datatype using the given RNG and distribution.
 template <typename Buffer>
 void randomize_buffer(xnn_datatype datatype,
-                      xnnpack::ReplicableRandomDevice& rng,
-                      std::uniform_real_distribution<double>& dist,
-                      Buffer& buf) {
-  const auto f = [&]() { return dist(rng); };
+                      xnnpack::ReplicableRandomDevice& rng, double min,
+                      double max, Buffer& buf) {
  switch (datatype) {
    case xnn_datatype_quint8:
-      std::generate(reinterpret_cast<uint8_t*>(buf.begin()),
-                    reinterpret_cast<uint8_t*>(buf.end()), f);
+      randomize_int_buffer<uint8_t>(datatype, rng, min, max, buf);
      break;
    case xnn_datatype_qint8:
-      std::generate(reinterpret_cast<int8_t*>(buf.begin()),
-                    reinterpret_cast<int8_t*>(buf.end()), f);
+      randomize_int_buffer<int8_t>(datatype, rng, min, max, buf);
      break;
    case xnn_datatype_int32:
-      std::generate(reinterpret_cast<int32_t*>(buf.begin()),
-                    reinterpret_cast<int32_t*>(buf.end()), f);
+      randomize_int_buffer<int32_t>(datatype, rng, min, max, buf);
      break;
    case xnn_datatype_fp16:
-      std::generate(reinterpret_cast<xnn_float16*>(buf.begin()),
-                    reinterpret_cast<xnn_float16*>(buf.end()), f);
+      randomize_float_buffer<xnn_float16>(datatype, rng, min, max, buf);
      break;
    case xnn_datatype_bf16:
-      std::generate(reinterpret_cast<xnn_bfloat16*>(buf.begin()),
-                    reinterpret_cast<xnn_bfloat16*>(buf.end()), f);
+      randomize_float_buffer<xnn_bfloat16>(datatype, rng, min, max, buf);
      break;
    case xnn_datatype_fp32:
-      std::generate(reinterpret_cast<float*>(buf.begin()),
-                    reinterpret_cast<float*>(buf.end()), f);
+      randomize_float_buffer<float>(datatype, rng, min, max, buf);
      break;
    default:
      assert(false);
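
The call sites in test/binary.cc and test/binary-elementwise-nd.cc above show the intent of this change: callers now pass plain min/max bounds and the helper picks an integer or floating-point distribution that matches the element type, instead of pushing every datatype through a uniform_real_distribution<double> and narrowing the result. A self-contained analogue of that dispatch (using std::vector rather than XNNPACK's Buffer and datatype enum):

    #include <algorithm>
    #include <cstdint>
    #include <random>
    #include <type_traits>
    #include <vector>

    template <typename T>
    void randomize(std::mt19937& rng, double min, double max, std::vector<T>& buf) {
      if constexpr (std::is_integral_v<T>) {
        std::uniform_int_distribution<int> dist(static_cast<int>(min),
                                                static_cast<int>(max));
        std::generate(buf.begin(), buf.end(),
                      [&]() { return static_cast<T>(dist(rng)); });
      } else {
        std::uniform_real_distribution<float> dist(static_cast<float>(min),
                                                   static_cast<float>(max));
        std::generate(buf.begin(), buf.end(),
                      [&]() { return static_cast<T>(dist(rng)); });
      }
    }

    int main() {
      std::mt19937 rng(0);
      std::vector<int8_t> qint8(16);
      std::vector<float> fp32(16);
      randomize(rng, -128, 127, qint8);   // integer path: no double -> int8 narrowing
      randomize(rng, -10.0, 10.0, fp32);  // floating-point path
    }
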
4 changes: 2 additions & 2 deletions test/packing.cc
@@ -4214,7 +4214,7 @@ TEST(PACK_QS8_MULTIPASS_DWCONV_GHW_W, one_middle_pass_channel_subtile_rounded) {
   // c rounded to channel_subtile is 8, so we will have 2 channel_tile loops in first and middle pass.
 
   std::vector<int32_t> b(c);
-  std::iota(b.begin(), b.end(), 0.0f);  // b = [0, 1, 2, 3, 4, 5, 6]
+  std::iota(b.begin(), b.end(), 0);  // b = [0, 1, 2, 3, 4, 5, 6]
   std::vector<int8_t> k(c * h * w);  // k = [7, 8,   // first 2x2 kernel
                                      //      9, 10,
                                      //      11, 12, // second 2x2 kernel
@@ -4429,7 +4429,7 @@ TEST(PACK_QS8_MULTIPASS_DWCONV_HWG_W, one_middle_pass_tile) {
   const size_t cr = 2;
 
   std::vector<int32_t> b(c);
-  std::iota(b.begin(), b.end(), 0.0f);  // b = [0, 1]
+  std::iota(b.begin(), b.end(), 0);  // b = [0, 1]
   std::vector<int8_t> k(c * h * w);  // k = [2, 3,
                                      //      4, 5,
                                      //      6, 7,
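
Same idea with std::iota: it assigns value++ into each element, so seeding a std::vector<int32_t> with the float literal 0.0f turns every assignment into a float-to-int conversion. The generated values are unchanged; only the seed's type now matches the container, as in this quick illustration:

    #include <cstdint>
    #include <numeric>
    #include <vector>

    int main() {
      std::vector<int32_t> b(7);
      // An integer seed keeps the whole sequence integral; 0.0f would force a
      // float -> int32_t conversion on every element.
      std::iota(b.begin(), b.end(), 0);  // b = [0, 1, 2, 3, 4, 5, 6]
    }
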
2 changes: 1 addition & 1 deletion test/packw-microkernel-tester.h
@@ -349,7 +349,7 @@ class PackWMicrokernelTester {
    std::fill(packed_w_ref.begin(), packed_w_ref.end(), pad_value);
 
    // Mandate zero-padding of weights to packed_k() in K dimension.
-   std::fill(padded_weights.begin(), padded_weights.end(), 0);
+   std::fill(padded_weights.begin(), padded_weights.end(), 0.0f);
    for (size_t gid = 0; gid < g(); gid++) {
      for (size_t i = 0; i < n(); i++) {
        for (size_t j = 0; j < k(); j++) {
3 changes: 2 additions & 1 deletion test/runtime-tester.h
@@ -107,7 +107,8 @@ class RuntimeTester : public SubgraphTester {
    EXPECT_EQ(status, xnn_status_success);
    size_t num_elements = NumElements(dims);
    xnnpack::Buffer<char> input(num_elements * sizeof(float) + XNN_EXTRA_BYTES * sizeof(char));
-   std::generate(input.begin(), input.end(), [&]() { return f32dist(rng_); });
+   float* data = reinterpret_cast<float*>(input.data());
+   std::generate(data, data + num_elements, [&]() { return f32dist(rng_); });
    external_tensors_[external_id] = std::move(input);
  }
 
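
The old line generated one value per char of the byte buffer and narrowed a float into each byte, so the buffer's contents were not the floats the test meant to feed in; the replacement views the same storage as floats and generates exactly num_elements of them. A stand-alone sketch of that pattern (plain std::vector<char> standing in for xnnpack::Buffer<char>):

    #include <algorithm>
    #include <cstddef>
    #include <random>
    #include <vector>

    int main() {
      std::mt19937 rng(0);
      std::uniform_real_distribution<float> f32dist(-1.0f, 1.0f);
      const size_t num_elements = 16;
      // Byte-typed storage that actually backs a float tensor.
      std::vector<char> input(num_elements * sizeof(float));
      // Iterating input.begin()/end() would write a narrowed float into every
      // byte; instead reinterpret the storage and generate num_elements floats.
      float* data = reinterpret_cast<float*>(input.data());
      std::generate(data, data + num_elements, [&]() { return f32dist(rng); });
    }
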
2 changes: 1 addition & 1 deletion test/spmm-microkernel-tester.h
@@ -307,7 +307,7 @@ class SpMMMicrokernelTester {
    std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
    std::fill(nmap.begin(), nmap.end(), 0);
    std::fill(dmap.begin(), dmap.end(), 0);
-   std::fill(w.begin(), w.end(), 0);
+   std::fill(w.begin(), w.end(), 0.0f);
 
    for (xnn_float16& b_value : b) {
      if (pdist(rng) <= sparsity()) {
3 changes: 2 additions & 1 deletion test/unary.cc
@@ -83,7 +83,8 @@ TEST_P(UnaryTest, matches_operator_api) {
   std::generate(dims.begin(), dims.end(), [&]() { return dim_dist(rng_); });
 
   size_t size =
-      std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>());
+      std::accumulate(dims.begin(), dims.end(), static_cast<size_t>(1),
+                      std::multiplies<size_t>());
   size_t channels = dims.empty() ? 1 : dims.back();
   size_t batch_size = size / channels;
 
