Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 43 additions & 16 deletions onnxruntime/core/providers/webgpu/webgpu_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <memory>
#include <cmath>
#include <string>

#if defined(__GNUC__)
#pragma GCC diagnostic push
Expand Down Expand Up @@ -282,16 +283,6 @@ Status WebGpuContext::Run(ComputeContextBase& context, const ProgramBase& progra

auto key = CalculateProgramCacheKey(program, inputs_segments, outputs_segments, is_1d_dispatch);

if (is_profiling_) {
PendingKernelInfo pending_kernel_info(context.NodeName(),
context.OpType(),
program.Name(),
key,
inputs,
outputs);
pending_kernels_.emplace_back(std::move(pending_kernel_info));
}

LOGS(context.Logger(), INFO) << "Starting program \"" << key << "\" (" << x << ", " << y << ", " << z << ")";

const auto* program_artifact = program_mgr_->Get(key);
Expand Down Expand Up @@ -480,6 +471,26 @@ Status WebGpuContext::Run(ComputeContextBase& context, const ProgramBase& progra
WriteTimestamp(num_pending_dispatches_ * 2 + 1);
++num_pending_dispatches_;

// Update profiling data after LaunchComputePipeline
if (is_profiling_) {
PendingKernelInfo pending_kernel_info(context.NodeName(),
context.OpType(),
program.Name(),
key,
inputs,
outputs);

if (graph_capture_state_ == GraphCaptureState::Capturing) {
// Update the last captured command's profiling info
if (external_captured_commands_ && !external_captured_commands_->empty()) {
external_captured_commands_->back().pending_kernel_info = std::move(pending_kernel_info);
}
} else {
// Add to pending kernels for current run profiling
pending_kernels_.emplace_back(std::move(pending_kernel_info));
}
}

if (num_pending_dispatches_ >= max_num_pending_dispatches_ ||
(is_profiling_ && query_type_ == TimestampQueryType::AtPasses)) {
EndComputePass();
Expand Down Expand Up @@ -577,7 +588,7 @@ wgpu::Limits WebGpuContext::GetRequiredLimits(const wgpu::Adapter& adapter) cons
}

void WebGpuContext::WriteTimestamp(uint32_t query_index) {
if (!is_profiling_ || query_type_ != TimestampQueryType::InsidePasses) {
if (!is_profiling_ || graph_capture_state_ == GraphCaptureState::Capturing || query_type_ != TimestampQueryType::InsidePasses) {
return;
}

Expand Down Expand Up @@ -714,7 +725,11 @@ void WebGpuContext::Flush(const webgpu::BufferManager& buffer_mgr) {

EndComputePass();

if (is_profiling_ && num_pending_dispatches_ > 0) {
if (is_profiling_ && num_pending_dispatches_ > 0 && graph_capture_state_ != GraphCaptureState::Capturing) {
ORT_ENFORCE(num_pending_dispatches_ == pending_kernels_.size(),
"Number of pending dispatches (", num_pending_dispatches_,
") does not match pending kernels size (", pending_kernels_.size(), ")");

uint32_t query_count = num_pending_dispatches_ * 2;
current_command_encoder_.ResolveQuerySet(
query_set_,
Expand Down Expand Up @@ -793,11 +808,14 @@ void WebGpuContext::LaunchComputePipeline(const wgpu::ComputePassEncoder& comput
if (indirect_dispatch_tensor != nullptr) {
indirect_buffer = reinterpret_cast<WGPUBuffer>(const_cast<void*>(indirect_dispatch_tensor->DataRaw()));
}

// Profiling data will be populated in Run() after this call returns.
external_captured_commands_->push_back({program_artifact.compute_pipeline,
bind_group,
bind_group_layout,
{x, y, z},
indirect_buffer});
indirect_buffer,
std::nullopt});
} else {
compute_pass_encoder.SetPipeline(program_artifact.compute_pipeline);
wgpuComputePassEncoderSetBindGroup(compute_pass_encoder.Get(), 0, bind_group, 0, nullptr);
Expand Down Expand Up @@ -828,9 +846,6 @@ void WebGpuContext::CaptureBegin(std::vector<webgpu::CapturedCommandInfo>* captu
external_captured_commands_->clear();
}

// TODO: support profiling with graph capture.
ORT_ENFORCE(!is_profiling_, "profiling is not supported yet under graph capture mode");

graph_capture_state_ = GraphCaptureState::Capturing;
}

Expand All @@ -843,6 +858,18 @@ void WebGpuContext::Replay(const std::vector<webgpu::CapturedCommandInfo>& captu
auto& command = captured_commands[i];
const auto& compute_pass_encoder = GetComputePassEncoder();
WriteTimestamp(num_pending_dispatches_ * 2);

// Restore profiling info when profiling is enabled. All commands are expected
// to have profiling data in this mode to keep pending_kernels_ consistent
// with num_pending_dispatches_.
if (is_profiling_) {
ORT_ENFORCE(command.pending_kernel_info.has_value(),
"WebGpuContext::Replay: profiling is enabled but captured command at index ",
i,
" is missing pending_kernel_info.");
pending_kernels_.emplace_back(*command.pending_kernel_info);
}

compute_pass_encoder.SetPipeline(command.compute_pipeline);
wgpuComputePassEncoderSetBindGroup(compute_pass_encoder.Get(), 0, command.bind_group, 0, nullptr);

Expand Down
68 changes: 37 additions & 31 deletions onnxruntime/core/providers/webgpu/webgpu_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include <memory>
#include <mutex>
#include <optional>

#include "core/providers/webgpu/webgpu_external_header.h"

Expand All @@ -25,13 +26,47 @@ class WebGpuContext;
class ComputeContextBase;
class ProgramBase;

// PendingKernelInfo captures the profiling metadata for one kernel dispatch.
// Only shapes (not tensor pointers) are retained, so the record stays valid
// after the underlying tensors have been released.
struct PendingKernelInfo {
  PendingKernelInfo(std::string_view kernel_name,
                    std::string_view kernel_type,
                    std::string_view program_name,
                    std::string_view cache_key,
                    const std::vector<ProgramInput>& inputs,
                    const std::vector<ProgramOutput>& outputs)
      : name{absl::StrJoin({kernel_name, kernel_type, program_name}, "&")}, cache_key{cache_key} {
    // Record the effective shape of every input/output. An override shape,
    // when requested, takes precedence over the tensor's own shape.
    input_shapes.reserve(inputs.size());
    for (size_t idx = 0; idx < inputs.size(); ++idx) {
      const auto& in = inputs[idx];
      input_shapes.push_back(in.use_override_shape ? in.override_shape : in.tensor->Shape());
    }
    output_shapes.reserve(outputs.size());
    for (size_t idx = 0; idx < outputs.size(); ++idx) {
      const auto& out = outputs[idx];
      output_shapes.push_back(out.use_override_shape ? out.override_shape : out.tensor->Shape());
    }
  }

  // Copyable and movable so instances can live inside captured-command
  // records as well as the pending-kernel list.
  PendingKernelInfo(const PendingKernelInfo&) = default;
  PendingKernelInfo& operator=(const PendingKernelInfo&) = default;
  PendingKernelInfo(PendingKernelInfo&&) = default;
  PendingKernelInfo& operator=(PendingKernelInfo&&) = default;

  std::string name;        // "<kernel name>&<kernel type>&<program name>"
  std::string cache_key;   // program cache key this dispatch resolved to
  std::vector<TensorShape> input_shapes;
  std::vector<TensorShape> output_shapes;
};

// Definition for CapturedCommandInfo in the webgpu namespace.
// One record per compute dispatch captured during graph capture; Replay()
// re-issues these without re-running shader/program resolution.
struct CapturedCommandInfo {
  wgpu::ComputePipeline compute_pipeline;
  WGPUBindGroup bind_group;
  WGPUBindGroupLayout bind_group_layout;
  std::array<uint32_t, 3> dispatch_group;
  // WGPUBuffer for indirect dispatch, nullptr if not using indirect dispatch
  WGPUBuffer indirect_buffer;
  // Optional profiling data; populated in Run() after the dispatch is
  // captured, and required by Replay() when profiling is enabled.
  std::optional<PendingKernelInfo> pending_kernel_info;
};

struct WebGpuBufferCacheConfig {
Expand Down Expand Up @@ -145,7 +180,7 @@ class WebGpuContext final {

wgpu::ComputePassDescriptor compute_pass_desc{};

if (is_profiling_ && query_type_ == TimestampQueryType::AtPasses) {
if (is_profiling_ && query_type_ == TimestampQueryType::AtPasses && graph_capture_state_ != GraphCaptureState::Capturing) {
wgpu::PassTimestampWrites timestampWrites = {
nullptr,
query_set_,
Expand Down Expand Up @@ -261,35 +296,6 @@ class WebGpuContext final {
wgpu::Limits GetRequiredLimits(const wgpu::Adapter& adapter) const;
void WriteTimestamp(uint32_t query_index);

struct PendingKernelInfo {
PendingKernelInfo(std::string_view kernel_name,
std::string_view kernel_type,
std::string_view program_name,
std::string_view cache_key,
const std::vector<ProgramInput>& inputs,
const std::vector<ProgramOutput>& outputs)
: name{absl::StrJoin({kernel_name, kernel_type, program_name}, "&")}, cache_key{cache_key} {
// Store shape information instead of tensor pointers to avoid accessing released tensors
input_shapes.reserve(inputs.size());
for (const auto& input : inputs) {
input_shapes.emplace_back(input.use_override_shape ? input.override_shape : input.tensor->Shape());
}
output_shapes.reserve(outputs.size());
for (const auto& output : outputs) {
output_shapes.emplace_back(output.use_override_shape ? output.override_shape : output.tensor->Shape());
}
}

PendingKernelInfo(PendingKernelInfo&&) = default;
PendingKernelInfo& operator=(PendingKernelInfo&&) = default;
ORT_DISALLOW_COPY_AND_ASSIGNMENT(PendingKernelInfo);

std::string name;
std::string cache_key;
std::vector<TensorShape> input_shapes;
std::vector<TensorShape> output_shapes;
};

struct PendingQueryInfo {
PendingQueryInfo(std::vector<PendingKernelInfo>&& kernels, wgpu::Buffer query_buffer)
: kernels{std::move(kernels)}, query_buffer{query_buffer} {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1075,7 +1075,14 @@

// Replays a previously captured WebGPU command graph.
// Precondition: the graph for `graph_annotation_id` was captured (enforced).
// Returns Status::OK() on success.
Status WebGpuExecutionProvider::ReplayGraph(int graph_annotation_id) {
  ORT_ENFORCE(IsGraphCaptured(graph_annotation_id));
  // When session-level profiling is active, bracket the replay with the
  // context's profiling start/collect so replayed kernels are reported
  // just like kernels of a regular (non-captured) run.
  if (session_profiler_ && session_profiler_->Enabled()) {
    context_.StartProfiling();
  }
  context_.Replay(captured_commands_, *graph_buffer_mgr_);
  if (session_profiler_ && session_profiler_->Enabled()) {
    context_.CollectProfilingData();
  }
  return Status::OK();
}

Expand Down
Loading