Skip to content

Commit e666503

Browse files
authored
[webgpu] no longer need pass-in gpu adapter for custom context (microsoft#23593)
### Description Remove the need to pass in the GPU adapter for the custom context. With the introduction of the `wgpuDeviceGetAdapterInfo` API, we no longer need user to specify the GPU adapter when creating a custom context.
1 parent af679a0 commit e666503

File tree

4 files changed

+22
-38
lines changed

4 files changed

+22
-38
lines changed

onnxruntime/core/providers/webgpu/webgpu_context.cc

+19-22
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ namespace webgpu {
3737

3838
void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type, bool enable_pix_capture) {
3939
std::call_once(init_flag_, [this, &buffer_cache_config, backend_type, enable_pix_capture]() {
40-
// Create wgpu::Adapter
41-
if (adapter_ == nullptr) {
40+
if (device_ == nullptr) {
41+
// Create wgpu::Adapter
4242
#if !defined(__wasm__) && defined(_MSC_VER) && defined(DAWN_ENABLE_D3D12) && !defined(USE_EXTERNAL_DAWN)
4343
// If we are using the D3D12 backend on Windows and the build does not use external Dawn, dxil.dll and dxcompiler.dll are required.
4444
//
@@ -77,20 +77,19 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
7777
req_adapter_options.nextInChain = &adapter_toggles_desc;
7878
#endif
7979

80+
wgpu::Adapter adapter;
8081
ORT_ENFORCE(wgpu::WaitStatus::Success == instance_.WaitAny(instance_.RequestAdapter(
8182
&req_adapter_options,
8283
wgpu::CallbackMode::WaitAnyOnly,
8384
[](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, wgpu::StringView message, wgpu::Adapter* ptr) {
8485
ORT_ENFORCE(status == wgpu::RequestAdapterStatus::Success, "Failed to get a WebGPU adapter: ", std::string_view{message});
85-
*ptr = adapter;
86+
*ptr = std::move(adapter);
8687
},
87-
&adapter_),
88+
&adapter),
8889
UINT64_MAX));
89-
ORT_ENFORCE(adapter_ != nullptr, "Failed to get a WebGPU adapter.");
90-
}
90+
ORT_ENFORCE(adapter != nullptr, "Failed to get a WebGPU adapter.");
9191

92-
// Create wgpu::Device
93-
if (device_ == nullptr) {
92+
// Create wgpu::Device
9493
wgpu::DeviceDescriptor device_desc = {};
9594

9695
#if !defined(__wasm__)
@@ -106,12 +105,12 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
106105
device_toggles_desc.disabledToggles = disabled_device_toggles.data();
107106
#endif
108107

109-
std::vector<wgpu::FeatureName> required_features = GetAvailableRequiredFeatures(adapter_);
108+
std::vector<wgpu::FeatureName> required_features = GetAvailableRequiredFeatures(adapter);
110109
if (required_features.size() > 0) {
111110
device_desc.requiredFeatures = required_features.data();
112111
device_desc.requiredFeatureCount = required_features.size();
113112
}
114-
wgpu::RequiredLimits required_limits = GetRequiredLimits(adapter_);
113+
wgpu::RequiredLimits required_limits = GetRequiredLimits(adapter);
115114
device_desc.requiredLimits = &required_limits;
116115

117116
// TODO: revise temporary error handling
@@ -123,20 +122,20 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
123122
LOGS_DEFAULT(INFO) << "WebGPU device lost (" << int(reason) << "): " << std::string_view{message};
124123
});
125124

126-
ORT_ENFORCE(wgpu::WaitStatus::Success == instance_.WaitAny(adapter_.RequestDevice(
125+
ORT_ENFORCE(wgpu::WaitStatus::Success == instance_.WaitAny(adapter.RequestDevice(
127126
&device_desc,
128127
wgpu::CallbackMode::WaitAnyOnly,
129128
[](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message, wgpu::Device* ptr) {
130129
ORT_ENFORCE(status == wgpu::RequestDeviceStatus::Success, "Failed to get a WebGPU device: ", std::string_view{message});
131-
*ptr = device;
130+
*ptr = std::move(device);
132131
},
133132
&device_),
134133
UINT64_MAX));
135134
ORT_ENFORCE(device_ != nullptr, "Failed to get a WebGPU device.");
136135
}
137136

138137
// cache adapter info
139-
ORT_ENFORCE(Adapter().GetInfo(&adapter_info_));
138+
ORT_ENFORCE(Device().GetAdapterInfo(&adapter_info_));
140139
// cache device limits
141140
wgpu::SupportedLimits device_supported_limits;
142141
ORT_ENFORCE(Device().GetLimits(&device_supported_limits));
@@ -706,13 +705,12 @@ wgpu::Instance WebGpuContextFactory::default_instance_;
706705
WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& config) {
707706
const int context_id = config.context_id;
708707
WGPUInstance instance = config.instance;
709-
WGPUAdapter adapter = config.adapter;
710708
WGPUDevice device = config.device;
711709

712710
if (context_id == 0) {
713711
// context ID is preserved for the default context. User cannot use context ID 0 as a custom context.
714-
ORT_ENFORCE(instance == nullptr && adapter == nullptr && device == nullptr,
715-
"WebGPU EP default context (contextId=0) must not have custom WebGPU instance, adapter or device.");
712+
ORT_ENFORCE(instance == nullptr && device == nullptr,
713+
"WebGPU EP default context (contextId=0) must not have custom WebGPU instance or device.");
716714

717715
std::call_once(init_default_flag_, [
718716
#if !defined(__wasm__)
@@ -750,23 +748,22 @@ WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& co
750748
});
751749
instance = default_instance_.Get();
752750
} else {
753-
// for context ID > 0, user must provide custom WebGPU instance, adapter and device.
754-
ORT_ENFORCE(instance != nullptr && adapter != nullptr && device != nullptr,
755-
"WebGPU EP custom context (contextId>0) must have custom WebGPU instance, adapter and device.");
751+
// for context ID > 0, user must provide custom WebGPU instance and device.
752+
ORT_ENFORCE(instance != nullptr && device != nullptr,
753+
"WebGPU EP custom context (contextId>0) must have custom WebGPU instance and device.");
756754
}
757755

758756
std::lock_guard<std::mutex> lock(mutex_);
759757

760758
auto it = contexts_.find(context_id);
761759
if (it == contexts_.end()) {
762760
GSL_SUPPRESS(r.11)
763-
auto context = std::unique_ptr<WebGpuContext>(new WebGpuContext(instance, adapter, device, config.validation_mode));
761+
auto context = std::unique_ptr<WebGpuContext>(new WebGpuContext(instance, device, config.validation_mode));
764762
it = contexts_.emplace(context_id, WebGpuContextFactory::WebGpuContextInfo{std::move(context), 0}).first;
765763
} else if (context_id != 0) {
766764
ORT_ENFORCE(it->second.context->instance_.Get() == instance &&
767-
it->second.context->adapter_.Get() == adapter &&
768765
it->second.context->device_.Get() == device,
769-
"WebGPU EP context ID ", context_id, " is already created with different WebGPU instance, adapter or device.");
766+
"WebGPU EP context ID ", context_id, " is already created with different WebGPU instance or device.");
770767
}
771768
it->second.ref_count++;
772769
return *it->second.context;

onnxruntime/core/providers/webgpu/webgpu_context.h

+2-5
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ class ProgramBase;
2929
struct WebGpuContextConfig {
3030
int context_id;
3131
WGPUInstance instance;
32-
WGPUAdapter adapter;
3332
WGPUDevice device;
3433
const void* dawn_proc_table;
3534
ValidationMode validation_mode;
@@ -76,7 +75,6 @@ class WebGpuContext final {
7675

7776
Status Wait(wgpu::Future f);
7877

79-
const wgpu::Adapter& Adapter() const { return adapter_; }
8078
const wgpu::Device& Device() const { return device_; }
8179

8280
const wgpu::AdapterInfo& AdapterInfo() const { return adapter_info_; }
@@ -149,8 +147,8 @@ class WebGpuContext final {
149147
AtPasses
150148
};
151149

152-
WebGpuContext(WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device, webgpu::ValidationMode validation_mode)
153-
: instance_{instance}, adapter_{adapter}, device_{device}, validation_mode_{validation_mode}, query_type_{TimestampQueryType::None} {}
150+
WebGpuContext(WGPUInstance instance, WGPUDevice device, webgpu::ValidationMode validation_mode)
151+
: instance_{instance}, device_{device}, validation_mode_{validation_mode}, query_type_{TimestampQueryType::None} {}
154152
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WebGpuContext);
155153

156154
std::vector<const char*> GetEnabledAdapterToggles() const;
@@ -198,7 +196,6 @@ class WebGpuContext final {
198196
LibraryHandles modules_;
199197

200198
wgpu::Instance instance_;
201-
wgpu::Adapter adapter_;
202199
wgpu::Device device_;
203200

204201
webgpu::ValidationMode validation_mode_;

onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc

+1-10
Original file line numberDiff line numberDiff line change
@@ -106,14 +106,6 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
106106
std::from_chars(webgpu_instance_str.data(), webgpu_instance_str.data() + webgpu_instance_str.size(), webgpu_instance).ec);
107107
}
108108

109-
size_t webgpu_adapter = 0;
110-
std::string webgpu_adapter_str;
111-
if (config_options.TryGetConfigEntry(kWebGpuAdapter, webgpu_adapter_str)) {
112-
static_assert(sizeof(WGPUAdapter) == sizeof(size_t), "WGPUAdapter size mismatch");
113-
ORT_ENFORCE(std::errc{} ==
114-
std::from_chars(webgpu_adapter_str.data(), webgpu_adapter_str.data() + webgpu_adapter_str.size(), webgpu_adapter).ec);
115-
}
116-
117109
size_t webgpu_device = 0;
118110
std::string webgpu_device_str;
119111
if (config_options.TryGetConfigEntry(kWebGpuDevice, webgpu_device_str)) {
@@ -154,7 +146,6 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
154146
webgpu::WebGpuContextConfig context_config{
155147
context_id,
156148
reinterpret_cast<WGPUInstance>(webgpu_instance),
157-
reinterpret_cast<WGPUAdapter>(webgpu_adapter),
158149
reinterpret_cast<WGPUDevice>(webgpu_device),
159150
reinterpret_cast<const void*>(dawn_proc_table),
160151
validation_mode,
@@ -238,7 +229,7 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
238229
// STEP.4 - start initialization.
239230
//
240231

241-
// Load the Dawn library and create the WebGPU instance and adapter.
232+
// Load the Dawn library and create the WebGPU instance.
242233
auto& context = webgpu::WebGpuContextFactory::CreateContext(context_config);
243234

244235
// Create WebGPU device and initialize the context.

onnxruntime/core/providers/webgpu/webgpu_provider_options.h

-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ constexpr const char* kDawnBackendType = "WebGPU:dawnBackendType";
1818

1919
constexpr const char* kDeviceId = "WebGPU:deviceId";
2020
constexpr const char* kWebGpuInstance = "WebGPU:webgpuInstance";
21-
constexpr const char* kWebGpuAdapter = "WebGPU:webgpuAdapter";
2221
constexpr const char* kWebGpuDevice = "WebGPU:webgpuDevice";
2322

2423
constexpr const char* kStorageBufferCacheMode = "WebGPU:storageBufferCacheMode";

0 commit comments

Comments
 (0)