@@ -957,15 +957,31 @@ WebGpuExecutionProvider::~WebGpuExecutionProvider() {
957957 WebGpuContextFactory::ReleaseContext (context_id_);
958958}
959959
// Previous implementation (lines marked '-'): returned a fresh WebGpuProfiler
// bound to context_ and kept no pointer to it.
960- std::unique_ptr<profiling::EpProfiler> WebGpuExecutionProvider::GetProfiler () {
961- return std::make_unique<WebGpuProfiler>(context_);
// File-local, per-thread, non-owning pointer to the profiler created by
// GetProfiler(true). It is read and reset to nullptr in OnRunEnd.
960+ namespace {
961+ thread_local webgpu::WebGpuProfiler* t_current_profiler = nullptr ;
962+ }
963+
// Creates a new WebGpuProfiler bound to context_ on every call and returns
// ownership of it to the caller.
//
// enable_profiling == true : the raw pointer is stored in the thread-local
//                            t_current_profiler.
// enable_profiling == false: the raw pointer is stored in session_profiler_,
//                            but only if session_profiler_ is still unset.
//
// NOTE(review): both stored pointers alias the object owned by the returned
// unique_ptr. Nothing visible here clears session_profiler_, so it presumably
// relies on the caller keeping the profiler alive for as long as this
// provider uses it -- confirm the lifetime against the caller.
964+ std::unique_ptr<profiling::EpProfiler> WebGpuExecutionProvider::GetProfiler (bool enable_profiling) {
965+ auto p = std::make_unique<WebGpuProfiler>(context_);
966+ if (enable_profiling) {
967+ t_current_profiler = p.get ();
968+ } else if (!session_profiler_) {
969+ // If not for a run, it's for the session (RegisterExecutionProvider)
970+ session_profiler_ = p.get ();
971+ }
972+ return p;
962973}
963974
// Per-run setup hook. In the visible part it pushes an error scope when the
// validation mode is at least Basic, starts context profiling when requested,
// and begins the graph-capture handling. The rest of the graph-capture
// branch is elided by the diff hunk boundary below.
964975Status WebGpuExecutionProvider::OnRunStart (const onnxruntime::RunOptions& run_options) {
965976 if (context_.ValidationMode () >= ValidationMode::Basic) {
966977 context_.PushErrorScope ();
967978 }
968979
// Profiling is started when either this run asks for it through
// run_options.enable_profiling, or a session profiler has been registered
// and reports Enabled().
980+ // Session-level profiling handling if needed
981+ if (run_options.enable_profiling || (session_profiler_ && session_profiler_->Enabled ())) {
982+ context_.StartProfiling ();
983+ }
984+
969985 if (IsGraphCaptureEnabled ()) {
970986 auto graph_annotation_str = run_options.config_options .GetConfigEntry (kOrtRunOptionsConfigCudaGraphAnnotation );
971987 int graph_annotation_id = 0 ;
@@ -984,7 +1000,7 @@ Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_op
9841000 return Status::OK ();
9851001}
9861002
// Per-run teardown hook. It flushes the context, handles graph capture
// (partly elided by the diff hunk boundary below), collects profiling data
// for every profiler that is enabled for this run, and then calls
// context_.OnRunEnd(). The signature change only names the previously unused
// run_options parameter so that enable_profiling can be read.
987- Status WebGpuExecutionProvider::OnRunEnd (bool /* sync_stream */ , const onnxruntime::RunOptions& /* run_options */ ) {
1003+ Status WebGpuExecutionProvider::OnRunEnd (bool /* sync_stream */ , const onnxruntime::RunOptions& run_options) {
9881004 context_.Flush (BufferManager ());
9891005
9901006 if (IsGraphCaptureEnabled () && !IsGraphCaptured (m_current_graph_annotation_id)) {
@@ -997,17 +1013,24 @@ Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxrunti
9971013 }
9981014 }
9991015
// Build the list of profilers that should receive data from this run.
// The session profiler is included whenever it is set and Enabled().
1000- if (context_.IsProfilingEnabled ()) {
1001- context_.CollectProfilingData ();
1016+ std::vector<WebGpuProfiler*> profilers;
1017+ if (session_profiler_ && session_profiler_->Enabled ()) {
1018+ profilers.push_back (session_profiler_);
10021019 }
10031020
// The thread-local per-run profiler is included only when this run requested
// profiling and the profiler reports Enabled(). In that branch the pointer is
// reset to nullptr whether or not it was enabled.
// NOTE(review): when run_options.enable_profiling is false the thread-local
// pointer is left untouched -- confirm it cannot be stale on a later run.
1004- context_.OnRunEnd ();
1021+ if (run_options.enable_profiling && t_current_profiler) {
1022+ if (t_current_profiler->Enabled ()) {
1023+ profilers.push_back (t_current_profiler);
1024+ }
1025+ t_current_profiler = nullptr ;
1026+ }
10051027
// NOTE(review): the removed lines returned context_.PopErrorScope() when the
// validation mode is at least Basic. The added lines return Status::OK()
// unconditionally, while OnRunStart still calls context_.PushErrorScope()
// under the same condition. Confirm the scope is popped elsewhere; otherwise
// the push is unbalanced and validation errors are no longer reported here.
1006- if (context_.ValidationMode () >= ValidationMode::Basic) {
1007- return context_.PopErrorScope ();
1008- } else {
1009- return Status::OK ();
1028+ if (!profilers.empty ()) {
1029+ context_.CollectProfilingData (profilers);
10101030 }
1031+
1032+ context_.OnRunEnd ();
1033+ return Status::OK ();
10111034}
10121035
10131036bool WebGpuExecutionProvider::IsGraphCaptureEnabled () const {
0 commit comments