
Commit 5560a90

added a firefox matmul backend
1 parent 49a80df commit 5560a90

File tree: 17 files changed (+797 −53 lines)

build.sh

Lines changed: 9 additions & 6 deletions

@@ -1,21 +1,24 @@
 #!/bin/bash
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.
+set -ex

 # Get directory this script is in
-DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
+DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
 OS=$(uname -s)

 if [ "$OS" = "Darwin" ]; then
-DIR_OS="MacOS"
+DIR_OS="MacOS"
 else
-DIR_OS="Linux"
+DIR_OS="Linux"
 fi

 if [[ "$*" == *"--ios"* ]]; then
-DIR_OS="iOS"
+DIR_OS="iOS"
 elif [[ "$*" == *"--android"* ]]; then
-DIR_OS="Android"
+DIR_OS="Android"
 fi

-python3 $DIR/tools/ci_build/build.py --build_dir $DIR/build/$DIR_OS "$@"
+PYTHON="${PYTHON:-python3}"
+
+$PYTHON $DIR/tools/ci_build/build.py --build_dir $DIR/build/$DIR_OS "$@"
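The functional changes here are `set -ex` (echo each command, abort on the first failure) and an overridable interpreter: with bash's `${PYTHON:-python3}` default expansion, the build can now be invoked as, say, `PYTHON=python3.11 ./build.sh --android` without editing the script. The paired −/+ `DIR_OS` lines appear to differ only in whitespace.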

cmake/onnxruntime_webassembly.cmake

Lines changed: 1 addition & 0 deletions

@@ -382,6 +382,7 @@ jsepDownload:_pp_")
   "SHELL:-s ASYNCIFY_STACK_SIZE=65536"
   "SHELL:-s ASYNCIFY_EXPORTS=['OrtRun']"
   "SHELL:-s ASYNCIFY_IMPORTS=['Module.jsepCopy','Module.jsepCopyAsync','jsepDownload']"
+  "SHELL:-s ERROR_ON_UNDEFINED_SYMBOLS=0"
 )
 set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js)
 endif()
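This single linker flag is what lets the WebAssembly build link against a GEMM routine that only exists inside Firefox. A minimal sketch of the mechanism, based on the declaration that appears later in dynamic_quantize_matmul.cc:

#include <cstdint>

// Declared as a Wasm import: there is no definition at static link time, which
// wasm-ld normally rejects. ERROR_ON_UNDEFINED_SYMBOLS=0 defers resolution to
// module instantiation, when the embedder (here, Firefox) supplies
// "wasm_gemm"."onnx_matmul_integer_to_float" in the import object.
extern "C"
__attribute__((import_module("wasm_gemm"), import_name("onnx_matmul_integer_to_float")))
void GeckoMatmulIntegerToFloat(const uint8_t* a_data, float zero_point_A,
                               const int8_t* input_B, const uint8_t* zero_point_B,
                               uint32_t rows_A, uint32_t width, uint32_t cols_B,
                               const float* b_scale_data, float is_b_scale_per_column,
                               float* output);

Note that the scalar parameters are float/uint32_t, presumably to keep the import signature to plain Wasm numeric types (there is no bool at that boundary).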

onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc

Lines changed: 4 additions & 2 deletions

@@ -62,7 +62,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, NGram
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BifurcationDetector);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QuickGelu);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention);
-
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, FirefoxMatMulInteger8);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, FirefoxMatMulInteger8);
 // ******** Start: Quantization ******************* //
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulInteger16);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearGlobalAveragePool);

@@ -285,6 +286,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<void>,  // default entry to avoid the list become empty after ops-reducing
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SampleOp)>,

+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, FirefoxMatMulInteger8)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, FirefoxMatMulInteger8)>,
       // add more kernels here
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GridSample)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Attention)>,

@@ -364,7 +367,6 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UnfoldTensor)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, DynamicTimeWarping)>,
-
 #ifdef ENABLE_ATEN
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kPytorchAtenDomain, 1, ATen)>,
 #endif
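These declarations and BuildKernelCreateInfo entries register two typed variants of the new contrib op, one per A-input element type. The kernel class itself is defined in one of the changed files not shown on this page; a sketch of what the matching definition would typically look like, following the other typed CPU contrib kernels (the file name and the T1/T2/T3 constraints are assumptions, mirrored from the existing integer-matmul kernels):

// Hypothetical registration in firefox_matmul_integer.cc (name assumed):
ONNX_OPERATOR_TYPED_KERNEL_EX(
    FirefoxMatMulInteger8,
    kMSDomain,
    1,
    uint8_t,  // element type of A; the int8_t variant is registered the same way
    kCpuExecutionProvider,
    KernelDefBuilder()
        .TypeConstraint("T1", DataTypeImpl::GetTensorType<uint8_t>())
        .TypeConstraint("T2", DataTypeImpl::GetTensorType<int8_t>())
        .TypeConstraint("T3", DataTypeImpl::GetTensorType<int32_t>()),
    FirefoxMatMulInteger8);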

onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc

Lines changed: 183 additions & 1 deletion

@@ -9,13 +9,35 @@
 #include "core/providers/cpu/quantization/matmul_integer_base.h"
 #include "core/util/math_cpuonly.h"
 #include "core/util/qmath.h"
+#include <cassert>

+#include <chrono>
 #include <algorithm>

 namespace onnxruntime {
 namespace contrib {

 namespace {
+
+
+using Index = uint32_t;
+
+extern "C" void
+__attribute__((import_module("wasm_gemm"), import_name("onnx_matmul_integer_to_float")))
+GeckoMatmulIntegerToFloat(
+    const uint8_t* a_data,
+    float zero_point_A,
+    const int8_t* input_B,
+    const uint8_t* zero_point_B,
+    uint32_t rows_A,
+    uint32_t width,
+    uint32_t cols_B,
+    const float* b_scale_data,
+    float is_b_scale_per_column,
+    float* output
+);
+
+
 void ScaleOutput(const Tensor& scale, Tensor& output) {
   ProcessBroadcastSpanFuncs funcs{
       [](BroadcastHelper& per_iter_bh) {
@@ -51,12 +73,65 @@ class MatMulIntegerToFloatBase : public MatMulIntegerBase {
     float a_scale,
     uint8_t a_zp,
     bool a_is_signed,
-    const Tensor* b_tensor,
+    const Tensor* b_tensor,
     const Tensor* b_scale,
     const Tensor* b_zp,
     const Tensor* bias_tensor) const;
 };

+void MatMulFull(const uint8_t* inputMatrixA,
+                const int8_t* inputMatrixB,
+                float* output,
+                size_t rowsA,
+                size_t width,
+                size_t colsB,
+                uint8_t zeroPointA,
+                const uint8_t* zeroPointB,
+                const float* b_scale_data,
+                bool is_b_scale_per_column) {
+
+  float matrixScale = is_b_scale_per_column ? 0.0f : b_scale_data[0];
+  int32_t matrixZeroPointB = is_b_scale_per_column ? 0 : static_cast<int32_t>(zeroPointB[0]);
+
+  for (size_t rowIndex = 0; rowIndex < rowsA; ++rowIndex) {
+    const uint8_t* aRow = inputMatrixA + rowIndex * width;  // Start of row in A
+    for (size_t colIndex = 0; colIndex < colsB; ++colIndex) {
+      int32_t tempResult = 0;
+
+      for (size_t k = 0; k < width; ++k) {
+        // Row-major access
+        uint8_t aValue = aRow[k];
+
+        // Column-major access for B
+        int8_t bValue = inputMatrixB[k * colsB + colIndex];
+
+        // Adjust for zero-point offsets
+        int32_t adjustedA = static_cast<int32_t>(aValue) - static_cast<int32_t>(zeroPointA);
+        int32_t adjustedB = static_cast<int32_t>(bValue);
+
+        if (is_b_scale_per_column) {
+          adjustedB -= static_cast<int32_t>(zeroPointB[colIndex]);
+        } else {
+          adjustedB -= matrixZeroPointB;
+        }
+        // Accumulate product
+        tempResult += adjustedA * adjustedB;
+      }
+
+      float scaledResult = tempResult;
+      if (is_b_scale_per_column) {
+        scaledResult *= b_scale_data[colIndex];
+      } else {
+        scaledResult *= matrixScale;
+      }
+
+      // Store the scaled result in y_data
+      output[rowIndex * colsB + colIndex] = scaledResult;
+    }
+  }
+}
+
 Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
     const uint8_t* a_data,
     const TensorShape& a_shape,
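MatMulFull is a plain scalar reference: accumulate (A[i][k] − zpA) · (B[k][j] − zpB_j) in int32 over k, then apply the per-column (or per-tensor) float scale. A tiny worked example with invented values; paste the MatMulFull definition above next to it to run it:

#include <cstdint>
#include <cstdio>

int main() {
  // 1x2 A (uint8, zero point 1) times 2x1 B (int8, per-column zero point 2).
  const uint8_t a[] = {2, 3};     // adjusted by zero point: {1, 2}
  const int8_t b[] = {4, 5};      // adjusted by zero point: {2, 3}
  const uint8_t b_zp[] = {2};
  const float b_scale[] = {0.5f};
  float out[1] = {0.0f};
  MatMulFull(a, b, out, /*rowsA=*/1, /*width=*/2, /*colsB=*/1,
             /*zeroPointA=*/1, b_zp, b_scale, /*is_b_scale_per_column=*/true);
  std::printf("%f\n", out[0]);    // (1*2 + 2*3) * 0.5 = 4.000000
  return 0;
}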
@@ -150,8 +225,107 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
   params.ldc = gemm_shape.N;
 }

+#if 0
+  std::vector<float> y_data_2(rowsA * colsB, 0.0f);
+  if (rowsA > 1) {
+    std::cout << "rowsA: " << rowsA << ", width: " << width << ", colsB: " << colsB << "\n";
+    std::cout << "a_zp: " << static_cast<int>(a_zp) << "\n";
+    std::cout << "is_b_scale_per_column: " << is_b_scale_per_column << "\n";
+    std::cout << "multiplier_per_tensor: " << multiplier_per_tensor << "\n";
+    std::cout << "b_scale_data sample: [";
+    for (size_t i = 0; i < 25; ++i) {
+      if (i > 0) std::cout << ", ";
+      std::cout << b_scale_data[i];
+    }
+    std::cout << "]\n";
+    std::cout << "b_zero point sample: [";
+    for (size_t i = 0; i < 25; ++i) {
+      if (i > 0) std::cout << ", ";
+      std::cout << static_cast<int>(b_zp_ptr[i]) << ", ";
+    }
+    std::cout << "]\n";
+
+    if (bias_data != nullptr) {
+      size_t bias_size = static_cast<size_t>(bias_tensor->Shape().Size());  // Get the total size of bias_data
+      size_t display_limit = std::min(bias_size, static_cast<size_t>(100));
+      std::cout << "First " << display_limit << " elements of bias_data: [";
+      for (size_t i = 0; i < display_limit; ++i) {
+        if (i > 0) std::cout << ", ";
+        std::cout << bias_data[i];
+      }
+      std::cout << "]" << std::endl;
+    }
+    std::cout << "multiplier_per_tensor: " << multiplier_per_tensor << std::endl;
+    std::cout << "b_scale_data[0]: " << b_scale_data[0] << std::endl;
+  }
+#endif
+  //auto start = std::chrono::steady_clock::now();
+  //std::cout << "Calling f32Multiply\n";
+  // should split in parts and call ctx.ParallelFor just on the rows part
+
+  // rowsA = M
+  // width = K
+  // colsB = N
+  size_t rowsA = static_cast<size_t>(helper.M());
+  size_t width = static_cast<size_t>(helper.K());
+  size_t colsB = static_cast<size_t>(helper.N());
+
+  const int8_t* b_data = static_cast<const int8_t*>(b_tensor->DataRaw());
+
+  GeckoMatmulIntegerToFloat(a_data,
+                            a_zp,
+                            b_data,
+                            b_zp_ptr,
+                            rowsA,
+                            width,
+                            colsB,
+                            b_scale_data,
+                            is_b_scale_per_column,
+                            y_data
+  );
+
+  // MlasGemmBatch(gemm_shape, gemm_data_vec.data(), num_gemms, ctx->GetOperatorThreadPool());
+  /*
+  auto end = std::chrono::steady_clock::now();
+  auto duration = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start).count();
+  std::cout << "Done calling f32Multiply. Duration: " << duration << " nano\n";
+
+  std::cout << "Calling MlasGemmBatch\n";
+  auto start2 = std::chrono::steady_clock::now();
 MlasGemmBatch(gemm_shape, gemm_data_vec.data(), num_gemms, ctx->GetOperatorThreadPool());
+  auto end2 = std::chrono::steady_clock::now();
+  auto duration2 = std::chrono::duration_cast<std::chrono::nanoseconds>(end2 - start2).count();
+  std::cout << "Done calling MlasGemmBatch. Duration: " << duration2 << " nano\n";
+  */
+  /*
+
+  // Compare y_data and y_data_2
+
+  size_t total_elements = rowsA * colsB;
+  size_t display_limit = std::min(total_elements, static_cast<size_t>(100));
+  bool mismatch_found = false;
+  for (size_t i = 0; i < total_elements; ++i) {
+    if (std::fabs(y_data[i] - y_data_2[i]) > 1e-6) {  // Tolerance for floating-point comparison
+      std::cerr << "Mismatch at index " << i << ": y_data=" << y_data[i] << ", y_data_2=" << y_data_2[i] << std::endl;
+      mismatch_found = true;
+      break;
+    }
+  }

+  if (mismatch_found) {
+    std::cerr << "Displaying the first 100 elements of y_data and y_data_2:" << std::endl;
+    std::cerr << "[";
+    for (size_t i = 0; i < display_limit; ++i) {
+      std::cerr << "(Index " << i << ": y_data=" << y_data[i] << ", y_data_2=" << y_data_2[i] << ")";
+      if (i != display_limit - 1) {
+        std::cerr << ", ";
+      }
+    }
+    std::cerr << "]" << std::endl;
+    std::cerr << "Mismatch found between y_data and y_data_2!" << std::endl;
+    assert(false && "Validation failed: y_data and y_data_2 are not equal.");
+  }
+  */
 return Status::OK();
 }

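Two things are easy to miss in this hunk: the original MlasGemmBatch call (the one unchanged line in the middle) now sits inside the added /* ... */ block, so the Gecko import is the only active path; and the disabled #if 0 dump plus the commented-out mismatch loop reference a y_data_2 that nothing fills. A sketch of what re-enabling the cross-check would need — my assumption, not part of the commit:

// Hypothetical re-enabled validation inside ComputeCommon: produce a reference
// result with the scalar MatMulFull above, then diff it against the Wasm output.
std::vector<float> y_data_2(rowsA * colsB, 0.0f);
MatMulFull(a_data, b_data, y_data_2.data(), rowsA, width, colsB,
           a_zp, b_zp_ptr, b_scale_data, is_b_scale_per_column);
// ...then run the commented-out mismatch loop comparing y_data to y_data_2.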
@@ -221,6 +395,10 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const {
 ParQuantizeLinearStd(a_data, a_data_quant, narrow<size_t>(num_of_elements), a_scale, a_zero_point, ctx->GetOperatorThreadPool());

 bool is_b_scale_supported = IsBQuantParamSupported(b_scale_tensor->Shape(), b ? b->Shape() : b_shape_);
+
+  //std::cout << "dynamic quantize matmul calling ComputeCommon" << std::endl;
+
+
 ORT_RETURN_IF_ERROR(ComputeCommon(
     ctx,
     a_data_quant,

@@ -234,6 +412,7 @@
     ctx->Input<Tensor>(IN_BIAS)));

 if (!is_b_scale_supported) {
+  //std::cout << "dynamic quantize matmul: b scale is not supported\n";
   ScaleOutput(*b_scale_tensor, *ctx->Output<Tensor>(0));
 }

@@ -275,6 +454,7 @@ Status MatMulIntegerToFloat::Compute(OpKernelContext* ctx) const {
   a_zero_point = *(static_cast<const uint8_t*>(a_zero_point_tensor->DataRaw()));
 }

+//std::cout << "matmul integer float calling ComputeCommon" << std::endl;
 const Tensor* b_zp_tensor = ctx->Input<Tensor>(IN_B_ZERO_POINT);
 ORT_RETURN_IF_ERROR(ComputeCommon(
     ctx,

@@ -289,9 +469,11 @@ Status MatMulIntegerToFloat::Compute(OpKernelContext* ctx) const {
     ctx->Input<Tensor>(IN_BIAS)));

 if (!is_a_scale_scalar) {
+  //std::cout << "dynamic quantize matmul: a scale is not scalar\n";
   ScaleOutput(*a_scale_tensor, *ctx->Output<Tensor>(0));
 }
 if (!is_b_scale_supported) {
+  //std::cout << "dynamic quantize matmul: b scale is not supported\n";
   ScaleOutput(*b_scale_tensor, *ctx->Output<Tensor>(0));
 }
