Skip to content

Commit df7d5ac

Browse files
committed
[webgpu] Enable profiling for graph capture
1 parent d8f0318 commit df7d5ac

File tree

3 files changed

+37
-5
lines changed

3 files changed

+37
-5
lines changed

onnxruntime/core/providers/webgpu/webgpu_context.cc

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -792,11 +792,20 @@ void WebGpuContext::LaunchComputePipeline(const wgpu::ComputePassEncoder& comput
792792
if (indirect_dispatch_tensor != nullptr) {
793793
indirect_buffer = reinterpret_cast<WGPUBuffer>(const_cast<void*>(indirect_dispatch_tensor->DataRaw()));
794794
}
795+
796+
// Store profiling info if profiling is enabled
797+
std::optional<std::tuple<std::string, std::string, std::vector<TensorShape>, std::vector<TensorShape>>> profiling_data;
798+
if (is_profiling_ && !pending_kernels_.empty()) {
799+
const auto& kernel_info = pending_kernels_.back();
800+
profiling_data = std::make_tuple(kernel_info.name, kernel_info.cache_key, kernel_info.input_shapes, kernel_info.output_shapes);
801+
}
802+
795803
external_captured_commands_->push_back({program_artifact.compute_pipeline,
796804
bind_group,
797805
bind_group_layout,
798806
{x, y, z},
799-
indirect_buffer});
807+
indirect_buffer,
808+
profiling_data});
800809
} else {
801810
compute_pass_encoder.SetPipeline(program_artifact.compute_pipeline);
802811
wgpuComputePassEncoderSetBindGroup(compute_pass_encoder.Get(), 0, bind_group, 0, nullptr);
@@ -827,9 +836,6 @@ void WebGpuContext::CaptureBegin(std::vector<webgpu::CapturedCommandInfo>* captu
827836
external_captured_commands_->clear();
828837
}
829838

830-
// TODO: support profiling with graph capture.
831-
ORT_ENFORCE(!is_profiling_, "profiling is not supported yet under graph capture mode");
832-
833839
graph_capture_state_ = GraphCaptureState::Capturing;
834840
}
835841

@@ -842,6 +848,13 @@ void WebGpuContext::Replay(const std::vector<webgpu::CapturedCommandInfo>& captu
842848
auto& command = captured_commands[i];
843849
const auto& compute_pass_encoder = GetComputePassEncoder();
844850
WriteTimestamp(num_pending_dispatches_ * 2);
851+
852+
// Restore profiling info if available and profiling is enabled
853+
if (is_profiling_ && command.pending_kernel_info.has_value()) {
854+
const auto& [name, cache_key, input_shapes, output_shapes] = command.pending_kernel_info.value();
855+
pending_kernels_.emplace_back(name, cache_key, input_shapes, output_shapes);
856+
}
857+
845858
compute_pass_encoder.SetPipeline(command.compute_pipeline);
846859
wgpuComputePassEncoderSetBindGroup(compute_pass_encoder.Get(), 0, command.bind_group, 0, nullptr);
847860

onnxruntime/core/providers/webgpu/webgpu_context.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
#include <memory>
77
#include <mutex>
8+
#include <optional>
9+
#include <tuple>
810

911
#include "core/providers/webgpu/webgpu_external_header.h"
1012

@@ -31,7 +33,8 @@ struct CapturedCommandInfo {
3133
WGPUBindGroup bind_group;
3234
WGPUBindGroupLayout bind_group_layout;
3335
std::array<uint32_t, 3> dispatch_group;
34-
WGPUBuffer indirect_buffer; // WGPUBuffer for indirect dispatch, nullptr if not using indirect dispatch
36+
WGPUBuffer indirect_buffer; // WGPUBuffer for indirect dispatch, nullptr if not using indirect dispatch
37+
std::optional<std::tuple<std::string, std::string, std::vector<TensorShape>, std::vector<TensorShape>>> pending_kernel_info; // Optional profiling data: (name, cache_key, input_shapes, output_shapes)
3538
};
3639

3740
struct WebGpuBufferCacheConfig {
@@ -280,6 +283,16 @@ class WebGpuContext final {
280283
}
281284
}
282285

286+
// Constructor for replay - takes shapes directly.
// Builds a PendingKernelInfo from values that already exist, instead of deriving them here:
// every argument is taken by value and moved into the member of the same name
// (name, cache_key, input_shapes, output_shapes).
// NOTE(review): presumably the target of the pending_kernels_.emplace_back(...) call made
// while replaying captured commands - confirm against the caller.
287+
PendingKernelInfo(std::string name_in,
288+
std::string cache_key_in,
289+
std::vector<TensorShape> input_shapes_in,
290+
std::vector<TensorShape> output_shapes_in)
291+
: name{std::move(name_in)},
292+
cache_key{std::move(cache_key_in)},
293+
input_shapes{std::move(input_shapes_in)},
294+
output_shapes{std::move(output_shapes_in)} {}
295+
283296
PendingKernelInfo(PendingKernelInfo&&) = default;
284297
PendingKernelInfo& operator=(PendingKernelInfo&&) = default;
285298
ORT_DISALLOW_COPY_AND_ASSIGNMENT(PendingKernelInfo);

onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,7 +1050,13 @@ bool WebGpuExecutionProvider::IsGraphCaptured(int graph_annotation_id) const {
10501050

10511051
// Replays the previously captured command list on the WebGPU context.
//
// graph_annotation_id is used only for the IsGraphCaptured() precondition check; the replay
// itself always runs captured_commands_ with *graph_buffer_mgr_.
//
// When profiler_->Enabled() is true, the replay is bracketed by context_.StartProfiling()
// before it and context_.CollectProfilingData(profiler_->Events()) after it. The profiler
// state is queried separately before and after the replay.
//
// Returns Status::OK() after the replay call returns; a failed precondition is reported
// through ORT_ENFORCE rather than through the returned Status.
Status WebGpuExecutionProvider::ReplayGraph(int graph_annotation_id) {
10521052
ORT_ENFORCE(IsGraphCaptured(graph_annotation_id));
1053+
if (profiler_->Enabled()) {  // NOTE(review): assumes profiler_ is non-null here - confirm
1054+
context_.StartProfiling();
1055+
}
10531056
context_.Replay(captured_commands_, *graph_buffer_mgr_);
1057+
if (profiler_->Enabled()) {  // second, independent query of the profiler state
1058+
context_.CollectProfilingData(profiler_->Events());
1059+
}
10541060
return Status::OK();
10551061
}
10561062

0 commit comments

Comments
 (0)