#include "core/util/qmath.h"
#include <cassert>

+#include <chrono>
#include <algorithm>

namespace onnxruntime {
@@ -20,30 +21,20 @@ namespace {


using Index = uint32_t;
-extern "C" void
-__attribute__((import_module("wasm_gemm"), import_name("int8_multiply")))
-int8Multiply(const uint8_t* input_A,
-             float zero_point_A,
-             const int8_t* input_B,
-             // const uint8_t* zero_point_B,
-             Index rows_A,
-             Index width,
-             Index cols_B,
-             float* output);

extern "C" void
-__attribute__((import_module("wasm_gemm"), import_name("f32_multiply")))
-f32Multiply(
-    const uint8_t* a_data,
-    float zero_point_A,
-    const int8_t* input_B,
-    const uint8_t* zero_point_B,
-    Index rows_A,
-    Index width,
-    Index cols_B,
-    const float* b_scale_data,
-    float is_b_scale_per_column,
-    float* output
+__attribute__((import_module("wasm_gemm"), import_name("onnx_matmul_integer_to_float")))
+GeckoMatmulIntegerToFloat(
+    const uint8_t* a_data,
+    float zero_point_A,
+    const int8_t* input_B,
+    const uint8_t* zero_point_B,
+    uint32_t rows_A,
+    uint32_t width,
+    uint32_t cols_B,
+    const float* b_scale_data,
+    float is_b_scale_per_column,
+    float* output
);


@@ -234,19 +225,9 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
    params.ldc = gemm_shape.N;
  }

-  // rowsA = M
-  // width = K
-  // colsB = N
-  size_t rowsA = static_cast<size_t>(helper.M());
-  size_t width = static_cast<size_t>(helper.K());
-  size_t colsB = static_cast<size_t>(helper.N());
-  const int8_t* b_data = static_cast<const int8_t*>(b_tensor->DataRaw());
-
-#if 0
-  size_t total_elements = rowsA * colsB;
-  size_t display_limit = std::min(total_elements, static_cast<size_t>(100));
+#if 0
  std::vector<float> y_data_2(rowsA * colsB, 0.0f);
-  std::cout << "Calling MatMulFull with the following parameters:\n";
+  if (rowsA > 1) {
  std::cout << "rowsA: " << rowsA << ", width: " << width << ", colsB: " << colsB << "\n";
  std::cout << "a_zp: " << static_cast<int>(a_zp) << "\n";
  std::cout << "is_b_scale_per_column: " << is_b_scale_per_column << "\n";
@@ -276,13 +257,22 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
  }
  std::cout << "multiplier_per_tensor: " << multiplier_per_tensor << std::endl;
  std::cout << "b_scale_data[0]: " << b_scale_data[0] << std::endl;
+  }
#endif
+  // auto start = std::chrono::steady_clock::now();
+  // std::cout << "Calling f32Multiply\n";
+  // should split in parts and call ctx.ParallelFor just on the rows part

-  // MatMulFull(a_data, b_data, y_data, rowsA, width, colsB, a_zp, b_zp_ptr, b_scale_data, is_b_scale_per_column);
-
-  std::cout << "Calling f32Multiply\n";
+  // rowsA = M
+  // width = K
+  // colsB = N
+  size_t rowsA = static_cast<size_t>(helper.M());
+  size_t width = static_cast<size_t>(helper.K());
+  size_t colsB = static_cast<size_t>(helper.N());
+
+  const int8_t* b_data = static_cast<const int8_t*>(b_tensor->DataRaw());

-  f32Multiply(a_data,
+  GeckoMatmulIntegerToFloat(a_data,
              a_zp,
              b_data,
              b_zp_ptr,
@@ -291,15 +281,28 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
              colsB,
              b_scale_data,
              is_b_scale_per_column,
-              y_data);
-
-  std::cout << "Done calling f32Multiply\n";
-
-
-#if 0
+              y_data
+  );
+
+  // MlasGemmBatch(gemm_shape, gemm_data_vec.data(), num_gemms, ctx->GetOperatorThreadPool());
+  /*
+  auto end = std::chrono::steady_clock::now();
+  auto duration = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start).count();
+  std::cout << "Done calling f32Multiply. Duration: " << duration << " nano\n";
+
+  std::cout << "Calling MlasGemmBatch\n";
+  auto start2 = std::chrono::steady_clock::now();
  MlasGemmBatch(gemm_shape, gemm_data_vec.data(), num_gemms, ctx->GetOperatorThreadPool());
+  auto end2 = std::chrono::steady_clock::now();
+  auto duration2 = std::chrono::duration_cast<std::chrono::nanoseconds>(end2 - start2).count();
+  std::cout << "Done calling MlasGemmBatch. Duration: " << duration2 << " nano\n";
+  */
+  /*

  // Compare y_data and y_data_2
+
+  size_t total_elements = rowsA * colsB;
+  size_t display_limit = std::min(total_elements, static_cast<size_t>(100));
  bool mismatch_found = false;
  for (size_t i = 0; i < total_elements; ++i) {
    if (std::fabs(y_data[i] - y_data_2[i]) > 1e-6) { // Tolerance for floating-point comparison
@@ -322,23 +325,10 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
    std::cerr << "Mismatch found between y_data and y_data_2!" << std::endl;
    assert(false && "Validation failed: y_data and y_data_2 are not equal.");
  }
-#endif
+  */
  return Status::OK();
}

-/*
-  int8Multiply(
-    reinterpret_cast<const uint8_t*>(a_data),
-    a_zp,
-    b_data,
-    //reinterpret_cast<const uint8_t*>(b_zero_point->DataRaw()),
-    rowsA,
-    width,
-    colsB,
-    reinterpret_cast<float*>(y_data)
-  );
-*/
-
class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase {
 public:
  DynamicQuantizeMatMul(const OpKernelInfo& info) : MatMulIntegerToFloatBase(info) {}
@@ -405,6 +395,10 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const {
  ParQuantizeLinearStd(a_data, a_data_quant, narrow<size_t>(num_of_elements), a_scale, a_zero_point, ctx->GetOperatorThreadPool());

  bool is_b_scale_supported = IsBQuantParamSupported(b_scale_tensor->Shape(), b ? b->Shape() : b_shape_);
+
+  // std::cout << "dynamic quantize matmul calling ComputeCommon" << std::endl;
+
+
  ORT_RETURN_IF_ERROR(ComputeCommon(
      ctx,
      a_data_quant,
@@ -418,7 +412,7 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const {
      ctx->Input<Tensor>(IN_BIAS)));

  if (!is_b_scale_supported) {
-    std::cout << "dynamic quantize matmul: b scale is not supported\n";
+    // std::cout << "dynamic quantize matmul: b scale is not supported\n";
    ScaleOutput(*b_scale_tensor, *ctx->Output<Tensor>(0));
  }

@@ -460,6 +454,7 @@ Status MatMulIntegerToFloat::Compute(OpKernelContext* ctx) const {
    a_zero_point = *(static_cast<const uint8_t*>(a_zero_point_tensor->DataRaw()));
  }

+  // std::cout << "matmul integer float calling ComputeCommon" << std::endl;
  const Tensor* b_zp_tensor = ctx->Input<Tensor>(IN_B_ZERO_POINT);
  ORT_RETURN_IF_ERROR(ComputeCommon(
      ctx,
@@ -474,11 +469,11 @@ Status MatMulIntegerToFloat::Compute(OpKernelContext* ctx) const {
      ctx->Input<Tensor>(IN_BIAS)));

  if (!is_a_scale_scalar) {
-    std::cout << "dynamic quantize matmul: a scale is not scalar\n";
+    // std::cout << "dynamic quantize matmul: a scale is not scalar\n";
    ScaleOutput(*a_scale_tensor, *ctx->Output<Tensor>(0));
  }
  if (!is_b_scale_supported) {
-    std::cout << "dynamic quantize matmul: b scale is not supported\n";
+    // std::cout << "dynamic quantize matmul: b scale is not supported\n";
    ScaleOutput(*b_scale_tensor, *ctx->Output<Tensor>(0));
  }

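For reference, below is a minimal scalar sketch of the integer-to-float matmul that the imported onnx_matmul_integer_to_float entry point is expected to compute, i.e. the semantics the removed MatMulFull comparison path appears to validate against. The helper name ReferenceMatmulIntegerToFloat, the row-major K x N layout of B, the nullable per-column B zero-point pointer, and the use of bool instead of the float is_b_scale_per_column flag are assumptions for illustration only, not code from this patch or from the Gecko-side implementation.

// Hypothetical reference only: dequantize A and B, accumulate in float,
// then apply the per-column (or per-tensor) B scale to each output element.
#include <cstddef>
#include <cstdint>

void ReferenceMatmulIntegerToFloat(const uint8_t* a_data,
                                   float a_zero_point,
                                   const int8_t* b_data,
                                   const uint8_t* b_zero_points,  // may be null
                                   size_t rows_a,
                                   size_t width,
                                   size_t cols_b,
                                   const float* b_scale_data,
                                   bool is_b_scale_per_column,
                                   float* output) {
  for (size_t m = 0; m < rows_a; ++m) {
    for (size_t n = 0; n < cols_b; ++n) {
      // Per-column quantization parameters fall back to a single
      // per-tensor value when is_b_scale_per_column is false.
      const size_t param_idx = is_b_scale_per_column ? n : 0;
      const float b_scale = b_scale_data[param_idx];
      const float b_zp = b_zero_points ? static_cast<float>(b_zero_points[param_idx]) : 0.0f;
      float acc = 0.0f;
      for (size_t k = 0; k < width; ++k) {
        const float a_val = static_cast<float>(a_data[m * width + k]) - a_zero_point;
        const float b_val = static_cast<float>(b_data[k * cols_b + n]) - b_zp;
        acc += a_val * b_val;
      }
      output[m * cols_b + n] = acc * b_scale;
    }
  }
}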