Skip to content

Commit 0aebe82

Browse files
authored
[WebGPU] Register DataTransfer to Env (#26450)
This pull request adds a C API for WebGPU data transfer, enabling tensor copying between CPU and GPU devices via the WebGPU execution provider. The main changes introduce a wrapper implementation for data transfer, integrate it with the plugin execution provider factory, and expose a creation function for use by the ONNX Runtime core.
1 parent 92c1ed2 commit 0aebe82

File tree

5 files changed

+201
-15
lines changed

5 files changed

+201
-15
lines changed

onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the MIT License.
33

44
#include <charconv>
5+
#include <mutex>
56

67
#include "core/framework/error_code_helper.h"
78
#include "core/providers/webgpu/buffer_manager.h"
@@ -12,9 +13,49 @@
1213
#include "core/session/ort_apis.h"
1314

1415
#include "core/providers/webgpu/webgpu_provider_options.h"
16+
#include "core/providers/webgpu/data_transfer.h"
1517
using namespace onnxruntime::webgpu::options;
1618

1719
namespace onnxruntime {
20+
// Helper struct that holds configuration parameters for creating a WebGPU context with default settings.
21+
// This is used during lazy initialization of the data transfer to create a context if one doesn't exist.
22+
struct WebGpuContextParams {
23+
webgpu::WebGpuContextConfig context_config; // WebGPU context configuration
24+
webgpu::WebGpuBufferCacheConfig buffer_cache_config; // Buffer cache settings
25+
int backend_type; // Dawn backend type (D3D12, Vulkan, etc.)
26+
bool enable_pix_capture; // Enable PIX GPU capture for debugging
27+
};
28+
29+
// Builds the parameter set used when a WebGPU context has to be created without any user-provided
// options: context 0, no externally supplied instance/device/proc table, validation disabled,
// high-performance adapter preference, and the buffer cache modes listed below.
static WebGpuContextParams GetDefaultWebGpuContextParams() {
  WebGpuContextParams defaults;

  // Backend selection: on Windows use D3D12 unless only Vulkan is enabled in Dawn; elsewhere use 0.
#if defined(_WIN32) && !defined(DAWN_ENABLE_D3D12) && defined(DAWN_ENABLE_VULKAN)
  defaults.backend_type = static_cast<int>(WGPUBackendType_Vulkan);
#elif defined(_WIN32)
  defaults.backend_type = static_cast<int>(WGPUBackendType_D3D12);
#else
  defaults.backend_type = 0;
#endif
  defaults.enable_pix_capture = false;

  // Context configuration.
  auto& config = defaults.context_config;
  config.context_id = 0;
  config.instance = nullptr;
  config.device = nullptr;
  config.dawn_proc_table = nullptr;
  config.validation_mode = webgpu::ValidationMode::Disabled;
  config.preserve_device = false;
  config.max_storage_buffer_binding_size = 0;
  config.power_preference = static_cast<int>(WGPUPowerPreference_HighPerformance);

  // Buffer cache modes, one per buffer category.
  auto& cache = defaults.buffer_cache_config;
  cache.storage.mode = webgpu::BufferCacheMode::Bucket;
  cache.uniform.mode = webgpu::BufferCacheMode::Simple;
  cache.query_resolve.mode = webgpu::BufferCacheMode::Disabled;
  cache.default_entry.mode = webgpu::BufferCacheMode::Disabled;

  return defaults;
}
1859

1960
struct WebGpuProviderFactory : IExecutionProviderFactory {
2061
WebGpuProviderFactory(int context_id, webgpu::WebGpuContext& context, WebGpuExecutionProviderConfig&& webgpu_ep_config)
@@ -291,4 +332,134 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
291332
return std::make_shared<WebGpuProviderFactory>(context_id, context, std::move(webgpu_ep_config));
292333
}
293334

335+
// WebGPU DataTransfer implementation wrapper for the C API with lazy initialization
336+
struct WebGpuDataTransferImpl : OrtDataTransferImpl {
337+
WebGpuDataTransferImpl(const OrtApi& ort_api_in)
338+
: ort_api{ort_api_in},
339+
ep_api{*ort_api_in.GetEpApi()},
340+
data_transfer_{nullptr},
341+
context_id_{0}, // Always use context 0 for Environment's data transfer
342+
init_mutex_{} {
343+
ort_version_supported = ORT_API_VERSION;
344+
CanCopy = CanCopyImpl; // OrtDataTransferImpl::CanCopy callback
345+
CopyTensors = CopyTensorsImpl; // OrtDataTransferImpl::CopyTensors callback
346+
Release = ReleaseImpl; // OrtDataTransferImpl::Release callback
347+
}
348+
349+
static bool CanCopyImpl(const OrtDataTransferImpl* this_ptr,
350+
const OrtMemoryDevice* src_memory_device,
351+
const OrtMemoryDevice* dst_memory_device) noexcept {
352+
const auto& impl = *static_cast<const WebGpuDataTransferImpl*>(this_ptr);
353+
OrtMemoryInfoDeviceType src_type = impl.ep_api.MemoryDevice_GetDeviceType(src_memory_device);
354+
OrtMemoryInfoDeviceType dst_type = impl.ep_api.MemoryDevice_GetDeviceType(dst_memory_device);
355+
356+
// Check if at least one device is GPU
357+
bool has_gpu = (src_type == OrtMemoryInfoDeviceType_GPU) || (dst_type == OrtMemoryInfoDeviceType_GPU);
358+
if (!has_gpu) {
359+
return false;
360+
}
361+
362+
// WebGPU uses vendor ID 0 (VendorIds::NONE). Only handle GPU devices with vendor ID 0.
363+
// This prevents attempting to copy data for other EPs' fake GPU devices (e.g., example EP with vendor 0xBE57)
364+
if (src_type == OrtMemoryInfoDeviceType_GPU) {
365+
uint32_t src_vendor = impl.ep_api.MemoryDevice_GetVendorId(src_memory_device);
366+
if (src_vendor != 0) {
367+
return false; // Not a WebGPU device
368+
}
369+
}
370+
371+
if (dst_type == OrtMemoryInfoDeviceType_GPU) {
372+
uint32_t dst_vendor = impl.ep_api.MemoryDevice_GetVendorId(dst_memory_device);
373+
if (dst_vendor != 0) {
374+
return false; // Not a WebGPU device
375+
}
376+
}
377+
378+
// If both are GPU, they must have the same device ID
379+
if (src_type == OrtMemoryInfoDeviceType_GPU && dst_type == OrtMemoryInfoDeviceType_GPU) {
380+
uint64_t src_device_id = impl.ep_api.MemoryDevice_GetDeviceId(src_memory_device);
381+
uint64_t dst_device_id = impl.ep_api.MemoryDevice_GetDeviceId(dst_memory_device);
382+
if (src_device_id != dst_device_id) {
383+
return false; // Cannot copy between different devices
384+
}
385+
}
386+
387+
// WebGPU supports GPU<->GPU, GPU<->CPU copies (where GPU has vendor ID 0)
388+
return (src_type == OrtMemoryInfoDeviceType_GPU && dst_type == OrtMemoryInfoDeviceType_GPU) ||
389+
(src_type == OrtMemoryInfoDeviceType_GPU && dst_type == OrtMemoryInfoDeviceType_CPU) ||
390+
(src_type == OrtMemoryInfoDeviceType_CPU && dst_type == OrtMemoryInfoDeviceType_GPU);
391+
}
392+
393+
static OrtStatus* CopyTensorsImpl(OrtDataTransferImpl* this_ptr,
394+
const OrtValue** src_tensors,
395+
OrtValue** dst_tensors,
396+
OrtSyncStream** /*streams*/,
397+
size_t num_tensors) noexcept {
398+
auto& impl = *static_cast<WebGpuDataTransferImpl*>(this_ptr);
399+
400+
if (num_tensors == 0) {
401+
return nullptr;
402+
}
403+
404+
// Lazy initialization: Use double-checked locking to avoid unnecessary lock operations
405+
if (impl.data_transfer_ == nullptr) {
406+
std::lock_guard<std::mutex> lock(impl.init_mutex_);
407+
if (impl.data_transfer_ == nullptr) {
408+
// Always create a new context with context_id 0
409+
WebGpuContextParams params = GetDefaultWebGpuContextParams();
410+
params.context_config.context_id = impl.context_id_;
411+
auto* context_ptr = &webgpu::WebGpuContextFactory::CreateContext(params.context_config);
412+
context_ptr->Initialize(params.buffer_cache_config, params.backend_type, params.enable_pix_capture);
413+
414+
// Create the DataTransfer instance
415+
// Note: The DataTransfer holds a const reference to BufferManager. The BufferManager's lifecycle
416+
// is managed by the WebGpuContext, which is stored in a static WebGpuContextFactory and persists
417+
// for the lifetime of the application, ensuring the reference remains valid.
418+
impl.data_transfer_ = std::make_unique<webgpu::DataTransfer>(context_ptr->BufferManager());
419+
}
420+
}
421+
422+
// Now perform the actual tensor copy
423+
for (size_t idx = 0; idx < num_tensors; ++idx) {
424+
const OrtValue* src_tensor = src_tensors[idx];
425+
OrtValue* dst_tensor = dst_tensors[idx];
426+
auto status = impl.data_transfer_->CopyTensor(src_tensor->Get<Tensor>(), *dst_tensor->GetMutable<Tensor>());
427+
if (!status.IsOK()) {
428+
return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, status.ErrorMessage().c_str());
429+
}
430+
}
431+
return nullptr;
432+
}
433+
434+
static void ReleaseImpl(OrtDataTransferImpl* this_ptr) noexcept {
435+
auto* p_impl = static_cast<WebGpuDataTransferImpl*>(this_ptr);
436+
int context_id = p_impl->context_id_;
437+
bool data_transfer_initialized = false;
438+
{
439+
std::lock_guard<std::mutex> lock(p_impl->init_mutex_);
440+
data_transfer_initialized = (p_impl->data_transfer_ != nullptr);
441+
}
442+
delete p_impl;
443+
if (data_transfer_initialized) {
444+
webgpu::WebGpuContextFactory::ReleaseContext(context_id);
445+
}
446+
}
447+
448+
const OrtApi& ort_api;
449+
const OrtEpApi& ep_api;
450+
std::unique_ptr<webgpu::DataTransfer> data_transfer_; // Lazy-initialized
451+
int context_id_; // Track which context we're using
452+
std::mutex init_mutex_; // Protects lazy initialization
453+
};
454+
455+
// Creates the WebGPU data transfer wrapper for the C API.
// Returns a heap-allocated WebGpuDataTransferImpl, or nullptr when OrtApis::GetApi does not provide
// an API table for ORT_API_VERSION.
OrtDataTransferImpl* OrtWebGpuCreateDataTransfer() {
  if (const OrtApi* ort_api = OrtApis::GetApi(ORT_API_VERSION); ort_api != nullptr) {
    return new WebGpuDataTransferImpl(*ort_api);
  }

  // Requested API version is not available - signal failure to the caller.
  return nullptr;
}
464+
294465
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,18 @@
1010

1111
#include "core/providers/webgpu/webgpu_provider_options.h"
1212

13+
struct OrtDataTransferImpl;
14+
1315
namespace onnxruntime {
1416
struct ConfigOptions;
1517

1618
struct WebGpuProviderFactoryCreator {
1719
static std::shared_ptr<IExecutionProviderFactory> Create(const ConfigOptions& config_options);
1820
};
1921

22+
// C API to create data transfer for WebGPU EP with lazy initialization
23+
// The WebGPU context (always context 0, with default settings) is created on the first non-empty CopyTensors call
24+
// Caller takes ownership of the returned OrtDataTransferImpl*
25+
OrtDataTransferImpl* OrtWebGpuCreateDataTransfer();
26+
2027
} // namespace onnxruntime

onnxruntime/core/session/ort_env.cc

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,20 @@ OrtEnv::OrtEnv(std::unique_ptr<onnxruntime::Environment> value1)
3636
}
3737

