Skip to content

Commit 9c834b2

Browse files
authored
Run onnxruntime_provider_test with example_plugin_ep in CI (#26214)
### Description <!-- Describe your changes. --> - Add test dynamic plugin EP configuration setting to skip specific tests. - Run test in x64 Windows Release CI build. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Enable more testing of plugin EP in CI build.
1 parent 54086d8 commit 9c834b2

File tree

6 files changed

+105
-1
lines changed

6 files changed

+105
-1
lines changed

.github/workflows/windows_x64_release_build_x64_release.yml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,34 @@ jobs:
9898
ALLOW_RELEASED_ONNX_OPSET_ONLY: '0'
9999
DocUpdateNeeded: 'false'
100100

101+
- name: Run onnxruntime_provider_test with example_plugin_ep
102+
shell: pwsh
103+
run: |
104+
# Note on skipped tests:
105+
# The skipped tests are either:
106+
# - relying on CPU EP fallback for BFloat16 which is not supported
107+
# - testing the LayerNormalization contrib op with mixed input/output types (only supported by a few EPs)
108+
# Some other hardcoded EP types are skipped in these tests. For a plugin EP, we skip these tests by
109+
# specifying them in the dynamic plugin EP configuration.
110+
111+
$env:ORT_UNIT_TEST_MAIN_DYNAMIC_PLUGIN_EP_CONFIG_JSON = @"
112+
{
113+
"ep_library_registration_name": "example_ep",
114+
"ep_library_path": "./example_plugin_ep.dll",
115+
"selected_ep_name": "example_ep",
116+
"tests_to_skip": [
117+
"LayerNormTest.LayerNorm_BFloat16Input",
118+
"LayerNormTest.LayerNorm_Scale_Float16Input",
119+
"LayerNormTest.LayerNorm_Scale_Float16ScaleOutput",
120+
"LayerNormTest.LayerNorm_Scale_Bias_Float16Input",
121+
"LayerNormTest.LayerNorm_Scale_Bias_Float16ScaleBiasOutput"
122+
]
123+
}
124+
"@
125+
126+
.\onnxruntime_provider_test.exe
127+
working-directory: ${{ github.workspace }}\build\RelWithDebInfo\RelWithDebInfo
128+
101129
- name: Validate C# native delegates
102130
shell: cmd
103131
run: python tools\ValidateNativeDelegateAttributes.py

onnxruntime/core/optimizer/transformer_memcpy.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,8 +250,11 @@ static const IExecutionProvider* FindProviderByType(ProviderTypeToProviderMap pr
250250

251251
bool TransformerMemcpyImpl::IsNodeCompatibleWithProvider(const onnxruntime::Node& node) const {
252252
const auto& node_provider_type = node.GetExecutionProviderType();
253+
ORT_ENFORCE(!node_provider_type.empty(),
254+
"Provider type for ", node.OpType(), " node with name '", node.Name(), "' is not set.");
253255
const auto* node_provider = FindProviderByType(providers_by_type_, node_provider_type);
254-
ORT_ENFORCE(node_provider != nullptr, "Unable to get provider associated with provider type ", node_provider_type);
256+
ORT_ENFORCE(node_provider != nullptr,
257+
"Unable to get provider associated with provider type '", node_provider_type, "'.");
255258

256259
// Same provider?
257260
if (node_provider->Type() == provider_.Type()) {

onnxruntime/test/unittest_main/test_main.cc

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33

44
#include <algorithm>
55
#include <cstdlib>
6+
#include <memory>
67
#include <optional>
78
#include <string>
9+
#include <vector>
810
#ifdef _WIN32
911
#include <iostream>
1012
#include <locale>
@@ -35,6 +37,7 @@
3537

3638
#if defined(TEST_MAIN_ENABLE_DYNAMIC_PLUGIN_EP_USAGE)
3739
#include "test/unittest_util/test_dynamic_plugin_ep.h"
40+
#include "test/util/include/skipping_test_listener.h"
3841
#endif // defined(TEST_MAIN_ENABLE_DYNAMIC_PLUGIN_EP_USAGE)
3942

4043
std::unique_ptr<Ort::Env> ort_env;
@@ -107,6 +110,19 @@ extern "C" void ortenv_teardown() {
107110
ort_env.reset();
108111
}
109112

113+
static std::vector<std::unique_ptr<::testing::TestEventListener>> MakeTestEventListeners() {
114+
std::vector<std::unique_ptr<::testing::TestEventListener>> result{};
115+
#if defined(TEST_MAIN_ENABLE_DYNAMIC_PLUGIN_EP_USAGE)
116+
{
117+
namespace dynamic_plugin_ep_infra = onnxruntime::test::dynamic_plugin_ep_infra;
118+
const auto tests_to_skip = dynamic_plugin_ep_infra::GetTestsToSkip();
119+
auto skipping_test_listener = std::make_unique<onnxruntime::test::SkippingTestListener>(tests_to_skip);
120+
result.emplace_back(std::move(skipping_test_listener));
121+
}
122+
#endif
123+
return result;
124+
}
125+
110126
#ifdef USE_TENSORRT
111127

112128
#if defined(_MSC_VER)
@@ -152,6 +168,14 @@ int TEST_MAIN(int argc, char** argv) {
152168
ortenv_setup();
153169
::testing::InitGoogleTest(&argc, argv);
154170

171+
{
172+
auto& test_listeners = ::testing::UnitTest::GetInstance()->listeners();
173+
auto test_listeners_to_add = MakeTestEventListeners();
174+
for (auto& test_listener_to_add : test_listeners_to_add) {
175+
test_listeners.Append(test_listener_to_add.release());
176+
}
177+
}
178+
155179
status = RUN_ALL_TESTS();
156180
}
157181
ORT_CATCH(const std::exception& ex) {

onnxruntime/test/unittest_util/test_dynamic_plugin_ep.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ Status ParseInitializationConfig(std::string_view json_str, InitializationConfig
9090
config.selected_ep_name = parsed_json.value<decltype(config.selected_ep_name)>("selected_ep_name", {});
9191
config.selected_ep_device_indices =
9292
parsed_json.value<decltype(config.selected_ep_device_indices)>("selected_ep_device_indices", {});
93+
config.tests_to_skip = parsed_json.value<decltype(config.tests_to_skip)>("tests_to_skip", {});
9394

9495
config_out = std::move(config);
9596
return Status::OK();
@@ -198,4 +199,12 @@ std::optional<std::string> GetEpName() {
198199
return g_plugin_ep_infrastructure_state->ep_name;
199200
}
200201

202+
std::vector<std::string> GetTestsToSkip() {
203+
if (!IsInitialized()) {
204+
return {};
205+
}
206+
207+
return g_plugin_ep_infrastructure_state->config.tests_to_skip;
208+
}
209+
201210
} // namespace onnxruntime::test::dynamic_plugin_ep_infra

onnxruntime/test/unittest_util/test_dynamic_plugin_ep.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ struct InitializationConfig {
4646
std::vector<size_t> selected_ep_device_indices{};
4747

4848
std::map<std::string, std::string> default_ep_options{};
49+
50+
// Specifies any tests to skip.
51+
// Tests should be specified by full name, i.e., "<test suite name>.<test name>".
52+
std::vector<std::string> tests_to_skip{};
4953
};
5054

5155
// Parses `InitializationConfig` from JSON.
@@ -75,6 +79,9 @@ std::unique_ptr<IExecutionProvider> MakeEp(const logging::Logger* logger = nullp
7579
// Gets the dynamic plugin EP name, or `std::nullopt` if uninitialized.
7680
std::optional<std::string> GetEpName();
7781

82+
// Gets the list of tests to skip, or an empty list if uninitialized.
83+
std::vector<std::string> GetTestsToSkip();
84+
7885
} // namespace dynamic_plugin_ep_infra
7986
} // namespace test
8087
} // namespace onnxruntime
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include <string>
7+
#include <unordered_set>
8+
9+
#include "gsl/gsl"
10+
11+
#include "gtest/gtest.h"
12+
13+
namespace onnxruntime::test {
14+
15+
// A test event listener that skips the specified tests.
16+
class SkippingTestListener : public ::testing::EmptyTestEventListener {
17+
public:
18+
explicit SkippingTestListener(gsl::span<const std::string> tests_to_skip)
19+
: tests_to_skip_(tests_to_skip.begin(), tests_to_skip.end()) {
20+
}
21+
22+
private:
23+
void OnTestStart(const ::testing::TestInfo& test_info) override {
24+
const auto full_test_name = std::string(test_info.test_suite_name()) + "." + test_info.name();
25+
if (tests_to_skip_.find(full_test_name) != tests_to_skip_.end()) {
26+
GTEST_SKIP();
27+
}
28+
}
29+
30+
std::unordered_set<std::string> tests_to_skip_;
31+
};
32+
33+
} // namespace onnxruntime::test

0 commit comments

Comments
 (0)