Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions onnxruntime/core/providers/webgpu/webgpu_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <memory>
#include <cmath>
#include <string>

#if defined(__GNUC__)
#pragma GCC diagnostic push
Expand Down Expand Up @@ -792,11 +793,19 @@ void WebGpuContext::LaunchComputePipeline(const wgpu::ComputePassEncoder& comput
if (indirect_dispatch_tensor != nullptr) {
indirect_buffer = reinterpret_cast<WGPUBuffer>(const_cast<void*>(indirect_dispatch_tensor->DataRaw()));
}

// Always store profiling metadata to support profiling during replay regardless of current profiling state
std::optional<PendingKernelInfo> profiling_data;
if (!pending_kernels_.empty()) {
profiling_data = pending_kernels_.back();
}

external_captured_commands_->push_back({program_artifact.compute_pipeline,
bind_group,
bind_group_layout,
{x, y, z},
indirect_buffer});
indirect_buffer,
profiling_data});
} else {
compute_pass_encoder.SetPipeline(program_artifact.compute_pipeline);
wgpuComputePassEncoderSetBindGroup(compute_pass_encoder.Get(), 0, bind_group, 0, nullptr);
Expand Down Expand Up @@ -827,9 +836,6 @@ void WebGpuContext::CaptureBegin(std::vector<webgpu::CapturedCommandInfo>* captu
external_captured_commands_->clear();
}

// TODO: support profiling with graph capture.
ORT_ENFORCE(!is_profiling_, "profiling is not supported yet under graph capture mode");

graph_capture_state_ = GraphCaptureState::Capturing;
}

