Skip to content

Commit 080d968

Browse files
adrastogiAditya Rastogi
andauthored
Move model compatibility checks ahead of session initialization (#27037)
### Description <!-- Describe your changes. --> The current infrastructure for validating compatibility of a precompiled model does the check after session initialization occurs, which turns out to be quite costly. The check should ideally happen beforehand, to short-circuit those expensive operations. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> This change will make it more tractable for applications to rely on the existing session machinery to check compatibility of any of their models. Co-authored-by: Aditya Rastogi <adityar@ntdev.microsoft.com>
1 parent f481b17 commit 080d968

File tree

3 files changed

+150
-6
lines changed

3 files changed

+150
-6
lines changed

onnxruntime/core/session/inference_session.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,19 @@ class InferenceSession {
505505
*/
506506
const std::vector<std::string>& GetRegisteredProviderTypes() const;
507507

508+
/**
509+
* Get the registered Execution Providers.
510+
*
511+
* This method can be called after EP registration but before Initialize() completes.
512+
* Used only for early validation of compiled model compatibility where accessing
513+
* EPs through session state is not yet possible.
514+
*
515+
* @return const reference to the ExecutionProviders collection.
516+
*/
517+
const ExecutionProviders& GetExecutionProviders() const noexcept {
518+
return execution_providers_;
519+
}
520+
508521
/*
509522
* Get the options this session was initialized with.
510523
*/

onnxruntime/core/session/utils.cc

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -335,8 +335,10 @@ static Status ValidateCompiledModelCompatibility(InferenceSession& sess) {
335335

336336
const auto& registered_provider_types = sess.GetRegisteredProviderTypes();
337337

338-
// Access the execution providers through the session state (available after Initialize)
339-
const auto& execution_providers = sess.GetSessionState().GetExecutionProviders();
338+
// Access the execution providers directly from the session.
339+
// This allows validation to run before Initialize() completes, avoiding expensive
340+
// graph transformations for incompatible models. EPs are fully registered at this point.
341+
const auto& execution_providers = sess.GetExecutionProviders();
340342

341343
for (const auto& ep_type : registered_provider_types) {
342344
// Construct the full metadata key using the prefix + EP type
@@ -455,14 +457,20 @@ OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options,
455457
reinterpret_cast<PrepackedWeightsContainer*>(prepacked_weights_container)));
456458
}
457459

458-
ORT_API_RETURN_IF_STATUS_NOT_OK(sess.Initialize());
459-
460460
#if !defined(ORT_MINIMAL_BUILD)
461-
// Validate compiled model compatibility for all registered execution providers
462-
// This must be done after Initialize() so the session state is available
461+
// Validate compiled model compatibility for all registered execution providers BEFORE Initialize().
462+
// This is an optimization to fail fast for incompatible models, avoiding expensive graph transformations,
463+
// partitioning, and kernel binding that occur during Initialize().
464+
// This is safe because:
465+
// 1. Model metadata (containing compatibility strings) is available after Load() completes.
466+
// 2. Compiling EPs are fully registered at this point.
467+
// 3. Non-compiling EPs (like CPU EP, which may be implicitly added during Initialize()) don't participate
468+
// in compatibility validation - they return NOT_APPLICABLE by default.
463469
ORT_API_RETURN_IF_STATUS_NOT_OK(ValidateCompiledModelCompatibility(sess));
464470
#endif // !defined(ORT_MINIMAL_BUILD)
465471

472+
ORT_API_RETURN_IF_STATUS_NOT_OK(sess.Initialize());
473+
466474
return nullptr;
467475
}
468476

onnxruntime/test/framework/ep_compatibility_test.cc

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,63 @@ class TestCompatibilityExecutionProvider : public IExecutionProvider {
9292
bool should_fail_validation_ = false;
9393
};
9494

