Skip to content

Commit 915dd0e

Browse files
author
Aditya Rastogi
committed
Initial draft
1 parent b3a34bb commit 915dd0e

File tree

10 files changed

+531
-6
lines changed

10 files changed

+531
-6
lines changed

csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ public struct OrtCompileApi
2525
public IntPtr ModelCompilationOptions_SetGraphOptimizationLevel;
2626
public IntPtr ModelCompilationOptions_SetOutputModelWriteFunc;
2727
public IntPtr ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc;
28+
public IntPtr ModelCompilationOptions_SetInputModel;
2829
}
2930

3031
internal class NativeMethods
@@ -136,6 +137,12 @@ public DOrtModelCompilationOptions_SetOutputModelWriteFunc
136137
public DOrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc
137138
OrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc;
138139

140+
/// <summary>
/// Managed signature for the native ModelCompilationOptions_SetInputModel entry of OrtCompileApi.
/// </summary>
/// <param name="options">Native OrtModelCompilationOptions* handle.</param>
/// <param name="inputModel">Native const OrtModel* handle to use as the compilation input.</param>
/// <returns>Native OrtStatus* handle.</returns>
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
141+
public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetInputModel(
142+
IntPtr /* OrtModelCompilationOptions* */ options,
143+
IntPtr /* const OrtModel* */ inputModel);
144+
// Delegate instance; bound in the NativeMethods constructor from
// _compileApi.ModelCompilationOptions_SetInputModel.
public DOrtModelCompilationOptions_SetInputModel OrtModelCompilationOptions_SetInputModel;
145+
139146
internal NativeMethods(OnnxRuntime.NativeMethods.DOrtGetCompileApi getCompileApi)
140147
{
141148

@@ -217,6 +224,11 @@ internal NativeMethods(OnnxRuntime.NativeMethods.DOrtGetCompileApi getCompileApi
217224
_compileApi.ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc,
218225
typeof(DOrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc));
219226

227+
OrtModelCompilationOptions_SetInputModel =
228+
(DOrtModelCompilationOptions_SetInputModel)Marshal.GetDelegateForFunctionPointer(
229+
_compileApi.ModelCompilationOptions_SetInputModel,
230+
typeof(DOrtModelCompilationOptions_SetInputModel));
231+
220232
}
221233
}
222234
}

include/onnxruntime/core/session/onnxruntime_c_api.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8037,6 +8037,29 @@ struct OrtCompileApi {
80378037
ORT_API2_STATUS(ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc,
80388038
_In_ OrtModelCompilationOptions* model_compile_options,
80398039
_In_ OrtGetInitializerLocationFunc get_initializer_location_func, _In_ void* state);
8040+
8041+
/** \brief Sets the OrtModel to compile.
8042+
*
8043+
* Sets an OrtModel created via the Model Editor API as the input for compilation.
8044+
*
8045+
* The input model's source (file path, memory buffer, or OrtModel) must be set with
8046+
* one of: ModelCompilationOptions_SetInputModelPath, ModelCompilationOptions_SetInputModelFromBuffer,
8047+
* or ModelCompilationOptions_SetInputModel.
8048+
*
8049+
* The OrtModel must have a complete graph (inputs, outputs, and nodes defined) and must declare at least one opset domain/version.
8050+
* The caller retains ownership of the OrtModel and must not release it until after
8051+
* CompileModel returns.
8052+
*
8053+
* \param[in] model_compile_options The OrtModelCompilationOptions instance.
8054+
* \param[in] model The OrtModel to compile. The model is borrowed (not copied or owned).
8055+
*
8056+
* \snippet{doc} snippets.dox OrtStatus Return Value
8057+
*
8058+
* \since Version 1.24.
8059+
*/
8060+
ORT_API2_STATUS(ModelCompilationOptions_SetInputModel,
8061+
_In_ OrtModelCompilationOptions* model_compile_options,
8062+
_In_ const OrtModel* model);
80408063
};
80418064

