@@ -245,42 +245,48 @@ TEST(OrtEpLibrary, KernelPluginEp_Inference) {
245245 example_kernel_ep));
246246 Ort::ConstEpDevice plugin_ep_device (example_kernel_ep.get ());
247247
248- // Create session with example kernel-based plugin EP
249- Ort::SessionOptions session_options;
250- session_options.AddConfigEntry (kOrtSessionOptionsDisableCPUEPFallback , " 1" ); // Fail if any node assigned to CPU EP.
248+ auto run_model_with_ep_options = [&](const std::unordered_map<std::string, std::string>& ep_options) {
249+ // Create session with example kernel-based plugin EP
250+ Ort::SessionOptions session_options;
251+ session_options.AddConfigEntry (kOrtSessionOptionsDisableCPUEPFallback , " 1" ); // Fail if any node assigned to CPU EP.
252+ session_options.AppendExecutionProvider_V2 (*ort_env, {plugin_ep_device}, ep_options);
251253
252- std::unordered_map<std::string, std::string> ep_options;
253- session_options. AppendExecutionProvider_V2 (*ort_env, {plugin_ep_device}, ep_options );
254+ // This model has Squeeze, Mul, and Relu nodes. The example plugin EP supports all nodes using registered kernels.
255+ Ort::Session session (*ort_env, ORT_TSTR ( " testdata/squeeze_mul_relu.onnx " ), session_options );
254256
255- // This model has Squeeze, Mul, and Relu nodes. The example plugin EP supports all nodes using registered kernels.
256- Ort::Session session (*ort_env, ORT_TSTR (" testdata/squeeze_mul_relu.onnx" ), session_options);
257+ // Create inputs
258+ Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu (OrtDeviceAllocator, OrtMemTypeCPU);
259+ std::array<int64_t , 3 > a_shape = {3 , 1 , 2 };
260+ std::array<int64_t , 2 > b_shape = {3 , 2 };
257261
258- // Create inputs
259- Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu (OrtDeviceAllocator, OrtMemTypeCPU);
260- std::array<int64_t , 3 > a_shape = {3 , 1 , 2 };
261- std::array<int64_t , 2 > b_shape = {3 , 2 };
262+ std::array<float , 6 > a_data = {1 .f , -2 .f , 3 .f , 4 .f , -5 .f , 6 .f };
263+ std::array<float , 6 > b_data = {2 .f , 3 .f , 4 .f , -5 .f , 6 .f , 7 .f };
262264
263- std::array<float , 6 > a_data = {1 .f , -2 .f , 3 .f , 4 .f , -5 .f , 6 .f };
264- std::array<float , 6 > b_data = {2 .f , 3 .f , 4 .f , -5 .f , 6 .f , 7 .f };
265+ std::vector<Ort::Value> ort_inputs{};
266+ ort_inputs.emplace_back (
267+ Ort::Value::CreateTensor<float >(memory_info, a_data.data (), a_data.size (), a_shape.data (), a_shape.size ()));
268+ ort_inputs.emplace_back (
269+ Ort::Value::CreateTensor<float >(memory_info, b_data.data (), b_data.size (), b_shape.data (), b_shape.size ()));
265270
266- std::vector<Ort::Value> ort_inputs{};
267- ort_inputs.emplace_back (
268- Ort::Value::CreateTensor<float >(memory_info, a_data.data (), a_data.size (), a_shape.data (), a_shape.size ()));
269- ort_inputs.emplace_back (
270- Ort::Value::CreateTensor<float >(memory_info, b_data.data (), b_data.size (), b_shape.data (), b_shape.size ()));
271+ std::array ort_input_names{" A" , " B" };
271272
272- std::array ort_input_names{" A" , " B" };
273+ // Run session and get outputs
274+ std::array output_names{" C" };
275+ std::vector<Ort::Value> ort_outputs = session.Run (Ort::RunOptions{nullptr }, ort_input_names.data (), ort_inputs.data (),
276+ ort_inputs.size (), output_names.data (), output_names.size ());
273277
274- // Run session and get outputs
275- std::array output_names{" C" };
276- std::vector<Ort::Value> ort_outputs = session.Run (Ort::RunOptions{nullptr }, ort_input_names.data (), ort_inputs.data (),
277- ort_inputs.size (), output_names.data (), output_names.size ());
278+ // Check expected output values
279+ Ort::Value& ort_output = ort_outputs[0 ];
280+ const float * output_data = ort_output.GetTensorData <float >();
281+ gsl::span<const float > output_span (output_data, 6 );
282+ EXPECT_THAT (output_span, ::testing::ElementsAre (4 , 0 , 24 , 0 , 0 , 84 ));
283+ };
278284
279- // Check expected output values
280- Ort::Value& ort_output = ort_outputs[ 0 ];
281- const float * output_data = ort_output. GetTensorData < float >();
282- gsl::span< const float > output_span (output_data, 6 );
283- EXPECT_THAT (output_span, :: testing::ElementsAre ( 4 , 0 , 24 , 0 , 0 , 84 ) );
285+ run_model_with_ep_options ({});
286+
287+ // Enable sharing of pre-packed weights.
288+ // This also tests the ability for the kernel implementation to retrieve the OrtEp and get its configuration.
289+ run_model_with_ep_options ({{ " enable_prepack_weight_sharing " , " 1 " }} );
284290}
285291} // namespace test
286292} // namespace onnxruntime
0 commit comments