Skip to content

Commit 4a5e066

Browse files
committed
Implement EP API for WebGPU EP
1 parent d6e0859 commit 4a5e066

27 files changed

+1081
-83
lines changed

cmake/onnxruntime_providers_webgpu.cmake

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,12 @@
122122
if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
123123
message(FATAL_ERROR "WebGPU EP shared library build is not supported on Emscripten. Please use static library build.")
124124
endif()
125+
126+
# Configure precompiled headers for shared library build
127+
# PCH ensures ep/_pch.h is included first and improves compilation speed
128+
target_precompile_headers(onnxruntime_providers_webgpu PRIVATE
129+
"${REPO_ROOT}/include/onnxruntime/ep/_pch.h"
130+
)
125131
endif()
126132

127133
set_target_properties(onnxruntime_providers_webgpu PROPERTIES CXX_STANDARD_REQUIRED ON)

cmake/onnxruntime_unittests.cmake

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,6 +1038,18 @@ function(onnxruntime_apply_test_target_workarounds target)
10381038
endif()
10391039
endfunction()
10401040

1041+
# Set environment variables for plugin EP tests when run via CTest.
1042+
function(onnxruntime_set_plugin_ep_test_environment target)
1043+
if(onnxruntime_USE_WEBGPU AND NOT onnxruntime_BUILD_WEBGPU_EP_STATIC_LIB)
1044+
set(ORT_PLUGIN_EP_JSON_CONFIG "{\"ep_library_registration_name\": \"WebGPU_PluginEP\", \"ep_library_path\": \"onnxruntime_providers_webgpu.dll\", \"selected_ep_name\": \"WebGpuExecutionProvider\"}")
1045+
set_tests_properties(${target} PROPERTIES
1046+
ENVIRONMENT "ORT_UNIT_TEST_MAIN_DYNAMIC_PLUGIN_EP_CONFIG_JSON=${ORT_PLUGIN_EP_JSON_CONFIG}"
1047+
)
1048+
# TODO: add for other plugin EPs if needed
1049+
# elseif()
1050+
endif()
1051+
endfunction()
1052+
10411053
function(onnxruntime_apply_emscripten_test_link_settings target)
10421054
if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
10431055
set_target_properties(${target} PROPERTIES LINK_DEPENDS ${TEST_SRC_DIR}/wasm/onnxruntime_test_adapter.js)
@@ -1239,6 +1251,7 @@ block()
12391251
)
12401252

12411253
onnxruntime_apply_test_target_workarounds(onnxruntime_provider_test)
1254+
onnxruntime_set_plugin_ep_test_environment(onnxruntime_provider_test)
12421255

12431256
# Expose QNN SDK headers to unit tests via an interface target
12441257
if(onnxruntime_USE_QNN)

