Skip to content

Commit 5ef1832

Browse files
authored
[WebGPU] Support PIX Capture for WebGPU EP (#23192)
PIX Capture tool requires 'present' to end a frame capture. ORT doesn't have rendering work so no 'present' happens. To avoid endless waiting for PIX capture tool, this PR added a blank surface and 'present' on it in each session run. The surface is created in WebGPU ep constructor and closed in WebGPU ep destructor. ### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
1 parent 0114551 commit 5ef1832

File tree

11 files changed

+225
-10
lines changed

11 files changed

+225
-10
lines changed

cmake/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ option(onnxruntime_USE_WEBGPU "Build with WebGPU support. Enable WebGPU via C/C+
139139
option(onnxruntime_USE_EXTERNAL_DAWN "Build with treating Dawn as external dependency. Will not link Dawn at build time." OFF)
140140
option(onnxruntime_CUSTOM_DAWN_SRC_PATH "Path to custom Dawn src dir.")
141141
option(onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY "Build Dawn as a monolithic library" OFF)
142+
option(onnxruntime_ENABLE_PIX_FOR_WEBGPU_EP "Adding frame present for PIX to capture a frame" OFF)
142143
# The following 2 options are only for Windows
143144
option(onnxruntime_ENABLE_DAWN_BACKEND_VULKAN "Enable Vulkan backend for Dawn (on Windows)" OFF)
144145
option(onnxruntime_ENABLE_DAWN_BACKEND_D3D12 "Enable D3D12 backend for Dawn (on Windows)" ON)
@@ -1038,6 +1039,14 @@ if (onnxruntime_USE_WEBGPU)
10381039
if (onnxruntime_ENABLE_DAWN_BACKEND_D3D12)
10391040
list(APPEND ORT_PROVIDER_FLAGS -DDAWN_ENABLE_D3D12=1)
10401041
endif()
1042+
if (onnxruntime_ENABLE_PIX_FOR_WEBGPU_EP)
1043+
if (NOT onnxruntime_ENABLE_DAWN_BACKEND_D3D12 OR NOT WIN32)
1044+
message(
1045+
FATAL_ERROR
1046+
"Option onnxruntime_ENABLE_PIX_FOR_WEBGPU_EP can only be set on windows with onnxruntime_ENABLE_DAWN_BACKEND_D3D12 is enabled.")
1047+
endif()
1048+
add_compile_definitions(ENABLE_PIX_FOR_WEBGPU_EP)
1049+
endif()
10411050
endif()
10421051
if (onnxruntime_USE_CANN)
10431052
list(APPEND ORT_PROVIDER_FLAGS -DUSE_CANN=1)

cmake/external/onnxruntime_external_deps.cmake

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,24 @@ if (onnxruntime_USE_WEBGPU)
685685
set(DAWN_ENABLE_INSTALL OFF CACHE BOOL "" FORCE)
686686
endif()
687687

688+
if (onnxruntime_ENABLE_PIX_FOR_WEBGPU_EP)
689+
set(DAWN_ENABLE_DESKTOP_GL ON CACHE BOOL "" FORCE)
690+
set(DAWN_ENABLE_OPENGLES ON CACHE BOOL "" FORCE)
691+
set(DAWN_SUPPORTS_GLFW_FOR_WINDOWING ON CACHE BOOL "" FORCE)
692+
set(DAWN_USE_GLFW ON CACHE BOOL "" FORCE)
693+
set(DAWN_USE_WINDOWS_UI ON CACHE BOOL "" FORCE)
694+
set(TINT_BUILD_GLSL_WRITER ON CACHE BOOL "" FORCE)
695+
set(TINT_BUILD_GLSL_VALIDATOR ON CACHE BOOL "" FORCE)
696+
else()
697+
set(DAWN_ENABLE_DESKTOP_GL OFF CACHE BOOL "" FORCE)
698+
set(DAWN_ENABLE_OPENGLES OFF CACHE BOOL "" FORCE)
699+
set(DAWN_SUPPORTS_GLFW_FOR_WINDOWING OFF CACHE BOOL "" FORCE)
700+
set(DAWN_USE_GLFW OFF CACHE BOOL "" FORCE)
701+
set(DAWN_USE_WINDOWS_UI OFF CACHE BOOL "" FORCE)
702+
set(TINT_BUILD_GLSL_WRITER OFF CACHE BOOL "" FORCE)
703+
set(TINT_BUILD_GLSL_VALIDATOR OFF CACHE BOOL "" FORCE)
704+
endif()
705+
688706
# disable things we don't use
689707
set(DAWN_DXC_ENABLE_ASSERTS_IN_NDEBUG OFF)
690708
set(DAWN_ENABLE_DESKTOP_GL OFF CACHE BOOL "" FORCE)
@@ -741,6 +759,10 @@ if (onnxruntime_USE_WEBGPU)
741759
list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::dawn_proc)
742760
endif()
743761
endif()
762+
763+
if (onnxruntime_ENABLE_PIX_FOR_WEBGPU_EP)
764+
list(APPEND onnxruntime_EXTERNAL_LIBRARIES glfw webgpu_glfw)
765+
endif()
744766
endif()
745767

746768
if(onnxruntime_USE_COREML)

onnxruntime/core/providers/webgpu/webgpu_context.cc

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@
3535
namespace onnxruntime {
3636
namespace webgpu {
3737

38-
void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type) {
39-
std::call_once(init_flag_, [this, &buffer_cache_config, backend_type]() {
38+
void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type, bool enable_pix_capture) {
39+
std::call_once(init_flag_, [this, &buffer_cache_config, backend_type, enable_pix_capture]() {
4040
// Create wgpu::Adapter
4141
if (adapter_ == nullptr) {
4242
#if !defined(__wasm__) && defined(_MSC_VER) && defined(DAWN_ENABLE_D3D12) && !defined(USE_EXTERNAL_DAWN)
@@ -162,6 +162,16 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
162162
} else {
163163
query_type_ = TimestampQueryType::None;
164164
}
165+
if (enable_pix_capture) {
166+
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
167+
// set pix frame generator
168+
pix_frame_generator_ = std::make_unique<WebGpuPIXFrameGenerator>(instance_,
169+
Adapter(),
170+
Device());
171+
#else
172+
ORT_THROW("Support PIX capture requires extra build flags (--enable_pix_capture)");
173+
#endif // ENABLE_PIX_FOR_WEBGPU_EP
174+
}
165175
});
166176
}
167177

@@ -680,6 +690,14 @@ void WebGpuContext::Flush() {
680690
num_pending_dispatches_ = 0;
681691
}
682692

693+
void WebGpuContext::OnRunEnd() {
694+
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
695+
if (pix_frame_generator_) {
696+
pix_frame_generator_->GeneratePIXFrame();
697+
}
698+
#endif // ENABLE_PIX_FOR_WEBGPU_EP
699+
}
700+
683701
std::unordered_map<int32_t, WebGpuContextFactory::WebGpuContextInfo> WebGpuContextFactory::contexts_;
684702
std::mutex WebGpuContextFactory::mutex_;
685703
std::once_flag WebGpuContextFactory::init_default_flag_;

onnxruntime/core/providers/webgpu/webgpu_context.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
#include "core/providers/webgpu/buffer_manager.h"
1515
#include "core/providers/webgpu/program_manager.h"
1616

17+
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
18+
#include "core/providers/webgpu/webgpu_pix_frame_generator.h"
19+
#endif // ENABLE_PIX_FOR_WEBGPU_EP
20+
1721
namespace onnxruntime {
1822
class Tensor;
1923

@@ -68,7 +72,7 @@ class WebGpuContextFactory {
6872
// Class WebGpuContext includes all necessary resources for the context.
6973
class WebGpuContext final {
7074
public:
71-
void Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type);
75+
void Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type, bool enable_pix_capture);
7276

7377
Status Wait(wgpu::Future f);
7478

@@ -136,6 +140,7 @@ class WebGpuContext final {
136140
Status PopErrorScope();
137141

138142
Status Run(ComputeContext& context, const ProgramBase& program);
143+
void OnRunEnd();
139144

140145
private:
141146
enum class TimestampQueryType {
@@ -222,6 +227,10 @@ class WebGpuContext final {
222227

223228
uint64_t gpu_timestamp_offset_ = 0;
224229
bool is_profiling_ = false;
230+
231+
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
232+
std::unique_ptr<WebGpuPIXFrameGenerator> pix_frame_generator_ = nullptr;
233+
#endif // ENABLE_PIX_FOR_WEBGPU_EP
225234
};
226235

227236
} // namespace webgpu

onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -747,8 +747,7 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id,
747747
context_{context},
748748
preferred_data_layout_{config.data_layout},
749749
force_cpu_node_names_{std::move(config.force_cpu_node_names)},
750-
enable_graph_capture_{config.enable_graph_capture} {
751-
}
750+
enable_graph_capture_{config.enable_graph_capture} {}
752751

753752
std::vector<AllocatorPtr> WebGpuExecutionProvider::CreatePreferredAllocators() {
754753
AllocatorCreationInfo gpuBufferAllocatorCreationInfo([&](int) {
@@ -862,11 +861,13 @@ Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxrunti
862861
context_.CollectProfilingData(profiler_->Events());
863862
}
864863

864+
context_.OnRunEnd();
865+
865866
if (context_.ValidationMode() >= ValidationMode::Basic) {
866867
return context_.PopErrorScope();
868+
} else {
869+
return Status::OK();
867870
}
868-
869-
return Status::OK();
870871
}
871872

872873
bool WebGpuExecutionProvider::IsGraphCaptureEnabled() const {

onnxruntime/core/providers/webgpu/webgpu_execution_provider.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,17 @@ class WebGpuProfiler;
2323
} // namespace webgpu
2424

2525
struct WebGpuExecutionProviderConfig {
26-
WebGpuExecutionProviderConfig(DataLayout data_layout, bool enable_graph_capture)
26+
WebGpuExecutionProviderConfig(DataLayout data_layout, bool enable_graph_capture, bool enable_pix_capture)
2727
: data_layout{data_layout},
28-
enable_graph_capture{enable_graph_capture} {}
28+
enable_graph_capture{enable_graph_capture},
29+
enable_pix_capture{enable_pix_capture} {}
2930
WebGpuExecutionProviderConfig(WebGpuExecutionProviderConfig&&) = default;
3031
WebGpuExecutionProviderConfig& operator=(WebGpuExecutionProviderConfig&&) = default;
3132
ORT_DISALLOW_COPY_AND_ASSIGNMENT(WebGpuExecutionProviderConfig);
3233

3334
DataLayout data_layout;
3435
bool enable_graph_capture;
36+
bool enable_pix_capture;
3537
std::vector<std::string> force_cpu_node_names;
3638
};
3739

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
5+
6+
#include <webgpu/webgpu_glfw.h>
7+
8+
#include "core/common/common.h"
9+
#include "core/providers/webgpu/webgpu_pix_frame_generator.h"
10+
11+
namespace onnxruntime {
12+
namespace webgpu {
13+
14+
WebGpuPIXFrameGenerator::WebGpuPIXFrameGenerator(wgpu::Instance instance, wgpu::Adapter adapter, wgpu::Device device) {
15+
// Trivial window size for surface texture creation and provide frame concept for PIX.
16+
static constexpr uint32_t kWidth = 512u;
17+
static constexpr uint32_t kHeight = 512u;
18+
19+
if (!glfwInit()) {
20+
ORT_ENFORCE("Failed to init glfw for PIX capture");
21+
}
22+
23+
glfwWindowHint(GLFW_CLIENT_API, GLFW_NO_API);
24+
25+
window_ =
26+
glfwCreateWindow(kWidth, kHeight, "WebGPU window", nullptr, nullptr);
27+
28+
ORT_ENFORCE(window_ != nullptr, "PIX Capture: Failed to create Window for capturing frames.");
29+
30+
surface_ = wgpu::glfw::CreateSurfaceForWindow(instance, window_);
31+
ORT_ENFORCE(surface_.Get() != nullptr, "PIX Capture: Failed to create surface for capturing frames.");
32+
33+
wgpu::TextureFormat format;
34+
wgpu::SurfaceCapabilities capabilities;
35+
surface_.GetCapabilities(adapter, &capabilities);
36+
format = capabilities.formats[0];
37+
38+
wgpu::SurfaceConfiguration config;
39+
config.device = device;
40+
config.format = format;
41+
config.width = kWidth;
42+
config.height = kHeight;
43+
44+
surface_.Configure(&config);
45+
}
46+
47+
void WebGpuPIXFrameGenerator::GeneratePIXFrame() {
48+
ORT_ENFORCE(surface_.Get() != nullptr, "PIX Capture: Cannot do present on null surface for capturing frames");
49+
wgpu::SurfaceTexture surfaceTexture;
50+
surface_.GetCurrentTexture(&surfaceTexture);
51+
52+
// Call present to trigger dxgi_swapchain present. PIX
53+
// take this as a frame boundary.
54+
surface_.Present();
55+
}
56+
57+
WebGpuPIXFrameGenerator::~WebGpuPIXFrameGenerator() {
58+
if (surface_.Get()) {
59+
surface_.Unconfigure();
60+
}
61+
62+
if (window_) {
63+
glfwDestroyWindow(window_);
64+
window_ = nullptr;
65+
}
66+
}
67+
68+
} // namespace webgpu
69+
} // namespace onnxruntime
70+
#endif // ENABLE_PIX_FOR_WEBGPU_EP
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
5+
#pragma once
6+
7+
#ifdef __EMSCRIPTEN__
8+
#include <emscripten/emscripten.h>
9+
#endif
10+
11+
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
12+
#include <GLFW/glfw3.h>
13+
#endif // ENABLE_PIX_FOR_WEBGPU_EP
14+
15+
#include <memory>
16+
17+
#include <webgpu/webgpu_cpp.h>
18+
19+
namespace onnxruntime {
20+
21+
namespace webgpu {
22+
23+
// PIX(https://devblogs.microsoft.com/pix/introduction/) is a profiling tool
24+
// provides by Microsoft. It has ability to do GPU capture to profile gpu
25+
// behavior among different GPU vendors. It works on Windows only.
26+
//
27+
// GPU capture(present-to-present) provided by PIX uses present as a frame boundary to
28+
// capture and generate a valid frame infos. But ORT WebGPU EP doesn't have any present logic
29+
// and hangs PIX GPU Capture forever.
30+
//
31+
// To make PIX works with ORT WebGPU EP on Windows, WebGpuPIXFrameGenerator class includes codes
32+
// to create a trivial window through glfw, config surface with Dawn device and call present in
33+
// proper place to trigger frame boundary for PIX GPU Capture.
34+
//
35+
// WebGpuPIXFrameGenerator is an friend class because:
36+
// - It should only be used in WebGpuContext class implementation.
37+
// - It requires instance and device from WebGpuContext.
38+
//
39+
// The lifecycle of WebGpuPIXFrameGenerator instance should be nested into WebGpuContext lifecycle.
40+
// WebGpuPIXFrameGenerator instance should be created during WebGpuContext creation and be destroyed during
41+
// WebGpuContext destruction.
42+
class WebGpuPIXFrameGenerator {
43+
public:
44+
WebGpuPIXFrameGenerator(wgpu::Instance instance, wgpu::Adapter adapter, wgpu::Device device);
45+
~WebGpuPIXFrameGenerator();
46+
void GeneratePIXFrame();
47+
48+
private:
49+
void CreateSurface();
50+
wgpu::Surface surface_;
51+
GLFWwindow* window_;
52+
53+
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WebGpuPIXFrameGenerator);
54+
};
55+
56+
} // namespace webgpu
57+
} // namespace onnxruntime
58+
#endif // ENABLE_PIX_FOR_WEBGPU_EP

onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
4040
DataLayout::NHWC,
4141
// graph capture feature is disabled by default
4242
false,
43+
// enable pix capture feature is diabled by default
44+
false,
4345
};
4446

4547
std::string preferred_layout_str;
@@ -219,6 +221,19 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
219221
buffer_cache_config.default_entry.mode = parse_buffer_cache_mode(kDefaultBufferCacheMode, webgpu::BufferCacheMode::Disabled);
220222
LOGS_DEFAULT(VERBOSE) << "WebGPU EP default buffer cache mode: " << buffer_cache_config.default_entry.mode;
221223

224+
bool enable_pix_capture = false;
225+
std::string enable_pix_capture_str;
226+
if (config_options.TryGetConfigEntry(kEnablePIXCapture, enable_pix_capture_str)) {
227+
if (enable_pix_capture_str == kEnablePIXCapture_ON) {
228+
enable_pix_capture = true;
229+
} else if (enable_pix_capture_str == kEnablePIXCapture_OFF) {
230+
enable_pix_capture = false;
231+
} else {
232+
ORT_THROW("Invalid enable pix capture: ", enable_pix_capture_str);
233+
}
234+
}
235+
LOGS_DEFAULT(VERBOSE) << "WebGPU EP pix capture enable: " << enable_pix_capture;
236+
222237
//
223238
// STEP.4 - start initialization.
224239
//
@@ -227,7 +242,7 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
227242
auto& context = webgpu::WebGpuContextFactory::CreateContext(context_config);
228243

229244
// Create WebGPU device and initialize the context.
230-
context.Initialize(buffer_cache_config, backend_type);
245+
context.Initialize(buffer_cache_config, backend_type, enable_pix_capture);
231246

232247
// Create WebGPU EP factory.
233248
return std::make_shared<WebGpuProviderFactory>(context_id, context, std::move(webgpu_ep_config));

onnxruntime/core/providers/webgpu/webgpu_provider_options.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ constexpr const char* kDefaultBufferCacheMode = "WebGPU:defaultBufferCacheMode";
2929
constexpr const char* kValidationMode = "WebGPU:validationMode";
3030

3131
constexpr const char* kForceCpuNodeNames = "WebGPU:forceCpuNodeNames";
32+
constexpr const char* kEnablePIXCapture = "WebGPU:enablePIXCapture";
3233

3334
// The following are the possible values for the provider options.
3435

@@ -41,6 +42,9 @@ constexpr const char* kPreferredLayout_NHWC = "NHWC";
4142
constexpr const char* kEnableGraphCapture_ON = "1";
4243
constexpr const char* kEnableGraphCapture_OFF = "0";
4344

45+
constexpr const char* kEnablePIXCapture_ON = "1";
46+
constexpr const char* kEnablePIXCapture_OFF = "0";
47+
4448
constexpr const char* kBufferCacheMode_Disabled = "disabled";
4549
constexpr const char* kBufferCacheMode_LazyRelease = "lazyRelease";
4650
constexpr const char* kBufferCacheMode_Simple = "simple";

0 commit comments

Comments
 (0)