Skip to content
Merged
43 changes: 30 additions & 13 deletions onnxruntime/test/unittest_util/base_tester.cc
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ void BaseTester::ExecuteModel(Model& model, SessionType& session,
bool SetEpsForAllNodes(Graph& graph,
const std::vector<std::unique_ptr<IExecutionProvider>>& execution_providers,
const std::vector<std::shared_ptr<CustomRegistry>>* custom_registries,
const std::function<bool(const IExecutionProvider&)>& ep_uses_kernel_registry_fn) {
const std::function<bool(const IExecutionProvider&)>& ep_only_uses_kernel_registry_fn) {
const OpSchemaKernelTypeStrResolver kernel_type_str_resolver{};
const KernelRegistry::TypeConstraintMap type_constraint_map{};

Expand All @@ -440,7 +440,7 @@ bool SetEpsForAllNodes(Graph& graph,

node.SetExecutionProviderType(provider_type);

if (!ep_uses_kernel_registry_fn(*ep)) {
if (!ep_only_uses_kernel_registry_fn(*ep)) {
found = true;
break;
}
Expand Down Expand Up @@ -659,7 +659,12 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter,
#endif
kDnnlExecutionProvider,
kTensorrtExecutionProvider,
#ifdef USE_NV
// Only include NV TRT RTX EP when is ORT is built with the provider-bridge
// version of the EP (i.e., USE_NV is defined). This allows use of the plugin EP version of the EP
// when ORT is not built any provider-bridge EPs.
kNvTensorRTRTXExecutionProvider,
#endif
kOpenVINOExecutionProvider,
kDmlExecutionProvider,
kAclExecutionProvider,
Expand Down Expand Up @@ -830,12 +835,15 @@ void BaseTester::ExecuteModelForEps(

ASSERT_TRUE(!execution_providers.empty()) << "Empty execution providers vector.";
if (try_assign_ep_for_nodes) {
auto ep_uses_kernel_registry = [](const IExecutionProvider& ep) {
auto ep_only_uses_kernel_registry = [](const IExecutionProvider& ep) {
const auto& provider_type = ep.Type();

constexpr std::array kEpsThatDoNotUseKernelRegistry{
constexpr std::array kEpsThatCompileNodes{
kOpenVINOExecutionProvider,
kTensorrtExecutionProvider,
kTensorrtExecutionProvider, // uses kernel registry for Memcpy* nodes only
#ifdef USE_NV
kNvTensorRTRTXExecutionProvider, // uses kernel registry for Memcpy* nodes only
#endif
kNnapiExecutionProvider,
kVSINPUExecutionProvider,
kCoreMLExecutionProvider,
Expand All @@ -844,24 +852,33 @@ void BaseTester::ExecuteModelForEps(
kSnpeExecutionProvider,
};

// check list of known EPs that do not use a kernel registry
if (const auto ep_it = std::find(kEpsThatDoNotUseKernelRegistry.begin(), kEpsThatDoNotUseKernelRegistry.end(),
// check list of known EPs that compile nodes
if (const auto ep_it = std::find(kEpsThatCompileNodes.begin(), kEpsThatCompileNodes.end(),
provider_type);
ep_it != kEpsThatDoNotUseKernelRegistry.end()) {
ep_it != kEpsThatCompileNodes.end()) {
return false;
}

// assume that a dynamic plugin EP which does not return a kernel registry does not use one
if (provider_type == dynamic_plugin_ep_infra::GetEpName() &&
ep.GetKernelRegistry() == nullptr) {
return false;
const OrtEp* ort_ep = ep.GetOrtEp();

if (ort_ep != nullptr) { // This is a plugin EP

if (ep.GetKernelRegistry() == nullptr) {
// assume that a dynamic plugin EP which does not return a kernel registry does not use one
return false;
}

if (ort_ep->Compile != nullptr) {
// assume that a plugin EP that compiles nodes does not use a kernel registry for all nodes
return false;
}
}

// otherwise, assume that the EP uses a kernel registry
return true;
};

if (!SetEpsForAllNodes(model.MainGraph(), execution_providers, custom_registries, ep_uses_kernel_registry)) {
if (!SetEpsForAllNodes(model.MainGraph(), execution_providers, custom_registries, ep_only_uses_kernel_registry)) {
std::string providers;
for (const auto& ep : execution_providers) {
providers.append(ep->Type() + " ");
Expand Down
Loading