Skip to content

Commit 1ca0d32

Browse files
committed
WIP
1 parent 72ecea9 commit 1ca0d32

24 files changed

+959
-43
lines changed

cmake/onnxruntime_providers_webgpu.cmake

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,12 @@
122122
if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
123123
message(FATAL_ERROR "WebGPU EP shared library build is not supported on Emscripten. Please use static library build.")
124124
endif()
125+
126+
# Configure precompiled headers for shared library build
127+
# PCH ensures ep/_pch.h is included first and improves compilation speed
128+
target_precompile_headers(onnxruntime_providers_webgpu PRIVATE
129+
"${REPO_ROOT}/include/onnxruntime/ep/_pch.h"
130+
)
125131
endif()
126132

127133
set_target_properties(onnxruntime_providers_webgpu PROPERTIES CXX_STANDARD_REQUIRED ON)

onnxruntime/core/providers/cpu/tensor/padbase.h

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -158,14 +158,20 @@ class PadBase {
158158
ORT_THROW("Invalid 'mode' attribute value");
159159
}
160160

161-
const auto& kernel_def = info.GetKernelDef();
161+
if constexpr (std::is_same_v<KernelInfoType, onnxruntime::OpKernelInfo>) {
162+
const auto& kernel_def = info.GetKernelDef();
162163

163-
int start_ver, end_ver;
164-
kernel_def.SinceVersion(&start_ver, &end_ver);
164+
int start_ver, end_ver;
165+
kernel_def.SinceVersion(&start_ver, &end_ver);
165166

166-
// kMSDomain contrib kernel OR OnnxDomain start version >= 11 => DynamicPad
167-
if (start_ver >= 11 || kernel_def.Domain() == kMSDomain) {
168-
is_dynamic_ = true;
167+
// kMSDomain contrib kernel OR OnnxDomain start version >= 11 => DynamicPad
168+
if (start_ver >= 11 || kernel_def.Domain() == kMSDomain) {
169+
is_dynamic_ = true;
170+
}
171+
} else {
172+
if (info.node().SinceVersion() >= 11) { // TODO(fs-eire): support contrib domain check
173+
is_dynamic_ = true;
174+
}
169175
}
170176

171177
if (!is_dynamic_) {

onnxruntime/core/providers/cpu/tensor/upsamplebase.h

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,22 @@ class UpsampleBase {
219219
if (scales_input_idx_ > 0) {
220220
const Tensor* scale;
221221
bool get_scale = info.TryGetConstantInput(scales_input_idx_, &scale);
222-
auto x_shape = node.InputDefs()[0]->Shape();
223-
int64_t rank = x_shape ? x_shape->dim_size() : -1;
222+
int64_t rank = -1;
223+
if constexpr (std::is_same_v<KernelInfoType, onnxruntime::OpKernelInfo>) {
224+
auto x_shape = node.InputDefs()[0]->Shape();
225+
if (x_shape != nullptr) {
226+
rank = x_shape->dim_size();
227+
}
228+
} else {
229+
int is_const;
230+
auto tensor = info.GetKernelInfo().GetTensorConstantInput(0, &is_const);
231+
if (is_const) {
232+
auto type_and_shape_info = tensor.GetTensorTypeAndShapeInfo();
233+
if (type_and_shape_info.HasShape()) {
234+
rank = static_cast<int64_t>(type_and_shape_info.GetShape().size());
235+
}
236+
}
237+
}
224238
if (get_scale && scale->Shape().Size() > 0 && ((opset < 18) || (rank > 0 && opset >= 18))) {
225239
ORT_THROW_IF_ERROR(ParseScalesData(scale, scales_, rank));
226240
scales_cached_ = true;

onnxruntime/core/providers/webgpu/compute_context.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,11 @@ class ComputeContextBase {
100100
// Get the logger.
101101
//
102102
inline const logging::Logger& Logger() const {
103+
#if defined(BUILD_WEBGPU_EP_STATIC_LIB)
103104
return *ep_.GetLogger();
105+
#else
106+
return ep_.GetEpLogger();
107+
#endif
104108
}
105109

106110
//

onnxruntime/core/providers/webgpu/controlflow/if.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,14 @@ ONNX_OPERATOR_KERNEL_EX(If,
6969
If);
7070

7171
Status If::Compute(OpKernelContext* ctx) const {
72+
#if defined(BUILD_WEBGPU_EP_STATIC_LIB)
7273
// call the base CPU version.
7374
return onnxruntime::If::Compute(ctx);
75+
#else
76+
// TODO(fs-eire): implement WebGPU If kernel
77+
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "If operator is not implemented for WebGPU EP yet.");
78+
#endif
7479
}
7580

7681
} // namespace webgpu
77-
} // namespace onnxruntime
82+
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/controlflow/if.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
namespace onnxruntime {
1111
namespace webgpu {
1212

13+
#if defined(BUILD_WEBGPU_EP_STATIC_LIB)
14+
1315
// Use the CPU implementation for the logic
1416
class If final : public onnxruntime::If {
1517
public:
@@ -18,5 +20,15 @@ class If final : public onnxruntime::If {
1820
Status Compute(OpKernelContext* ctx) const override;
1921
};
2022

23+
#else
24+
25+
class If final : public OpKernel {
26+
public:
27+
If(const OpKernelInfo& info) : OpKernel(info) {}
28+
29+
Status Compute(OpKernelContext* ctx) const override;
30+
};
31+
#endif
32+
2133
} // namespace webgpu
22-
} // namespace onnxruntime
34+
} // namespace onnxruntime
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#define ORT_API_MANUAL_INIT
5+
#include "onnxruntime_cxx_api.h"
6+
#undef ORT_API_MANUAL_INIT
7+
8+
#include <memory>
9+
10+
#include "core/providers/webgpu/ep/factory.h"
11+
12+
// To make symbols visible on macOS/iOS
13+
#ifdef __APPLE__
14+
#define EXPORT_SYMBOL __attribute__((visibility("default")))
15+
#else
16+
#define EXPORT_SYMBOL
17+
#endif
18+
19+
namespace onnxruntime {
20+
namespace webgpu {
21+
void CleanupWebGpuContexts();
22+
void CleanupKernelRegistries();
23+
} // namespace webgpu
24+
} // namespace onnxruntime
25+
26+
namespace google {
27+
namespace protobuf {
28+
void ShutdownProtobufLibrary();
29+
} // namespace protobuf
30+
} // namespace google
31+
32+
extern "C" {
33+
//
34+
// Public symbols
35+
//
36+
EXPORT_SYMBOL OrtStatus* CreateEpFactories(const char* registration_name, const OrtApiBase* ort_api_base,
37+
const OrtLogger* default_logger,
38+
OrtEpFactory** factories, size_t max_factories, size_t* num_factories) {
39+
// Manual init for the C++ API
40+
onnxruntime::ep::ApiInit(ort_api_base);
41+
42+
if (max_factories < 1) {
43+
return onnxruntime::ep::Api().ort.CreateStatus(ORT_INVALID_ARGUMENT,
44+
"Not enough space to return EP factory. Need at least one.");
45+
}
46+
47+
// Initialize the global default logger
48+
::onnxruntime::ep::adapter::Logger::CreateDefaultLogger(default_logger);
49+
50+
// Factory could use registration_name or define its own EP name.
51+
std::unique_ptr<OrtEpFactory> factory = std::make_unique<onnxruntime::webgpu::ep::Factory>();
52+
53+
factories[0] = factory.release();
54+
*num_factories = 1;
55+
56+
return nullptr;
57+
}
58+
59+
EXPORT_SYMBOL OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) {
60+
// STEP.1 - Release the factory
61+
delete static_cast<onnxruntime::webgpu::ep::Factory*>(factory);
62+
63+
// STEP.2 - Clean up cached kernel registries
64+
onnxruntime::webgpu::CleanupKernelRegistries();
65+
66+
// STEP.3 - Clean up WebGPU contexts
67+
onnxruntime::webgpu::CleanupWebGpuContexts();
68+
69+
// STEP.4 - Destroy the global default logger wrapper
70+
::onnxruntime::ep::adapter::Logger::DestroyDefaultLogger();
71+
72+
// STEP.5 - Shutdown protobuf library
73+
google::protobuf::ShutdownProtobufLibrary();
74+
75+
return nullptr;
76+
}
77+
78+
} // extern "C"

0 commit comments

Comments
 (0)