Skip to content

Commit 9dd6a16

Browse files
edgchen1baijumeswani
authored andcommitted
Add QNN EP HTP shared memory allocator (#23136)
Adds QNN EP HTP shared memory allocator. The HTP shared memory allocator (`HtpSharedMemoryAllocator`) calls the rpcmem shared library (libcdsprpc.so/dll) to allocate and free memory that can be shared between HTP and CPU. The allocator can be enabled by setting QNN EP option `enable_htp_shared_memory_allocator` to `1`. `QNNExecutionProvider::CreatePreferredAllocators()` will then return an instance of `HtpSharedMemoryAllocator`. For each QNN context, we also need to register and unregister memory handles in order to use the HTP shared memory. This memory handle management is added to `QnnBackendManager`, which also manages the QNN context handles. For more information about using HTP shared memory with QNN, see: https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/htp_shared_buffer_tutorial.html#shared-buffer-tutorial Limitations: - HTP shared memory usage is only supported for graph inputs and outputs. Intermediate values are not supported. - An allocation is assigned to a single shared memory buffer. The allocator is not smart enough to have multiple allocations share a single shared memory buffer. Co-authored-by: Baiju Meswani <bmeswani@microsoft.com>
1 parent 0dd5526 commit 9dd6a16

35 files changed

+1469
-251
lines changed

include/onnxruntime/core/framework/allocator.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ constexpr const char* OpenVINO_CPU = "OpenVINO_CPU";
5252
constexpr const char* OpenVINO_GPU = "OpenVINO_GPU";
5353
constexpr const char* OpenVINO_RT = "OpenVINO_RT";
5454
constexpr const char* OpenVINO_RT_NPU = "OpenVINO_RT_NPU";
55+
constexpr const char* QNN_HTP_SHARED = "QnnHtpShared";
5556
constexpr const char* WEBGPU_BUFFER = "WebGPU_Buffer";
5657
constexpr const char* WEBNN_TENSOR = "WebNN_Tensor";
5758

@@ -81,6 +82,10 @@ class IAllocator {
8182
*/
8283
virtual void* Alloc(size_t size) = 0;
8384

85+
/**
86+
* Free memory at p.
87+
* If p is nullptr, do nothing.
88+
*/
8489
virtual void Free(void* p) = 0;
8590

8691
// Reserve() is an interface exposed for an implementation of IAllocator

include/onnxruntime/core/framework/ortdevice.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ struct OrtDevice {
2525
static const MemoryType CUDA_PINNED = 1;
2626
static const MemoryType HIP_PINNED = 2;
2727
static const MemoryType CANN_PINNED = 3;
28+
static const MemoryType QNN_HTP_SHARED = 4;
2829
};
2930

3031
constexpr OrtDevice(DeviceType device_type_, MemoryType memory_type_, DeviceId device_id_)

include/onnxruntime/core/framework/ortmemoryinfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
#include <string_view>
77

88
#include "core/common/hash_combine.h"
9+
#include "core/framework/ortdevice.h"
10+
#include "core/session/onnxruntime_c_api.h" // for OrtMemType, OrtAllocatorType
911

1012
struct OrtMemoryInfo {
1113
OrtMemoryInfo() = default; // to allow default construction of Tensor

include/onnxruntime/core/session/onnxruntime_c_api.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3670,6 +3670,10 @@ struct OrtApi {
36703670
* "enable_htp_spill_fill_buffer": Enable HTP spill fill buffer setting. The flag is used while generating context binary.
36713671
* - "0": Default. Disabled.
36723672
* - "1": Enabled.
3673+
* "enable_htp_shared_memory_allocator": Enable the QNN HTP shared memory allocator. Requires libcdsprpc.so/dll to
3674+
* be available.
3675+
* - "0": Default. Disabled.
3676+
* - "1": Enabled.
36733677
*
36743678
* SNPE supported keys:
36753679
* "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16",

include/onnxruntime/core/session/onnxruntime_cxx_api.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2130,10 +2130,10 @@ struct KernelContext {
21302130
explicit KernelContext(OrtKernelContext* context);
21312131
size_t GetInputCount() const;
21322132
size_t GetOutputCount() const;
2133-
// If input is optional and is not present, the method returns en empty ConstValue
2133+
// If input is optional and is not present, the method returns an empty ConstValue
21342134
// which can be compared to nullptr.
21352135
ConstValue GetInput(size_t index) const;
2136-
// If outout is optional and is not present, the method returns en empty UnownedValue
2136+
// If output is optional and is not present, the method returns an empty UnownedValue
21372137
// which can be compared to nullptr.
21382138
UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const;
21392139
UnownedValue GetOutput(size_t index, const std::vector<int64_t>& dims) const;

onnxruntime/core/framework/allocator.cc

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,11 +155,18 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA
155155
mem_type1);
156156
} else if (strcmp(name1, onnxruntime::CUDA_PINNED) == 0) {
157157
*out = new OrtMemoryInfo(
158-
onnxruntime::CUDA_PINNED, type, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, static_cast<OrtDevice::DeviceId>(id1)),
158+
onnxruntime::CUDA_PINNED, type,
159+
OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, static_cast<OrtDevice::DeviceId>(id1)),
159160
id1, mem_type1);
160161
} else if (strcmp(name1, onnxruntime::HIP_PINNED) == 0) {
161162
*out = new OrtMemoryInfo(
162-
onnxruntime::HIP_PINNED, type, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, static_cast<OrtDevice::DeviceId>(id1)),
163+
onnxruntime::HIP_PINNED, type,
164+
OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, static_cast<OrtDevice::DeviceId>(id1)),
165+
id1, mem_type1);
166+
} else if (strcmp(name1, onnxruntime::QNN_HTP_SHARED) == 0) {
167+
*out = new OrtMemoryInfo(
168+
onnxruntime::QNN_HTP_SHARED, type,
169+
OrtDevice(OrtDevice::CPU, OrtDevice::MemType::QNN_HTP_SHARED, static_cast<OrtDevice::DeviceId>(id1)),
163170
id1, mem_type1);
164171
} else {
165172
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Specified device is not supported.");

onnxruntime/core/framework/session_state.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ SessionState::SessionState(Graph& graph,
101101
for (auto& ep : execution_providers_) {
102102
auto allocators = ep->CreatePreferredAllocators();
103103
for (auto& alloc : allocators) {
104-
allocators_->insert({alloc->Info().device, alloc}); // DONT overwrite existing key
104+
allocators_->insert({alloc->Info().device, alloc}); // DON'T overwrite existing key
105105
}
106106
}
107107
}

onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc

Lines changed: 105 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,22 @@
77
#include <fstream>
88
#include <string>
99
#include "QnnOpDef.h"
10-
#include "HTP/QnnHtpPerfInfrastructure.h"
11-
#include "HTP/QnnHtpSystemContext.h"
1210
#include "CPU/QnnCpuCommon.h"
1311
// TODO: not exist for Windows yet
1412
// #include "GPU/QnnGpuCommon.h"
1513
#include "DSP/QnnDspCommon.h"
1614
#include "HTP/QnnHtpCommon.h"
1715
#include "HTP/QnnHtpContext.h"
16+
#include "HTP/QnnHtpPerfInfrastructure.h"
17+
#include "HTP/QnnHtpSystemContext.h"
1818
#include "Saver/QnnSaver.h"
1919
#include <gsl/gsl>
2020
#include "core/framework/endian_utils.h"
2121
#include "core/common/logging/capture.h"
22+
#include "core/providers/qnn/qnn_allocator.h"
2223
#include "core/providers/qnn/builder/onnx_ctx_model_helper.h"
2324
#include "core/providers/qnn/builder/qnn_configs_helper.h"
25+
#include "core/providers/qnn/builder/qnn_utils.h"
2426

2527
#ifdef _WIN32
2628
#include <winmeta.h>
@@ -46,6 +48,14 @@ static Qnn_Version_t GetQnnInterfaceApiVersion(const QnnSystemInterface_t* qnn_i
4648
return qnn_interface->systemApiVersion;
4749
}
4850

51+
static const char* DlError() {
52+
#ifdef _WIN32
53+
return "";
54+
#else
55+
return ::dlerror();
56+
#endif
57+
}
58+
4959
template <typename F, class T>
5060
Status QnnBackendManager::GetQnnInterfaceProvider(const char* lib_path,
5161
const char* interface_provider_name,
@@ -545,10 +555,11 @@ Status QnnBackendManager::CreateContext() {
545555
device_handle_,
546556
context_configs,
547557
&context);
548-
contexts_.push_back(context);
549558

550559
ORT_RETURN_IF(QNN_CONTEXT_NO_ERROR != result, "Failed to create context. Error: ", QnnErrorHandleToString(result));
551560

561+
ORT_RETURN_IF_ERROR(AddQnnContextHandle(context));
562+
552563
context_created_ = true;
553564
return Status::OK();
554565
}
@@ -558,14 +569,9 @@ Status QnnBackendManager::ReleaseContext() {
558569
return Status::OK();
559570
}
560571

561-
bool failed = false;
562-
for (auto context : contexts_) {
563-
Qnn_ErrorHandle_t result = qnn_interface_.contextFree(context, nullptr);
564-
if (QNN_CONTEXT_NO_ERROR != result) {
565-
failed = true;
566-
}
567-
}
568-
ORT_RETURN_IF(failed, "Failed to release context.");
572+
// release QNN context handles
573+
contexts_.clear();
574+
context_map_.clear();
569575

570576
context_created_ = false;
571577
return Status::OK();
@@ -766,7 +772,7 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
766772
&context,
767773
profile_backend_handle_);
768774
ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary. Error code: ", rt);
769-
contexts_.push_back(context);
775+
ORT_RETURN_IF_ERROR(AddQnnContextHandle(context));
770776
if (1 == graph_count) {
771777
// in case the EPContext node is generated from script
772778
// the graph name from the context binary may not match the EPContext node name
@@ -1452,13 +1458,8 @@ const char* QnnBackendManager::QnnProfileErrorToString(QnnProfile_Error_t error)
14521458
}
14531459
}
14541460

