@@ -792,11 +792,20 @@ void WebGpuContext::LaunchComputePipeline(const wgpu::ComputePassEncoder& comput
792792 if (indirect_dispatch_tensor != nullptr ) {
793793 indirect_buffer = reinterpret_cast <WGPUBuffer>(const_cast <void *>(indirect_dispatch_tensor->DataRaw ()));
794794 }
795+
796+ // Store profiling info if profiling is enabled
797+ std::optional<std::tuple<std::string, std::string, std::vector<TensorShape>, std::vector<TensorShape>>> profiling_data;
798+ if (is_profiling_ && !pending_kernels_.empty ()) {
799+ const auto & kernel_info = pending_kernels_.back ();
800+ profiling_data = std::make_tuple (kernel_info.name , kernel_info.cache_key , kernel_info.input_shapes , kernel_info.output_shapes );
801+ }
802+
795803 external_captured_commands_->push_back ({program_artifact.compute_pipeline ,
796804 bind_group,
797805 bind_group_layout,
798806 {x, y, z},
799- indirect_buffer});
807+ indirect_buffer,
808+ profiling_data});
800809 } else {
801810 compute_pass_encoder.SetPipeline (program_artifact.compute_pipeline );
802811 wgpuComputePassEncoderSetBindGroup (compute_pass_encoder.Get (), 0 , bind_group, 0 , nullptr );
@@ -827,9 +836,6 @@ void WebGpuContext::CaptureBegin(std::vector<webgpu::CapturedCommandInfo>* captu
827836 external_captured_commands_->clear ();
828837 }
829838
830- // TODO: support profiling with graph capture.
831- ORT_ENFORCE (!is_profiling_, " profiling is not supported yet under graph capture mode" );
832-
833839 graph_capture_state_ = GraphCaptureState::Capturing;
834840}
835841
@@ -842,6 +848,13 @@ void WebGpuContext::Replay(const std::vector<webgpu::CapturedCommandInfo>& captu
842848 auto & command = captured_commands[i];
843849 const auto & compute_pass_encoder = GetComputePassEncoder ();
844850 WriteTimestamp (num_pending_dispatches_ * 2 );
851+
852+ // Restore profiling info if available and profiling is enabled
853+ if (is_profiling_ && command.pending_kernel_info .has_value ()) {
854+ const auto & [name, cache_key, input_shapes, output_shapes] = command.pending_kernel_info .value ();
855+ pending_kernels_.emplace_back (name, cache_key, input_shapes, output_shapes);
856+ }
857+
845858 compute_pass_encoder.SetPipeline (command.compute_pipeline );
846859 wgpuComputePassEncoderSetBindGroup (compute_pass_encoder.Get (), 0 , command.bind_group , 0 , nullptr );
847860
0 commit comments