Skip to content

Commit 0c38d7c

Browse files
committed
[webgpu] make PIX frame generator per session
1 parent 2f0eb77 commit 0c38d7c

File tree

4 files changed: 30 additions and 24 deletions.

onnxruntime/core/providers/webgpu/webgpu_context.cc

Lines changed: 0 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -161,15 +161,6 @@ void WebGpuContext::Initialize(const WebGpuContextConfig& config) {
161161
} else {
162162
query_type_ = TimestampQueryType::None;
163163
}
164-
if (config.enable_pix_capture) {
165-
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
166-
// set pix frame generator
167-
pix_frame_generator_ = std::make_unique<WebGpuPIXFrameGenerator>(instance_,
168-
Device());
169-
#else
170-
ORT_THROW("Support PIX capture requires extra build flags (--enable_pix_capture)");
171-
#endif // ENABLE_PIX_FOR_WEBGPU_EP
172-
}
173164
});
174165
}
175166

@@ -757,14 +748,6 @@ void WebGpuContext::Flush(const webgpu::BufferManager& buffer_mgr) {
757748
num_pending_dispatches_ = 0;
758749
}
759750

760-
void WebGpuContext::OnRunEnd() {
761-
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
762-
if (pix_frame_generator_) {
763-
pix_frame_generator_->GeneratePIXFrame();
764-
}
765-
#endif // ENABLE_PIX_FOR_WEBGPU_EP
766-
}
767-
768751
void WebGpuContext::LaunchComputePipeline(const wgpu::ComputePassEncoder& compute_pass_encoder,
769752
const std::vector<WGPUBuffer>& bind_buffers,
770753
const std::vector<uint32_t>& bind_buffers_segments,

onnxruntime/core/providers/webgpu/webgpu_context.h

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -78,7 +78,6 @@ struct WebGpuContextConfig {
7878
0
7979
#endif
8080
};
81-
bool enable_pix_capture{false};
8281
};
8382

8483
class WebGpuContextFactory {
@@ -215,7 +214,13 @@ class WebGpuContext final {
215214
Status PopErrorScope();
216215

217216
Status Run(ComputeContextBase& context, const ProgramBase& program);
218-
void OnRunEnd();
217+
218+
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
219+
std::unique_ptr<WebGpuPIXFrameGenerator> CreatePIXFrameGenerator() {
220+
return std::make_unique<WebGpuPIXFrameGenerator>(instance_,
221+
Device());
222+
}
223+
#endif // ENABLE_PIX_FOR_WEBGPU_EP
219224

220225
private:
221226
enum class TimestampQueryType {
@@ -334,10 +339,6 @@ class WebGpuContext final {
334339

335340
// External vector to store captured commands, owned by EP
336341
std::vector<webgpu::CapturedCommandInfo>* external_captured_commands_ = nullptr;
337-
338-
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
339-
std::unique_ptr<WebGpuPIXFrameGenerator> pix_frame_generator_ = nullptr;
340-
#endif // ENABLE_PIX_FOR_WEBGPU_EP
341342
};
342343

343344
} // namespace webgpu

onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -810,6 +810,15 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id,
810810
webgpu::BufferCacheMode::GraphSimple,
811811
webgpu::BufferCacheMode::Disabled);
812812
}
813+
814+
if (config.enable_pix_capture) {
815+
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
816+
// set pix frame generator
817+
pix_frame_generator_ = context_.CreatePIXFrameGenerator();
818+
#else
819+
ORT_THROW("Support PIX capture requires extra build flags (--enable_pix_capture)");
820+
#endif // ENABLE_PIX_FOR_WEBGPU_EP
821+
}
813822
}
814823

815824
std::vector<AllocatorPtr> WebGpuExecutionProvider::CreatePreferredAllocators() {
@@ -1008,7 +1017,11 @@ Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxrunti
10081017
context_.CollectProfilingData(profiler_->Events());
10091018
}
10101019

1011-
context_.OnRunEnd();
1020+
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
1021+
if (pix_frame_generator_) {
1022+
pix_frame_generator_->GeneratePIXFrame();
1023+
}
1024+
#endif // ENABLE_PIX_FOR_WEBGPU_EP
10121025

10131026
if (context_.ValidationMode() >= ValidationMode::Basic) {
10141027
return context_.PopErrorScope();

onnxruntime/core/providers/webgpu/webgpu_execution_provider.h

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,10 @@
1010
#include "core/providers/providers.h"
1111
#include "core/providers/webgpu/buffer_manager.h"
1212

13+
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
14+
#include "core/providers/webgpu/webgpu_pix_frame_generator.h"
15+
#endif // ENABLE_PIX_FOR_WEBGPU_EP
16+
1317
struct pthreadpool;
1418
namespace onnxruntime {
1519
namespace webgpu {
@@ -29,6 +33,7 @@ struct CapturedCommandInfo;
2933
struct WebGpuExecutionProviderConfig {
3034
DataLayout data_layout{DataLayout::NHWC}; // preferred layout is NHWC by default
3135
bool enable_graph_capture{false}; // graph capture feature is disabled by default
36+
bool enable_pix_capture{false}; // PIX capture is disabled by default
3237
std::vector<std::string> force_cpu_node_names{};
3338
};
3439

@@ -92,6 +97,10 @@ class WebGpuExecutionProvider : public IExecutionProvider {
9297
const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations.
9398
int m_current_graph_annotation_id = 0;
9499

100+
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
101+
std::unique_ptr<WebGpuPIXFrameGenerator> pix_frame_generator_ = nullptr;
102+
#endif // ENABLE_PIX_FOR_WEBGPU_EP
103+
95104
// Buffer manager specifically for graph capture mode
96105
std::unique_ptr<webgpu::BufferManager> graph_buffer_mgr_ = nullptr;
97106

0 commit comments

Comments
 (0)