onnxruntime/core/providers/cpu/tensor/upsamplebase.h

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,22 @@ class UpsampleBase {
219219
if (scales_input_idx_ > 0) {
220220
const Tensor* scale;
221221
bool get_scale = info.TryGetConstantInput(scales_input_idx_, &scale);
222-
auto x_shape = node.InputDefs()[0]->Shape();
223-
int64_t rank = x_shape ? x_shape->dim_size() : -1;
222+
int64_t rank = -1;
223+
if constexpr (std::is_same_v<KernelInfoType, onnxruntime::OpKernelInfo>) {
224+
auto x_shape = node.InputDefs()[0]->Shape();
225+
if (x_shape != nullptr) {
226+
rank = x_shape->dim_size();
227+
}
228+
} else {
229+
int is_const;
230+
auto tensor = info.GetKernelInfo().GetTensorConstantInput(0, &is_const);
231+
if (is_const) {
232+
auto type_and_shape_info = tensor.GetTensorTypeAndShapeInfo();
233+
if (type_and_shape_info.HasShape()) {
234+
rank = static_cast<int64_t>(type_and_shape_info.GetShape().size());
235+
}
236+
}
237+
}
224238
if (get_scale && scale->Shape().Size() > 0 && ((opset < 18) || (rank > 0 && opset >= 18))) {
225239
ORT_THROW_IF_ERROR(ParseScalesData(scale, scales_, rank));
226240
scales_cached_ = true;

onnxruntime/core/providers/webgpu/compute_context.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,11 @@ class ComputeContextBase {
100100
// Get the logger.
101101
//
102102
inline const logging::Logger& Logger() const {
103+
#if defined(BUILD_WEBGPU_EP_STATIC_LIB)
103104
return *ep_.GetLogger();
105+
#else
106+
return ep_.GetEpLogger();
107+
#endif
104108
}
105109

106110
//

onnxruntime/core/providers/webgpu/controlflow/if.cc

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33

44
#include "core/providers/webgpu/controlflow/if.h"
55

6+
#if !defined(BUILD_WEBGPU_EP_STATIC_LIB)
7+
#include "core/framework/error_code_helper.h"
8+
#endif
9+
610
using namespace ONNX_NAMESPACE;
711
using namespace onnxruntime::common;
812

@@ -68,10 +72,20 @@ ONNX_OPERATOR_KERNEL_EX(If,
6872
.TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()),
6973
If);
7074

75+
#if defined(BUILD_WEBGPU_EP_STATIC_LIB)
7176
Status If::Compute(OpKernelContext* ctx) const {
7277
// call the base CPU version.
7378
return onnxruntime::If::Compute(ctx);
7479
}
80+
#else
81+
Status If::CreateControlFlowKernelImpl(const OrtKernelInfo* info, OrtKernelImpl** impl) {
82+
return ToStatusAndRelease(ep::Api().ep.CreateIfKernel(info, impl));
83+
}
84+
85+
Status If::Compute(OpKernelContext* ctx) const {
86+
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "If operator should be handled by ORT core.");
87+
}
88+
#endif
7589

7690
} // namespace webgpu
77-
} // namespace onnxruntime
91+
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/controlflow/if.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
namespace onnxruntime {
1111
namespace webgpu {
1212

13+
#if defined(BUILD_WEBGPU_EP_STATIC_LIB)
14+
1315
// Use the CPU implementation for the logic
1416
class If final : public onnxruntime::If {
1517
public:
@@ -18,5 +20,16 @@ class If final : public onnxruntime::If {
1820
Status Compute(OpKernelContext* ctx) const override;
1921
};
2022

23+
#else
24+
25+
class If final : public OpKernel {
26+
public:
27+
If(const OpKernelInfo& info) : OpKernel(info) {}
28+
29+
Status CreateControlFlowKernelImpl(const OrtKernelInfo* info, OrtKernelImpl** impl) override;
30+
Status Compute(OpKernelContext* ctx) const override;
31+
};
32+
#endif
33+
2134
} // namespace webgpu
22-
} // namespace onnxruntime
35+
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/data_transfer.cc

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,32 +13,45 @@ bool DataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_dev
1313
(dst_device.Type() == OrtDevice::CPU && src_device.Type() == OrtDevice::GPU);
1414
}
1515

