Skip to content

Commit 8e5b8b6

Browse files
quic-ashigarg (Ashish Garg, AISW)
authored and committed
Set shared memory type based on options during the compilation phase (microsoft#24196)
### Description

During inference, using the QNN EP option to set enable_htp_shared_memory_allocator gives a hint that we use RPC-allocated buffers to avoid buffer copies between the CPU and the NPU. With the current PR, we add hints in the compilation phase so that, if RPC memory is going to be used, any additional allocations done on the CPU can be avoided.

### Motivation and Context

This should help reduce the peak CPU memory consumption while running AI workloads using shared memory. Related PR: microsoft#23136

Co-authored-by: Ashish Garg (AISW) <ashigarg@qti.qualcomm.com>
1 parent 54f07b5 commit 8e5b8b6

File tree

5 files changed

+13
-7
lines changed

5 files changed

+13
-7
lines changed

onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,10 @@ Status BaseOpBuilder::ProcessOutputs(QnnModelWrapper& qnn_model_wrapper,
204204
std::string output_name;
205205
};
206206
std::vector<CastNodeInfo> cast_node_info_vec;
207-
207+
auto mem_type = QNN_TENSORMEMTYPE_RAW;
208+
if (true == qnn_model_wrapper.GetModelSettings().htp_shared_memory) {
209+
mem_type = QNN_TENSORMEMTYPE_MEMHANDLE;
210+
}
208211
const auto output_count = GetOutputCountQnnRequired(node_unit);
209212
for (size_t output_i = 0; output_i < output_count; ++output_i) {
210213
const auto& output_name = outputs[output_i].node_arg.Name();
@@ -255,7 +258,8 @@ Status BaseOpBuilder::ProcessOutputs(QnnModelWrapper& qnn_model_wrapper,
255258
QNN_TENSOR_TYPE_NATIVE,
256259
supported_qnn_data_type,
257260
output_info.quant_param.Copy(),
258-
std::move(cast_output_shape));
261+
std::move(cast_output_shape), {},
262+
mem_type);
259263
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(cast_input_tensorwrapper)), "Failed to add tensor.");
260264
output_names.push_back(cast_input_name);
261265
cast_node_info_vec.push_back({cast_node_name, cast_input_name, output_name});

onnxruntime/core/providers/qnn/builder/qnn_def.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -188,10 +188,6 @@ class QnnTensorWrapper {
188188
SetQnnTensorClientBuf(qnn_tensor_, client_buf_);
189189
}
190190

191-
if (mem_type != QNN_TENSORMEMTYPE_RAW) {
192-
ORT_THROW("mem_type not supported for now.");
193-
}
194-
195191
SetQnnTensorQParams(qnn_tensor_, quant_params_.Get());
196192
}
197193

onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,13 @@ Status QnnModelWrapper::MakeTensorWrapper(const NodeUnitIODef& tensor, QnnTensor
6868
ORT_RETURN_IF_ERROR(UnpackInitializerData(*tensor_info.initializer_tensor, unpacked_tensor));
6969
}
7070

71+
Qnn_TensorMemType_t mem_type = QNN_TENSORMEMTYPE_RAW;
72+
if (true == model_settings_.htp_shared_memory) {
73+
mem_type = QNN_TENSORMEMTYPE_MEMHANDLE;
74+
}
7175
tensor_wrapper = QnnTensorWrapper(tensor_name, GetTensorType(tensor_name), tensor_info.qnn_data_type,
7276
std::move(tensor_info.quant_param), std::move(tensor_info.shape),
73-
std::move(unpacked_tensor));
77+
std::move(unpacked_tensor), mem_type);
7478
return Status::OK();
7579
}
7680

onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ struct TensorInfo {
3030

3131
struct ModelSettings {
3232
bool offload_graph_io_quantization = false;
33+
bool htp_shared_memory = false;
3334
};
3435

3536
class QnnModelWrapper {

onnxruntime/core/providers/qnn/qnn_execution_provider.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
431431
// Initialize rpcmem_library_.
432432
// This is necessary for HtpSharedMemoryAllocator to function and also indicates that the allocator is available.
433433
rpcmem_library_ = std::make_shared<qnn::RpcMemLibrary>();
434+
model_settings_.htp_shared_memory = true;
434435
}
435436

436437
dump_json_qnn_graph_ = ParseBoolOption("dump_json_qnn_graph", false, provider_options_map);

0 commit comments

Comments (0)