|
2 | 2 | // Licensed under the MIT License. |
3 | 3 |
|
4 | 4 | #include <charconv> |
| 5 | +#include <mutex> |
5 | 6 |
|
6 | 7 | #include "core/framework/error_code_helper.h" |
7 | 8 | #include "core/providers/webgpu/buffer_manager.h" |
|
12 | 13 | #include "core/session/ort_apis.h" |
13 | 14 |
|
14 | 15 | #include "core/providers/webgpu/webgpu_provider_options.h" |
| 16 | +#include "core/providers/webgpu/data_transfer.h" |
15 | 17 | using namespace onnxruntime::webgpu::options; |
16 | 18 |
|
17 | 19 | namespace onnxruntime { |
| 20 | +// Helper struct that holds configuration parameters for creating a WebGPU context with default settings. |
| 21 | +// This is used during lazy initialization of the data transfer to create a context if one doesn't exist. |
| 22 | +struct WebGpuContextParams { |
| 23 | + webgpu::WebGpuContextConfig context_config; // WebGPU context configuration |
| 24 | + webgpu::WebGpuBufferCacheConfig buffer_cache_config; // Buffer cache settings |
| 25 | + int backend_type; // Dawn backend type (D3D12, Vulkan, etc.) |
| 26 | + bool enable_pix_capture; // Enable PIX GPU capture for debugging |
| 27 | +}; |
| 28 | + |
| 29 | +static WebGpuContextParams GetDefaultWebGpuContextParams() { |
| 30 | + WebGpuContextParams params; |
| 31 | + params.context_config.context_id = 0; |
| 32 | + params.context_config.instance = nullptr; |
| 33 | + params.context_config.device = nullptr; |
| 34 | + params.context_config.dawn_proc_table = nullptr; |
| 35 | + params.context_config.validation_mode = webgpu::ValidationMode::Disabled; |
| 36 | + params.context_config.preserve_device = false; |
| 37 | + params.context_config.max_storage_buffer_binding_size = 0; |
| 38 | + params.context_config.power_preference = static_cast<int>(WGPUPowerPreference_HighPerformance); |
| 39 | + |
| 40 | + params.buffer_cache_config.storage.mode = webgpu::BufferCacheMode::Bucket; |
| 41 | + params.buffer_cache_config.uniform.mode = webgpu::BufferCacheMode::Simple; |
| 42 | + params.buffer_cache_config.query_resolve.mode = webgpu::BufferCacheMode::Disabled; |
| 43 | + params.buffer_cache_config.default_entry.mode = webgpu::BufferCacheMode::Disabled; |
| 44 | + |
| 45 | +#ifdef _WIN32 |
| 46 | +#if defined(DAWN_ENABLE_D3D12) |
| 47 | + params.backend_type = static_cast<int>(WGPUBackendType_D3D12); |
| 48 | +#elif defined(DAWN_ENABLE_VULKAN) |
| 49 | + params.backend_type = static_cast<int>(WGPUBackendType_Vulkan); |
| 50 | +#else |
| 51 | + params.backend_type = static_cast<int>(WGPUBackendType_D3D12); |
| 52 | +#endif |
| 53 | +#else |
| 54 | + params.backend_type = 0; |
| 55 | +#endif |
| 56 | + params.enable_pix_capture = false; |
| 57 | + return params; |
| 58 | +} |
18 | 59 |
|
19 | 60 | struct WebGpuProviderFactory : IExecutionProviderFactory { |
20 | 61 | WebGpuProviderFactory(int context_id, webgpu::WebGpuContext& context, WebGpuExecutionProviderConfig&& webgpu_ep_config) |
@@ -291,4 +332,134 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create( |
291 | 332 | return std::make_shared<WebGpuProviderFactory>(context_id, context, std::move(webgpu_ep_config)); |
292 | 333 | } |
293 | 334 |
|
| 335 | +// WebGPU DataTransfer implementation wrapper for the C API with lazy initialization |
| 336 | +struct WebGpuDataTransferImpl : OrtDataTransferImpl { |
| 337 | + WebGpuDataTransferImpl(const OrtApi& ort_api_in) |
| 338 | + : ort_api{ort_api_in}, |
| 339 | + ep_api{*ort_api_in.GetEpApi()}, |
| 340 | + data_transfer_{nullptr}, |
| 341 | + context_id_{0}, // Always use context 0 for Environment's data transfer |
| 342 | + init_mutex_{} { |
| 343 | + ort_version_supported = ORT_API_VERSION; |
| 344 | + CanCopy = CanCopyImpl; // OrtDataTransferImpl::CanCopy callback |
| 345 | + CopyTensors = CopyTensorsImpl; // OrtDataTransferImpl::CopyTensors callback |
| 346 | + Release = ReleaseImpl; // OrtDataTransferImpl::Release callback |
| 347 | + } |
| 348 | + |
| 349 | + static bool CanCopyImpl(const OrtDataTransferImpl* this_ptr, |
| 350 | + const OrtMemoryDevice* src_memory_device, |
| 351 | + const OrtMemoryDevice* dst_memory_device) noexcept { |
| 352 | + const auto& impl = *static_cast<const WebGpuDataTransferImpl*>(this_ptr); |
| 353 | + OrtMemoryInfoDeviceType src_type = impl.ep_api.MemoryDevice_GetDeviceType(src_memory_device); |
| 354 | + OrtMemoryInfoDeviceType dst_type = impl.ep_api.MemoryDevice_GetDeviceType(dst_memory_device); |
| 355 | + |
| 356 | + // Check if at least one device is GPU |
| 357 | + bool has_gpu = (src_type == OrtMemoryInfoDeviceType_GPU) || (dst_type == OrtMemoryInfoDeviceType_GPU); |
| 358 | + if (!has_gpu) { |
| 359 | + return false; |
| 360 | + } |
| 361 | + |
| 362 | + // WebGPU uses vendor ID 0 (VendorIds::NONE). Only handle GPU devices with vendor ID 0. |
| 363 | + // This prevents attempting to copy data for other EPs' fake GPU devices (e.g., example EP with vendor 0xBE57) |
| 364 | + if (src_type == OrtMemoryInfoDeviceType_GPU) { |
| 365 | + uint32_t src_vendor = impl.ep_api.MemoryDevice_GetVendorId(src_memory_device); |
| 366 | + if (src_vendor != 0) { |
| 367 | + return false; // Not a WebGPU device |
| 368 | + } |
| 369 | + } |
| 370 | + |
| 371 | + if (dst_type == OrtMemoryInfoDeviceType_GPU) { |
| 372 | + uint32_t dst_vendor = impl.ep_api.MemoryDevice_GetVendorId(dst_memory_device); |
| 373 | + if (dst_vendor != 0) { |
| 374 | + return false; // Not a WebGPU device |
| 375 | + } |
| 376 | + } |
| 377 | + |
| 378 | + // If both are GPU, they must have the same device ID |
| 379 | + if (src_type == OrtMemoryInfoDeviceType_GPU && dst_type == OrtMemoryInfoDeviceType_GPU) { |
| 380 | + uint64_t src_device_id = impl.ep_api.MemoryDevice_GetDeviceId(src_memory_device); |
| 381 | + uint64_t dst_device_id = impl.ep_api.MemoryDevice_GetDeviceId(dst_memory_device); |
| 382 | + if (src_device_id != dst_device_id) { |
| 383 | + return false; // Cannot copy between different devices |
| 384 | + } |
| 385 | + } |
| 386 | + |
| 387 | + // WebGPU supports GPU<->GPU, GPU<->CPU copies (where GPU has vendor ID 0) |
| 388 | + return (src_type == OrtMemoryInfoDeviceType_GPU && dst_type == OrtMemoryInfoDeviceType_GPU) || |
| 389 | + (src_type == OrtMemoryInfoDeviceType_GPU && dst_type == OrtMemoryInfoDeviceType_CPU) || |
| 390 | + (src_type == OrtMemoryInfoDeviceType_CPU && dst_type == OrtMemoryInfoDeviceType_GPU); |
| 391 | + } |
| 392 | + |
| 393 | + static OrtStatus* CopyTensorsImpl(OrtDataTransferImpl* this_ptr, |
| 394 | + const OrtValue** src_tensors, |
| 395 | + OrtValue** dst_tensors, |
| 396 | + OrtSyncStream** /*streams*/, |
| 397 | + size_t num_tensors) noexcept { |
| 398 | + auto& impl = *static_cast<WebGpuDataTransferImpl*>(this_ptr); |
| 399 | + |
| 400 | + if (num_tensors == 0) { |
| 401 | + return nullptr; |
| 402 | + } |
| 403 | + |
| 404 | + // Lazy initialization: Use double-checked locking to avoid unnecessary lock operations |
| 405 | + if (impl.data_transfer_ == nullptr) { |
| 406 | + std::lock_guard<std::mutex> lock(impl.init_mutex_); |
| 407 | + if (impl.data_transfer_ == nullptr) { |
| 408 | + // Always create a new context with context_id 0 |
| 409 | + WebGpuContextParams params = GetDefaultWebGpuContextParams(); |
| 410 | + params.context_config.context_id = impl.context_id_; |
| 411 | + auto* context_ptr = &webgpu::WebGpuContextFactory::CreateContext(params.context_config); |
| 412 | + context_ptr->Initialize(params.buffer_cache_config, params.backend_type, params.enable_pix_capture); |
| 413 | + |
| 414 | + // Create the DataTransfer instance |
| 415 | + // Note: The DataTransfer holds a const reference to BufferManager. The BufferManager's lifecycle |
| 416 | + // is managed by the WebGpuContext, which is stored in a static WebGpuContextFactory and persists |
| 417 | + // for the lifetime of the application, ensuring the reference remains valid. |
| 418 | + impl.data_transfer_ = std::make_unique<webgpu::DataTransfer>(context_ptr->BufferManager()); |
| 419 | + } |
| 420 | + } |
| 421 | + |
| 422 | + // Now perform the actual tensor copy |
| 423 | + for (size_t idx = 0; idx < num_tensors; ++idx) { |
| 424 | + const OrtValue* src_tensor = src_tensors[idx]; |
| 425 | + OrtValue* dst_tensor = dst_tensors[idx]; |
| 426 | + auto status = impl.data_transfer_->CopyTensor(src_tensor->Get<Tensor>(), *dst_tensor->GetMutable<Tensor>()); |
| 427 | + if (!status.IsOK()) { |
| 428 | + return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, status.ErrorMessage().c_str()); |
| 429 | + } |
| 430 | + } |
| 431 | + return nullptr; |
| 432 | + } |
| 433 | + |
| 434 | + static void ReleaseImpl(OrtDataTransferImpl* this_ptr) noexcept { |
| 435 | + auto* p_impl = static_cast<WebGpuDataTransferImpl*>(this_ptr); |
| 436 | + int context_id = p_impl->context_id_; |
| 437 | + bool data_transfer_initialized = false; |
| 438 | + { |
| 439 | + std::lock_guard<std::mutex> lock(p_impl->init_mutex_); |
| 440 | + data_transfer_initialized = (p_impl->data_transfer_ != nullptr); |
| 441 | + } |
| 442 | + delete p_impl; |
| 443 | + if (data_transfer_initialized) { |
| 444 | + webgpu::WebGpuContextFactory::ReleaseContext(context_id); |
| 445 | + } |
| 446 | + } |
| 447 | + |
| 448 | + const OrtApi& ort_api; |
| 449 | + const OrtEpApi& ep_api; |
| 450 | + std::unique_ptr<webgpu::DataTransfer> data_transfer_; // Lazy-initialized |
| 451 | + int context_id_; // Track which context we're using |
| 452 | + std::mutex init_mutex_; // Protects lazy initialization |
| 453 | +}; |
| 454 | + |
| 455 | +OrtDataTransferImpl* OrtWebGpuCreateDataTransfer() { |
| 456 | + // Validate API version is supported |
| 457 | + const OrtApi* api = OrtApis::GetApi(ORT_API_VERSION); |
| 458 | + if (!api) { |
| 459 | + // API version not supported - return nullptr to indicate failure |
| 460 | + return nullptr; |
| 461 | + } |
| 462 | + return new WebGpuDataTransferImpl(*api); |
| 463 | +} |
| 464 | + |
294 | 465 | } // namespace onnxruntime |
0 commit comments