-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Add OrtModel input support for Compile API #27332
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,6 +12,8 @@ | |
| #include "core/common/path_string.h" | ||
| #include "core/framework/allocator.h" | ||
| #include "core/framework/ep_context_options.h" | ||
| #include "core/platform/env.h" | ||
| #include "core/session/inference_session_utils.h" | ||
| #include "core/session/onnxruntime_session_options_config_keys.h" | ||
| #include "core/session/environment.h" | ||
|
|
||
|
|
@@ -45,6 +47,11 @@ void ModelCompilationOptions::SetInputModelFromBuffer(const void* input_model_da | |
| input_model_data_size_ = input_model_data_size; | ||
| } | ||
|
|
||
| // Selects an in-memory OrtModel as the model to compile. | ||
| // All input-source settings are reset first, so the OrtModel passed here is the only | ||
| // input source configured once this call returns. | ||
| void ModelCompilationOptions::SetInputModel(const OrtModel* model) { | ||
| this->ResetInputModelSettings(); | ||
| this->input_model_ = model; | ||
| } | ||
|
|
||
| Status ModelCompilationOptions::SetOutputModelPath(const std::filesystem::path& output_model_path) { | ||
| ConfigOptions& config_options = session_options_.value.config_options; | ||
| epctx::ModelGenOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options; | ||
|
|
@@ -186,10 +193,19 @@ size_t ModelCompilationOptions::GetInputModelDataSize() const { | |
| return input_model_data_size_; | ||
| } | ||
|
|
||
| // Reports whether an OrtModel pointer is currently set as the input model. | ||
| bool ModelCompilationOptions::InputModelComesFromOrtModel() const { | ||
| const bool has_ort_model = (input_model_ != nullptr); | ||
| return has_ort_model; | ||
| } | ||
|
|
||
| // Returns the stored OrtModel input pointer as-is; it may be nullptr when no OrtModel is set. | ||
| const OrtModel* ModelCompilationOptions::GetInputModel() const { | ||
| return this->input_model_; | ||
| } | ||
|
|
||
| // Clears every input-source field (OrtModel pointer, memory buffer and its size, file path) | ||
| // so that no input model is configured afterwards. | ||
| void ModelCompilationOptions::ResetInputModelSettings() { | ||
| // OrtModel input. | ||
| input_model_ = nullptr; | ||
| // Memory-buffer input: pointer and size are reset together. | ||
| input_model_data_size_ = 0; | ||
| input_model_data_ = nullptr; | ||
| // File input. | ||
| input_model_path_.clear(); | ||
| } | ||
|
|
||
| Status ModelCompilationOptions::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) { | ||
|
|
@@ -229,16 +245,21 @@ Status ModelCompilationOptions::Check() const { | |
| // Check input model settings. | ||
| const bool input_from_file = !input_model_path_.empty(); | ||
| const bool input_from_memory = input_model_data_ != nullptr; | ||
| const bool input_from_model = input_model_ != nullptr; | ||
|
|
||
| int input_source_count = (input_from_file ? 1 : 0) + | ||
| (input_from_memory ? 1 : 0) + | ||
| (input_from_model ? 1 : 0); | ||
|
|
||
| if (!input_from_file && !input_from_memory) { | ||
| if (input_source_count == 0) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, | ||
| "Input model to compile must be loaded from either a file or a memory buffer"); | ||
| "Input model to compile must be specified via file path, memory buffer, or OrtModel"); | ||
| } | ||
|
|
||
| if (input_from_file && input_from_memory) { | ||
| if (input_source_count > 1) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, | ||
| "Input model to compile must be loaded from either a file or a memory buffer, ", | ||
| "but not both."); | ||
| "Input model to compile must be specified via exactly one of: ", | ||
| "file path, memory buffer, or OrtModel"); | ||
| } | ||
|
|
||
| if (input_from_file && !std::filesystem::exists(input_model_path_)) { | ||
|
|
@@ -249,12 +270,77 @@ Status ModelCompilationOptions::Check() const { | |
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Buffer for input model data has size 0"); | ||
| } | ||
|
|
||
| // Validate OrtModel input | ||
| if (input_from_model) { | ||
| if (input_model_->graph == nullptr) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, | ||
| "OrtModel has no graph. Call AddGraphToModel before compilation."); | ||
| } | ||
|
|
||
| if (input_model_->graph->GetNumInputs() == 0 || input_model_->graph->GetNumOutputs() == 0) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, | ||
| "OrtModel graph must have at least one input and one output defined."); | ||
| } | ||
|
|
||
| if (input_model_->domain_to_version.empty()) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, | ||
| "OrtModel must specify at least one opset domain/version."); | ||
| } | ||
|
|
||
| // Note: Additional validation (node connections, schema) happens during | ||
| // Model::LoadFromModelEditorApiModel -> Graph::Resolve() | ||
| } | ||
|
|
||
| // ORT_LOAD_CONFIG_FROM_MODEL is not supported for OrtModel input. | ||
| // Check early so we fail before session construction. | ||
| if (input_from_model) { | ||
| const Env& os_env = Env::Default(); | ||
| if (os_env.GetEnvironmentVar(inference_session_utils::kOrtLoadConfigFromModelEnvVar) == "1") { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, | ||
| "The environment variable ORT_LOAD_CONFIG_FROM_MODEL=1 is set, but loading " | ||
| "config from model is not supported for in-memory OrtModel input. " | ||
| "OrtModel is programmatically constructed and has no embedded ORT config. " | ||
| "Unset ORT_LOAD_CONFIG_FROM_MODEL or use file/buffer input instead."); | ||
| } | ||
| } | ||
|
Comment on lines
+294
to
+305
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: not sure this is worth the extra code here, given that (afaik) reading the config from the model is a very niche usage. |
||
|
|
||
| // Check output model settings. | ||
| const epctx::ModelGenOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options; | ||
| bool has_no_output_model_location = std::holds_alternative<std::monostate>( | ||
| ep_context_gen_options.output_model_location); | ||
|
|
||
| if (has_no_output_model_location && input_from_file) { | ||
| // Also treat an empty output file path as "no location" since it's not usable. | ||
| const auto* output_path = ep_context_gen_options.TryGetOutputModelPath(); | ||
| if (!has_no_output_model_location && output_path != nullptr && output_path->empty()) { | ||
| has_no_output_model_location = true; | ||
| } | ||
|
Comment on lines
+313
to
+316
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should we instead prevent an empty path from being set? |
||
|
|
||
| // Determine if we can derive an output path from the input | ||
| bool can_derive_output_path = input_from_file; | ||
| bool model_has_path = false; | ||
|
|
||
| // For OrtModel input, check if model_path is set in the graph using the virtual GetModelPath() method | ||
| // (avoids dynamic_cast which requires RTTI) | ||
| if (input_from_model && input_model_->graph) { | ||
| const ORTCHAR_T* model_path_cstr = input_model_->graph->GetModelPath(); | ||
| if (model_path_cstr && model_path_cstr[0] != ORT_TSTR('\0')) { | ||
| can_derive_output_path = true; | ||
| model_has_path = true; | ||
| } | ||
| } | ||
|
|
||
| // Fast-fail: If OrtModel has no model_path and user hasn't specified output location or embed mode, | ||
| // EPs that need to write context binaries will fail later. Fail early with a clear error. | ||
| if (input_from_model && !model_has_path && has_no_output_model_location && | ||
| !ep_context_gen_options.embed_ep_context_in_model) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, | ||
| "OrtModel has no model_path set and no output location was specified. " | ||
| "Please either: (1) set the model_path on the OrtGraph before adding to OrtModel, " | ||
| "(2) call SetOutputModelPath/SetOutputModelBuffer to specify an output location, or " | ||
| "(3) call SetEpContextEmbedMode(true) to embed EP context in the model."); | ||
| } | ||
|
Comment on lines
+332
to
+341
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is it worth duplicating code for this? Not sure we need a 'fast fail' path if the issue is a usage error, as that should be caught and fixed during development. |
||
|
|
||
| if (has_no_output_model_location && can_derive_output_path) { | ||
| // User did not specify an output file, an output buffer, or an output write function. We default to generating an | ||
| // output file with a name based on the input file name, so do not return an error. | ||
| return Status::OK(); | ||
|
|
@@ -294,7 +380,13 @@ Status ModelCompilationOptions::Check() const { | |
| } | ||
|
|
||
| std::string ModelCompilationOptions::GetInputSourceForTelemetry() const { | ||
| return InputModelComesFromFile() ? "file" : "buffer"; | ||
| if (InputModelComesFromFile()) { | ||
| return "file"; | ||
| } | ||
| if (InputModelComesFromOrtModel()) { | ||
| return "ort_model"; | ||
| } | ||
| return "buffer"; | ||
| } | ||
|
|
||
| std::string ModelCompilationOptions::GetOutputTargetForTelemetry() const { | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Debating whether or not we need this, ultimately... I realized that there aren't managed versions of the model editor APIs. But since I am touching the API table, it seemed mandatory(?).