Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MIGraphX EP] Adding Ortvalue features support for MGX EP #23404

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,21 @@ typedef struct OrtMIGraphXProviderOptions {
int migraphx_load_compiled_model; // migraphx int8 cal table. Default 0 = false, noznero = true
const char* migraphx_load_model_path; // migraphx model path name
bool migraphx_exhaustive_tune; // migraphx tuned compile Default = false

/** \brief MIGraphX memory limit (To use all possible memory pass in maximum size_t)
* Defaults to SIZE_MAX.
* \note If a ::OrtArenaCfg has been applied, it will override this field
*/
size_t migraphx_mem_limit;

/** \brief Strategy used to grow the memory arena
* 0 = kNextPowerOfTwo<br>
* 1 = kSameAsRequested<br>
* Defaults to 0.
* \note If a ::OrtArenaCfg has been applied, it will override this field
*/
int migraphx_arena_extend_strategy;

} OrtMIGraphXProviderOptions;

/** \brief OpenVINO Provider Options
Expand Down
45 changes: 42 additions & 3 deletions onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
#include "core/common/safeint.h"
#include "core/common/logging/severity.h"
#include "migraphx_execution_provider.h"
#include "migraphx_execution_provider_info.h"

Check warning on line 16 in onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Include the directory when naming header files [build/include_subdir] [4] Raw Output: onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc:16: Include the directory when naming header files [build/include_subdir] [4]
#include "migraphx_execution_provider_utils.h"
#include "migraphx_allocator.h"
#include "gpu_data_transfer.h"
#include "migraphx_inc.h"
#include <hip/hip_version.h>
#include "migraphx_call.h"

Check warning on line 21 in onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Include the directory when naming header files [build/include_subdir] [4] Raw Output: onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc:21: Include the directory when naming header files [build/include_subdir] [4]

#include "migraphx_stream_handle.h"

Expand Down Expand Up @@ -211,12 +212,50 @@
MIGraphXExecutionProvider::~MIGraphXExecutionProvider() {
}

AllocatorPtr MIGraphXExecutionProvider::CreateMIGraphXAllocator(OrtDevice::DeviceId device_id,
size_t migx_mem_limit,
ArenaExtendStrategy arena_extend_strategy,
MIGraphXExecutionProviderExternalAllocatorInfo
external_allocator_info,
const OrtArenaCfg* default_memory_arena_cfg) {
if (external_allocator_info.UseExternalAllocator()) {
AllocatorCreationInfo default_memory_info(
[external_allocator_info](OrtDevice::DeviceId id) {
return std::make_unique<MIGraphXExternalAllocator>(id, HIP,
external_allocator_info.alloc,
external_allocator_info.free,
external_allocator_info.empty_cache);
},
device_id,
false);

return CreateAllocator(default_memory_info);
} else {
AllocatorCreationInfo default_memory_info(
[](OrtDevice::DeviceId id) {
return std::make_unique<MIGraphXAllocator>(id, HIP);
},
device_id,
true,
{default_memory_arena_cfg ? *default_memory_arena_cfg
: OrtArenaCfg(migx_mem_limit, static_cast<int>(arena_extend_strategy),
-1, -1, -1, -1L)},
// make it stream aware
true,
// enable cross stream sharing?
false);

// ROCM malloc/free is expensive so always use an arena
return CreateAllocator(default_memory_info);
}
}

std::vector<AllocatorPtr> MIGraphXExecutionProvider::CreatePreferredAllocators() {
AllocatorCreationInfo default_memory_info(
[](OrtDevice::DeviceId device_id) { return CreateMIGraphXAllocator(device_id, onnxruntime::CUDA); }, info_.device_id);
[](OrtDevice::DeviceId device_id) { return std::make_unique<MIGraphXAllocator>(device_id, onnxruntime::CUDA); }, info_.device_id);
AllocatorCreationInfo pinned_allocator_info(
[](OrtDevice::DeviceId device_id) {
return CreateMIGraphXPinnedAllocator(device_id, onnxruntime::CUDA_PINNED);
return std::make_unique<HIPPinnedAllocator>(device_id, onnxruntime::CUDA_PINNED);
},
0);
return std::vector<AllocatorPtr>{CreateAllocator(default_memory_info), CreateAllocator(pinned_allocator_info)};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include "core/framework/execution_provider.h"
#include <mutex>
#include "core/providers/migraphx/migraphx_execution_provider_info.h"
#include "core/providers/migraphx/migraphx_inc.h"
#include "core/providers/migraphx/migraphx_call.h"

#include <map>
#include <unordered_map>
Expand Down Expand Up @@ -76,6 +76,9 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
virtual std::shared_ptr<KernelRegistry> GetKernelRegistry() const override;
std::unique_ptr<onnxruntime::IDataTransfer> GetDataTransfer() const override;

static AllocatorPtr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, size_t migx_mem_limit, ArenaExtendStrategy arena_extend_strategy,
MIGraphXExecutionProviderExternalAllocatorInfo external_alloc_info, const OrtArenaCfg* arena_cfg);

std::unique_ptr<IndexedSubGraph> GetSubGraph(const std::vector<std::size_t>& graph_nodes_index, const GraphViewer& graph) const;
void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override;
OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/shared_library/provider_api.h"
#include "core/providers/migraphx/migraphx_execution_provider_info.h"

#include "core/common/make_string.h"
Expand All @@ -10,6 +11,12 @@
#include "migraphx_call.h"

namespace onnxruntime {

const EnumNameMapping<ArenaExtendStrategy> arena_extend_strategy_mapping{
{ArenaExtendStrategy::kNextPowerOfTwo, "kNextPowerOfTwo"},
{ArenaExtendStrategy::kSameAsRequested, "kSameAsRequested"},
};

namespace migraphx {
namespace provider_option_names {
constexpr const char* kDeviceId = "device_id";
Expand All @@ -22,12 +29,20 @@
constexpr const char* kLoadCompiledModel = "migx_load_compiled_model";
constexpr const char* kLoadModelPath = "migx_load_model_name";
constexpr const char* kExhaustiveTune = "migx_exhaustive_tune";
constexpr const char* kMemLimit = "migx_mem_limit";
constexpr const char* kArenaExtendStrategy = "migx_arena_extend_strategy";
constexpr const char* kGpuExternalAlloc = "migx_external_alloc";
constexpr const char* kGpuExternalFree = "migx_external_free";
constexpr const char* kGpuExternalEmptyCache = "migx_external_empty_cache";

} // namespace provider_option_names
} // namespace migraphx

MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) {
MIGraphXExecutionProviderInfo info{};
void* alloc = nullptr;
void* free = nullptr;
void* empty_cache = nullptr;
ORT_THROW_IF_ERROR(
ProviderOptionsParser{}
.AddValueParser(
Expand All @@ -42,13 +57,42 @@
", must be between 0 (inclusive) and ", num_devices, " (exclusive).");
return Status::OK();
})
.AddValueParser(
migraphx::provider_option_names::kGpuExternalAlloc,
[&alloc](const std::string& value_str) -> Status {
size_t address;
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address));
alloc = reinterpret_cast<void*>(address);
return Status::OK();
})
.AddValueParser(
migraphx::provider_option_names::kGpuExternalFree,
[&free](const std::string& value_str) -> Status {
size_t address;
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address));
free = reinterpret_cast<void*>(address);
return Status::OK();
})
.AddValueParser(
migraphx::provider_option_names::kGpuExternalEmptyCache,
[&empty_cache](const std::string& value_str) -> Status {

Check warning on line 78 in onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc:78: Add #include <string> for string [build/include_what_you_use] [4]
size_t address;
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address));
empty_cache = reinterpret_cast<void*>(address);
return Status::OK();
})
.AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable)
.AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable)
.AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model)
.AddAssignmentToReference(migraphx::provider_option_names::kLoadCompiledModel, info.load_compiled_model)
.AddAssignmentToReference(migraphx::provider_option_names::kExhaustiveTune, info.exhaustive_tune)
.AddAssignmentToReference(migraphx::provider_option_names::kMemLimit, info.mem_limit)
.AddAssignmentToEnumReference(migraphx::provider_option_names::kArenaExtendStrategy, arena_extend_strategy_mapping, info.arena_extend_strategy)
.Parse(options));

MIGraphXExecutionProviderExternalAllocatorInfo alloc_info{alloc, free, empty_cache};
info.external_allocator_info = alloc_info;

return info;
}

Expand All @@ -59,6 +103,12 @@
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)},
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)},
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)},
{migraphx::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.mem_limit)},
{migraphx::provider_option_names::kGpuExternalAlloc, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.alloc))},
{migraphx::provider_option_names::kGpuExternalFree, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.free))},
{migraphx::provider_option_names::kGpuExternalEmptyCache, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.empty_cache))},
{migraphx::provider_option_names::kArenaExtendStrategy,
EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)},
{migraphx::provider_option_names::kExhaustiveTune, MakeStringWithClassicLocale(info.exhaustive_tune)},
};
return options;
Expand All @@ -71,6 +121,8 @@
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)},
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)},
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)},
{migraphx::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.migraphx_mem_limit)},
{migraphx::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, static_cast<onnxruntime::ArenaExtendStrategy>(info.migraphx_arena_extend_strategy))},
{migraphx::provider_option_names::kExhaustiveTune, MakeStringWithClassicLocale(info.migraphx_exhaustive_tune)},
};
return options;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,36 @@
#include <string>

#include "core/framework/ortdevice.h"
#include "core/common/hash_combine.h"
#include "core/framework/arena_extend_strategy.h"
#include "core/framework/provider_options.h"
#include "core/session/onnxruntime_c_api.h"

namespace onnxruntime {

// Information needed to construct MIGraphX execution providers.
struct MIGraphXExecutionProviderExternalAllocatorInfo {
void* alloc{nullptr};
void* free{nullptr};
void* empty_cache{nullptr};

MIGraphXExecutionProviderExternalAllocatorInfo() {
alloc = nullptr;
free = nullptr;
empty_cache = nullptr;
}

MIGraphXExecutionProviderExternalAllocatorInfo(void* a, void* f, void* e) {
alloc = a;
free = f;
empty_cache = e;
}

bool UseExternalAllocator() const {
return (alloc != nullptr) && (free != nullptr);
}
};

// Information needed to construct trt execution providers.
struct MIGraphXExecutionProviderInfo {
std::string target_device;
Expand All @@ -25,8 +51,42 @@ struct MIGraphXExecutionProviderInfo {
std::string load_model_file{"./compiled_model.mxr"};
bool exhaustive_tune{false};

size_t mem_limit{std::numeric_limits<size_t>::max()}; // Will be over-ridden by contents of `default_memory_arena_cfg` (if specified)
ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo}; // Will be over-ridden by contents of `default_memory_arena_cfg` (if specified)

OrtArenaCfg* default_memory_arena_cfg{nullptr};
MIGraphXExecutionProviderExternalAllocatorInfo external_allocator_info{};

static MIGraphXExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
static ProviderOptions ToProviderOptions(const MIGraphXExecutionProviderInfo& info);
static ProviderOptions ToProviderOptions(const OrtMIGraphXProviderOptions& info);
};
} // namespace onnxruntime

template <>
struct std::hash<::onnxruntime::MIGraphXExecutionProviderInfo> {
size_t operator()(const ::onnxruntime::MIGraphXExecutionProviderInfo& info) const {
size_t value{0xbc9f1d34}; // seed

// Bits: device_id (16), arena_extend_strategy (reserved 2), boolean options (1 each)
size_t data = static_cast<size_t>(info.device_id) ^
(static_cast<size_t>(info.arena_extend_strategy) << 16) ^
(static_cast<size_t>(info.fp16_enable) << 18) ^
(static_cast<size_t>(info.int8_enable) << 19) ^
(static_cast<size_t>(info.int8_use_native_calibration_table) << 20) ^
(static_cast<size_t>(info.save_compiled_model) << 21) ^
(static_cast<size_t>(info.load_compiled_model) << 22) ^
(static_cast<size_t>(info.exhaustive_tune) << 23);
onnxruntime::HashCombine(data, value);

onnxruntime::HashCombine(info.mem_limit, value);

// Memory pointers
onnxruntime::HashCombine(reinterpret_cast<size_t>(info.external_allocator_info.alloc), value);
onnxruntime::HashCombine(reinterpret_cast<size_t>(info.external_allocator_info.free), value);
onnxruntime::HashCombine(reinterpret_cast<size_t>(info.external_allocator_info.empty_cache), value);

// The default memory arena cfg is not used in hashing right now.
return value;
}
};
26 changes: 26 additions & 0 deletions onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "core/providers/shared_library/provider_api.h"
#include "core/providers/migraphx/migraphx_provider_factory.h"
#include "migraphx_execution_provider.h"
#include "migraphx_execution_provider_info.h"

Check warning on line 8 in onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Include the directory when naming header files [build/include_subdir] [4] Raw Output: onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc:8: Include the directory when naming header files [build/include_subdir] [4]
#include "migraphx_provider_factory_creator.h"
#include "migraphx_allocator.h"
#include "gpu_data_transfer.h"
Expand Down Expand Up @@ -42,6 +43,27 @@
return std::make_unique<HIPPinnedAllocator>(device_id, name);
}

void MIGraphXMemcpy_HostToDevice(void* dst, const void* src, size_t count) override {
// hipMemcpy() operates on the default stream
HIP_CALL_THROW(hipMemcpy(dst, src, count, hipMemcpyHostToDevice));

// To ensure that the copy has completed, invoke a stream sync for the default stream.
// For transfers from pageable host memory to device memory, a stream sync is performed before the copy is initiated.
// The function will return once the pageable buffer has been copied to the staging memory for DMA transfer
// to device memory, but the DMA to final destination may not have completed.

HIP_CALL_THROW(hipStreamSynchronize(0));
}

// Used by onnxruntime_pybind_state.cc
void MIGraphXMemcpy_DeviceToHost(void* dst, const void* src, size_t count) override {
// For transfers from device to either pageable or pinned host memory, the function returns only once the copy has completed.
HIP_CALL_THROW(hipMemcpy(dst, src, count, hipMemcpyDeviceToHost));
}

std::shared_ptr<IAllocator> CreateMIGraphXAllocator(int16_t device_id, size_t migx_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) override {
return MIGraphXExecutionProvider::CreateMIGraphXAllocator(device_id, migx_mem_limit, arena_extend_strategy, external_allocator_info, default_memory_arena_cfg);
}
} g_info;

struct MIGraphX_Provider : Provider {
Expand Down Expand Up @@ -77,6 +99,8 @@
if (options.migraphx_load_model_path != nullptr) {
info.load_model_file = options.migraphx_load_model_path;
}
info.arena_extend_strategy = static_cast<onnxruntime::ArenaExtendStrategy>(options.migraphx_arena_extend_strategy);
info.mem_limit = options.migraphx_mem_limit;
return std::make_shared<MIGraphXProviderFactory>(info);
}

Expand Down Expand Up @@ -109,6 +133,8 @@
migx_options.migraphx_save_model_path = internal_options.save_model_file.c_str();
migx_options.migraphx_load_compiled_model = internal_options.load_compiled_model;
migx_options.migraphx_load_model_path = internal_options.load_model_file.c_str();
migx_options.migraphx_arena_extend_strategy = static_cast<int>(internal_options.arena_extend_strategy);
migx_options.migraphx_mem_limit = internal_options.mem_limit;
}

ProviderOptions GetProviderOptions(const void* provider_options) override {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
struct ProviderInfo_MIGraphX {
virtual std::unique_ptr<onnxruntime::IAllocator> CreateMIGraphXAllocator(int16_t device_id, const char* name) = 0;
virtual std::unique_ptr<onnxruntime::IAllocator> CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name) = 0;
virtual void MIGraphXMemcpy_HostToDevice(void* dst, const void* src, size_t count) = 0;
virtual void MIGraphXMemcpy_DeviceToHost(void* dst, const void* src, size_t count) = 0;
virtual std::shared_ptr<onnxruntime::IAllocator> CreateMIGraphXAllocator(int16_t device_id, size_t migx_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) = 0;

Check warning on line 19 in onnxruntime/core/providers/migraphx/migraphx_provider_factory.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <memory> for shared_ptr<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/migraphx/migraphx_provider_factory.h:19: Add #include <memory> for shared_ptr<> [build/include_what_you_use] [4]

protected:
~ProviderInfo_MIGraphX() = default; // Can only be destroyed through a subclass instance
Expand Down
Loading
Loading