Skip to content

Commit b7c183a

Browse files
authored
[webgpu] Enable profiling for graph capture (#27058)
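### Description
Enables WebGPU profiling when graph capture is in use. While a graph is being captured, `WebGpuContext::Run` stores each kernel's `PendingKernelInfo` on the most recently captured command instead of appending it to `pending_kernels_`; timestamp writes and query resolution are skipped during capture. On `Replay`, the stored profiling info of every captured command is restored into `pending_kernels_` so that it stays consistent with `num_pending_dispatches_`. `PendingKernelInfo` is moved out of `WebGpuContext` to namespace scope and made copyable so it can be held in `CapturedCommandInfo` as a `std::optional`. `WebGpuExecutionProvider::ReplayGraph` now starts profiling before the replay and collects profiling data afterwards when the session profiler is enabled.

### Motivation and Context
Previously `WebGpuContext::CaptureBegin` enforced that profiling was disabled ("profiling is not supported yet under graph capture mode"), so a session that used graph capture could not be profiled. This change removes that restriction.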
1 parent 2b08a0c commit b7c183a

File tree

3 files changed

+87
-47
lines changed

3 files changed

+87
-47
lines changed

onnxruntime/core/providers/webgpu/webgpu_context.cc

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <memory>
55
#include <cmath>
6+
#include <string>
67

78
#if defined(__GNUC__)
89
#pragma GCC diagnostic push
@@ -282,16 +283,6 @@ Status WebGpuContext::Run(ComputeContextBase& context, const ProgramBase& progra
282283

283284
auto key = CalculateProgramCacheKey(program, inputs_segments, outputs_segments, is_1d_dispatch);
284285

285-
if (is_profiling_) {
286-
PendingKernelInfo pending_kernel_info(context.NodeName(),
287-
context.OpType(),
288-
program.Name(),
289-
key,
290-
inputs,
291-
outputs);
292-
pending_kernels_.emplace_back(std::move(pending_kernel_info));
293-
}
294-
295286
LOGS(context.Logger(), INFO) << "Starting program \"" << key << "\" (" << x << ", " << y << ", " << z << ")";
296287

297288
const auto* program_artifact = program_mgr_->Get(key);
@@ -480,6 +471,26 @@ Status WebGpuContext::Run(ComputeContextBase& context, const ProgramBase& progra
480471
WriteTimestamp(num_pending_dispatches_ * 2 + 1);
481472
++num_pending_dispatches_;
482473

474+
// Update profiling data after LaunchComputePipeline
475+
if (is_profiling_) {
476+
PendingKernelInfo pending_kernel_info(context.NodeName(),
477+
context.OpType(),
478+
program.Name(),
479+
key,
480+
inputs,
481+
outputs);
482+
483+
if (graph_capture_state_ == GraphCaptureState::Capturing) {
484+
// Update the last captured command's profiling info
485+
if (external_captured_commands_ && !external_captured_commands_->empty()) {
486+
external_captured_commands_->back().pending_kernel_info = std::move(pending_kernel_info);
487+
}
488+
} else {
489+
// Add to pending kernels for current run profiling
490+
pending_kernels_.emplace_back(std::move(pending_kernel_info));
491+
}
492+
}
493+
483494
if (num_pending_dispatches_ >= max_num_pending_dispatches_ ||
484495
(is_profiling_ && query_type_ == TimestampQueryType::AtPasses)) {
485496
EndComputePass();
@@ -577,7 +588,7 @@ wgpu::Limits WebGpuContext::GetRequiredLimits(const wgpu::Adapter& adapter) cons
577588
}
578589

579590
void WebGpuContext::WriteTimestamp(uint32_t query_index) {
580-
if (!is_profiling_ || query_type_ != TimestampQueryType::InsidePasses) {
591+
if (!is_profiling_ || graph_capture_state_ == GraphCaptureState::Capturing || query_type_ != TimestampQueryType::InsidePasses) {
581592
return;
582593
}
583594

@@ -714,7 +725,11 @@ void WebGpuContext::Flush(const webgpu::BufferManager& buffer_mgr) {
714725

715726
EndComputePass();
716727

717-
if (is_profiling_ && num_pending_dispatches_ > 0) {
728+
if (is_profiling_ && num_pending_dispatches_ > 0 && graph_capture_state_ != GraphCaptureState::Capturing) {
729+
ORT_ENFORCE(num_pending_dispatches_ == pending_kernels_.size(),
730+
"Number of pending dispatches (", num_pending_dispatches_,
731+
") does not match pending kernels size (", pending_kernels_.size(), ")");
732+
718733
uint32_t query_count = num_pending_dispatches_ * 2;
719734
current_command_encoder_.ResolveQuerySet(
720735
query_set_,
@@ -793,11 +808,14 @@ void WebGpuContext::LaunchComputePipeline(const wgpu::ComputePassEncoder& comput
793808
if (indirect_dispatch_tensor != nullptr) {
794809
indirect_buffer = reinterpret_cast<WGPUBuffer>(const_cast<void*>(indirect_dispatch_tensor->DataRaw()));
795810
}
811+
812+
// Profiling data will be populated in Run() after this call returns.
796813
external_captured_commands_->push_back({program_artifact.compute_pipeline,
797814
bind_group,
798815
bind_group_layout,
799816
{x, y, z},
800-
indirect_buffer});
817+
indirect_buffer,
818+
std::nullopt});
801819
} else {
802820
compute_pass_encoder.SetPipeline(program_artifact.compute_pipeline);
803821
wgpuComputePassEncoderSetBindGroup(compute_pass_encoder.Get(), 0, bind_group, 0, nullptr);
@@ -828,9 +846,6 @@ void WebGpuContext::CaptureBegin(std::vector<webgpu::CapturedCommandInfo>* captu
828846
external_captured_commands_->clear();
829847
}
830848

831-
// TODO: support profiling with graph capture.
832-
ORT_ENFORCE(!is_profiling_, "profiling is not supported yet under graph capture mode");
833-
834849
graph_capture_state_ = GraphCaptureState::Capturing;
835850
}
836851

@@ -843,6 +858,18 @@ void WebGpuContext::Replay(const std::vector<webgpu::CapturedCommandInfo>& captu
843858
auto& command = captured_commands[i];
844859
const auto& compute_pass_encoder = GetComputePassEncoder();
845860
WriteTimestamp(num_pending_dispatches_ * 2);
861+
862+
// Restore profiling info when profiling is enabled. All commands are expected
863+
// to have profiling data in this mode to keep pending_kernels_ consistent
864+
// with num_pending_dispatches_.
865+
if (is_profiling_) {
866+
ORT_ENFORCE(command.pending_kernel_info.has_value(),
867+
"WebGpuContext::Replay: profiling is enabled but captured command at index ",
868+
i,
869+
" is missing pending_kernel_info.");
870+
pending_kernels_.emplace_back(*command.pending_kernel_info);
871+
}
872+
846873
compute_pass_encoder.SetPipeline(command.compute_pipeline);
847874
wgpuComputePassEncoderSetBindGroup(compute_pass_encoder.Get(), 0, command.bind_group, 0, nullptr);
848875

onnxruntime/core/providers/webgpu/webgpu_context.h

Lines changed: 37 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
#include <memory>
77
#include <mutex>
8+
#include <optional>
89

910
#include "core/providers/webgpu/webgpu_external_header.h"
1011

@@ -25,13 +26,47 @@ class WebGpuContext;
2526
class ComputeContextBase;
2627
class ProgramBase;
2728

29+
// PendingKernelInfo stores profiling information for a kernel execution
30+
struct PendingKernelInfo {
31+
PendingKernelInfo(std::string_view kernel_name,
32+
std::string_view kernel_type,
33+
std::string_view program_name,
34+
std::string_view cache_key,
35+
const std::vector<ProgramInput>& inputs,
36+
const std::vector<ProgramOutput>& outputs)
37+
: name{absl::StrJoin({kernel_name, kernel_type, program_name}, "&")}, cache_key{cache_key} {
38+
// Store shape information instead of tensor pointers to avoid accessing released tensors
39+
input_shapes.reserve(inputs.size());
40+
for (const auto& input : inputs) {
41+
input_shapes.emplace_back(input.use_override_shape ? input.override_shape : input.tensor->Shape());
42+
}
43+
output_shapes.reserve(outputs.size());
44+
for (const auto& output : outputs) {
45+
output_shapes.emplace_back(output.use_override_shape ? output.override_shape : output.tensor->Shape());
46+
}
47+
}
48+
49+
PendingKernelInfo(const PendingKernelInfo&) = default;
50+
PendingKernelInfo& operator=(const PendingKernelInfo&) = default;
51+
PendingKernelInfo(PendingKernelInfo&&) = default;
52+
PendingKernelInfo& operator=(PendingKernelInfo&&) = default;
53+
54+
std::string name;
55+
std::string cache_key;
56+
std::vector<TensorShape> input_shapes;
57+
std::vector<TensorShape> output_shapes;
58+
};
59+
2860
// Definition for CapturedCommandInfo in the webgpu namespace
2961
struct CapturedCommandInfo {
3062
wgpu::ComputePipeline compute_pipeline;
3163
WGPUBindGroup bind_group;
3264
WGPUBindGroupLayout bind_group_layout;
3365
std::array<uint32_t, 3> dispatch_group;
34-
WGPUBuffer indirect_buffer; // WGPUBuffer for indirect dispatch, nullptr if not using indirect dispatch
66+
// WGPUBuffer for indirect dispatch, nullptr if not using indirect dispatch
67+
WGPUBuffer indirect_buffer;
68+
// Optional profiling data
69+
std::optional<PendingKernelInfo> pending_kernel_info;
3570
};
3671

3772
struct WebGpuBufferCacheConfig {
@@ -145,7 +180,7 @@ class WebGpuContext final {
145180

146181
wgpu::ComputePassDescriptor compute_pass_desc{};
147182

148-
if (is_profiling_ && query_type_ == TimestampQueryType::AtPasses) {
183+
if (is_profiling_ && query_type_ == TimestampQueryType::AtPasses && graph_capture_state_ != GraphCaptureState::Capturing) {
149184
wgpu::PassTimestampWrites timestampWrites = {
150185
nullptr,
151186
query_set_,
@@ -261,35 +296,6 @@ class WebGpuContext final {
261296
wgpu::Limits GetRequiredLimits(const wgpu::Adapter& adapter) const;
262297
void WriteTimestamp(uint32_t query_index);
263298

264-
struct PendingKernelInfo {
265-
PendingKernelInfo(std::string_view kernel_name,
266-
std::string_view kernel_type,
267-
std::string_view program_name,
268-
std::string_view cache_key,
269-
const std::vector<ProgramInput>& inputs,
270-
const std::vector<ProgramOutput>& outputs)
271-
: name{absl::StrJoin({kernel_name, kernel_type, program_name}, "&")}, cache_key{cache_key} {
272-
// Store shape information instead of tensor pointers to avoid accessing released tensors
273-
input_shapes.reserve(inputs.size());
274-
for (const auto& input : inputs) {
275-
input_shapes.emplace_back(input.use_override_shape ? input.override_shape : input.tensor->Shape());
276-
}
277-
output_shapes.reserve(outputs.size());
278-
for (const auto& output : outputs) {
279-
output_shapes.emplace_back(output.use_override_shape ? output.override_shape : output.tensor->Shape());
280-
}
281-
}
282-
283-
PendingKernelInfo(PendingKernelInfo&&) = default;
284-
PendingKernelInfo& operator=(PendingKernelInfo&&) = default;
285-
ORT_DISALLOW_COPY_AND_ASSIGNMENT(PendingKernelInfo);
286-
287-
std::string name;
288-
std::string cache_key;
289-
std::vector<TensorShape> input_shapes;
290-
std::vector<TensorShape> output_shapes;
291-
};
292-
293299
struct PendingQueryInfo {
294300
PendingQueryInfo(std::vector<PendingKernelInfo>&& kernels, wgpu::Buffer query_buffer)
295301
: kernels{std::move(kernels)}, query_buffer{query_buffer} {}

onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1075,7 +1075,14 @@ bool WebGpuExecutionProvider::IsGraphCaptured(int graph_annotation_id) const {
10751075

10761076
Status WebGpuExecutionProvider::ReplayGraph(int graph_annotation_id) {
10771077
ORT_ENFORCE(IsGraphCaptured(graph_annotation_id));
1078+
// TODO: enable profiling at the run level
1079+
if (session_profiler_ && session_profiler_->Enabled()) {
1080+
context_.StartProfiling();
1081+
}
10781082
context_.Replay(captured_commands_, *graph_buffer_mgr_);
1083+
if (session_profiler_ && session_profiler_->Enabled()) {
1084+
context_.CollectProfilingData();
1085+
}
10791086
return Status::OK();
10801087
}
10811088

0 commit comments

Comments
 (0)