Expand All @@ -842,6 +848,12 @@ void WebGpuContext::Replay(const std::vector<webgpu::CapturedCommandInfo>& captu
auto& command = captured_commands[i];
const auto& compute_pass_encoder = GetComputePassEncoder();
WriteTimestamp(num_pending_dispatches_ * 2);

// Restore profiling info if available and profiling is enabled
if (is_profiling_ && command.pending_kernel_info.has_value()) {
pending_kernels_.emplace_back(command.pending_kernel_info.value());
}

compute_pass_encoder.SetPipeline(command.compute_pipeline);
wgpuComputePassEncoderSetBindGroup(compute_pass_encoder.Get(), 0, command.bind_group, 0, nullptr);

Expand Down
66 changes: 36 additions & 30 deletions onnxruntime/core/providers/webgpu/webgpu_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include <memory>
#include <mutex>
#include <optional>

#include "core/providers/webgpu/webgpu_external_header.h"

Expand All @@ -25,13 +26,47 @@ class WebGpuContext;
class ComputeContextBase;
class ProgramBase;

// PendingKernelInfo stores profiling information for a kernel execution
// PendingKernelInfo stores profiling information for a kernel execution
struct PendingKernelInfo {
  // Builds the profiling record for one dispatched program.
  // The display name is "<kernel_name>&<kernel_type>&<program_name>".
  // Shapes are copied out eagerly because the underlying tensors may be
  // released before the profiling data is ultimately collected.
  PendingKernelInfo(std::string_view kernel_name,
                    std::string_view kernel_type,
                    std::string_view program_name,
                    std::string_view cache_key,
                    const std::vector<ProgramInput>& inputs,
                    const std::vector<ProgramOutput>& outputs)
      : name{absl::StrJoin({kernel_name, kernel_type, program_name}, "&")},
        cache_key{cache_key} {
    input_shapes.reserve(inputs.size());
    for (const auto& in : inputs) {
      input_shapes.push_back(in.use_override_shape ? in.override_shape : in.tensor->Shape());
    }
    output_shapes.reserve(outputs.size());
    for (const auto& out : outputs) {
      output_shapes.push_back(out.use_override_shape ? out.override_shape : out.tensor->Shape());
    }
  }

  // Copyable and movable so instances can be stored alongside captured
  // commands and duplicated into the pending-kernel list on replay.
  PendingKernelInfo(const PendingKernelInfo&) = default;
  PendingKernelInfo& operator=(const PendingKernelInfo&) = default;
  PendingKernelInfo(PendingKernelInfo&&) = default;
  PendingKernelInfo& operator=(PendingKernelInfo&&) = default;

  std::string name;       // "kernel&type&program" label for profiler events
  std::string cache_key;  // program cache key of the compiled pipeline
  std::vector<TensorShape> input_shapes;
  std::vector<TensorShape> output_shapes;
};

// Definition for CapturedCommandInfo in the webgpu namespace
struct CapturedCommandInfo {
wgpu::ComputePipeline compute_pipeline;
WGPUBindGroup bind_group;
WGPUBindGroupLayout bind_group_layout;
std::array<uint32_t, 3> dispatch_group;
WGPUBuffer indirect_buffer; // WGPUBuffer for indirect dispatch, nullptr if not using indirect dispatch
// WGPUBuffer for indirect dispatch, nullptr if not using indirect dispatch
WGPUBuffer indirect_buffer;
// Optional profiling data
std::optional<PendingKernelInfo> pending_kernel_info;
};

struct WebGpuBufferCacheConfig {
Expand Down Expand Up @@ -261,35 +296,6 @@ class WebGpuContext final {
wgpu::Limits GetRequiredLimits(const wgpu::Adapter& adapter) const;
void WriteTimestamp(uint32_t query_index);

// NOTE(review): this nested copy appears to be the pre-refactor (deleted-side)
// version of PendingKernelInfo shown by the diff; the namespace-scope
// definition earlier in the file supersedes it — confirm against the merged
// file before keeping both.
// Stores profiling metadata for one dispatched kernel so resolved timestamps
// can be attributed back to the kernel that produced them.
struct PendingKernelInfo {
PendingKernelInfo(std::string_view kernel_name,
std::string_view kernel_type,
std::string_view program_name,
std::string_view cache_key,
const std::vector<ProgramInput>& inputs,
const std::vector<ProgramOutput>& outputs)
// Display name is the three parts joined with '&'.
: name{absl::StrJoin({kernel_name, kernel_type, program_name}, "&")}, cache_key{cache_key} {
// Store shape information instead of tensor pointers to avoid accessing released tensors
input_shapes.reserve(inputs.size());
for (const auto& input : inputs) {
input_shapes.emplace_back(input.use_override_shape ? input.override_shape : input.tensor->Shape());
}
output_shapes.reserve(outputs.size());
for (const auto& output : outputs) {
output_shapes.emplace_back(output.use_override_shape ? output.override_shape : output.tensor->Shape());
}
}

// Move-only in this version, unlike the copyable namespace-scope variant.
PendingKernelInfo(PendingKernelInfo&&) = default;
PendingKernelInfo& operator=(PendingKernelInfo&&) = default;
ORT_DISALLOW_COPY_AND_ASSIGNMENT(PendingKernelInfo);

std::string name;       // "kernel&type&program" label used in profiler events
std::string cache_key;  // program cache key of the compiled pipeline
std::vector<TensorShape> input_shapes;
std::vector<TensorShape> output_shapes;
};

struct PendingQueryInfo {
PendingQueryInfo(std::vector<PendingKernelInfo>&& kernels, wgpu::Buffer query_buffer)
: kernels{std::move(kernels)}, query_buffer{query_buffer} {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1050,7 +1050,13 @@ bool WebGpuExecutionProvider::IsGraphCaptured(int graph_annotation_id) const {

// Re-submits the previously captured WebGPU command list for the given graph.
// Requires that the graph was captured (enforced below); returns OK on success.
Status WebGpuExecutionProvider::ReplayGraph(int graph_annotation_id) {
  ORT_ENFORCE(IsGraphCaptured(graph_annotation_id));
  // Evaluate the profiling state once so StartProfiling/CollectProfilingData
  // stay paired even if the profiler is toggled while Replay is running;
  // the original re-queried Enabled() after Replay, which could start
  // profiling without ever collecting it (or vice versa).
  const bool profiling_enabled = profiler_->Enabled();
  if (profiling_enabled) {
    context_.StartProfiling();
  }
  context_.Replay(captured_commands_, *graph_buffer_mgr_);
  if (profiling_enabled) {
    // Flush the timestamps gathered during replay into the profiler's events.
    context_.CollectProfilingData(profiler_->Events());
  }
  return Status::OK();
}

Expand Down
Loading