|
3 | 3 |
|
4 | 4 | #include <algorithm> |
5 | 5 | #include <cassert> |
| 6 | +#include <climits> |
6 | 7 | #include <cstring> |
7 | 8 | #include <functional> |
8 | 9 | #include <mutex> |
9 | | -#include <vector> |
10 | 10 | #include <sstream> |
| 11 | +#include <vector> |
11 | 12 |
|
12 | 13 | #include "core/common/common.h" |
13 | 14 | #include "core/common/logging/logging.h" |
|
30 | 31 | #include "core/framework/utils.h" |
31 | 32 | #include "core/graph/constants.h" |
32 | 33 | #include "core/graph/graph.h" |
| 34 | +#include "core/graph/model.h" |
33 | 35 | #include "core/graph/model_editor_api_types.h" |
34 | 36 | #include "core/graph/ep_api_types.h" |
| 37 | +#include "core/graph/onnx_protobuf.h" |
35 | 38 | #include "core/providers/get_execution_providers.h" |
36 | 39 | #include "core/session/abi_devices.h" |
37 | 40 | #include "core/session/abi_session_options_impl.h" |
38 | 41 | #include "core/session/allocator_adapters.h" |
39 | 42 | #include "core/session/compile_api.h" |
40 | 43 | #include "core/session/environment.h" |
41 | 44 | #include "core/session/interop_api.h" |
| 45 | +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" |
42 | 46 | #include "core/session/plugin_ep/ep_api.h" |
43 | 47 | #include "core/session/plugin_ep/ep_library_internal.h" |
44 | 48 | #include "core/session/inference_session.h" |
@@ -3565,6 +3569,93 @@ ORT_API_STATUS_IMPL(OrtApis::GetModelCompatibilityForEpDevices, |
3565 | 3569 | API_IMPL_END |
3566 | 3570 | } |
3567 | 3571 |
|
| 3572 | +// Helper function to extract compatibility info from model metadata |
| 3573 | +static OrtStatus* ExtractCompatibilityInfoFromModelProto( |
| 3574 | + const ONNX_NAMESPACE::ModelProto& model_proto, |
| 3575 | + const char* ep_type, |
| 3576 | + OrtAllocator* allocator, |
| 3577 | + char** compatibility_info) { |
| 3578 | + // Build the key we're looking for |
| 3579 | + std::string target_key = std::string(kOrtModelMetadata_EpCompatibilityInfoPrefix) + ep_type; |
| 3580 | + |
| 3581 | + // Search through metadata_props for the matching key |
| 3582 | + for (const auto& prop : model_proto.metadata_props()) { |
| 3583 | + if (prop.key() == target_key) { |
| 3584 | + // Found it - allocate and copy the value using the provided allocator |
| 3585 | + *compatibility_info = onnxruntime::StrDup(prop.value(), allocator); |
| 3586 | + if (*compatibility_info == nullptr) { |
| 3587 | + return OrtApis::CreateStatus(ORT_FAIL, "Failed to allocate memory for compatibility info."); |
| 3588 | + } |
| 3589 | + return nullptr; |
| 3590 | + } |
| 3591 | + } |
| 3592 | + |
| 3593 | + // Key not found - return nullptr (not an error, just means no compat info for this EP) |
| 3594 | + *compatibility_info = nullptr; |
| 3595 | + return nullptr; |
| 3596 | +} |
| 3597 | + |
| 3598 | +// Extract EP compatibility info from a model file |
| 3599 | +ORT_API_STATUS_IMPL(OrtApis::GetCompatibilityInfoFromModel, |
| 3600 | + _In_ const ORTCHAR_T* model_path, |
| 3601 | + _In_ const char* ep_type, |
| 3602 | + _Inout_ OrtAllocator* allocator, |
| 3603 | + _Outptr_result_maybenull_ char** compatibility_info) { |
| 3604 | + API_IMPL_BEGIN |
| 3605 | + if (model_path == nullptr || ep_type == nullptr || ep_type[0] == '\0' || |
| 3606 | + allocator == nullptr || compatibility_info == nullptr) { |
| 3607 | + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, |
| 3608 | + "Invalid argument provided to GetCompatibilityInfoFromModel."); |
| 3609 | + } |
| 3610 | + |
| 3611 | + *compatibility_info = nullptr; |
| 3612 | + |
| 3613 | + // Use Model::Load for proper cross-platform path handling via file descriptor |
| 3614 | + ONNX_NAMESPACE::ModelProto model_proto; |
| 3615 | + auto status = Model::Load(PathString(model_path), model_proto); |
| 3616 | + if (!status.IsOK()) { |
| 3617 | + if (status.Code() == common::NO_SUCHFILE) { |
| 3618 | + return OrtApis::CreateStatus(ORT_NO_SUCHFILE, status.ErrorMessage().c_str()); |
| 3619 | + } |
| 3620 | + return OrtApis::CreateStatus(ORT_INVALID_GRAPH, status.ErrorMessage().c_str()); |
| 3621 | + } |
| 3622 | + |
| 3623 | + return ExtractCompatibilityInfoFromModelProto(model_proto, ep_type, allocator, compatibility_info); |
| 3624 | + API_IMPL_END |
| 3625 | +} |
| 3626 | + |
| 3627 | +// Extract EP compatibility info from model bytes in memory |
| 3628 | +ORT_API_STATUS_IMPL(OrtApis::GetCompatibilityInfoFromModelBytes, |
| 3629 | + _In_reads_(model_data_length) const void* model_data, |
| 3630 | + _In_ size_t model_data_length, |
| 3631 | + _In_ const char* ep_type, |
| 3632 | + _Inout_ OrtAllocator* allocator, |
| 3633 | + _Outptr_result_maybenull_ char** compatibility_info) { |
| 3634 | + API_IMPL_BEGIN |
| 3635 | + if (model_data == nullptr || model_data_length == 0 || ep_type == nullptr || ep_type[0] == '\0' || |
| 3636 | + allocator == nullptr || compatibility_info == nullptr) { |
| 3637 | + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, |
| 3638 | + "Invalid argument provided to GetCompatibilityInfoFromModelBytes."); |
| 3639 | + } |
| 3640 | + |
| 3641 | + *compatibility_info = nullptr; |
| 3642 | + |
| 3643 | + // Explicit check for size limit - Model::LoadFromBytes uses int for size due to protobuf API |
| 3644 | + if (model_data_length > static_cast<size_t>(INT_MAX)) { |
| 3645 | + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, |
| 3646 | + "Model data size exceeds maximum supported size (2GB). Use GetCompatibilityInfoFromModel with a file path instead."); |
| 3647 | + } |
| 3648 | + |
| 3649 | + ONNX_NAMESPACE::ModelProto model_proto; |
| 3650 | + auto status = Model::LoadFromBytes(static_cast<int>(model_data_length), model_data, model_proto); |
| 3651 | + if (!status.IsOK()) { |
| 3652 | + return OrtApis::CreateStatus(ORT_INVALID_GRAPH, status.ErrorMessage().c_str()); |
| 3653 | + } |
| 3654 | + |
| 3655 | + return ExtractCompatibilityInfoFromModelProto(model_proto, ep_type, allocator, compatibility_info); |
| 3656 | + API_IMPL_END |
| 3657 | +} |
| 3658 | + |
3568 | 3659 | // GetInteropApi - returns the Interop API struct |
3569 | 3660 | ORT_API(const OrtInteropApi*, OrtApis::GetInteropApi) { |
3570 | 3661 | return OrtInteropAPI::GetInteropApi(); |
@@ -3603,6 +3694,29 @@ ORT_API_STATUS_IMPL(OrtApis::GetModelCompatibilityForEpDevices, |
3603 | 3694 | API_IMPL_END |
3604 | 3695 | } |
3605 | 3696 |
|
| 3697 | +// Minimal build stub for GetCompatibilityInfoFromModel |
| 3698 | +ORT_API_STATUS_IMPL(OrtApis::GetCompatibilityInfoFromModel, |
| 3699 | + _In_ const ORTCHAR_T* /*model_path*/, |
| 3700 | + _In_ const char* /*ep_type*/, |
| 3701 | + _Inout_ OrtAllocator* /*allocator*/, |
| 3702 | + _Outptr_result_maybenull_ char** /*compatibility_info*/) { |
| 3703 | + API_IMPL_BEGIN |
| 3704 | + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "GetCompatibilityInfoFromModel is not supported in a minimal build."); |
| 3705 | + API_IMPL_END |
| 3706 | +} |
| 3707 | + |
| 3708 | +// Minimal build stub for GetCompatibilityInfoFromModelBytes |
| 3709 | +ORT_API_STATUS_IMPL(OrtApis::GetCompatibilityInfoFromModelBytes, |
| 3710 | + _In_reads_(model_data_length) const void* /*model_data*/, |
| 3711 | + _In_ size_t /*model_data_length*/, |
| 3712 | + _In_ const char* /*ep_type*/, |
| 3713 | + _Inout_ OrtAllocator* /*allocator*/, |
| 3714 | + _Outptr_result_maybenull_ char** /*compatibility_info*/) { |
| 3715 | + API_IMPL_BEGIN |
| 3716 | + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "GetCompatibilityInfoFromModelBytes is not supported in a minimal build."); |
| 3717 | + API_IMPL_END |
| 3718 | +} |
| 3719 | + |
3606 | 3720 | ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_V2, _In_ OrtSessionOptions* /*session_options*/, |
3607 | 3721 | _In_ OrtEnv* /*env*/, |
3608 | 3722 | _In_reads_(num_ep_devices) const OrtEpDevice* const* /*ep_devices*/, |
@@ -4298,6 +4412,8 @@ static constexpr OrtApi ort_api_1_to_24 = { |
4298 | 4412 | &OrtApis::DeviceEpIncompatibilityDetails_GetNotes, |
4299 | 4413 | &OrtApis::DeviceEpIncompatibilityDetails_GetErrorCode, |
4300 | 4414 | &OrtApis::ReleaseDeviceEpIncompatibilityDetails, |
| 4415 | + &OrtApis::GetCompatibilityInfoFromModel, |
| 4416 | + &OrtApis::GetCompatibilityInfoFromModelBytes, |
4301 | 4417 | }; |
4302 | 4418 |
|
4303 | 4419 | // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. |
|
0 commit comments