Commit c90b933

savepoint
1 parent 26dd1ee commit c90b933

File tree

7 files changed: 227 additions, 81 deletions

onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc

Lines changed: 56 additions & 61 deletions
@@ -11,6 +11,7 @@
 #include "core/util/qmath.h"
 #include <cassert>
 
+#include <chrono>
 #include <algorithm>
 
 namespace onnxruntime {
@@ -20,30 +21,20 @@ namespace {
 
 
 using Index = uint32_t;
-extern "C" void
-__attribute__((import_module("wasm_gemm"), import_name("int8_multiply")))
-int8Multiply(const uint8_t* input_A,
-             float zero_point_A,
-             const int8_t* input_B,
-             //const uint8_t* zero_point_B,
-             Index rows_A,
-             Index width,
-             Index cols_B,
-             float* output);
 
 extern "C" void
-__attribute__((import_module("wasm_gemm"), import_name("f32_multiply")))
-f32Multiply(
-    const uint8_t* a_data,
-    float zero_point_A,
-    const int8_t* input_B,
-    const uint8_t* zero_point_B,
-    Index rows_A,
-    Index width,
-    Index cols_B,
-    const float* b_scale_data,
-    float is_b_scale_per_column,
-    float* output
+__attribute__((import_module("wasm_gemm"), import_name("onnx_matmul_integer_to_float")))
+GeckoMatmulIntegerToFloat(
+    const uint8_t* a_data,
+    float zero_point_A,
+    const int8_t* input_B,
+    const uint8_t* zero_point_B,
+    uint32_t rows_A,
+    uint32_t width,
+    uint32_t cols_B,
+    const float* b_scale_data,
+    float is_b_scale_per_column,
+    float* output
 );
 
 
@@ -234,19 +225,9 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
     params.ldc = gemm_shape.N;
   }
 
-  // rowsA = M
-  // width = K
-  // colsB = N
-  size_t rowsA = static_cast<size_t>(helper.M());
-  size_t width = static_cast<size_t>(helper.K());
-  size_t colsB = static_cast<size_t>(helper.N());
-  const int8_t* b_data = static_cast<const int8_t*>(b_tensor->DataRaw());
-
-#if 0
-  size_t total_elements = rowsA * colsB;
-  size_t display_limit = std::min(total_elements, static_cast<size_t>(100));
+#if 0
   std::vector<float> y_data_2(rowsA * colsB, 0.0f);
-  std::cout << "Calling MatMulFull with the following parameters:\n";
+  if (rowsA > 1) {
   std::cout << "rowsA: " << rowsA << ", width: " << width << ", colsB: " << colsB << "\n";
   std::cout << "a_zp: " << static_cast<int>(a_zp) << "\n";
   std::cout << "is_b_scale_per_column: " << is_b_scale_per_column << "\n";
@@ -276,13 +257,22 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
   }
   std::cout << "multiplier_per_tensor: " << multiplier_per_tensor << std::endl;
   std::cout << "b_scale_data[0]: " << b_scale_data[0] << std::endl;
+  }
 #endif
+  //auto start = std::chrono::steady_clock::now();
+  //std::cout << "Calling f32Multiply\n";
+  // should split in parts and call ctx.ParallelFor just on the rows part
 
-  //MatMulFull(a_data, b_data, y_data, rowsA, width, colsB, a_zp, b_zp_ptr, b_scale_data, is_b_scale_per_column);
-
-  std::cout << "Calling f32Multiply\n";
+  // rowsA = M
+  // width = K
+  // colsB = N
+  size_t rowsA = static_cast<size_t>(helper.M());
+  size_t width = static_cast<size_t>(helper.K());
+  size_t colsB = static_cast<size_t>(helper.N());
+
+  const int8_t* b_data = static_cast<const int8_t*>(b_tensor->DataRaw());
 
-  f32Multiply(a_data,
+  GeckoMatmulIntegerToFloat(a_data,
               a_zp,
               b_data,
               b_zp_ptr,
@@ -291,15 +281,28 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
              colsB,
              b_scale_data,
              is_b_scale_per_column,
-              y_data);
-
-  std::cout << "Done calling f32Multiply\n";
-
-
-#if 0
+              y_data
+  );
+
+  // MlasGemmBatch(gemm_shape, gemm_data_vec.data(), num_gemms, ctx->GetOperatorThreadPool());
+  /*
+  auto end = std::chrono::steady_clock::now();
+  auto duration = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start).count();
+  std::cout << "Done calling f32Multiply. Duration: " << duration << " nano\n";
+
+  std::cout << "Calling MlasGemmBatch\n";
+  auto start2 = std::chrono::steady_clock::now();
   MlasGemmBatch(gemm_shape, gemm_data_vec.data(), num_gemms, ctx->GetOperatorThreadPool());
+  auto end2 = std::chrono::steady_clock::now();
+  auto duration2 = std::chrono::duration_cast<std::chrono::nanoseconds>(end2 - start2).count();
+  std::cout << "Done calling MlasGemmBatch. Duration: " << duration2 << " nano\n";
+  */
+  /*
 
   // Compare y_data and y_data_2
+
+  size_t total_elements = rowsA * colsB;
+  size_t display_limit = std::min(total_elements, static_cast<size_t>(100));
   bool mismatch_found = false;
   for (size_t i = 0; i < total_elements; ++i) {
     if (std::fabs(y_data[i] - y_data_2[i]) > 1e-6) { // Tolerance for floating-point comparison
@@ -322,23 +325,10 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
     std::cerr << "Mismatch found between y_data and y_data_2!" << std::endl;
     assert(false && "Validation failed: y_data and y_data_2 are not equal.");
   }
-#endif
+  */
   return Status::OK();
 }
 
-/*
-  int8Multiply(
-    reinterpret_cast<const uint8_t*>(a_data),
-    a_zp,
-    b_data,
-    //reinterpret_cast<const uint8_t*>(b_zero_point->DataRaw()),
-    rowsA,
-    width,
-    colsB,
-    reinterpret_cast<float*>(y_data)
-  );
-*/
-
 class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase {
  public:
   DynamicQuantizeMatMul(const OpKernelInfo& info) : MatMulIntegerToFloatBase(info) {}
@@ -405,6 +395,10 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const {
   ParQuantizeLinearStd(a_data, a_data_quant, narrow<size_t>(num_of_elements), a_scale, a_zero_point, ctx->GetOperatorThreadPool());
 
   bool is_b_scale_supported = IsBQuantParamSupported(b_scale_tensor->Shape(), b ? b->Shape() : b_shape_);
+
+  //std::cout << "dynamic quantize matmul calling ComputeCommon" << std::endl;
+
+
   ORT_RETURN_IF_ERROR(ComputeCommon(
       ctx,
       a_data_quant,
@@ -418,7 +412,7 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const {
       ctx->Input<Tensor>(IN_BIAS)));
 
   if (!is_b_scale_supported) {
-    std::cout << "dynamic quantize matmul: b scale is not supported\n";
+    //std::cout << "dynamic quantize matmul: b scale is not supported\n";
     ScaleOutput(*b_scale_tensor, *ctx->Output<Tensor>(0));
   }
 
@@ -460,6 +454,7 @@ Status MatMulIntegerToFloat::Compute(OpKernelContext* ctx) const {
     a_zero_point = *(static_cast<const uint8_t*>(a_zero_point_tensor->DataRaw()));
   }
 
+  //std::cout << "matmul integer float calling ComputeCommon" << std::endl;
   const Tensor* b_zp_tensor = ctx->Input<Tensor>(IN_B_ZERO_POINT);
   ORT_RETURN_IF_ERROR(ComputeCommon(
       ctx,
@@ -474,11 +469,11 @@ Status MatMulIntegerToFloat::Compute(OpKernelContext* ctx) const {
       ctx->Input<Tensor>(IN_BIAS)));
 
   if (!is_a_scale_scalar) {
-    std::cout << "dynamic quantize matmul: a scale is not scalar\n";
+    //std::cout << "dynamic quantize matmul: a scale is not scalar\n";
     ScaleOutput(*a_scale_tensor, *ctx->Output<Tensor>(0));
   }
   if (!is_b_scale_supported) {
-    std::cout << "dynamic quantize matmul: b scale is not supported\n";
+    //std::cout << "dynamic quantize matmul: b scale is not supported\n";
     ScaleOutput(*b_scale_tensor, *ctx->Output<Tensor>(0));
  }
 
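For reference, below is a minimal scalar sketch of what the imported GeckoMatmulIntegerToFloat kernel is declared to compute, assuming row-major A (rows_A x width, uint8) and B (width x cols_B, int8) and the usual MatMulIntegerToFloat semantics: (A - a_zp) * (B - b_zp), scaled per column when is_b_scale_per_column is set. The reference function name is hypothetical, and the treatment of the A-side scale (assumed here to be folded into b_scale_data or applied afterwards via ScaleOutput) is an assumption, not something this commit specifies; the real kernel lives on the host (Gecko) side.

// Hypothetical scalar reference, not the Gecko implementation.
// Note: the import declares is_b_scale_per_column as float; a bool is
// used here for clarity.
#include <cstdint>

void MatmulIntegerToFloatRef(const uint8_t* a_data, float zero_point_A,
                             const int8_t* input_B, const uint8_t* zero_point_B,
                             uint32_t rows_A, uint32_t width, uint32_t cols_B,
                             const float* b_scale_data, bool is_b_scale_per_column,
                             float* output) {
  const int32_t a_zp = static_cast<int32_t>(zero_point_A);
  for (uint32_t m = 0; m < rows_A; ++m) {
    for (uint32_t n = 0; n < cols_B; ++n) {
      // Per-column zero point and scale when B is quantized per column,
      // otherwise a single per-tensor value (index 0).
      const uint32_t q = is_b_scale_per_column ? n : 0;
      const int32_t b_zp = zero_point_B ? static_cast<int32_t>(zero_point_B[q]) : 0;
      int32_t acc = 0;
      for (uint32_t k = 0; k < width; ++k) {
        acc += (static_cast<int32_t>(a_data[m * width + k]) - a_zp) *
               (static_cast<int32_t>(input_B[k * cols_B + n]) - b_zp);
      }
      output[m * cols_B + n] = static_cast<float>(acc) * b_scale_data[q];
    }
  }
}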

onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc

Lines changed: 4 additions & 1 deletion
@@ -149,16 +149,18 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
   }
 #endif
   //auto start_matmul = Clock::now();
+  /*
   int8Multiply(
     reinterpret_cast<const uint8_t*>(a->DataRaw()),
     a_offset,
     reinterpret_cast<const int8_t*>(b->DataRaw()),
-    //reinterpret_cast<const uint8_t*>(b_zero_point->DataRaw()),
+    reinterpret_cast<const uint8_t*>(b_zero_point->DataRaw()),
     M,
     K,
     N,
     reinterpret_cast<float*>(y_data)
   );
+  */
   //auto end_matmul = Clock::now();
   //auto matmul_time = std::chrono::duration_cast<Microseconds>(end_matmul - start_matmul).count();
 
@@ -202,6 +204,7 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
   std::cout << "MlasGemmBatch: " << mblas_time << "\n";
 
 #endif
+  MlasGemmBatch(gemm_shape, gemm_data_vec.data(), batch_size, ctx->GetOperatorThreadPool());
   return Status::OK();
 }
 
onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.h

Lines changed: 6 additions & 6 deletions
@@ -39,18 +39,18 @@ class FirefoxMatMulInteger8 final : public MatMulIntegerBase {
 
 #include <cstdint>
 
-using Index = uint32_t;
+#if 0
 extern "C" void
 __attribute__((import_module("wasm_gemm"), import_name("int8_multiply")))
 int8Multiply(const uint8_t* input_A,
              float zero_point_A,
              const int8_t* input_B,
-             //const uint8_t* zero_point_B,
-             Index rows_A,
-             Index width,
-             Index cols_B,
+             const uint8_t* zero_point_B,
+             float rows_A,
+             float width,
+             float cols_B,
              float* output);
-
+#endif
 
 #endif
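The import_module/import_name attribute pair used throughout this commit is clang's WebAssembly mechanism for declaring a function that the module does not define itself: the symbol becomes an entry in the wasm import section, and the embedder (here, presumably a wasm_gemm module provided by Gecko) must supply it at instantiation time. A minimal self-contained sketch, with illustrative names that are not from this commit:

// Declares a host import: the wasm module will import "demo.add_one"
// instead of linking a local definition. Names here are illustrative.
#include <cstdint>

extern "C" int32_t
__attribute__((import_module("demo"), import_name("add_one")))
demo_add_one(int32_t x);

extern "C" int32_t use_import(int32_t v) {
  // Resolved at instantiation time from the embedder's import object.
  return demo_add_one(v) * 2;
}

When targeting wasm32, the undefined symbol is typically allowed through the linker (e.g. wasm-ld's --allow-undefined) and the JavaScript or native embedder passes the implementation in the import object when instantiating the module.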
5656

onnxruntime/core/mlas/lib/qgemm.cpp

Lines changed: 16 additions & 0 deletions
@@ -17,6 +17,7 @@ Module Name:
 
 #include "mlasi.h"
 #include "qgemm.h"
+#include <iostream>
 
 //
 // Define the parameters to execute segments of a QGEMM operation on worker
@@ -144,6 +145,8 @@ MlasGemmBatch(
 
     const double Complexity = double(M) * double(N) * double(K) * double(BatchN);
 
+    //std::cout << "Complexity: " << Complexity << std::endl;
+
     ptrdiff_t TargetThreadCount;
 
     if (Complexity < double(MLAS_QGEMM_THREAD_COMPLEXITY * GetMlasPlatform().MaximumThreadCount)) {
@@ -194,10 +197,16 @@ MlasGemmBatch(
         WorkBlock.ThreadCountN = 1;
     }
     TargetThreadCount = ThreadsPerGemm * BatchN;
+    //std::cout << "ThreadsPerGemm: " << ThreadsPerGemm << std::endl;
+    //std::cout << "TargetThreadCount: " << TargetThreadCount << std::endl;
+    //std::cout << "MaximumThreadCount: " << MaximumThreadCount << std::endl;
+
+
 
     MlasTrySimpleParallel(ThreadPool, TargetThreadCount, [&](ptrdiff_t tid) {
         const auto gemm_i = tid / ThreadsPerGemm;
         const auto blk_i = tid % ThreadsPerGemm;
+        //std::cout << "gemm_i: " << gemm_i << " blk_i: " << blk_i << std::endl;
         MlasGemmQuantThreaded(&WorkBlock, &Shape, &DataParams[gemm_i], blk_i);
     });
 }
@@ -277,6 +286,13 @@ MlasSymmQgemmBatch(
     const size_t ThreadCountM = MlasDivRoundup(M, StrideM);
     const size_t ThreadCountN = MlasDivRoundup(N, StrideN);
     ThreadsPerGemm = ThreadCountM * ThreadCountN;
+
+    /*
+    std::cout << "ThreadsPerGemm" << ThreadsPerGemm << std::endl;
+    std::cout << "TargetThreadCount " << TargetThreadCount << std::endl;
+    std::cout << "ThreadCountM" << ThreadCountM << std::endl;
+    std::cout << "ThreadCountN" << ThreadCountN << std::endl;
+    */
 
     MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) {
         auto uarch = MLAS_CPUIDINFO::GetCPUIDInfo().IsCurrentCoreArmv8NarrowLd();
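The gemm_i/blk_i prints above probe how MlasGemmBatch flattens its parallelism: TargetThreadCount = ThreadsPerGemm * BatchN, and each worker's flat index decomposes into (GEMM in the batch, block within that GEMM). A standalone sketch of the mapping, with illustrative sizes:

#include <cstddef>
#include <iostream>

int main() {
  const std::ptrdiff_t ThreadsPerGemm = 4;  // ThreadCountM * ThreadCountN
  const std::ptrdiff_t BatchN = 3;          // GEMMs in the batch
  // Same decomposition as the lambda passed to MlasTrySimpleParallel.
  for (std::ptrdiff_t tid = 0; tid < ThreadsPerGemm * BatchN; ++tid) {
    const auto gemm_i = tid / ThreadsPerGemm;  // which GEMM
    const auto blk_i = tid % ThreadsPerGemm;   // which block of it
    std::cout << "tid " << tid << " -> gemm_i " << gemm_i
              << ", blk_i " << blk_i << "\n";
  }
  return 0;
}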

onnxruntime/core/mlas/lib/qgemm.h

Lines changed: 9 additions & 0 deletions
@@ -36,6 +36,15 @@ Module Name:
 #include <string>
 #include <cstdlib>
 
+extern "C" void
+__attribute__((import_module("wasm_gemm"), import_name("mlas_gemm_u8x8")))
+xMlasGemmU8X8MultiplyAccumulateRowWasmSimd(
+    const float* A,
+    const float* B,
+    const float* C
+);
+
+
 //
 // Define the default striding parameters used for the quantized integer
 // matrix/matrix multiply operation.
