Skip to content

Commit 55841b4

Browse files
author
Aditya Rastogi
committed
Fixes / improvements based on testing and code analysis
1 parent 915dd0e commit 55841b4

File tree

4 files changed

+136
-21
lines changed

4 files changed

+136
-21
lines changed

onnxruntime/core/session/model_compilation_options.cc

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
#include "core/common/path_string.h"
1313
#include "core/framework/allocator.h"
1414
#include "core/framework/ep_context_options.h"
15+
#include "core/platform/env.h"
16+
#include "core/session/inference_session_utils.h"
1517
#include "core/session/onnxruntime_session_options_config_keys.h"
1618
#include "core/session/environment.h"
1719

@@ -289,11 +291,30 @@ Status ModelCompilationOptions::Check() const {
289291
// Model::LoadFromModelEditorApiModel -> Graph::Resolve()
290292
}
291293

294+
// ORT_LOAD_CONFIG_FROM_MODEL is not supported for OrtModel input.
295+
// Check early so we fail before session construction.
296+
if (input_from_model) {
297+
const Env& os_env = Env::Default();
298+
if (os_env.GetEnvironmentVar(inference_session_utils::kOrtLoadConfigFromModelEnvVar) == "1") {
299+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
300+
"The environment variable ORT_LOAD_CONFIG_FROM_MODEL=1 is set, but loading "
301+
"config from model is not supported for in-memory OrtModel input. "
302+
"OrtModel is programmatically constructed and has no embedded ORT config. "
303+
"Unset ORT_LOAD_CONFIG_FROM_MODEL or use file/buffer input instead.");
304+
}
305+
}
306+
292307
// Check output model settings.
293308
const epctx::ModelGenOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options;
294309
bool has_no_output_model_location = std::holds_alternative<std::monostate>(
295310
ep_context_gen_options.output_model_location);
296311

312+
// Also treat an empty output file path as "no location" since it's not usable.
313+
const auto* output_path = ep_context_gen_options.TryGetOutputModelPath();
314+
if (!has_no_output_model_location && output_path != nullptr && output_path->empty()) {
315+
has_no_output_model_location = true;
316+
}
317+
297318
// Determine if we can derive an output path from the input
298319
bool can_derive_output_path = input_from_file;
299320
bool model_has_path = false;
@@ -315,9 +336,8 @@ Status ModelCompilationOptions::Check() const {
315336
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
316337
"OrtModel has no model_path set and no output location was specified. "
317338
"Please either: (1) set the model_path on the OrtGraph before adding to OrtModel, "
318-
"(2) call SetOutputModelPath/SetOutputModelBuffer to specify an output location, "
319-
"(3) call SetEpContextEmbedMode(true) to embed EP context in the model, or "
320-
"(4) call SetEpContextBinaryInformation to specify the binary output directory.");
339+
"(2) call SetOutputModelPath/SetOutputModelBuffer to specify an output location, or "
340+
"(3) call SetEpContextEmbedMode(true) to embed EP context in the model.");
321341
}
322342