95+
// Test execution provider that tracks whether GetCapability is called.
96+
// This is used to verify that early validation fails BEFORE Initialize() does expensive work.
97+
class TestEarlyValidationExecutionProvider : public IExecutionProvider {
98+
public:
99+
static constexpr const char* kTestEarlyValidationExecutionProviderType = "TestEarlyValidationExecutionProvider";
100+
101+
TestEarlyValidationExecutionProvider() : IExecutionProvider(kTestEarlyValidationExecutionProviderType) {
102+
}
103+
104+
std::shared_ptr<KernelRegistry> GetKernelRegistry() const override {
105+
return std::make_shared<KernelRegistry>();
106+
}
107+
108+
std::vector<AllocatorPtr> CreatePreferredAllocators() override {
109+
return {};
110+
}
111+
112+
// Override GetCapability to track if it's called (happens during Initialize())
113+
std::vector<std::unique_ptr<ComputeCapability>> GetCapability(
114+
const onnxruntime::GraphViewer& graph_viewer,
115+
const IKernelLookup& kernel_lookup,
116+
const GraphOptimizerRegistry& graph_optimizer_registry,
117+
IResourceAccountant* resource_accountant = nullptr) const override {
118+
ORT_UNUSED_PARAMETER(graph_viewer);
119+
ORT_UNUSED_PARAMETER(kernel_lookup);
120+
ORT_UNUSED_PARAMETER(graph_optimizer_registry);
121+
ORT_UNUSED_PARAMETER(resource_accountant);
122+
get_capability_called_ = true;
123+
return {}; // Return empty - we don't actually want to handle any nodes
124+
}
125+
126+
// Configurable mock behavior for validation
127+
void SetMockCompatibilityStatus(OrtCompiledModelCompatibility status) {
128+
mock_compatibility_status_ = status;
129+
}
130+
131+
common::Status ValidateCompiledModelCompatibilityInfo(const std::string& compatibility_info,
132+
OrtCompiledModelCompatibility& model_compatibility) const override {
133+
ORT_UNUSED_PARAMETER(compatibility_info);
134+
model_compatibility = mock_compatibility_status_;
135+
return Status::OK();
136+
}
137+
138+
// Query whether GetCapability was called
139+
bool WasGetCapabilityCalled() const {
140+
return get_capability_called_;
141+
}
142+
143+
void ResetGetCapabilityCalled() {
144+
get_capability_called_ = false;
145+
}
146+
147+
private:
148+
OrtCompiledModelCompatibility mock_compatibility_status_ = OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL;
149+
mutable bool get_capability_called_ = false;
150+
};
151+
95152
// Helper class to create test models
96153
class ModelBuilderWithCompatibility {
97154
public:
@@ -390,6 +447,72 @@ TEST_F(EpCompatibilityTest, TestEpValidationFailure) {
390447
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Mock validation failure"));
391448
}
392449

450+
// Test that early validation optimization works: when a model is incompatible,
451+
// validation should fail BEFORE Initialize() performs expensive graph partitioning.
452+
// We verify this by checking that GetCapability() is NOT called when validation fails.
453+
TEST_F(EpCompatibilityTest, TestEarlyValidation_FailsBeforeGetCapability) {
454+
const std::string ep_type = TestEarlyValidationExecutionProvider::kTestEarlyValidationExecutionProviderType;
455+
const std::string compatibility_string = "test_compatibility_v1.0";
456+
457+
auto test_ep = std::make_unique<TestEarlyValidationExecutionProvider>();
458+
test_ep->SetMockCompatibilityStatus(OrtCompiledModelCompatibility_EP_UNSUPPORTED);
459+
460+
// Verify GetCapability hasn't been called yet
461+
EXPECT_FALSE(test_ep->WasGetCapabilityCalled());
462+
463+
// Create model with compatibility metadata for this EP
464+
std::map<std::string, std::string> compatibility_info = {{ep_type, compatibility_string}};
465+
auto model_with_metadata = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata(compatibility_info);
466+
467+
auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(model_with_metadata));
468+
469+
// Keep a raw pointer to check state after move
470+
auto* test_ep_ptr = test_ep.get();
471+
472+
ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(test_ep)));
473+
474+
// Initialization should fail due to incompatible model
475+
auto status = InitializeSessionWithValidation(*session);
476+
EXPECT_FALSE(status.IsOK());
477+
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("not supported"));
478+
479+
// CRITICAL: GetCapability should NOT have been called because validation failed early,
480+
// before Initialize() could perform graph partitioning
481+
EXPECT_FALSE(test_ep_ptr->WasGetCapabilityCalled())
482+
<< "GetCapability was called, indicating validation did not fail early before Initialize()";
483+
}
484+
485+
// Test that when validation succeeds, GetCapability IS called (normal flow)
486+
TEST_F(EpCompatibilityTest, TestEarlyValidation_SucceedsAndProceedsToGetCapability) {
487+
const std::string ep_type = TestEarlyValidationExecutionProvider::kTestEarlyValidationExecutionProviderType;
488+
const std::string compatibility_string = "test_compatibility_v1.0";
489+
490+
auto test_ep = std::make_unique<TestEarlyValidationExecutionProvider>();
491+
test_ep->SetMockCompatibilityStatus(OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL);
492+
493+
// Verify GetCapability hasn't been called yet
494+
EXPECT_FALSE(test_ep->WasGetCapabilityCalled());
495+
496+
// Create model with compatibility metadata for this EP
497+
std::map<std::string, std::string> compatibility_info = {{ep_type, compatibility_string}};
498+
auto model_with_metadata = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata(compatibility_info);
499+
500+
auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(model_with_metadata));
501+
502+
// Keep a raw pointer to check state after move
503+
auto* test_ep_ptr = test_ep.get();
504+
505+
ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(test_ep)));
506+
507+
// Initialization should succeed
508+
ASSERT_STATUS_OK(InitializeSessionWithValidation(*session));
509+
510+
// GetCapability SHOULD have been called because validation succeeded and
511+
// Initialize() proceeded normally with graph partitioning
512+
EXPECT_TRUE(test_ep_ptr->WasGetCapabilityCalled())
513+
<< "GetCapability was not called, but it should have been after successful validation";
514+
}
515+
393516
// Test session option configuration for fail on suboptimal
394517
TEST_F(EpCompatibilityTest, TestSessionOptionConfiguration) {
395518
SessionOptions so;

0 commit comments

Comments
 (0)