16-
common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const {
17-
size_t bytes = src.SizeInBytes();
16+
common::Status DataTransfer::CopyTensorImpl(void const* src_data,
17+
bool src_is_gpu,
18+
void* dst_data,
19+
bool dst_is_gpu,
20+
size_t bytes) const {
1821
if (bytes > 0) {
19-
void const* src_data = src.DataRaw();
20-
void* dst_data = dst.MutableDataRaw();
21-
22-
auto& src_device = src.Location().device;
23-
auto& dst_device = dst.Location().device;
24-
25-
if (dst_device.Type() == OrtDevice::GPU) {
26-
if (src_device.Type() == OrtDevice::GPU) {
22+
if (dst_is_gpu) {
23+
if (src_is_gpu) {
2724
// copy from GPU to GPU
2825
buffer_manager_.MemCpy(static_cast<WGPUBuffer>(const_cast<void*>(src_data)),
29-
static_cast<WGPUBuffer>(dst_data), bytes);
26+
static_cast<WGPUBuffer>(dst_data),
27+
bytes);
3028
} else {
3129
// copy from CPU to GPU
32-
buffer_manager_.Upload(const_cast<void*>(src_data), static_cast<WGPUBuffer>(dst_data), bytes);
30+
buffer_manager_.Upload(const_cast<void*>(src_data),
31+
static_cast<WGPUBuffer>(dst_data),
32+
bytes);
3333
}
34-
} else /* if (src_device.Type() == OrtDevice::GPU) */ {
34+
} else {
3535
// copy from GPU to CPU
36-
buffer_manager_.Download(static_cast<WGPUBuffer>(const_cast<void*>(src_data)), dst_data, bytes);
36+
buffer_manager_.Download(static_cast<WGPUBuffer>(const_cast<void*>(src_data)),
37+
dst_data,
38+
bytes);
3739
}
3840
}
3941

4042
return Status::OK();
4143
}
4244

45+
common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const {
46+
void const* src_data = src.DataRaw();
47+
void* dst_data = dst.MutableDataRaw();
48+
49+
return CopyTensorImpl(src_data,
50+
src.Location().device.Type() == OrtDevice::GPU,
51+
dst_data,
52+
dst.Location().device.Type() == OrtDevice::GPU,
53+
src.SizeInBytes());
54+
}
55+
4356
} // namespace webgpu
4457
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/data_transfer.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@ class DataTransfer : public IDataTransfer {
2020

2121
common::Status CopyTensor(const Tensor& src, Tensor& dst) const override;
2222

23+
common::Status CopyTensorImpl(void const* src_data,
24+
bool src_is_gpu,
25+
void* dst_data,
26+
bool dst_is_gpu,
27+
size_t bytes) const;
28+
2329
private:
2430
const BufferManager& buffer_manager_;
2531
};
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#define ORT_API_MANUAL_INIT
5+
#include "onnxruntime_cxx_api.h"
6+
#undef ORT_API_MANUAL_INIT
7+
8+
#include <memory>
9+
10+
#include "core/providers/webgpu/ep/factory.h"
11+
12+
// To make symbols visible on macOS/iOS
13+
#ifdef __APPLE__
14+
#define EXPORT_SYMBOL __attribute__((visibility("default")))
15+
#else
16+
#define EXPORT_SYMBOL
17+
#endif
18+
19+
namespace onnxruntime {
20+
namespace webgpu {
21+
void CleanupWebGpuContexts();
22+
void CleanupKernelRegistries();
23+
} // namespace webgpu
24+
} // namespace onnxruntime
25+
26+
namespace google {
27+
namespace protobuf {
28+
void ShutdownProtobufLibrary();
29+
} // namespace protobuf
30+
} // namespace google
31+
32+
extern "C" {
33+
//
34+
// Public symbols
35+
//
36+
EXPORT_SYMBOL OrtStatus* CreateEpFactories(const char* registration_name, const OrtApiBase* ort_api_base,
37+
const OrtLogger* default_logger,
38+
OrtEpFactory** factories, size_t max_factories, size_t* num_factories) {
39+
// Manual init for the C++ API
40+
onnxruntime::ep::ApiInit(ort_api_base);
41+
42+
if (max_factories < 1) {
43+
return onnxruntime::ep::Api().ort.CreateStatus(ORT_INVALID_ARGUMENT,
44+
"Not enough space to return EP factory. Need at least one.");
45+
}
46+
47+
// Initialize the global default logger
48+
::onnxruntime::ep::adapter::Logger::CreateDefaultLogger(default_logger);
49+
50+
// Factory could use registration_name or define its own EP name.
51+
std::unique_ptr<OrtEpFactory> factory = std::make_unique<onnxruntime::webgpu::ep::Factory>();
52+
53+
factories[0] = factory.release();
54+
*num_factories = 1;
55+
56+
return nullptr;
57+
}
58+
59+
EXPORT_SYMBOL OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) {
60+
// STEP.1 - Release the factory
61+
delete static_cast<onnxruntime::webgpu::ep::Factory*>(factory);
62+
63+
// STEP.2 - Clean up cached kernel registries
64+
onnxruntime::webgpu::CleanupKernelRegistries();
65+
66+
// STEP.3 - Clean up WebGPU contexts
67+
onnxruntime::webgpu::CleanupWebGpuContexts();
68+
69+
// STEP.4 - Destroy the global default logger wrapper
70+
::onnxruntime::ep::adapter::Logger::DestroyDefaultLogger();
71+
72+
// STEP.5 - Shutdown protobuf library
73+
google::protobuf::ShutdownProtobufLibrary();
74+
75+
return nullptr;
76+
}
77+
78+
} // extern "C"

0 commit comments

Comments
 (0)