323343
if (has_no_output_model_location && can_derive_output_path) {
@@ -360,7 +380,13 @@ Status ModelCompilationOptions::Check() const {
360380
}
361381

362382
std::string ModelCompilationOptions::GetInputSourceForTelemetry() const {
363-
return InputModelComesFromFile() ? "file" : "buffer";
383+
if (InputModelComesFromFile()) {
384+
return "file";
385+
}
386+
if (InputModelComesFromOrtModel()) {
387+
return "ort_model";
388+
}
389+
return "buffer";
364390
}
365391

366392
std::string ModelCompilationOptions::GetOutputTargetForTelemetry() const {

onnxruntime/core/session/model_compilation_options.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,9 @@ class ModelCompilationOptions {
183183
// Telemetry helper methods
184184

185185
/// <summary>
186-
/// Returns a string describing the input source type: "file" or "buffer".
186+
/// Returns a string describing the input source type: "file", "buffer", or "ort_model".
187187
/// </summary>
188-
/// <returns>"file" or "buffer"</returns>
188+
/// <returns>"file", "buffer", or "ort_model"</returns>
189189
std::string GetInputSourceForTelemetry() const;
190190

191191
/// <summary>

onnxruntime/core/session/utils.cc

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -301,10 +301,6 @@ static OrtStatus* CreateSessionAndLoadModelImpl(_In_ const OrtSessionOptions* op
301301
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtModel pointer is null");
302302
}
303303

304-
const Env& os_env = Env::Default();
305-
bool load_config_from_model =
306-
os_env.GetEnvironmentVar(inference_session_utils::kOrtLoadConfigFromModelEnvVar) == "1";
307-
308304
// Check EPContext model generation options - OrtModel has no file path by default,
309305
// so we need explicit output location or embedded model path.
310306
if (options) {
@@ -325,22 +321,14 @@ static OrtStatus* CreateSessionAndLoadModelImpl(_In_ const OrtSessionOptions* op
325321
(!ep_ctx_gen_options.HasOutputModelLocation() ||
326322
(output_model_path != nullptr && output_model_path->empty()))) {
327323
return OrtApis::CreateStatus(ORT_FAIL,
328-
"OrtModel loaded without a model_path was configured with EPContext "
329-
"model generation enabled but without a valid location (e.g., file or buffer) "
330-
"for the output model. Please specify a valid output location via "
324+
"OrtModel has no model_path set and no valid output location was specified "
325+
"for EPContext model generation. Please call "
331326
"SetOutputModelPath/SetOutputModelBuffer, or set the model_path on the "
332-
"OrtGraph before adding to OrtModel.");
327+
"OrtGraph before adding it to OrtModel.");
333328
}
334329
}
335330
}
336331

337-
// Note: load_config_from_model is not applicable for OrtModel since there's no serialized
338-
// model to load config from. We treat this as a regular load.
339-
if (load_config_from_model) {
340-
// Log a warning but proceed - OrtModel doesn't support loading config from model
341-
// as it's already in-memory and constructed programmatically
342-
}
343-
344332
sess = std::make_unique<onnxruntime::InferenceSession>(
345333
options == nullptr ? onnxruntime::SessionOptions() : options->value,
346334
env);
@@ -352,6 +340,30 @@ static OrtStatus* CreateSessionAndLoadModelImpl(_In_ const OrtSessionOptions* op
352340
}
353341
#endif
354342

343+
// Add custom domains for all OrtEpDevice instances to inference session.
344+
// The custom domains should be registered before model load for ORT to validate the custom ops.
345+
// This mirrors the same block in the file/buffer overload to maintain load-path parity.
346+
if (options != nullptr &&
347+
options->provider_factories.empty() &&
348+
options->value.ep_selection_policy.enable) {
349+
InlinedVector<OrtCustomOpDomain*> all_ep_custom_op_domains;
350+
351+
for (const OrtEpDevice* ep_device : env.GetOrtEpDevices()) {
352+
InlinedVector<OrtCustomOpDomain*> domains;
353+
ORT_API_RETURN_IF_STATUS_NOT_OK(GetCustomOpDomainsFromEpDevice(*ep_device, domains));
354+
355+
for (auto domain : domains) {
356+
if (ShouldAddDomain(domain, options->custom_op_domains_)) {
357+
all_ep_custom_op_domains.push_back(domain);
358+
}
359+
}
360+
}
361+
362+
if (!all_ep_custom_op_domains.empty()) {
363+
ORT_API_RETURN_IF_STATUS_NOT_OK(sess->AddCustomOpDomains(all_ep_custom_op_domains));
364+
}
365+
}
366+
355367
// Load from OrtModel
356368
ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load(*model));
357369

onnxruntime/test/shared_lib/test_model_builder_api.cc

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
#include "test/shared_lib/test_fixture.h"
2020
#include "test/shared_lib/utils.h"
21+
#include "test/util/include/scoped_env_vars.h"
2122
#include "test/util/include/test_allocator.h"
2223

