@@ -92,6 +92,63 @@ class TestCompatibilityExecutionProvider : public IExecutionProvider {
9292 bool should_fail_validation_ = false ;
9393};
9494
95+ // Test execution provider that tracks whether GetCapability is called.
96+ // This is used to verify that early validation fails BEFORE Initialize() does expensive work.
97+ class TestEarlyValidationExecutionProvider : public IExecutionProvider {
98+ public:
99+ static constexpr const char * kTestEarlyValidationExecutionProviderType = " TestEarlyValidationExecutionProvider" ;
100+
101+ TestEarlyValidationExecutionProvider () : IExecutionProvider(kTestEarlyValidationExecutionProviderType ) {
102+ }
103+
104+ std::shared_ptr<KernelRegistry> GetKernelRegistry () const override {
105+ return std::make_shared<KernelRegistry>();
106+ }
107+
108+ std::vector<AllocatorPtr> CreatePreferredAllocators () override {
109+ return {};
110+ }
111+
112+ // Override GetCapability to track if it's called (happens during Initialize())
113+ std::vector<std::unique_ptr<ComputeCapability>> GetCapability (
114+ const onnxruntime::GraphViewer& graph_viewer,
115+ const IKernelLookup& kernel_lookup,
116+ const GraphOptimizerRegistry& graph_optimizer_registry,
117+ IResourceAccountant* resource_accountant = nullptr ) const override {
118+ ORT_UNUSED_PARAMETER (graph_viewer);
119+ ORT_UNUSED_PARAMETER (kernel_lookup);
120+ ORT_UNUSED_PARAMETER (graph_optimizer_registry);
121+ ORT_UNUSED_PARAMETER (resource_accountant);
122+ get_capability_called_ = true ;
123+ return {}; // Return empty - we don't actually want to handle any nodes
124+ }
125+
126+ // Configurable mock behavior for validation
127+ void SetMockCompatibilityStatus (OrtCompiledModelCompatibility status) {
128+ mock_compatibility_status_ = status;
129+ }
130+
131+ common::Status ValidateCompiledModelCompatibilityInfo (const std::string& compatibility_info,
132+ OrtCompiledModelCompatibility& model_compatibility) const override {
133+ ORT_UNUSED_PARAMETER (compatibility_info);
134+ model_compatibility = mock_compatibility_status_;
135+ return Status::OK ();
136+ }
137+
138+ // Query whether GetCapability was called
139+ bool WasGetCapabilityCalled () const {
140+ return get_capability_called_;
141+ }
142+
143+ void ResetGetCapabilityCalled () {
144+ get_capability_called_ = false ;
145+ }
146+
147+ private:
148+ OrtCompiledModelCompatibility mock_compatibility_status_ = OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL;
149+ mutable bool get_capability_called_ = false ;
150+ };
151+
95152// Helper class to create test models
96153class ModelBuilderWithCompatibility {
97154 public:
@@ -390,6 +447,72 @@ TEST_F(EpCompatibilityTest, TestEpValidationFailure) {
390447 EXPECT_THAT (status.ErrorMessage (), testing::HasSubstr (" Mock validation failure" ));
391448}
392449
450+ // Test that early validation optimization works: when a model is incompatible,
451+ // validation should fail BEFORE Initialize() performs expensive graph partitioning.
452+ // We verify this by checking that GetCapability() is NOT called when validation fails.
453+ TEST_F (EpCompatibilityTest, TestEarlyValidation_FailsBeforeGetCapability) {
454+ const std::string ep_type = TestEarlyValidationExecutionProvider::kTestEarlyValidationExecutionProviderType ;
455+ const std::string compatibility_string = " test_compatibility_v1.0" ;
456+
457+ auto test_ep = std::make_unique<TestEarlyValidationExecutionProvider>();
458+ test_ep->SetMockCompatibilityStatus (OrtCompiledModelCompatibility_EP_UNSUPPORTED);
459+
460+ // Verify GetCapability hasn't been called yet
461+ EXPECT_FALSE (test_ep->WasGetCapabilityCalled ());
462+
463+ // Create model with compatibility metadata for this EP
464+ std::map<std::string, std::string> compatibility_info = {{ep_type, compatibility_string}};
465+ auto model_with_metadata = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata (compatibility_info);
466+
467+ auto session = SessionBuilderWithCompatibility::CreateTestSession (std::move (model_with_metadata));
468+
469+ // Keep a raw pointer to check state after move
470+ auto * test_ep_ptr = test_ep.get ();
471+
472+ ASSERT_STATUS_OK (session->RegisterExecutionProvider (std::move (test_ep)));
473+
474+ // Initialization should fail due to incompatible model
475+ auto status = InitializeSessionWithValidation (*session);
476+ EXPECT_FALSE (status.IsOK ());
477+ EXPECT_THAT (status.ErrorMessage (), testing::HasSubstr (" not supported" ));
478+
479+ // CRITICAL: GetCapability should NOT have been called because validation failed early,
480+ // before Initialize() could perform graph partitioning
481+ EXPECT_FALSE (test_ep_ptr->WasGetCapabilityCalled ())
482+ << " GetCapability was called, indicating validation did not fail early before Initialize()" ;
483+ }
484+
485+ // Test that when validation succeeds, GetCapability IS called (normal flow)
486+ TEST_F (EpCompatibilityTest, TestEarlyValidation_SucceedsAndProceedsToGetCapability) {
487+ const std::string ep_type = TestEarlyValidationExecutionProvider::kTestEarlyValidationExecutionProviderType ;
488+ const std::string compatibility_string = " test_compatibility_v1.0" ;
489+
490+ auto test_ep = std::make_unique<TestEarlyValidationExecutionProvider>();
491+ test_ep->SetMockCompatibilityStatus (OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL);
492+
493+ // Verify GetCapability hasn't been called yet
494+ EXPECT_FALSE (test_ep->WasGetCapabilityCalled ());
495+
496+ // Create model with compatibility metadata for this EP
497+ std::map<std::string, std::string> compatibility_info = {{ep_type, compatibility_string}};
498+ auto model_with_metadata = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata (compatibility_info);
499+
500+ auto session = SessionBuilderWithCompatibility::CreateTestSession (std::move (model_with_metadata));
501+
502+ // Keep a raw pointer to check state after move
503+ auto * test_ep_ptr = test_ep.get ();
504+
505+ ASSERT_STATUS_OK (session->RegisterExecutionProvider (std::move (test_ep)));
506+
507+ // Initialization should succeed
508+ ASSERT_STATUS_OK (InitializeSessionWithValidation (*session));
509+
510+ // GetCapability SHOULD have been called because validation succeeded and
511+ // Initialize() proceeded normally with graph partitioning
512+ EXPECT_TRUE (test_ep_ptr->WasGetCapabilityCalled ())
513+ << " GetCapability was not called, but it should have been after successful validation" ;
514+ }
515+
393516// Test session option configuration for fail on suboptimal
394517TEST_F (EpCompatibilityTest, TestSessionOptionConfiguration) {
395518 SessionOptions so;
0 commit comments