3838
OrtEnv::~OrtEnv() {
39-
#ifdef USE_WEBGPU
40-
webgpu::CleanupWebGpuContexts();
41-
#endif
42-
4339
// We don't support any shared providers in the minimal build yet
4440
#if !defined(ORT_MINIMAL_BUILD)
4541
UnloadSharedProviders();
4642
#endif
43+
44+
// Explicitly destroy the Environment here, inside the destructor body, so that it is gone before
45+
// the WebGPU contexts are cleaned up below. Per the original note this cleans up the
// DataTransferManager and calls ReleaseImpl on WebGpuDataTransferImpl, which releases its
// WebGPU context through WebGpuContextFactory::ReleaseContext if it was ever initialized.
// NOTE(review): presumably value_ would otherwise only be destroyed after this body returns,
// i.e. after CleanupWebGpuContexts() - confirm against the OrtEnv member declaration.
46+
value_.reset();
47+
48+
// Now that the Environment is destroyed and all data transfers are cleaned up,
49+
// we can safely clean up any remaining WebGPU contexts
50+
#ifdef USE_WEBGPU
51+
webgpu::CleanupWebGpuContexts();
52+
#endif
4753
}
4854

4955
OrtEnv* OrtEnv::GetInstance(const OrtEnv::LoggingManagerConstructionInfo& lm_info,

onnxruntime/core/session/plugin_ep/ep_factory_webgpu.cc

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -57,20 +57,20 @@ OrtStatus* WebGpuEpFactory::CreateIExecutionProvider(const OrtHardwareDevice* co
5757
return nullptr;
5858
}
5959

60-
/* TODO: Implement CreateAllocator and CreateDataTransfer to support shared allocators and data transfer outside of
61-
an InferenceSession.
62-
OrtStatus* WebGpuEpFactory::CreateAllocator(const OrtMemoryInfo* memory_info,
63-
const OrtKeyValuePairs* allocator_options,
64-
OrtAllocator** allocator) noexcept override {
65-
*allocator = device_allocators[memory_info->device.Id()].get();
66-
}
60+
OrtStatus* WebGpuEpFactory::CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) noexcept {
61+
// Call the WebGPU provider's C API to create the data transfer
62+
// This is implemented in the WebGPU provider backend which has access to WebGPU headers
63+
*data_transfer = OrtWebGpuCreateDataTransfer();
64+
65+
// API version mismatch is a fatal error - return error status if creation failed
66+
if (*data_transfer == nullptr) {
67+
return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION,
68+
"Failed to create WebGPU data transfer - API version mismatch.");
69+
}
6770

68-
OrtStatus* WebGpuEpFactory::CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) override {
69-
// TODO: Wrap the IDataTransfer implementation so we can copy to device using OrtApi CopyTensors.
70-
*data_transfer = nullptr;
7171
return nullptr;
7272
}
73-
*/
73+
7474
} // namespace onnxruntime
7575

7676
#endif // USE_WEBGPU

onnxruntime/core/session/plugin_ep/ep_factory_webgpu.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ class WebGpuEpFactory : public EpFactoryInternalImpl {
2929
const OrtSessionOptions* session_options,
3030
const OrtLogger* session_logger,
3131
std::unique_ptr<IExecutionProvider>* ep) noexcept override;
32+
33+
OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) noexcept override;
3234
};
3335
} // namespace onnxruntime
3436

0 commit comments

Comments
 (0)