2324
#include "onnxruntime_config.h" // generated file in build output dir
@@ -1014,3 +1015,79 @@ TEST(ModelEditorCompileAPITest, CompileFromOrtModelToFile) {
10141015
// Cleanup
10151016
std::filesystem::remove(output_path);
10161017
}
1018+
1019+
// Test: ORT_LOAD_CONFIG_FROM_MODEL=1 with OrtModel input fails fast with a clear,
1020+
// actionable error message.
1021+
TEST(ModelEditorCompileAPITest, LoadConfigFromModelEnvVarFailsForOrtModel) {
1022+
// RAII helper saves the current env var value and restores it when the scope exits.
1023+
onnxruntime::test::ScopedEnvironmentVariables scoped_env(
1024+
{{"ORT_LOAD_CONFIG_FROM_MODEL", "1"}});
1025+
1026+
std::vector<std::unique_ptr<std::vector<float>>> weights;
1027+
auto model = CreateSimpleGemmModel(weights);
1028+
1029+
Ort::SessionOptions session_options;
1030+
Ort::ModelCompilationOptions compile_options(*ort_env, session_options);
1031+
compile_options.SetInputModel(static_cast<const OrtModel*>(model));
1032+
compile_options.SetEpContextEmbedMode(true);
1033+
1034+
std::unique_ptr<MockedOrtAllocator> allocator = std::make_unique<MockedOrtAllocator>();
1035+
void* output_buffer = nullptr;
1036+
size_t output_size = 0;
1037+
compile_options.SetOutputModelBuffer(allocator.get(), &output_buffer, &output_size);
1038+
1039+
// Should fail with a clear error about the unsupported env-var/input combination.
1040+
Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
1041+
EXPECT_FALSE(status.IsOK());
1042+
EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("ORT_LOAD_CONFIG_FROM_MODEL=1"));
1043+
EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("in-memory OrtModel input"));
1044+
EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("Unset ORT_LOAD_CONFIG_FROM_MODEL"));
1045+
1046+
if (output_buffer != nullptr) {
1047+
allocator->Free(output_buffer);
1048+
}
1049+
}
1050+
1051+
// Test: Validation error for OrtModel with no model_path, no output location, and no embed mode.
1052+
// Verifies the error message contains the expected remediation guidance.
1053+
TEST(ModelEditorCompileAPITest, NoOutputLocationNoModelPathFails) {
1054+
std::vector<std::unique_ptr<std::vector<float>>> weights;
1055+
auto model = CreateSimpleGemmModel(weights);
1056+
1057+
Ort::SessionOptions session_options;
1058+
Ort::ModelCompilationOptions compile_options(*ort_env, session_options);
1059+
compile_options.SetInputModel(static_cast<const OrtModel*>(model));
1060+
// Intentionally do NOT call SetEpContextEmbedMode, SetOutputModelPath, or SetOutputModelBuffer
1061+
1062+
Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
1063+
EXPECT_FALSE(status.IsOK());
1064+
// Should suggest setting output location or model_path, including embed mode as an option
1065+
EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("SetOutputModelPath"));
1066+
EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("SetEpContextEmbedMode(true)"));
1067+
// Should NOT suggest SetEpContextBinaryInformation (that alone is not sufficient)
1068+
EXPECT_THAT(status.GetErrorMessage(), ::testing::Not(::testing::HasSubstr("SetEpContextBinaryInformation")));
1069+
}
1070+
1071+
// Test: Setting embed mode with buffer output satisfies the output location requirement
1072+
// for OrtModel with no model_path.
1073+
TEST(ModelEditorCompileAPITest, EmbedModeWithBufferOutputSatisfiesValidation) {
1074+
std::vector<std::unique_ptr<std::vector<float>>> weights;
1075+
auto model = CreateSimpleGemmModel(weights);
1076+
1077+
Ort::SessionOptions session_options;
1078+
Ort::ModelCompilationOptions compile_options(*ort_env, session_options);
1079+
compile_options.SetInputModel(static_cast<const OrtModel*>(model));
1080+
compile_options.SetEpContextEmbedMode(true);
1081+
1082+
std::unique_ptr<MockedOrtAllocator> allocator = std::make_unique<MockedOrtAllocator>();
1083+
void* output_buffer = nullptr;
1084+
size_t output_size = 0;
1085+
compile_options.SetOutputModelBuffer(allocator.get(), &output_buffer, &output_size);
1086+
1087+
Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
1088+
EXPECT_TRUE(status.IsOK()) << "CompileModel failed: " << status.GetErrorMessage();
1089+
1090+
if (output_buffer != nullptr) {
1091+
allocator->Free(output_buffer);
1092+
}
1093+
}

0 commit comments

Comments
 (0)