diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs index 84020d84c9e73..00ca25d0a6367 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs @@ -25,6 +25,7 @@ public struct OrtCompileApi public IntPtr ModelCompilationOptions_SetGraphOptimizationLevel; public IntPtr ModelCompilationOptions_SetOutputModelWriteFunc; public IntPtr ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc; + public IntPtr ModelCompilationOptions_SetInputModel; } internal class NativeMethods @@ -136,6 +137,12 @@ public DOrtModelCompilationOptions_SetOutputModelWriteFunc public DOrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc OrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetInputModel( + IntPtr /* OrtModelCompilationOptions* */ options, + IntPtr /* const OrtModel* */ inputModel); + public DOrtModelCompilationOptions_SetInputModel OrtModelCompilationOptions_SetInputModel; + internal NativeMethods(OnnxRuntime.NativeMethods.DOrtGetCompileApi getCompileApi) { @@ -217,6 +224,11 @@ internal NativeMethods(OnnxRuntime.NativeMethods.DOrtGetCompileApi getCompileApi _compileApi.ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc, typeof(DOrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc)); + OrtModelCompilationOptions_SetInputModel = + (DOrtModelCompilationOptions_SetInputModel)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetInputModel, + typeof(DOrtModelCompilationOptions_SetInputModel)); + } } } diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 221f3673f2027..5c80ed131f4c6 100644 --- 
a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -8037,6 +8037,29 @@ struct OrtCompileApi { ORT_API2_STATUS(ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc, _In_ OrtModelCompilationOptions* model_compile_options, _In_ OrtGetInitializerLocationFunc get_initializer_location_func, _In_ void* state); + + /** \brief Sets the OrtModel to compile. + * + * Sets an OrtModel created via the Model Editor API as the input for compilation. Overrides any previously set input model source. + * + * The input model's source (file path, memory buffer, or OrtModel) must be set with + * one of: ModelCompilationOptions_SetInputModelPath, ModelCompilationOptions_SetInputModelFromBuffer, + * or ModelCompilationOptions_SetInputModel. + * + * The OrtModel must have a complete graph with inputs, outputs, and nodes defined, and must declare at least one opset domain/version. + * The caller retains ownership of the OrtModel and must not release it until after + * CompileModel returns. + * + * \param[in] model_compile_options The OrtModelCompilationOptions instance. + * \param[in] model The OrtModel to compile. Must not be null. The model is borrowed (not copied or owned). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. 
+ */ + ORT_API2_STATUS(ModelCompilationOptions_SetInputModel, + _In_ OrtModelCompilationOptions* model_compile_options, + _In_ const OrtModel* model); }; /** diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 2c1d52894e7f3..8dae24a3bffe7 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1612,6 +1612,8 @@ struct ModelCompilationOptions : detail::Base { ModelCompilationOptions& SetFlags(uint32_t flags); ///< Wraps OrtApi::ModelCompilationOptions_SetFlags ModelCompilationOptions& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtApi::ModelCompilationOptions_SetGraphOptimizationLevel + + ModelCompilationOptions& SetInputModel(const OrtModel* model); ///< Wraps OrtCompileApi::ModelCompilationOptions_SetInputModel }; /** \brief Compiles an input model to generate a model with EPContext nodes that execute EP-specific kernels. Wraps OrtApi::CompileModels. 
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 745128fe6c7b4..bce2aa97d47cd 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -1180,6 +1180,11 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetGraphOptimizationLev return *this; } +inline ModelCompilationOptions& ModelCompilationOptions::SetInputModel(const OrtModel* model) { + Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetInputModel(this->p_, model)); + return *this; +} + namespace detail { template diff --git a/onnxruntime/core/session/compile_api.cc b/onnxruntime/core/session/compile_api.cc index 12127e9708255..54d26021d8c99 100644 --- a/onnxruntime/core/session/compile_api.cc +++ b/onnxruntime/core/session/compile_api.cc @@ -306,6 +306,27 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetGraphOptimizationL API_IMPL_END } +ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetInputModel, + _In_ OrtModelCompilationOptions* ort_model_compile_options, + _In_ const OrtModel* model) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + auto model_compile_options = reinterpret_cast(ort_model_compile_options); + + if (model == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid input model: OrtModel pointer is null"); + } + + model_compile_options->SetInputModel(model); + return nullptr; +#else + ORT_UNUSED_PARAMETER(ort_model_compile_options); + ORT_UNUSED_PARAMETER(model); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build"); +#endif // !defined(ORT_MINIMAL_BUILD) + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtCompileAPI::CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* ort_model_compile_options) { API_IMPL_BEGIN @@ -343,6 +364,9 @@ static constexpr OrtCompileApi ort_compile_api = { 
&OrtCompileAPI::ModelCompilationOptions_SetOutputModelWriteFunc, &OrtCompileAPI::ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc, // End of Version 23 - DO NOT MODIFY ABOVE + + &OrtCompileAPI::ModelCompilationOptions_SetInputModel, + // End of Version 24 - DO NOT MODIFY ABOVE }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned @@ -350,6 +374,8 @@ static_assert(offsetof(OrtCompileApi, CompileModel) / sizeof(void*) == 8, "Size of version 22 Api cannot change"); // initial version in ORT 1.22 static_assert(offsetof(OrtCompileApi, ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc) / sizeof(void*) == 13, "Size of version 23 of Api cannot change"); +static_assert(offsetof(OrtCompileApi, ModelCompilationOptions_SetInputModel) / sizeof(void*) == 14, + "Size of version 24 of Api cannot change"); ORT_API(const OrtCompileApi*, OrtCompileAPI::GetCompileApi) { return &ort_compile_api; diff --git a/onnxruntime/core/session/compile_api.h b/onnxruntime/core/session/compile_api.h index 34fa06340a7f9..e8f171ee24295 100644 --- a/onnxruntime/core/session/compile_api.h +++ b/onnxruntime/core/session/compile_api.h @@ -41,5 +41,8 @@ ORT_API_STATUS_IMPL(ModelCompilationOptions_SetOutputModelWriteFunc, ORT_API_STATUS_IMPL(ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc, _In_ OrtModelCompilationOptions* model_compile_options, _In_ OrtGetInitializerLocationFunc get_initializer_location_func, _In_ void* state); +ORT_API_STATUS_IMPL(ModelCompilationOptions_SetInputModel, + _In_ OrtModelCompilationOptions* model_compile_options, + _In_ const OrtModel* model); } // namespace OrtCompileAPI diff --git a/onnxruntime/core/session/model_compilation_options.cc b/onnxruntime/core/session/model_compilation_options.cc index 468dacc30c054..f161a1b4b6987 100644 --- a/onnxruntime/core/session/model_compilation_options.cc +++ b/onnxruntime/core/session/model_compilation_options.cc @@ -12,6 +12,8 
@@ #include "core/common/path_string.h" #include "core/framework/allocator.h" #include "core/framework/ep_context_options.h" +#include "core/platform/env.h" +#include "core/session/inference_session_utils.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/environment.h" @@ -45,6 +47,11 @@ void ModelCompilationOptions::SetInputModelFromBuffer(const void* input_model_da input_model_data_size_ = input_model_data_size; } +void ModelCompilationOptions::SetInputModel(const OrtModel* model) { + ResetInputModelSettings(); + input_model_ = model; +} + Status ModelCompilationOptions::SetOutputModelPath(const std::filesystem::path& output_model_path) { ConfigOptions& config_options = session_options_.value.config_options; epctx::ModelGenOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options; @@ -186,10 +193,19 @@ size_t ModelCompilationOptions::GetInputModelDataSize() const { return input_model_data_size_; } +bool ModelCompilationOptions::InputModelComesFromOrtModel() const { + return input_model_ != nullptr; +} + +const OrtModel* ModelCompilationOptions::GetInputModel() const { + return input_model_; +} + void ModelCompilationOptions::ResetInputModelSettings() { input_model_path_.clear(); input_model_data_ = nullptr; input_model_data_size_ = 0; + input_model_ = nullptr; } Status ModelCompilationOptions::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) { @@ -229,16 +245,21 @@ Status ModelCompilationOptions::Check() const { // Check input model settings. const bool input_from_file = !input_model_path_.empty(); const bool input_from_memory = input_model_data_ != nullptr; + const bool input_from_model = input_model_ != nullptr; + + int input_source_count = (input_from_file ? 1 : 0) + + (input_from_memory ? 1 : 0) + + (input_from_model ? 
1 : 0); - if (!input_from_file && !input_from_memory) { + if (input_source_count == 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input model to compile must be loaded from either a file or a memory buffer"); + "Input model to compile must be specified via file path, memory buffer, or OrtModel"); } - if (input_from_file && input_from_memory) { + if (input_source_count > 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input model to compile must be loaded from either a file or a memory buffer, ", - "but not both."); + "Input model to compile must be specified via exactly one of: ", + "file path, memory buffer, or OrtModel"); } if (input_from_file && !std::filesystem::exists(input_model_path_)) { @@ -249,12 +270,77 @@ Status ModelCompilationOptions::Check() const { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Buffer for input model data has size 0"); } + // Validate OrtModel input + if (input_from_model) { + if (input_model_->graph == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "OrtModel has no graph. Call AddGraphToModel before compilation."); + } + + if (input_model_->graph->GetNumInputs() == 0 || input_model_->graph->GetNumOutputs() == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "OrtModel graph must have at least one input and one output defined."); + } + + if (input_model_->domain_to_version.empty()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "OrtModel must specify at least one opset domain/version."); + } + + // Note: Additional validation (node connections, schema) happens during + // Model::LoadFromModelEditorApiModel -> Graph::Resolve() + } + + // ORT_LOAD_CONFIG_FROM_MODEL is not supported for OrtModel input. + // Check early so we fail before session construction. 
+ if (input_from_model) { + const Env& os_env = Env::Default(); + if (os_env.GetEnvironmentVar(inference_session_utils::kOrtLoadConfigFromModelEnvVar) == "1") { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "The environment variable ORT_LOAD_CONFIG_FROM_MODEL=1 is set, but loading " + "config from model is not supported for in-memory OrtModel input. " + "OrtModel is programmatically constructed and has no embedded ORT config. " + "Unset ORT_LOAD_CONFIG_FROM_MODEL or use file/buffer input instead."); + } + } + // Check output model settings. const epctx::ModelGenOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options; bool has_no_output_model_location = std::holds_alternative( ep_context_gen_options.output_model_location); - if (has_no_output_model_location && input_from_file) { + // Also treat an empty output file path as "no location" since it's not usable. + const auto* output_path = ep_context_gen_options.TryGetOutputModelPath(); + if (!has_no_output_model_location && output_path != nullptr && output_path->empty()) { + has_no_output_model_location = true; + } + + // Determine if we can derive an output path from the input + bool can_derive_output_path = input_from_file; + bool model_has_path = false; + + // For OrtModel input, check if model_path is set in the graph using the virtual GetModelPath() method + // (avoids dynamic_cast which requires RTTI) + if (input_from_model && input_model_->graph) { + const ORTCHAR_T* model_path_cstr = input_model_->graph->GetModelPath(); + if (model_path_cstr && model_path_cstr[0] != ORT_TSTR('\0')) { + can_derive_output_path = true; + model_has_path = true; + } + } + + // Fast-fail: If OrtModel has no model_path and user hasn't specified output location or embed mode, + // EPs that need to write context binaries will fail later. Fail early with a clear error. 
+ if (input_from_model && !model_has_path && has_no_output_model_location && + !ep_context_gen_options.embed_ep_context_in_model) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "OrtModel has no model_path set and no output location was specified. " + "Please either: (1) set the model_path on the OrtGraph before adding to OrtModel, " + "(2) call SetOutputModelPath/SetOutputModelBuffer to specify an output location, or " + "(3) call SetEpContextEmbedMode(true) to embed EP context in the model."); + } + + if (has_no_output_model_location && can_derive_output_path) { // User did not specify an output file, an output buffer, or an output write function. We default to generating an // output file with a name based on the input file name, so do not return an error. return Status::OK(); @@ -294,7 +380,13 @@ Status ModelCompilationOptions::Check() const { } std::string ModelCompilationOptions::GetInputSourceForTelemetry() const { - return InputModelComesFromFile() ? "file" : "buffer"; + if (InputModelComesFromFile()) { + return "file"; + } + if (InputModelComesFromOrtModel()) { + return "ort_model"; + } + return "buffer"; } std::string ModelCompilationOptions::GetOutputTargetForTelemetry() const { diff --git a/onnxruntime/core/session/model_compilation_options.h b/onnxruntime/core/session/model_compilation_options.h index 4ba8712a6c9c7..47529e794677e 100644 --- a/onnxruntime/core/session/model_compilation_options.h +++ b/onnxruntime/core/session/model_compilation_options.h @@ -10,6 +10,7 @@ #include "core/common/status.h" #include "core/common/path_string.h" #include "core/framework/allocator.h" +#include "core/graph/model_editor_api_types.h" #include "core/session/abi_session_options_impl.h" #include "core/session/onnxruntime_c_api.h" #include "core/session/onnxruntime_session_options_config_keys.h" @@ -45,6 +46,14 @@ class ModelCompilationOptions { /// The size in bytes of the input model's buffer void SetInputModelFromBuffer(const void* input_model_data, 
size_t input_model_data_size); + /// + /// Sets the OrtModel to compile. + /// The OrtModel is borrowed (not copied) - caller must keep it alive until CompileModel returns. + /// Overrides any previous call to SetInputModelPath(), SetInputModelFromBuffer(), or SetInputModel(). + /// + /// The OrtModel to compile + void SetInputModel(const OrtModel* model); + /// /// Sets the file path to store the output/compiled ONNX model. /// Overrides any previous call to SetOutputModelPath() or SetOutputModelBuffer(). @@ -132,6 +141,18 @@ class ModelCompilationOptions { /// true if input model comes from a file bool InputModelComesFromFile() const; + /// + /// Returns true if the input model comes from an OrtModel pointer. + /// + /// true if input model comes from an OrtModel + bool InputModelComesFromOrtModel() const; + + /// + /// Returns the OrtModel to compile, or nullptr if not set. + /// + /// pointer to the OrtModel or nullptr + const OrtModel* GetInputModel() const; + /// /// Returns the buffer that contains the bytes for the input ONNX model. /// Returns nullptr if the input model is not stored in a buffer. @@ -162,9 +183,9 @@ class ModelCompilationOptions { // Telemetry helper methods /// - /// Returns a string describing the input source type: "file" or "buffer". + /// Returns a string describing the input source type: "file", "buffer", or "ort_model". 
/// - /// "file" or "buffer" + /// "file", "buffer", or "ort_model" std::string GetInputSourceForTelemetry() const; /// @@ -205,6 +226,7 @@ class ModelCompilationOptions { std::filesystem::path input_model_path_; const void* input_model_data_ = nullptr; size_t input_model_data_size_ = 0; + const OrtModel* input_model_ = nullptr; // Borrowed pointer }; } // namespace onnxruntime #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index a354cf26368d4..677da43f970c2 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -22,6 +22,7 @@ #include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #if !defined(ORT_MINIMAL_BUILD) +#include "core/graph/model_editor_api_types.h" #include "core/session/plugin_ep/ep_factory_internal.h" #include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" #include "core/session/plugin_ep/ep_library_plugin.h" @@ -288,6 +289,88 @@ static OrtStatus* CreateSessionAndLoadModelImpl(_In_ const OrtSessionOptions* op return nullptr; } +#if !defined(ORT_MINIMAL_BUILD) +// Overload of CreateSessionAndLoadModelImpl that takes an OrtModel* directly. +// Mirrors the file/buffer load path for EP-context output validation and custom op domain wiring. +// ORT_LOAD_CONFIG_FROM_MODEL is not read here; ModelCompilationOptions::Check() rejects it for OrtModel input. +static OrtStatus* CreateSessionAndLoadModelImpl(_In_ const OrtSessionOptions* options, + const onnxruntime::Environment& env, + _In_ const OrtModel* model, + std::unique_ptr& sess) { + if (model == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtModel pointer is null"); + } + + // Check EPContext model generation options - OrtModel has no file path by default, + // so we need explicit output location or embedded model path. 
+ if (options) { + epctx::ModelGenOptions ep_ctx_gen_options = options->value.GetEpContextGenerationOptions(); + + if (ep_ctx_gen_options.enable) { + auto* output_model_path = ep_ctx_gen_options.TryGetOutputModelPath(); + + // Check if OrtModel has a model_path set + bool has_model_path = false; + if (model->graph) { + const ORTCHAR_T* model_path_cstr = model->graph->GetModelPath(); + has_model_path = model_path_cstr && model_path_cstr[0] != ORT_TSTR('\0'); + } + + // If there's no model path and no output location, fail early + if (!has_model_path && + (!ep_ctx_gen_options.HasOutputModelLocation() || + (output_model_path != nullptr && output_model_path->empty()))) { + return OrtApis::CreateStatus(ORT_FAIL, + "OrtModel has no model_path set and no valid output location was specified " + "for EPContext model generation. " + "SetOutputModelPath/SetOutputModelBuffer, or set the model_path on the " + "OrtGraph before adding it to OrtModel."); + } + } + } + + sess = std::make_unique( + options == nullptr ? onnxruntime::SessionOptions() : options->value, + env); + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) + // Add custom domains + if (options && !options->custom_op_domains_.empty()) { + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->AddCustomOpDomains(options->custom_op_domains_)); + } +#endif + + // Add custom domains for all OrtEpDevice instances to inference session. + // The custom domains should be registered before model load for ORT to validate the custom ops. + // This mirrors the same block in the file/buffer overload to maintain load-path parity. 
+ if (options != nullptr && + options->provider_factories.empty() && + options->value.ep_selection_policy.enable) { + InlinedVector all_ep_custom_op_domains; + + for (const OrtEpDevice* ep_device : env.GetOrtEpDevices()) { + InlinedVector domains; + ORT_API_RETURN_IF_STATUS_NOT_OK(GetCustomOpDomainsFromEpDevice(*ep_device, domains)); + + for (auto domain : domains) { + if (ShouldAddDomain(domain, options->custom_op_domains_)) { + all_ep_custom_op_domains.push_back(domain); + } + } + } + + if (!all_ep_custom_op_domains.empty()) { + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->AddCustomOpDomains(all_ep_custom_op_domains)); + } + } + + // Load from OrtModel + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load(*model)); + + return nullptr; +} +#endif // !defined(ORT_MINIMAL_BUILD) + // Creates an InferenceSession and loads the model. // Caller should provide either model_path, or modal_data + model_data_length. OrtStatus* CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, @@ -491,6 +574,12 @@ Status CompileModel(const Environment& env, const ModelCompilationOptions& model status = ToStatusAndRelease(CreateSessionAndLoadModelImpl(session_options, env, input_model_path.c_str(), nullptr, 0, session)); + } else if (model_compile_options.InputModelComesFromOrtModel()) { + // Use the OrtModel overload of CreateSessionAndLoadModelImpl to maintain load-path parity + // with file/buffer inputs (EP-context output validation, custom op domain registration). 
+ const OrtModel* input_model = model_compile_options.GetInputModel(); + status = ToStatusAndRelease(CreateSessionAndLoadModelImpl(session_options, env, + input_model, session)); } else { status = ToStatusAndRelease(CreateSessionAndLoadModelImpl(session_options, env, nullptr, model_compile_options.GetInputModelData(), diff --git a/onnxruntime/test/shared_lib/test_model_builder_api.cc b/onnxruntime/test/shared_lib/test_model_builder_api.cc index 018204bd1dfb0..b81ee9ac9b401 100644 --- a/onnxruntime/test/shared_lib/test_model_builder_api.cc +++ b/onnxruntime/test/shared_lib/test_model_builder_api.cc @@ -18,6 +18,7 @@ #include "test/shared_lib/test_fixture.h" #include "test/shared_lib/utils.h" +#include "test/util/include/scoped_env_vars.h" #include "test/util/include/test_allocator.h" #include "onnxruntime_config.h" // generated file in build output dir @@ -725,3 +726,369 @@ TEST(ModelEditorAPITest, CreateTypeInfo) { api.ReleaseTypeInfo(base_tensor_type_info); } + +// +// Tests for Model Editor API + Compile API integration +// + +namespace { +// Helper to create a simple model for testing with Model Editor API. The backing data for the Y initializer is appended to `weights` (presumably it must outlive the returned model, since Y is added as external data). +// Creates a model with a Gemm operation: Z = alpha * (X * Y) with alpha = 2.0, where X is input and Y is initializer +Ort::Model CreateSimpleGemmModel(std::vector>>& weights) { + Ort::Graph graph; + + std::vector graph_inputs; + std::vector graph_outputs; + + // Input: X is 3x4 + std::vector input_dims({3, 4}); + TensorTypeAndShapeInfo input_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + input_dims); + auto input_type_info = TypeInfo::CreateTensorInfo(input_tensor_info.GetConst()); + graph_inputs.emplace_back("X", input_type_info.GetConst()); + + // Output: Z is 3x8 + std::vector output_dims = {3, 8}; + TensorTypeAndShapeInfo output_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + output_dims); + auto output_type_info = TypeInfo::CreateTensorInfo(output_tensor_info.GetConst()); + graph_outputs.emplace_back("Z", 
output_type_info.GetConst()); + + graph.SetInputs(graph_inputs); + graph.SetOutputs(graph_outputs); + + // Gemm node with alpha=2.0 + std::vector attributes; + float alpha_value = 2.0; + attributes.push_back(OpAttr("alpha", &alpha_value, 1, OrtOpAttrType::ORT_OP_ATTR_FLOAT)); + + Node node("Gemm", onnxruntime::kOnnxDomain, "Gemm1", {"X", "Y"}, {"Z"}, attributes); + graph.AddNode(node); + + // Y initializer: 4x8 + std::vector y_dims = {4, 8}; + weights.emplace_back(std::make_unique>(32)); + auto& y_values = *weights.back(); + std::iota(y_values.begin(), y_values.end(), 1.0f); + + auto info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + auto y_tensor = Value::CreateTensor(info, y_values.data(), y_values.size(), y_dims.data(), y_dims.size()); + graph.AddInitializer("Y", y_tensor, /*data is external*/ true); + + std::vector opsets{{onnxruntime::kOnnxDomain, 18}}; + Model model(opsets); + model.AddGraph(graph); + + return model; +} +} // namespace + +// Test basic compilation from OrtModel +TEST(ModelEditorCompileAPITest, BasicCompileFromOrtModel) { + std::vector>> weights; + auto model = CreateSimpleGemmModel(weights); + + Ort::SessionOptions session_options; + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + + // Set the OrtModel as input + compile_options.SetInputModel(static_cast(model)); + + // Set output to buffer - use embed mode for simplicity + compile_options.SetEpContextEmbedMode(true); + + std::unique_ptr allocator = std::make_unique(); + void* output_buffer = nullptr; + size_t output_size = 0; + compile_options.SetOutputModelBuffer(allocator.get(), &output_buffer, &output_size); + + // Compile should succeed (note: may not produce EPContext nodes without specific EP, but validation passes) + Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + EXPECT_TRUE(status.IsOK()) << "CompileModel failed: " << status.GetErrorMessage(); + + // Verify output was produced + EXPECT_NE(output_buffer, 
nullptr); + EXPECT_GT(output_size, 0u); + + // Cleanup + if (output_buffer != nullptr) { + allocator->Free(output_buffer); + } +} + +// Test validation: null model pointer +TEST(ModelEditorCompileAPITest, CompileFromNullModel_Fails) { + Ort::SessionOptions session_options; + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + + try { + compile_options.SetInputModel(nullptr); + FAIL() << "Expected exception for null model pointer"; + } catch (const Ort::Exception& e) { + EXPECT_THAT(e.what(), ::testing::HasSubstr("null")); + } +} + +// Test validation: model with no graph +TEST(ModelEditorCompileAPITest, CompileFromModelWithNoGraph_Fails) { + // Create a model but don't add a graph + std::vector opsets{{onnxruntime::kOnnxDomain, 18}}; + Model model(opsets); + + Ort::SessionOptions session_options; + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + + compile_options.SetInputModel(static_cast(model)); + compile_options.SetEpContextEmbedMode(true); + + std::unique_ptr allocator = std::make_unique(); + void* output_buffer = nullptr; + size_t output_size = 0; + compile_options.SetOutputModelBuffer(allocator.get(), &output_buffer, &output_size); + + Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + EXPECT_FALSE(status.IsOK()) << "Expected CompileModel to fail for model with no graph"; + EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("graph")); +} + +// Test validation: model with empty inputs/outputs +TEST(ModelEditorCompileAPITest, CompileFromModelWithEmptyInputsOutputs_Fails) { + // Create a model with a graph that has no inputs or outputs + Ort::Graph graph; + // Don't set inputs or outputs + + std::vector opsets{{onnxruntime::kOnnxDomain, 18}}; + Model model(opsets); + model.AddGraph(graph); + + Ort::SessionOptions session_options; + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + + compile_options.SetInputModel(static_cast(model)); + 
compile_options.SetEpContextEmbedMode(true); + + std::unique_ptr allocator = std::make_unique(); + void* output_buffer = nullptr; + size_t output_size = 0; + compile_options.SetOutputModelBuffer(allocator.get(), &output_buffer, &output_size); + + Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + EXPECT_FALSE(status.IsOK()) << "Expected CompileModel to fail for model with empty inputs/outputs"; + EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("input")); +} + +// Test: model can be reused after compilation +TEST(ModelEditorCompileAPITest, ModelCanBeReusedAfterCompilation) { + std::vector>> weights; + auto model = CreateSimpleGemmModel(weights); + + // First compilation + { + Ort::SessionOptions session_options; + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetInputModel(static_cast(model)); + compile_options.SetEpContextEmbedMode(true); + + std::unique_ptr allocator = std::make_unique(); + void* output_buffer = nullptr; + size_t output_size = 0; + compile_options.SetOutputModelBuffer(allocator.get(), &output_buffer, &output_size); + + Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + EXPECT_TRUE(status.IsOK()) << "First CompileModel failed: " << status.GetErrorMessage(); + + if (output_buffer != nullptr) { + allocator->Free(output_buffer); + } + } + + // Second compilation with same model + { + Ort::SessionOptions session_options; + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetInputModel(static_cast(model)); + compile_options.SetEpContextEmbedMode(true); + + std::unique_ptr allocator = std::make_unique(); + void* output_buffer = nullptr; + size_t output_size = 0; + compile_options.SetOutputModelBuffer(allocator.get(), &output_buffer, &output_size); + + Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + EXPECT_TRUE(status.IsOK()) << "Second CompileModel failed: " << status.GetErrorMessage(); + + if 
(output_buffer != nullptr) {
+      allocator->Free(output_buffer);
+    }
+  }
+
+  // Model should still be usable for creating a session
+  Ort::SessionOptions session_options;
+  Ort::Session session(*ort_env, model, session_options);
+  EXPECT_EQ(session.GetInputCount(), 1u);
+  EXPECT_EQ(session.GetOutputCount(), 1u);
+}
+
+// NOTE(review): the template argument lists on std::vector, static_cast, std::unique_ptr and
+// std::make_unique in the tests below look like they were lost in extraction (for example,
+// `static_cast(model)` is not valid C++) - restore them from the original change before applying.
+
+// Test: SetInputModel overrides previous input source (file path)
+// Sets an input path first, then an OrtModel, and expects CompileModel to report success.
+TEST(ModelEditorCompileAPITest, SetInputModelOverridesPreviousInputPath) {
+  std::vector>> weights;
+  auto model = CreateSimpleGemmModel(weights);
+
+  Ort::SessionOptions session_options;
+  Ort::ModelCompilationOptions compile_options(*ort_env, session_options);
+
+  // First set a file path (doesn't need to exist since we'll override it)
+  compile_options.SetInputModelPath(ORT_TSTR("nonexistent_file.onnx"));
+
+  // Then override with OrtModel
+  compile_options.SetInputModel(static_cast(model));
+  compile_options.SetEpContextEmbedMode(true);
+
+  // Direct the compiled output to a buffer; the allocator and both out-parameters are handed
+  // to the compilation options here.
+  std::unique_ptr allocator = std::make_unique();
+  void* output_buffer = nullptr;
+  size_t output_size = 0;
+  compile_options.SetOutputModelBuffer(allocator.get(), &output_buffer, &output_size);
+
+  // Should use the OrtModel, not the nonexistent file
+  Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
+  EXPECT_TRUE(status.IsOK()) << "CompileModel failed: " << status.GetErrorMessage();
+
+  // Release whatever was written to output_buffer with the allocator passed to
+  // SetOutputModelBuffer.
+  if (output_buffer != nullptr) {
+    allocator->Free(output_buffer);
+  }
+}
+
+// Test: SetInputModelPath overrides previous OrtModel setting
+// Sets an OrtModel first, then a file path, and expects CompileModel to report success.
+// NOTE(review): only success is asserted; nothing here distinguishes which of the two inputs
+// was actually compiled - confirm that is sufficient for the "override" claim.
+TEST(ModelEditorCompileAPITest, SetInputModelPathOverridesPreviousModel) {
+  std::vector>> weights;
+  auto model = CreateSimpleGemmModel(weights);
+
+  Ort::SessionOptions session_options;
+  Ort::ModelCompilationOptions compile_options(*ort_env, session_options);
+
+  // First set an OrtModel
+  compile_options.SetInputModel(static_cast(model));
+
+  // Then override with a real file path
+  compile_options.SetInputModelPath(ORT_TSTR("testdata/matmul_1.onnx"));
+  compile_options.SetEpContextEmbedMode(true);
+
+  // Direct the compiled output to a buffer.
+  std::unique_ptr allocator = std::make_unique();
+  void* output_buffer = nullptr;
+  size_t output_size = 0;
+  compile_options.SetOutputModelBuffer(allocator.get(), &output_buffer, &output_size);
+
+  // Should use the file path, not the OrtModel
+  Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
+  EXPECT_TRUE(status.IsOK()) << "CompileModel failed: " << status.GetErrorMessage();
+
+  // Release whatever was written to output_buffer with the allocator passed to
+  // SetOutputModelBuffer.
+  if (output_buffer != nullptr) {
+    allocator->Free(output_buffer);
+  }
+}
+
+// Test: Compile with output to file
+// Compiles an OrtModel with the output directed to a file path, then checks that the file
+// exists and that a session can be created from it.
+TEST(ModelEditorCompileAPITest, CompileFromOrtModelToFile) {
+  std::vector>> weights;
+  auto model = CreateSimpleGemmModel(weights);
+
+  auto output_path = ORT_TSTR("test_compile_from_ortmodel_output.onnx");
+
+  Ort::SessionOptions session_options;
+  Ort::ModelCompilationOptions compile_options(*ort_env, session_options);
+
+  compile_options.SetInputModel(static_cast(model));
+  compile_options.SetOutputModelPath(output_path);
+  compile_options.SetEpContextEmbedMode(true);
+
+  // EXPECT_* is non-fatal, so the checks below still run if compilation reports an error.
+  Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
+  EXPECT_TRUE(status.IsOK()) << "CompileModel failed: " << status.GetErrorMessage();
+
+  // Verify output file exists
+  EXPECT_TRUE(std::filesystem::exists(output_path));
+
+  // Verify the output model can be loaded
+  Ort::Session session(*ort_env, output_path, Ort::SessionOptions());
+  EXPECT_GE(session.GetInputCount(), 1u);
+  EXPECT_GE(session.GetOutputCount(), 1u);
+
+  // Cleanup
+  // NOTE(review): this removal is not reached if an exception escapes the lines above (e.g.
+  // from the Ort::Session constructor - confirm it can throw), which would leave the file behind.
+  std::filesystem::remove(output_path);
+}
+
+// Test: ORT_LOAD_CONFIG_FROM_MODEL=1 with OrtModel input fails fast with a clear,
+// actionable error message.
+TEST(ModelEditorCompileAPITest, LoadConfigFromModelEnvVarFailsForOrtModel) {
+  // RAII helper saves the current env var value and restores it when the scope exits.
+  onnxruntime::test::EnvVarMap env_vars;
+  env_vars["ORT_LOAD_CONFIG_FROM_MODEL"] = "1";
+  onnxruntime::test::ScopedEnvironmentVariables scoped_env(env_vars);
+
+  std::vector>> weights;
+  auto model = CreateSimpleGemmModel(weights);
+
+  Ort::SessionOptions session_options;
+  Ort::ModelCompilationOptions compile_options(*ort_env, session_options);
+  compile_options.SetInputModel(static_cast(model));
+  compile_options.SetEpContextEmbedMode(true);
+
+  // Direct the compiled output to a buffer.
+  std::unique_ptr allocator = std::make_unique();
+  void* output_buffer = nullptr;
+  size_t output_size = 0;
+  compile_options.SetOutputModelBuffer(allocator.get(), &output_buffer, &output_size);
+
+  // Should fail with a clear error about the unsupported env-var/input combination.
+  // The three substring checks require the message to name the variable setting, the kind of
+  // input, and the remedy.
+  Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
+  EXPECT_FALSE(status.IsOK());
+  EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("ORT_LOAD_CONFIG_FROM_MODEL=1"));
+  EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("in-memory OrtModel input"));
+  EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("Unset ORT_LOAD_CONFIG_FROM_MODEL"));
+
+  // Defensive: free the buffer if one was written even though a failure status is expected.
+  if (output_buffer != nullptr) {
+    allocator->Free(output_buffer);
+  }
+}
+
+// Test: Validation error for OrtModel with no model_path, no output location, and no embed mode.
+// Verifies the error message contains the expected remediation guidance.
+TEST(ModelEditorCompileAPITest, NoOutputLocationNoModelPathFails) {
+  // NOTE(review): the template argument lists on std::vector, static_cast, std::unique_ptr and
+  // std::make_unique in this test and the next look like they were lost in extraction (for
+  // example, `static_cast(model)` is not valid C++) - restore them before applying.
+  std::vector>> weights;
+  auto model = CreateSimpleGemmModel(weights);
+
+  // Only the input model is configured on the compilation options.
+  Ort::SessionOptions session_options;
+  Ort::ModelCompilationOptions compile_options(*ort_env, session_options);
+  compile_options.SetInputModel(static_cast(model));
+  // Intentionally do NOT call SetEpContextEmbedMode, SetOutputModelPath, or SetOutputModelBuffer
+
+  Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
+  EXPECT_FALSE(status.IsOK());
+  // Should suggest setting output location or model_path, including embed mode as an option
+  EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("SetOutputModelPath"));
+  EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("SetEpContextEmbedMode(true)"));
+  // Should NOT suggest SetEpContextBinaryInformation (that alone is not sufficient)
+  EXPECT_THAT(status.GetErrorMessage(), ::testing::Not(::testing::HasSubstr("SetEpContextBinaryInformation")));
+}
+
+// Test: Setting embed mode with buffer output satisfies the output location requirement
+// for OrtModel with no model_path.
+// Same input setup as the failing case above, plus SetEpContextEmbedMode(true) and a buffer
+// output; CompileModel is then expected to report success.
+TEST(ModelEditorCompileAPITest, EmbedModeWithBufferOutputSatisfiesValidation) {
+  std::vector>> weights;
+  auto model = CreateSimpleGemmModel(weights);
+
+  Ort::SessionOptions session_options;
+  Ort::ModelCompilationOptions compile_options(*ort_env, session_options);
+  compile_options.SetInputModel(static_cast(model));
+  compile_options.SetEpContextEmbedMode(true);
+
+  // Direct the compiled output to a buffer; the allocator and both out-parameters are handed
+  // to the compilation options here.
+  std::unique_ptr allocator = std::make_unique();
+  void* output_buffer = nullptr;
+  size_t output_size = 0;
+  compile_options.SetOutputModelBuffer(allocator.get(), &output_buffer, &output_size);
+
+  Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
+  EXPECT_TRUE(status.IsOK()) << "CompileModel failed: " << status.GetErrorMessage();
+
+  // Release whatever was written to output_buffer with the allocator passed to
+  // SetOutputModelBuffer.
+  if (output_buffer != nullptr) {
+    allocator->Free(output_buffer);
+  }
+}