80428065
/**

include/onnxruntime/core/session/onnxruntime_cxx_api.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1612,6 +1612,8 @@ struct ModelCompilationOptions : detail::Base<OrtModelCompilationOptions> {
16121612
ModelCompilationOptions& SetFlags(uint32_t flags); ///< Wraps OrtCompileApi::ModelCompilationOptions_SetFlags
16131613

16141614
ModelCompilationOptions& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtCompileApi::ModelCompilationOptions_SetGraphOptimizationLevel
1615+
1616+
ModelCompilationOptions& SetInputModel(const OrtModel* model); ///< Wraps OrtCompileApi::ModelCompilationOptions_SetInputModel. `model` is borrowed, not copied: keep it alive until CompileModel returns.
16151617
};
16161618

16171619
/** \brief Compiles an input model to generate a model with EPContext nodes that execute EP-specific kernels. Wraps OrtCompileApi::CompileModel.

include/onnxruntime/core/session/onnxruntime_cxx_inline.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1180,6 +1180,11 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetGraphOptimizationLev
11801180
return *this;
11811181
}
11821182

1183+
// Forwards `model` to OrtCompileApi::ModelCompilationOptions_SetInputModel and passes the
// resulting status to Ort::ThrowOnError. Returns *this so option setters can be chained.
// `model` is only borrowed by the underlying options object.
inline ModelCompilationOptions& ModelCompilationOptions::SetInputModel(const OrtModel* model) {
  auto* const status = GetCompileApi().ModelCompilationOptions_SetInputModel(this->p_, model);
  Ort::ThrowOnError(status);
  return *this;
}
1187+
11831188
namespace detail {
11841189

11851190
template <typename T>

onnxruntime/core/session/compile_api.cc

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,27 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetGraphOptimizationL
306306
API_IMPL_END
307307
}
308308

309+
// C API entry point: registers a caller-owned OrtModel as the model to compile.
//
// Returns:
//   - ORT_INVALID_ARGUMENT status when `model` is null,
//   - ORT_NOT_IMPLEMENTED status when built with ORT_MINIMAL_BUILD,
//   - nullptr (success) otherwise.
ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetInputModel,
                    _In_ OrtModelCompilationOptions* ort_model_compile_options,
                    _In_ const OrtModel* model) {
  API_IMPL_BEGIN
#if defined(ORT_MINIMAL_BUILD)
  ORT_UNUSED_PARAMETER(ort_model_compile_options);
  ORT_UNUSED_PARAMETER(model);
  return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build");
#else
  // Reject a missing model before touching the options object.
  if (model == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid input model: OrtModel pointer is null");
  }

  auto* options_impl = reinterpret_cast<onnxruntime::ModelCompilationOptions*>(ort_model_compile_options);
  options_impl->SetInputModel(model);
  return nullptr;
#endif  // defined(ORT_MINIMAL_BUILD)
  API_IMPL_END
}
329+
309330
ORT_API_STATUS_IMPL(OrtCompileAPI::CompileModel, _In_ const OrtEnv* env,
310331
_In_ const OrtModelCompilationOptions* ort_model_compile_options) {
311332
API_IMPL_BEGIN
@@ -343,13 +364,18 @@ static constexpr OrtCompileApi ort_compile_api = {
343364
&OrtCompileAPI::ModelCompilationOptions_SetOutputModelWriteFunc,
344365
&OrtCompileAPI::ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc,
345366
// End of Version 23 - DO NOT MODIFY ABOVE
367+
368+
&OrtCompileAPI::ModelCompilationOptions_SetInputModel,
369+
// End of Version 24 - DO NOT MODIFY ABOVE
346370
};
347371

348372
// checks that we don't violate the rule that the functions must remain in the slots they were originally assigned
349373
static_assert(offsetof(OrtCompileApi, CompileModel) / sizeof(void*) == 8,
350374
"Size of version 22 Api cannot change"); // initial version in ORT 1.22
351375
static_assert(offsetof(OrtCompileApi, ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc) / sizeof(void*) == 13,
352376
"Size of version 23 of Api cannot change");
377+
static_assert(offsetof(OrtCompileApi, ModelCompilationOptions_SetInputModel) / sizeof(void*) == 14,
378+
"Size of version 24 of Api cannot change");
353379

354380
ORT_API(const OrtCompileApi*, OrtCompileAPI::GetCompileApi) {
355381
return &ort_compile_api;

onnxruntime/core/session/compile_api.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,5 +41,8 @@ ORT_API_STATUS_IMPL(ModelCompilationOptions_SetOutputModelWriteFunc,
4141
ORT_API_STATUS_IMPL(ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc,
4242
_In_ OrtModelCompilationOptions* model_compile_options,
4343
_In_ OrtGetInitializerLocationFunc get_initializer_location_func, _In_ void* state);
44+
ORT_API_STATUS_IMPL(ModelCompilationOptions_SetInputModel,
45+
_In_ OrtModelCompilationOptions* model_compile_options,
46+
_In_ const OrtModel* model);
4447

4548
} // namespace OrtCompileAPI

onnxruntime/core/session/model_compilation_options.cc

Lines changed: 72 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ void ModelCompilationOptions::SetInputModelFromBuffer(const void* input_model_da
4545
input_model_data_size_ = input_model_data_size;
4646
}
4747

48+
void ModelCompilationOptions::SetInputModel(const OrtModel* model) {
49+
ResetInputModelSettings();
50+
input_model_ = model;
51+
}
52+
4853
Status ModelCompilationOptions::SetOutputModelPath(const std::filesystem::path& output_model_path) {
4954
ConfigOptions& config_options = session_options_.value.config_options;
5055
epctx::ModelGenOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options;
@@ -186,10 +191,19 @@ size_t ModelCompilationOptions::GetInputModelDataSize() const {
186191
return input_model_data_size_;
187192
}
188193

194+
bool ModelCompilationOptions::InputModelComesFromOrtModel() const {
195+
return input_model_ != nullptr;
196+
}
197+
198+
// Returns the stored OrtModel pointer, or nullptr when no OrtModel is set.
// The pointer is whatever was last passed to SetInputModel; this class does not copy the model.
const OrtModel* ModelCompilationOptions::GetInputModel() const {
199+
return input_model_;
200+
}
201+
189202
void ModelCompilationOptions::ResetInputModelSettings() {
190203
input_model_path_.clear();
191204
input_model_data_ = nullptr;
192205
input_model_data_size_ = 0;
206+
input_model_ = nullptr;
193207
}
194208

195209
Status ModelCompilationOptions::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) {
@@ -229,16 +243,21 @@ Status ModelCompilationOptions::Check() const {
229243
// Check input model settings.
230244
const bool input_from_file = !input_model_path_.empty();
231245
const bool input_from_memory = input_model_data_ != nullptr;
246+
const bool input_from_model = input_model_ != nullptr;
247+
248+
int input_source_count = (input_from_file ? 1 : 0) +
249+
(input_from_memory ? 1 : 0) +
250+
(input_from_model ? 1 : 0);
232251

233-
if (!input_from_file && !input_from_memory) {
252+
if (input_source_count == 0) {
234253
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
235-
"Input model to compile must be loaded from either a file or a memory buffer");
254+
"Input model to compile must be specified via file path, memory buffer, or OrtModel");
236255
}
237256

238-
if (input_from_file && input_from_memory) {
257+
if (input_source_count > 1) {
239258
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
240-
"Input model to compile must be loaded from either a file or a memory buffer, ",
241-
"but not both.");
259+
"Input model to compile must be specified via exactly one of: ",
260+
"file path, memory buffer, or OrtModel");
242261
}
243262

244263
if (input_from_file && !std::filesystem::exists(input_model_path_)) {
@@ -249,12 +268,59 @@ Status ModelCompilationOptions::Check() const {
249268
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Buffer for input model data has size 0");
250269
}
251270

271+
// Validate OrtModel input
272+
if (input_from_model) {
273+
if (input_model_->graph == nullptr) {
274+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
275+
"OrtModel has no graph. Call AddGraphToModel before compilation.");
276+
}
277+
278+
if (input_model_->graph->GetNumInputs() == 0 || input_model_->graph->GetNumOutputs() == 0) {
279+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
280+
"OrtModel graph must have at least one input and one output defined.");
281+
}
282+
283+
if (input_model_->domain_to_version.empty()) {
284+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
285+
"OrtModel must specify at least one opset domain/version.");
286+
}
287+
288+
// Note: Additional validation (node connections, schema) happens during
289+
// Model::LoadFromModelEditorApiModel -> Graph::Resolve()
290+
}
291+
252292
// Check output model settings.
253293
const epctx::ModelGenOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options;
254294
bool has_no_output_model_location = std::holds_alternative<std::monostate>(
255295
ep_context_gen_options.output_model_location);
256296

257-
if (has_no_output_model_location && input_from_file) {
297+
// Determine if we can derive an output path from the input
298+
bool can_derive_output_path = input_from_file;
299+
bool model_has_path = false;
300+
301+
// For OrtModel input, check if model_path is set in the graph using the virtual GetModelPath() method
302+
// (avoids dynamic_cast which requires RTTI)
303+
if (input_from_model && input_model_->graph) {
304+
const ORTCHAR_T* model_path_cstr = input_model_->graph->GetModelPath();
305+
if (model_path_cstr && model_path_cstr[0] != ORT_TSTR('\0')) {
306+
can_derive_output_path = true;
307+
model_has_path = true;
308+
}
309+
}
310+
311+
// Fast-fail: If OrtModel has no model_path and user hasn't specified output location or embed mode,
312+
// EPs that need to write context binaries will fail later. Fail early with a clear error.
313+
if (input_from_model && !model_has_path && has_no_output_model_location &&
314+
!ep_context_gen_options.embed_ep_context_in_model) {
315+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
316+
"OrtModel has no model_path set and no output location was specified. "
317+
"Please either: (1) set the model_path on the OrtGraph before adding to OrtModel, "
318+
"(2) call SetOutputModelPath/SetOutputModelBuffer to specify an output location, "
319+
"(3) call SetEpContextEmbedMode(true) to embed EP context in the model, or "
320+
"(4) call SetEpContextBinaryInformation to specify the binary output directory.");
321+
}
322+
323+
if (has_no_output_model_location && can_derive_output_path) {
258324
// User did not specify an output file, an output buffer, or an output write function. We default to generating an
259325
// output file with a name based on the input file name, so do not return an error.
260326
return Status::OK();

onnxruntime/core/session/model_compilation_options.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "core/common/status.h"
1111
#include "core/common/path_string.h"
1212
#include "core/framework/allocator.h"
13+
#include "core/graph/model_editor_api_types.h"
1314
#include "core/session/abi_session_options_impl.h"
1415
#include "core/session/onnxruntime_c_api.h"
1516
#include "core/session/onnxruntime_session_options_config_keys.h"
@@ -45,6 +46,14 @@ class ModelCompilationOptions {
4546
/// <param name="input_model_data_size">The size in bytes of the input model's buffer</param>
4647
void SetInputModelFromBuffer(const void* input_model_data, size_t input_model_data_size);
4748

49+
/// <summary>
50+
/// Sets the OrtModel to compile.
51+
/// The OrtModel is borrowed (not copied) - caller must keep it alive until CompileModel returns.
52+
/// Overrides any previous call to SetInputModelPath(), SetInputModelFromBuffer(), or SetInputModel().
53+
/// </summary>
54+
/// <param name="model">The OrtModel to compile</param>
55+
void SetInputModel(const OrtModel* model);
56+
4857
/// <summary>
4958
/// Sets the file path to store the output/compiled ONNX model.
5059
/// Overrides any previous call to SetOutputModelPath() or SetOutputModelBuffer().
@@ -132,6 +141,18 @@ class ModelCompilationOptions {
132141
/// <returns>true if input model comes from a file</returns>
133142
bool InputModelComesFromFile() const;
134143

144+
/// <summary>
145+
/// Returns true if the input model comes from an OrtModel pointer.
146+
/// </summary>
147+
/// <returns>true if input model comes from an OrtModel</returns>
148+
bool InputModelComesFromOrtModel() const;
149+
150+
/// <summary>
151+
/// Returns the OrtModel to compile, or nullptr if not set.
152+
/// </summary>
153+
/// <returns>pointer to the OrtModel or nullptr</returns>
154+
const OrtModel* GetInputModel() const;
155+
135156
/// <summary>
136157
/// Returns the buffer that contains the bytes for the input ONNX model.
137158
/// Returns nullptr if the input model is not stored in a buffer.
@@ -205,6 +226,7 @@ class ModelCompilationOptions {
205226
std::filesystem::path input_model_path_;
206227
const void* input_model_data_ = nullptr;
207228
size_t input_model_data_size_ = 0;
229+
const OrtModel* input_model_ = nullptr; // Borrowed pointer
208230
};
209231
} // namespace onnxruntime
210232
#endif // !defined(ORT_MINIMAL_BUILD)

0 commit comments

Comments
 (0)