
Commit ba11af4

[QNN-EP] Add MatMulNBits translation for GPU (#26340)
### Description

Add support for translating the MatMulNBits contrib op to the QNN FullyConnected operation with INT4 block-quantized weights.

Implementation details:
- Translate MatMulNBits to FullyConnected in the OpBuilder
- Support QNN_QUANTIZATION_ENCODING_BLOCK for INT4 weights
- Pass INT4 weights and quantization parameters to QNN as BlockQuantization encoding params

Testing:
- Added new unit tests for MatMulNBits -> QNN-GPU
- Validated all OnnxRuntime tests
- Validated the following LLMs through the Olive and ORT-GenAI execution flow:
  - Llama 3.2 1B
  - Qwen2.5
  - DeepSeek-R1-Qwen 1.5b
  - Phi3.5-mini-instruct

### Motivation and Context

An INT4 quantization pass in Olive generates LLM models containing MatMulNBits contrib ops. To run these ops via QNN-EP, MatMulNBits is translated to the QNN FullyConnected op with INT4 weights.

---------

Co-authored-by: tirupath-qti <tirupath@qti.qualcomm.com>
1 parent b6ed7f3 commit ba11af4
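
For context, the numeric relationship the translation relies on can be sketched as follows. This is an illustrative reference only, not code from this commit: names, the `(quant - offset) * scale` dequantization convention, and the unpacked one-int8-per-weight layout are assumptions for clarity. Each output element of QNN's FullyConnected is a dot product between the activations and one row of the [N, K] weight matrix, where every weight is dequantized with the scale/offset of its block along K.

```cpp
// Illustrative sketch (not code from this commit): one output element of
// FullyConnected over a block-quantized [N, K] weight. INT4 values are shown
// unpacked to one int8 per element; real = (quant - offset) * scale is assumed.
#include <cstddef>
#include <cstdint>
#include <vector>

float FullyConnectedRow(const std::vector<float>& x,          // K activations
                        const std::vector<int8_t>& w_row,     // K int4 values of row n
                        const std::vector<float>& scales,     // K / block_size scales for row n
                        const std::vector<int32_t>& offsets,  // matching per-block offsets
                        size_t block_size) {
  float acc = 0.0f;
  for (size_t k = 0; k < x.size(); ++k) {
    const size_t block = k / block_size;  // block_size consecutive weights share scale/offset
    const float w = (w_row[k] - offsets[block]) * scales[block];
    acc += w * x[k];
  }
  return acc;  // y[n] of MatMulNBits(x, W) == FullyConnected(x, W) with block-dequantized W
}
```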

File tree

11 files changed: +568 −16 lines


onnxruntime/core/providers/qnn/builder/op_builder_factory.cc

Lines changed: 4 additions & 0 deletions
@@ -227,6 +227,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() {
   {
     CreateInverseOpBuilder("Inverse", *this);
   }
+
+  {
+    CreateMatMulNBitsOpBuilder("MatMulNBits", *this);
+  }
 }

 const IOpBuilder* GetOpBuilder(const std::string& onnx_op_type) {
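
With this registration in place, lookups by ONNX op type resolve to the new builder. A minimal sketch of the expected call, based only on the `GetOpBuilder` declaration visible above (not code from this commit):

```cpp
// Sketch: resolving the builder registered above by its ONNX op type.
const IOpBuilder* builder = GetOpBuilder("MatMulNBits");
if (builder != nullptr) {
  // QNN-EP can translate this node; the builder lowers it to QNN FullyConnected.
}
```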

onnxruntime/core/providers/qnn/builder/op_builder_factory.h

Lines changed: 1 addition & 0 deletions
@@ -127,6 +127,7 @@ void CreateSTFTOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_

 void CreateInverseOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

+void CreateMatMulNBitsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
 void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

 }  // namespace qnn

onnxruntime/core/providers/qnn/builder/opbuilder/matmulnbits_op_builder.cc

Lines changed: 381 additions & 0 deletions
Large diffs are not rendered by default.

onnxruntime/core/providers/qnn/builder/qnn_def.cc

Lines changed: 1 addition & 1 deletion
@@ -456,7 +456,7 @@ bool CreateTensorInQnnGraph(const QNN_INTERFACE_VER_TYPE& qnn_interface,
     return false;
   }
   // verify size expressed by the dims matches the raw tensor size
-  uint32_t qnn_tensor_size = CalcQnnTensorNumElems(qnn_tensor) * gsl::narrow_cast<uint32_t>(data_size);
+  const auto qnn_tensor_size = utils::GetQnnTensorDataSizeInBytes(qnn_tensor);
   auto qnn_tensor_buf_size = GetQnnTensorClientBuf(qnn_tensor).dataSize;
   if (qnn_tensor_size != qnn_tensor_buf_size) {
     ss << "Data length mismatch for static tensor. node_name: " << node_name

onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc

Lines changed: 6 additions & 4 deletions
@@ -881,7 +881,8 @@ void QnnModelWrapper::GetGraphInputOutputTensorWrapper(const std::vector<std::st
 }

 Status QnnModelWrapper::UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& initializer,
-                                              std::vector<uint8_t>& unpacked_tensor) const {
+                                              std::vector<uint8_t>& unpacked_tensor,
+                                              const bool unpack_4_bit_to_8_bit) const {
   if (initializer.data_location() == onnx::TensorProto_DataLocation_EXTERNAL) {
     ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(initializer, graph_viewer_.ModelPath(),
                                                                   unpacked_tensor));
@@ -891,12 +892,13 @@ Status QnnModelWrapper::UnpackInitializerData(const ONNX_NAMESPACE::TensorProto&

   int32_t onnx_data_type = initializer.data_type();

-  // If this is an int4, we need to unpack it because QNN treats int4 as a full int8.
-  if (onnx_data_type == ONNX_NAMESPACE::TensorProto_DataType_INT4) {
+  // If this is an int4 and unpack_4_bit_to_8_bit is true, we need to unpack it
+  // because QNN HTP treats int4 as a full int8.
+  if (unpack_4_bit_to_8_bit && onnx_data_type == ONNX_NAMESPACE::TensorProto_DataType_INT4) {
     TensorShape shape(qnn::utils::GetInitializerShape<int64_t>(initializer));
     const size_t num_int4_elems = shape.Size();
     ORT_RETURN_IF_ERROR(qnn::utils::UnpackInt4ToInt8<true>(num_int4_elems, unpacked_tensor));
-  } else if (onnx_data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT4) {
+  } else if (unpack_4_bit_to_8_bit && onnx_data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT4) {
     TensorShape shape(qnn::utils::GetInitializerShape<int64_t>(initializer));
     const size_t num_uint4_elems = shape.Size();
     ORT_RETURN_IF_ERROR(qnn::utils::UnpackInt4ToInt8<false>(num_uint4_elems, unpacked_tensor));
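
A hedged sketch of how the two kinds of call sites are expected to differ. The signature and the HTP comment come from the diff above; `wrapper` and `weight_proto` are hypothetical names, and the GPU rationale follows from this PR's description of passing packed INT4 block-quantized weights to QNN:

```cpp
// Hypothetical call sites (not code from this commit) illustrating the new flag.
std::vector<uint8_t> unpacked;

// HTP path (default): int4 data is widened to one int8 per element,
// because QNN HTP treats int4 as a full int8.
ORT_RETURN_IF_ERROR(wrapper.UnpackInitializerData(weight_proto, unpacked));

// GPU block-quantized path: keep the two-nibbles-per-byte packing so the raw
// bytes can be handed to QNN as QNN_DATATYPE_SFIXED_POINT_4 tensor data.
ORT_RETURN_IF_ERROR(wrapper.UnpackInitializerData(weight_proto, unpacked,
                                                  /*unpack_4_bit_to_8_bit=*/false));
```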

onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h

Lines changed: 2 additions & 1 deletion
@@ -245,7 +245,8 @@ class QnnModelWrapper {
   }

   Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& initializer,
-                               std::vector<uint8_t>& unpacked_tensor) const;
+                               std::vector<uint8_t>& unpacked_tensor,
+                               const bool unpack_4_bit_to_8_bit = true) const;

   QnnBackendType GetQnnBackendType() const { return qnn_backend_type_; }

onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.cc

Lines changed: 79 additions & 6 deletions
@@ -20,8 +20,11 @@ QnnQuantParamsWrapper::QnnQuantParamsWrapper(const QnnQuantParamsWrapper& other)
   size_t num_scaleoffsets = 0;
   if (other.IsLPBQ()) {
     num_scaleoffsets = other.per_channel_scales_size_;
+  } else if (other.IsBlockQuantized()) {
+    block_encoding_tensor_rank_ = other.block_encoding_tensor_rank_;
+    num_scaleoffsets = other.num_blocks_;
   }
-  Status status = Init(other.params_, num_scaleoffsets);
+  Status status = Init(other.params_, num_scaleoffsets, block_encoding_tensor_rank_);
   assert(status.IsOK());  // Expect other QnnQuantParamsWrapper to always have a supported quantization encoding.
 }

@@ -30,8 +33,11 @@ QnnQuantParamsWrapper& QnnQuantParamsWrapper::operator=(const QnnQuantParamsWrap
   size_t num_scaleoffsets = 0;
   if (other.IsLPBQ()) {
     num_scaleoffsets = other.per_channel_scales_size_;
+  } else if (other.IsBlockQuantized()) {
+    block_encoding_tensor_rank_ = other.block_encoding_tensor_rank_;
+    num_scaleoffsets = other.num_blocks_;
   }
-  Status status = Init(other.params_, num_scaleoffsets);
+  Status status = Init(other.params_, num_scaleoffsets, block_encoding_tensor_rank_);
   assert(status.IsOK());  // Expect other QnnQuantParamsWrapper to always have a supported quantization encoding.
 }

@@ -156,6 +162,39 @@ QnnQuantParamsWrapper::QnnQuantParamsWrapper(gsl::span<const float> per_channel_
   params_.blockwiseExpansion = lpbqPtr;
 }

+// Construct a BlockEncoding BQ quantization param.
+QnnQuantParamsWrapper::QnnQuantParamsWrapper(
+    gsl::span<const float> scales,
+    gsl::span<const int32_t> offsets,
+    gsl::span<const uint32_t> block_sizes,
+    Qnn_DataType_t tensor_data_type) {
+  ORT_UNUSED_PARAMETER(tensor_data_type);
+  assert(block_sizes.size() > 0);
+  assert(scales.size() > 0);
+  assert(scales.size() == offsets.size());  // Logic error if sizes don't match.
+
+  num_blocks_ = static_cast<uint32_t>(scales.size());
+  params_.encodingDefinition = QNN_DEFINITION_DEFINED;
+  params_.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BLOCK;
+
+  block_encoding_tensor_rank_ = static_cast<uint32_t>(block_sizes.size());
+  block_encoding_axis_data_ = std::make_unique<uint32_t[]>(block_encoding_tensor_rank_);
+  std::memcpy(block_encoding_axis_data_.get(),
+              block_sizes.data(),
+              static_cast<size_t>(block_encoding_tensor_rank_) * sizeof(uint32_t));
+  params_.blockEncoding.blockSize = block_encoding_axis_data_.get();
+
+  // Deep copy the scale offsets
+  if (num_blocks_ > 0) {
+    block_encoding_scale_offsets_data_ = std::make_unique<Qnn_ScaleOffset_t[]>(num_blocks_);
+    for (size_t i = 0; i < num_blocks_; ++i) {
+      block_encoding_scale_offsets_data_[i].offset = offsets[i];
+      block_encoding_scale_offsets_data_[i].scale = scales[i];
+    }
+    params_.blockEncoding.scaleOffset = block_encoding_scale_offsets_data_.get();
+  }
+}
+
 // Get a copy of scales. Works for both per-tensor and per-channel.
 Status QnnQuantParamsWrapper::GetScales(/*out*/ std::vector<float>& scales) const {
   ORT_RETURN_IF_NOT(params_.encodingDefinition == QNN_DEFINITION_DEFINED, "Unquantized qparams does not have scales");

@@ -195,6 +234,18 @@ Status QnnQuantParamsWrapper::GetScales(/*out*/ std::vector<float>& scales) cons
       }
       break;
     }
+    case QNN_QUANTIZATION_ENCODING_BLOCK: {
+      scales.resize(num_blocks_);
+
+      if (num_blocks_ > 0) {
+        gsl::span<const Qnn_ScaleOffset_t> scale_offsets(params_.blockEncoding.scaleOffset, num_blocks_);
+
+        for (size_t i = 0; i < num_blocks_; i++) {
+          scales[i] = scale_offsets[i].scale;
+        }
+      }
+      break;
+    }
     default:
       return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported QNN quantization encoding: ",
                              params_.quantizationEncoding);

@@ -208,7 +259,7 @@ QnnQuantParamsWrapper QnnQuantParamsWrapper::Copy() const {
 }

 // Initializes by copying from a Qnn_QuantizeParams_t.
-Status QnnQuantParamsWrapper::Init(const Qnn_QuantizeParams_t& params, const size_t lpbq_num_scaleoffsets) {
+Status QnnQuantParamsWrapper::Init(const Qnn_QuantizeParams_t& params, const size_t num_scaleoffsets, const size_t tensor_rank) {
   if (per_channel_data_) {
     per_channel_data_.reset(nullptr);
     params_ = QNN_QUANTIZE_PARAMS_INIT;

@@ -278,7 +329,7 @@ Status QnnQuantParamsWrapper::Init(const Qnn_QuantizeParams_t& params, const siz
       break;
     }
     case QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION: {
-      assert(lpbq_num_scaleoffsets && "Can't create BlockwiseExpansion encoding object with zero ScaleOffsets");
+      assert(num_scaleoffsets && "Can't create BlockwiseExpansion encoding object with zero ScaleOffsets");
       params_.encodingDefinition = params.encodingDefinition;
       params_.quantizationEncoding = params.quantizationEncoding;

@@ -291,7 +342,7 @@ Status QnnQuantParamsWrapper::Init(const Qnn_QuantizeParams_t& params, const siz
       params_.blockwiseExpansion = bwe_aligned_dst;

       // Deep copy the scaleoffsets
-      const size_t so_num_elems = lpbq_num_scaleoffsets;
+      const size_t so_num_elems = num_scaleoffsets;
       const size_t so_num_bytes = so_num_elems * sizeof(Qnn_ScaleOffset_t);
       constexpr std::uintptr_t so_align = alignof(Qnn_ScaleOffset_t);
       per_channel_data_ = std::make_unique<char[]>(so_num_bytes + so_align);

@@ -301,7 +352,7 @@ Status QnnQuantParamsWrapper::Init(const Qnn_QuantizeParams_t& params, const siz
       params_.blockwiseExpansion->scaleOffsets = so_aligned_dst;

       // Deep copy blockscales
-      const size_t bs_num_elems = lpbq_num_scaleoffsets * params.blockwiseExpansion->numBlocksPerAxis;
+      const size_t bs_num_elems = num_scaleoffsets * params.blockwiseExpansion->numBlocksPerAxis;
       const size_t bs_num_bytes = bs_num_elems * sizeof(uint8_t);
       constexpr std::uintptr_t bs_align = alignof(uint8_t);
       block_scales_data_ = std::make_unique<uint8_t[]>(bs_num_bytes + bs_align);

@@ -310,6 +361,28 @@ Status QnnQuantParamsWrapper::Init(const Qnn_QuantizeParams_t& params, const siz
       params_.blockwiseExpansion->blocksScale8 = bs_aligned_dst;
       break;
     }
+    case QNN_QUANTIZATION_ENCODING_BLOCK: {
+      assert(num_scaleoffsets && "Can't create Block encoding object with zero ScaleOffsets");
+      params_.encodingDefinition = params.encodingDefinition;
+      params_.quantizationEncoding = params.quantizationEncoding;
+
+      block_encoding_tensor_rank_ = static_cast<uint32_t>(tensor_rank);
+      block_encoding_axis_data_ = std::make_unique<uint32_t[]>(block_encoding_tensor_rank_);
+      std::memcpy(block_encoding_axis_data_.get(),
+                  params.blockEncoding.blockSize,
+                  static_cast<size_t>(block_encoding_tensor_rank_) * sizeof(uint32_t));
+      params_.blockEncoding.blockSize = block_encoding_axis_data_.get();
+
+      // Deep copy the scale offsets
+      block_encoding_scale_offsets_data_ = std::make_unique<Qnn_ScaleOffset_t[]>(num_scaleoffsets);
+      for (size_t i = 0; i < num_scaleoffsets; ++i) {
+        block_encoding_scale_offsets_data_[i].scale = params.blockEncoding.scaleOffset[i].scale;
+        block_encoding_scale_offsets_data_[i].offset = params.blockEncoding.scaleOffset[i].offset;
+      }
+      params_.blockEncoding.scaleOffset = block_encoding_scale_offsets_data_.get();
+
+      break;
+    }
     default:
       return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported QNN quantization encoding: ", params.quantizationEncoding);
   }
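
A minimal sketch of constructing block-encoded quant params with the new constructor. The signature comes from this diff; the concrete values and the `{1, 32}` per-dimension block shape (rows unblocked, 32-element blocks along K) are illustrative assumptions:

```cpp
// Sketch (not code from this commit): block quantization params for a [2, 64]
// INT4 weight blocked along K with block_size = 32.
const uint32_t num_blocks = 2 * (64 / 32);    // one (scale, offset) pair per block
std::vector<float> scales(num_blocks, 0.05f);
std::vector<int32_t> offsets(num_blocks, 0);
std::vector<uint32_t> block_sizes = {1, 32};  // assumed per-dimension block shape
qnn::QnnQuantParamsWrapper quant_params(scales, offsets, block_sizes,
                                        QNN_DATATYPE_SFIXED_POINT_4);
assert(quant_params.IsBlockQuantized());      // QNN_QUANTIZATION_ENCODING_BLOCK is set
```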

onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.h

Lines changed: 17 additions & 1 deletion
@@ -34,11 +34,16 @@ class QnnQuantParamsWrapper {
   QnnQuantParamsWrapper(gsl::span<const float> per_channel_float_scales, gsl::span<const uint8_t> per_block_int_scales,
                         gsl::span<const int32_t> offsets, int64_t axis, int64_t block_size, bool is_int4);

+  // Construct a BQ quantization param.
+  QnnQuantParamsWrapper(
+      gsl::span<const float> scales, gsl::span<const int32_t> offsets,
+      gsl::span<const uint32_t> block_size, Qnn_DataType_t tensor_data_type);
+
   Qnn_QuantizeParams_t& Get() { return params_; }
   const Qnn_QuantizeParams_t& Get() const { return params_; }

   // Initialize this object from a raw Qnn_QuantizeParam_t object.
-  Status Init(const Qnn_QuantizeParams_t& params, const size_t lpbq_num_scaleoffsets = 0);
+  Status Init(const Qnn_QuantizeParams_t& params, const size_t num_scaleoffsets = 0, const size_t tensor_rank = 0);

   // Initialize this object from a (potentially) quantized ONNX tensor.
   // QnnModelWrapper provides utilities for unpacking scale and zero-point ONNX initializers.

@@ -67,6 +72,11 @@ class QnnQuantParamsWrapper {
            (params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION);
   }

+  bool IsBlockQuantized() const {
+    return params_.encodingDefinition == QNN_DEFINITION_DEFINED &&
+           (params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BLOCK);
+  }
+
   // Get a copy of scales. Works for both per-tensor and per-channel.
   Status GetScales(/*out*/ std::vector<float>& scales) const;

@@ -163,6 +173,12 @@ class QnnQuantParamsWrapper {
   uint32_t per_channel_scales_size_;
   std::unique_ptr<uint8_t[]> block_scales_data_;
   std::unique_ptr<char[]> blockwise_expansion_data_;
+
+  // Stores BlockEncoding axis and scale offset data
+  uint32_t block_encoding_tensor_rank_ = 0;
+  uint32_t num_blocks_ = 0;
+  std::unique_ptr<uint32_t[]> block_encoding_axis_data_;
+  std::unique_ptr<Qnn_ScaleOffset_t[]> block_encoding_scale_offsets_data_;
 };

 }  // namespace qnn

onnxruntime/core/providers/qnn/builder/qnn_utils.cc

Lines changed: 19 additions & 3 deletions
@@ -36,9 +36,11 @@ size_t GetElementSizeByType(const Qnn_DataType_t& data_type) {
       {QNN_DATATYPE_FLOAT_32, 4},
       {QNN_DATATYPE_BFLOAT_16, 2},
       {QNN_DATATYPE_BOOL_8, 1},
+      {QNN_DATATYPE_SFIXED_POINT_4, sizeof(Int4x2)},
       {QNN_DATATYPE_SFIXED_POINT_8, 1},
       {QNN_DATATYPE_SFIXED_POINT_16, 2},
       {QNN_DATATYPE_SFIXED_POINT_32, 4},
+      {QNN_DATATYPE_UFIXED_POINT_4, sizeof(Int4x2)},
       {QNN_DATATYPE_UFIXED_POINT_8, 1},
       {QNN_DATATYPE_UFIXED_POINT_16, 2},
       {QNN_DATATYPE_UFIXED_POINT_32, 4},

@@ -105,11 +107,25 @@ size_t GetElementSizeByType(ONNX_NAMESPACE::TensorProto_DataType onnx_type) {
   }
   // Unreachable
 }
+size_t GetQnnTensorDataSizeInBytes(size_t num_elements, Qnn_DataType_t element_type) {
+  SafeInt<size_t> safe_num_elements = num_elements;
+  if (element_type == QNN_DATATYPE_SFIXED_POINT_4 || element_type == QNN_DATATYPE_UFIXED_POINT_4) {
+    return (safe_num_elements + 1) / 2;
+  }
+  return (safe_num_elements * GetElementSizeByType(element_type));
+}

 size_t GetQnnTensorDataSizeInBytes(gsl::span<const uint32_t> shape, Qnn_DataType_t element_type) {
   ORT_ENFORCE(!shape.empty(), "Empty shape not allowed.");  // TODO can we just treat empty shape as a scalar?
-  SafeInt<size_t> data_length = GetElementSizeByType(element_type);
-  return std::accumulate(shape.begin(), shape.end(), data_length, std::multiplies<>{});
+  SafeInt<size_t> num_elements = std::accumulate(shape.begin(), shape.end(), SafeInt<size_t>{1}, std::multiplies<>{});
+  return GetQnnTensorDataSizeInBytes(num_elements, element_type);
+}
+
+size_t GetQnnTensorDataSizeInBytes(const Qnn_Tensor_t& tensor) {
+  uint32_t rank = GetQnnTensorRank(tensor);
+  uint32_t* dims = GetQnnTensorDims(tensor);
+  gsl::span<const uint32_t> shape{dims, static_cast<size_t>(rank)};
+  return GetQnnTensorDataSizeInBytes(shape, GetQnnTensorDataType(tensor));
 }

 bool QnnTensorHasDynamicShape(const Qnn_Tensor_t& tensor) {

@@ -999,7 +1015,7 @@ Status QuantizeData(gsl::span<const float> data, gsl::span<const uint32_t> shape
   const size_t num_dims = shape.size();
   const size_t num_elems = ShapeSizeCalc(shape, 0, num_dims);
   ORT_RETURN_IF_NOT(num_elems == data.size(), "Shape mismatch with data to quantize");
-  size_t expected_num_quant_bytes = GetElementSizeByType(data_type) * data.size();
+  size_t expected_num_quant_bytes = GetQnnTensorDataSizeInBytes(data.size(), data_type);
   ORT_RETURN_IF_NOT(quant_bytes.size() == expected_num_quant_bytes,
                     "Cannot quantize data because output buffer is not the correct size");
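
The packed 4-bit size math above can be spot-checked with a few values. This is a sketch, not code from this commit; the expectations follow directly from the `(n + 1) / 2` branch and the 1-byte `Int4x2` entries, and the `qnn::utils` namespace for these helpers is assumed:

```cpp
// Sketch: expected results of the new overload for packed 4-bit vs. 8-bit types.
using namespace qnn::utils;  // assumed namespace for these helpers
assert(GetQnnTensorDataSizeInBytes(64, QNN_DATATYPE_SFIXED_POINT_4) == 32);  // 64 nibbles -> 32 bytes
assert(GetQnnTensorDataSizeInBytes(7, QNN_DATATYPE_UFIXED_POINT_4) == 4);    // (7 + 1) / 2 rounds up
assert(GetQnnTensorDataSizeInBytes(64, QNN_DATATYPE_SFIXED_POINT_8) == 64);  // unchanged for 8-bit
```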

onnxruntime/core/providers/qnn/builder/qnn_utils.h

Lines changed: 2 additions & 0 deletions
@@ -78,7 +78,9 @@ class QnnJSONGraph {

 size_t GetElementSizeByType(ONNX_NAMESPACE::TensorProto_DataType onnx_type);

+size_t GetQnnTensorDataSizeInBytes(size_t num_elements, Qnn_DataType_t element_data_type);
 size_t GetQnnTensorDataSizeInBytes(gsl::span<const uint32_t> shape, Qnn_DataType_t element_data_type);
+size_t GetQnnTensorDataSizeInBytes(const Qnn_Tensor_t& tensor);

 bool QnnTensorHasDynamicShape(const Qnn_Tensor_t& tensor);