1455-
const char* QnnBackendManager::QnnErrorHandleToString(Qnn_ErrorHandle_t error) {
1456-
// From QNN SDK: The memory is statically owned and should not be freed by the caller.
1457-
const char* error_msg = nullptr;
1458-
if (QNN_SUCCESS == qnn_interface_.errorGetMessage(error, &error_msg)) {
1459-
return error_msg;
1460-
}
1461-
return "Unknown";
1461+
std::string_view QnnBackendManager::QnnErrorHandleToString(Qnn_ErrorHandle_t error) {
1462+
return utils::GetQnnErrorMessage(qnn_interface_, error);
14621463
}
14631464

14641465
const std::string QnnBackendManager::ExtractQnnScalarValue(const Qnn_Scalar_t& scalar) {
@@ -1691,5 +1692,90 @@ void* QnnBackendManager::LibFunction(void* handle, const char* symbol, std::stri
16911692
#endif
16921693
}
16931694

1695+
Status QnnBackendManager::AddQnnContextHandle(Qnn_ContextHandle_t raw_context_handle) {
1696+
ORT_RETURN_IF(logger_ == nullptr, "logger_ should be set.");
1697+
1698+
auto free_context_handle = [this, &logger = *logger_](Qnn_ContextHandle_t raw_context_handle) {
1699+
const auto free_result = qnn_interface_.contextFree(raw_context_handle, nullptr);
1700+
if (free_result != QNN_CONTEXT_NO_ERROR) {
1701+
LOGS(logger, ERROR) << "qnn_interface.contextFree() failed: "
1702+
<< utils::GetVerboseQnnErrorMessage(qnn_interface_, free_result);
1703+
}
1704+
};
1705+
1706+
// take ownership of `raw_context_handle`
1707+
auto context_handle = UniqueQnnContextHandle(raw_context_handle, free_context_handle);
1708+
auto mem_handle_manager = std::make_unique<QnnContextMemHandleManager>(GetQnnInterface(), raw_context_handle,
1709+
*logger_);
1710+
1711+
auto context_handle_record = std::make_shared<QnnContextHandleRecord>();
1712+
context_handle_record->context_handle = std::move(context_handle);
1713+
context_handle_record->mem_handles = std::move(mem_handle_manager);
1714+
1715+
const bool inserted = context_map_.try_emplace(raw_context_handle, std::move(context_handle_record)).second;
1716+
ORT_RETURN_IF_NOT(inserted, "QNN context was already added: ", raw_context_handle);
1717+
1718+
contexts_.push_back(raw_context_handle);
1719+
1720+
return Status::OK();
1721+
}
1722+
1723+
Status QnnBackendManager::GetOrRegisterContextMemHandle(Qnn_ContextHandle_t context_handle,
1724+
void* shared_memory_address,
1725+
const Qnn_Tensor_t& qnn_tensor,
1726+
Qnn_MemHandle_t& mem_handle) {
1727+
// Multi-threading situations to consider:
1728+
// 1) Shared memory allocation is being freed in another thread while we are processing `shared_memory_address`.
1729+
// This implies incorrect usage as the memory is being freed while it is still in use. Let's assume this won't
1730+
// happen.
1731+
// 2) The shared memory allocation clean up function is being run from another thread while the
1732+
// QnnContextHandleRecord or QnnBackendManager objects are being destroyed.
1733+
// Usage of weak_ptrs from the clean up function should ensure that those objects are only accessed while they are
1734+
// in scope.
1735+
1736+
const auto context_handle_record_it = context_map_.find(context_handle);
1737+
ORT_RETURN_IF_NOT(context_handle_record_it != context_map_.end(), "QNN context not found: ", context_handle);
1738+
1739+
auto& context_handle_record = context_handle_record_it->second;
1740+
auto& context_mem_handle_manager = context_handle_record->mem_handles;
1741+
1742+
bool did_register{};
1743+
ORT_RETURN_IF_ERROR(context_mem_handle_manager->GetOrRegister(shared_memory_address, qnn_tensor,
1744+
mem_handle, did_register));
1745+
1746+
if (did_register) {
1747+
HtpSharedMemoryAllocator::AllocationCleanUpFn unregister_mem_handle =
1748+
[&logger = *logger_,
1749+
weak_backend_manager = weak_from_this(),
1750+
weak_context_handle_record = std::weak_ptr{context_handle_record}](
1751+
void* shared_memory_address) {
1752+
// Lock QnnBackendManager shared_ptr to ensure that QNN interface is still valid.
1753+
auto backend_manager = weak_backend_manager.lock();
1754+
if (!backend_manager) {
1755+
return;
1756+
}
1757+
1758+
// Lock QnnContextHandleRecord shared_ptr to ensure that QNN context handle is still valid.
1759+
auto context_handle_record = weak_context_handle_record.lock();
1760+
if (!context_handle_record) {
1761+
return;
1762+
}
1763+
1764+
auto& context_mem_handle_manager = context_handle_record->mem_handles;
1765+
1766+
auto unregister_status = context_mem_handle_manager->Unregister(shared_memory_address);
1767+
if (!unregister_status.IsOK()) {
1768+
LOGS(logger, ERROR) << "Failed to unregister shared memory mem handle for address: "
1769+
<< shared_memory_address << ", error: " << unregister_status.ErrorMessage();
1770+
}
1771+
};
1772+
1773+
ORT_RETURN_IF_ERROR(HtpSharedMemoryAllocator::AddAllocationCleanUp(shared_memory_address,
1774+
std::move(unregister_mem_handle)));
1775+
}
1776+
1777+
return Status::OK();
1778+
}
1779+
16941780
} // namespace qnn
16951781
} // namespace onnxruntime

0 commit comments

Comments
 (0)