@@ -37,8 +37,8 @@ namespace webgpu {
37
37
38
38
void WebGpuContext::Initialize (const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type, bool enable_pix_capture) {
39
39
std::call_once (init_flag_, [this , &buffer_cache_config, backend_type, enable_pix_capture]() {
40
- // Create wgpu::Adapter
41
- if (adapter_ == nullptr ) {
40
+ if (device_ == nullptr ) {
41
+ // Create wgpu::Adapter
42
42
#if !defined(__wasm__) && defined(_MSC_VER) && defined(DAWN_ENABLE_D3D12) && !defined(USE_EXTERNAL_DAWN)
43
43
// If we are using the D3D12 backend on Windows and the build does not use external Dawn, dxil.dll and dxcompiler.dll are required.
44
44
//
@@ -77,20 +77,19 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
77
77
req_adapter_options.nextInChain = &adapter_toggles_desc;
78
78
#endif
79
79
80
+ wgpu::Adapter adapter;
80
81
ORT_ENFORCE (wgpu::WaitStatus::Success == instance_.WaitAny (instance_.RequestAdapter (
81
82
&req_adapter_options,
82
83
wgpu::CallbackMode::WaitAnyOnly,
83
84
[](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, wgpu::StringView message, wgpu::Adapter* ptr) {
84
85
ORT_ENFORCE (status == wgpu::RequestAdapterStatus::Success, " Failed to get a WebGPU adapter: " , std::string_view{message});
85
- *ptr = adapter;
86
+ *ptr = std::move ( adapter) ;
86
87
},
87
- &adapter_ ),
88
+ &adapter ),
88
89
UINT64_MAX));
89
- ORT_ENFORCE (adapter_ != nullptr , " Failed to get a WebGPU adapter." );
90
- }
90
+ ORT_ENFORCE (adapter != nullptr , " Failed to get a WebGPU adapter." );
91
91
92
- // Create wgpu::Device
93
- if (device_ == nullptr ) {
92
+ // Create wgpu::Device
94
93
wgpu::DeviceDescriptor device_desc = {};
95
94
96
95
#if !defined(__wasm__)
@@ -106,12 +105,12 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
106
105
device_toggles_desc.disabledToggles = disabled_device_toggles.data ();
107
106
#endif
108
107
109
- std::vector<wgpu::FeatureName> required_features = GetAvailableRequiredFeatures (adapter_ );
108
+ std::vector<wgpu::FeatureName> required_features = GetAvailableRequiredFeatures (adapter );
110
109
if (required_features.size () > 0 ) {
111
110
device_desc.requiredFeatures = required_features.data ();
112
111
device_desc.requiredFeatureCount = required_features.size ();
113
112
}
114
- wgpu::RequiredLimits required_limits = GetRequiredLimits (adapter_ );
113
+ wgpu::RequiredLimits required_limits = GetRequiredLimits (adapter );
115
114
device_desc.requiredLimits = &required_limits;
116
115
117
116
// TODO: revise temporary error handling
@@ -123,20 +122,20 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
123
122
LOGS_DEFAULT (INFO) << " WebGPU device lost (" << int (reason) << " ): " << std::string_view{message};
124
123
});
125
124
126
- ORT_ENFORCE (wgpu::WaitStatus::Success == instance_.WaitAny (adapter_ .RequestDevice (
125
+ ORT_ENFORCE (wgpu::WaitStatus::Success == instance_.WaitAny (adapter .RequestDevice (
127
126
&device_desc,
128
127
wgpu::CallbackMode::WaitAnyOnly,
129
128
[](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message, wgpu::Device* ptr) {
130
129
ORT_ENFORCE (status == wgpu::RequestDeviceStatus::Success, " Failed to get a WebGPU device: " , std::string_view{message});
131
- *ptr = device;
130
+ *ptr = std::move ( device) ;
132
131
},
133
132
&device_),
134
133
UINT64_MAX));
135
134
ORT_ENFORCE (device_ != nullptr , " Failed to get a WebGPU device." );
136
135
}
137
136
138
137
// cache adapter info
139
- ORT_ENFORCE (Adapter ().GetInfo (&adapter_info_));
138
+ ORT_ENFORCE (Device ().GetAdapterInfo (&adapter_info_));
140
139
// cache device limits
141
140
wgpu::SupportedLimits device_supported_limits;
142
141
ORT_ENFORCE (Device ().GetLimits (&device_supported_limits));
@@ -706,13 +705,12 @@ wgpu::Instance WebGpuContextFactory::default_instance_;
706
705
WebGpuContext& WebGpuContextFactory::CreateContext (const WebGpuContextConfig& config) {
707
706
const int context_id = config.context_id ;
708
707
WGPUInstance instance = config.instance ;
709
- WGPUAdapter adapter = config.adapter ;
710
708
WGPUDevice device = config.device ;
711
709
712
710
if (context_id == 0 ) {
713
711
// context ID is preserved for the default context. User cannot use context ID 0 as a custom context.
714
- ORT_ENFORCE (instance == nullptr && adapter == nullptr && device == nullptr ,
715
- " WebGPU EP default context (contextId=0) must not have custom WebGPU instance, adapter or device." );
712
+ ORT_ENFORCE (instance == nullptr && device == nullptr ,
713
+ " WebGPU EP default context (contextId=0) must not have custom WebGPU instance or device." );
716
714
717
715
std::call_once (init_default_flag_, [
718
716
#if !defined(__wasm__)
@@ -750,23 +748,22 @@ WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& co
750
748
});
751
749
instance = default_instance_.Get ();
752
750
} else {
753
- // for context ID > 0, user must provide custom WebGPU instance, adapter and device.
754
- ORT_ENFORCE (instance != nullptr && adapter != nullptr && device != nullptr ,
755
- " WebGPU EP custom context (contextId>0) must have custom WebGPU instance, adapter and device." );
751
+ // for context ID > 0, user must provide custom WebGPU instance and device.
752
+ ORT_ENFORCE (instance != nullptr && device != nullptr ,
753
+ " WebGPU EP custom context (contextId>0) must have custom WebGPU instance and device." );
756
754
}
757
755
758
756
std::lock_guard<std::mutex> lock (mutex_);
759
757
760
758
auto it = contexts_.find (context_id);
761
759
if (it == contexts_.end ()) {
762
760
GSL_SUPPRESS (r.11 )
763
- auto context = std::unique_ptr<WebGpuContext>(new WebGpuContext (instance, adapter, device, config.validation_mode ));
761
+ auto context = std::unique_ptr<WebGpuContext>(new WebGpuContext (instance, device, config.validation_mode ));
764
762
it = contexts_.emplace (context_id, WebGpuContextFactory::WebGpuContextInfo{std::move (context), 0 }).first ;
765
763
} else if (context_id != 0 ) {
766
764
ORT_ENFORCE (it->second .context ->instance_ .Get () == instance &&
767
- it->second .context ->adapter_ .Get () == adapter &&
768
765
it->second .context ->device_ .Get () == device,
769
- " WebGPU EP context ID " , context_id, " is already created with different WebGPU instance, adapter or device." );
766
+ " WebGPU EP context ID " , context_id, " is already created with different WebGPU instance or device." );
770
767
}
771
768
it->second .ref_count ++;
772
769
return *it->second .context ;
0 commit comments