diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc
index ce84928329e30..318e815255ff7 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_context.cc
+++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc
@@ -3,6 +3,7 @@
 
 #include <memory>
 #include <cmath>
+#include <optional>
 
 #if defined(__GNUC__)
 #pragma GCC diagnostic push
@@ -282,16 +283,6 @@ Status WebGpuContext::Run(ComputeContextBase& context, const ProgramBase& program,
   auto key = CalculateProgramCacheKey(program, inputs_segments, outputs_segments, is_1d_dispatch);
 
-  if (is_profiling_) {
-    PendingKernelInfo pending_kernel_info(context.NodeName(),
-                                          context.OpType(),
-                                          program.Name(),
-                                          key,
-                                          inputs,
-                                          outputs);
-    pending_kernels_.emplace_back(std::move(pending_kernel_info));
-  }
-
   LOGS(context.Logger(), INFO) << "Starting program \"" << key << "\" (" << x << ", " << y << ", " << z << ")";
 
   const auto* program_artifact = program_mgr_->Get(key);
@@ -480,6 +471,26 @@ Status WebGpuContext::Run(ComputeContextBase& context, const ProgramBase& program,
   WriteTimestamp(num_pending_dispatches_ * 2 + 1);
   ++num_pending_dispatches_;
 
+  // Update profiling data after LaunchComputePipeline
+  if (is_profiling_) {
+    PendingKernelInfo pending_kernel_info(context.NodeName(),
+                                          context.OpType(),
+                                          program.Name(),
+                                          key,
+                                          inputs,
+                                          outputs);
+
+    if (graph_capture_state_ == GraphCaptureState::Capturing) {
+      // Update the last captured command's profiling info
+      if (external_captured_commands_ && !external_captured_commands_->empty()) {
+        external_captured_commands_->back().pending_kernel_info = std::move(pending_kernel_info);
+      }
+    } else {
+      // Add to pending kernels for current run profiling
+      pending_kernels_.emplace_back(std::move(pending_kernel_info));
+    }
+  }
+
   if (num_pending_dispatches_ >= max_num_pending_dispatches_ ||
       (is_profiling_ && query_type_ == TimestampQueryType::AtPasses)) {
     EndComputePass();
@@ -577,7 +588,7 @@ wgpu::Limits WebGpuContext::GetRequiredLimits(const wgpu::Adapter& adapter) const {
 }
 
 void WebGpuContext::WriteTimestamp(uint32_t query_index) {
-  if (!is_profiling_ || query_type_ != TimestampQueryType::InsidePasses) {
+  if (!is_profiling_ || graph_capture_state_ == GraphCaptureState::Capturing || query_type_ != TimestampQueryType::InsidePasses) {
     return;
   }
@@ -714,7 +725,11 @@ void WebGpuContext::Flush(const webgpu::BufferManager& buffer_mgr) {
   EndComputePass();
 
-  if (is_profiling_ && num_pending_dispatches_ > 0) {
+  if (is_profiling_ && num_pending_dispatches_ > 0 && graph_capture_state_ != GraphCaptureState::Capturing) {
+    ORT_ENFORCE(num_pending_dispatches_ == pending_kernels_.size(),
+                "Number of pending dispatches (", num_pending_dispatches_,
+                ") does not match pending kernels size (", pending_kernels_.size(), ")");
+
     uint32_t query_count = num_pending_dispatches_ * 2;
     current_command_encoder_.ResolveQuerySet(
         query_set_,
@@ -793,11 +808,14 @@ void WebGpuContext::LaunchComputePipeline(const wgpu::ComputePassEncoder& compute_pass_encoder,
     if (indirect_dispatch_tensor != nullptr) {
       indirect_buffer = reinterpret_cast<WGPUBuffer>(const_cast<void*>(indirect_dispatch_tensor->DataRaw()));
     }
+
+    // Profiling data will be populated in Run() after this call returns.
     external_captured_commands_->push_back({program_artifact.compute_pipeline,
                                            bind_group,
                                            bind_group_layout,
                                            {x, y, z},
-                                           indirect_buffer});
+                                           indirect_buffer,
+                                           std::nullopt});
   } else {
     compute_pass_encoder.SetPipeline(program_artifact.compute_pipeline);
     wgpuComputePassEncoderSetBindGroup(compute_pass_encoder.Get(), 0, bind_group, 0, nullptr);
@@ -828,9 +846,6 @@ void WebGpuContext::CaptureBegin(std::vector<CapturedCommandInfo>* captured_commands) {
     external_captured_commands_->clear();
   }
 
-  // TODO: support profiling with graph capture.
-  ORT_ENFORCE(!is_profiling_, "profiling is not supported yet under graph capture mode");
-
   graph_capture_state_ = GraphCaptureState::Capturing;
 }
@@ -843,6 +858,18 @@ void WebGpuContext::Replay(const std::vector<CapturedCommandInfo>& captured_commands,
     auto& command = captured_commands[i];
     const auto& compute_pass_encoder = GetComputePassEncoder();
     WriteTimestamp(num_pending_dispatches_ * 2);
+
+    // Restore profiling info when profiling is enabled. All commands are expected
+    // to have profiling data in this mode to keep pending_kernels_ consistent
+    // with num_pending_dispatches_.
+    if (is_profiling_) {
+      ORT_ENFORCE(command.pending_kernel_info.has_value(),
+                  "WebGpuContext::Replay: profiling is enabled but captured command at index ",
+                  i,
+                  " is missing pending_kernel_info.");
+      pending_kernels_.emplace_back(*command.pending_kernel_info);
+    }
+
     compute_pass_encoder.SetPipeline(command.compute_pipeline);
     wgpuComputePassEncoderSetBindGroup(compute_pass_encoder.Get(), 0, command.bind_group, 0, nullptr);
diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h
index ac33d8ecc2ab2..7645f1e6b2482 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_context.h
+++ b/onnxruntime/core/providers/webgpu/webgpu_context.h
@@ -5,6 +5,7 @@
 
 #include <memory>
 #include <mutex>
+#include <optional>
 
 #include "core/providers/webgpu/webgpu_external_header.h"
@@ -25,13 +26,47 @@ class WebGpuContext;
 class ComputeContextBase;
 class ProgramBase;
 
+// PendingKernelInfo stores profiling information for a kernel execution
+struct PendingKernelInfo {
+  PendingKernelInfo(std::string_view kernel_name,
+                    std::string_view kernel_type,
+                    std::string_view program_name,
+                    std::string_view cache_key,
+                    const std::vector<ProgramInput>& inputs,
+                    const std::vector<ProgramOutput>& outputs)
+      : name{absl::StrJoin({kernel_name, kernel_type, program_name}, "&")}, cache_key{cache_key} {
+    // Store shape information instead of tensor pointers to avoid accessing released tensors
+    input_shapes.reserve(inputs.size());
+    for (const auto& input : inputs) {
+      input_shapes.emplace_back(input.use_override_shape ? input.override_shape : input.tensor->Shape());
+    }
+    output_shapes.reserve(outputs.size());
+    for (const auto& output : outputs) {
+      output_shapes.emplace_back(output.use_override_shape ? output.override_shape : output.tensor->Shape());
+    }
+  }
+
+  PendingKernelInfo(const PendingKernelInfo&) = default;
+  PendingKernelInfo& operator=(const PendingKernelInfo&) = default;
+  PendingKernelInfo(PendingKernelInfo&&) = default;
+  PendingKernelInfo& operator=(PendingKernelInfo&&) = default;
+
+  std::string name;
+  std::string cache_key;
+  std::vector<TensorShape> input_shapes;
+  std::vector<TensorShape> output_shapes;
+};
+
 // Definition for CapturedCommandInfo in the webgpu namespace
 struct CapturedCommandInfo {
   wgpu::ComputePipeline compute_pipeline;
   WGPUBindGroup bind_group;
   WGPUBindGroupLayout bind_group_layout;
   std::array<uint32_t, 3> dispatch_group;
-  WGPUBuffer indirect_buffer;  // WGPUBuffer for indirect dispatch, nullptr if not using indirect dispatch
+  // WGPUBuffer for indirect dispatch, nullptr if not using indirect dispatch
+  WGPUBuffer indirect_buffer;
+  // Optional profiling data
+  std::optional<PendingKernelInfo> pending_kernel_info;
 };
 
 struct WebGpuBufferCacheConfig {
@@ -145,7 +180,7 @@ class WebGpuContext final {
 
       wgpu::ComputePassDescriptor compute_pass_desc{};
 
-      if (is_profiling_ && query_type_ == TimestampQueryType::AtPasses) {
+      if (is_profiling_ && query_type_ == TimestampQueryType::AtPasses && graph_capture_state_ != GraphCaptureState::Capturing) {
        wgpu::PassTimestampWrites timestampWrites = {
            nullptr,
            query_set_,
@@ -261,35 +296,6 @@ class WebGpuContext final {
   wgpu::Limits GetRequiredLimits(const wgpu::Adapter& adapter) const;
   void WriteTimestamp(uint32_t query_index);
 
-  struct PendingKernelInfo {
-    PendingKernelInfo(std::string_view kernel_name,
-                      std::string_view kernel_type,
-                      std::string_view program_name,
-                      std::string_view cache_key,
-                      const std::vector<ProgramInput>& inputs,
-                      const std::vector<ProgramOutput>& outputs)
-        : name{absl::StrJoin({kernel_name, kernel_type, program_name}, "&")}, cache_key{cache_key} {
-      // Store shape information instead of tensor pointers to avoid accessing released tensors
-      input_shapes.reserve(inputs.size());
-      for (const auto& input : inputs) {
-        input_shapes.emplace_back(input.use_override_shape ? input.override_shape : input.tensor->Shape());
-      }
-      output_shapes.reserve(outputs.size());
-      for (const auto& output : outputs) {
-        output_shapes.emplace_back(output.use_override_shape ? output.override_shape : output.tensor->Shape());
-      }
-    }
-
-    PendingKernelInfo(PendingKernelInfo&&) = default;
-    PendingKernelInfo& operator=(PendingKernelInfo&&) = default;
-    ORT_DISALLOW_COPY_AND_ASSIGNMENT(PendingKernelInfo);
-
-    std::string name;
-    std::string cache_key;
-    std::vector<TensorShape> input_shapes;
-    std::vector<TensorShape> output_shapes;
-  };
-
   struct PendingQueryInfo {
     PendingQueryInfo(std::vector<PendingKernelInfo>&& kernels, wgpu::Buffer query_buffer)
         : kernels{std::move(kernels)}, query_buffer{query_buffer} {}
diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
index e7da76d90fe4e..844591a930c0c 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
+++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
@@ -1075,7 +1075,14 @@ bool WebGpuExecutionProvider::IsGraphCaptured(int graph_annotation_id) const {
 
 Status WebGpuExecutionProvider::ReplayGraph(int graph_annotation_id) {
   ORT_ENFORCE(IsGraphCaptured(graph_annotation_id));
+  // TODO: enable profiling in run level
+  if (session_profiler_ && session_profiler_->Enabled()) {
+    context_.StartProfiling();
+  }
   context_.Replay(captured_commands_, *graph_buffer_mgr_);
+  if (session_profiler_ && session_profiler_->Enabled()) {
+    context_.CollectProfilingData();
+  }
   return Status::OK();
 }