Skip to content

Commit 9008d33

Browse files
committed
impl
1 parent 1159e35 commit 9008d33

File tree

13 files changed

+102
-52
lines changed

13 files changed

+102
-52
lines changed

include/onnxruntime/core/framework/execution_provider.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ class IExecutionProvider {
356356
return logger_;
357357
}
358358

359-
virtual std::unique_ptr<profiling::EpProfiler> GetProfiler() {
359+
virtual std::unique_ptr<profiling::EpProfiler> GetProfiler(bool /*enable_profiling*/) {
360360
return {};
361361
}
362362

include/onnxruntime/core/framework/run_options.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ struct OrtRunOptions {
3838
// Set to 'true' to enable profiling for this run.
3939
bool enable_profiling = false;
4040

41+
// File prefix for profiling result for this run.
42+
// The actual filename will be: <profile_file_prefix>_<timestamp>.json
43+
// Only used when enable_profiling is true.
44+
std::string profile_file_prefix = "onnxruntime_run_profile";
45+
4146
#ifdef ENABLE_TRAINING
4247
// Used by onnxruntime::training::TrainingSession. This class is now deprecated.
4348
// Delete training_mode when TrainingSession is deleted.

onnxruntime/core/framework/sequential_executor.cc

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -180,14 +180,12 @@ class SessionScope {
180180
bool session_profiling_enabled = session_state_.Profiler().IsEnabled();
181181
bool run_profiling_enabled = run_profiler_ && run_profiler_->IsEnabled();
182182

183-
if (session_profiling_enabled || run_profiling_enabled) {
184-
auto now = std::chrono::high_resolution_clock::now();
185-
if (session_profiling_enabled) {
186-
session_start_ = session_state_.Profiler().Start(now);
187-
}
188-
if (run_profiling_enabled) {
189-
run_profiler_start_ = run_profiler_->Start(now);
190-
}
183+
auto now = std::chrono::high_resolution_clock::now();
184+
if (session_profiling_enabled) {
185+
session_start_ = session_state_.Profiler().Start(now);
186+
}
187+
if (run_profiling_enabled) {
188+
run_profiler_start_ = run_profiler_->Start(now);
191189
}
192190

193191
auto& logger = session_state_.Logger();
@@ -238,14 +236,12 @@ class SessionScope {
238236
bool session_profiling_enabled = session_state_.Profiler().IsEnabled();
239237
bool run_profiling_enabled = run_profiler_ && run_profiler_->IsEnabled();
240238

241-
if (session_profiling_enabled || run_profiling_enabled) {
242-
auto now = std::chrono::high_resolution_clock::now();
243-
if (session_profiling_enabled) {
244-
session_state_.Profiler().EndTimeAndRecordEvent(profiling::SESSION_EVENT, "SequentialExecutor::Execute", session_start_, now);
245-
}
246-
if (run_profiling_enabled) {
247-
run_profiler_->EndTimeAndRecordEvent(profiling::SESSION_EVENT, "SequentialExecutor::Execute", run_profiler_start_, now);
248-
}
239+
auto now = std::chrono::high_resolution_clock::now();
240+
if (session_profiling_enabled) {
241+
session_state_.Profiler().EndTimeAndRecordEvent(profiling::SESSION_EVENT, "SequentialExecutor::Execute", session_start_, now);
242+
}
243+
if (run_profiling_enabled) {
244+
run_profiler_->EndTimeAndRecordEvent(profiling::SESSION_EVENT, "SequentialExecutor::Execute", run_profiler_start_, now);
249245
}
250246

251247
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)

onnxruntime/core/providers/cuda/cuda_execution_provider.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ ITuningContext* CUDAExecutionProvider::GetTuningContext() const {
410410
return const_cast<cuda::tunable::CudaTuningContext*>(&tuning_context_);
411411
}
412412

413-
std::unique_ptr<profiling::EpProfiler> CUDAExecutionProvider::GetProfiler() {
413+
std::unique_ptr<profiling::EpProfiler> CUDAExecutionProvider::GetProfiler(bool /*enable_profiling*/) {
414414
return std::make_unique<profiling::CudaProfiler>();
415415
}
416416

onnxruntime/core/providers/cuda/cuda_execution_provider.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ class CUDAExecutionProvider : public IExecutionProvider {
117117

118118
ITuningContext* GetTuningContext() const override;
119119

120-
std::unique_ptr<profiling::EpProfiler> GetProfiler() override;
120+
std::unique_ptr<profiling::EpProfiler> GetProfiler(bool enable_profiling) override;
121121

122122
bool IsGraphCaptureEnabled() const override;
123123
bool IsGraphCaptured(CudaGraphAnnotation_t graph_annotation_id) const override;

onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ common::Status VitisAIExecutionProvider::SetEpDynamicOptions(gsl::span<const cha
141141
return Status::OK();
142142
}
143143

144-
std::unique_ptr<profiling::EpProfiler> VitisAIExecutionProvider::GetProfiler() {
144+
std::unique_ptr<profiling::EpProfiler> VitisAIExecutionProvider::GetProfiler(bool /*enable_profiling*/) {
145145
return std::make_unique<profiling::VitisaiProfiler>();
146146
}
147147

onnxruntime/core/providers/vitisai/vitisai_execution_provider.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class VitisAIExecutionProvider : public IExecutionProvider {
3838
std::vector<NodeComputeInfo>& node_compute_funcs) override;
3939
std::shared_ptr<KernelRegistry> GetKernelRegistry() const override;
4040

41-
std::unique_ptr<profiling::EpProfiler> GetProfiler() override;
41+
std::unique_ptr<profiling::EpProfiler> GetProfiler(bool enable_profiling) override;
4242

4343
// This method is called after both `GetComputeCapabilityOps()` and `Compile()`.
4444
// This timing is required to work with both compliation-based EPs and non-compilation-based EPs.

onnxruntime/core/providers/webgpu/webgpu_context.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -621,7 +621,7 @@ void WebGpuContext::StartProfiling() {
621621
}
622622
}
623623

624-
void WebGpuContext::CollectProfilingData() {
624+
void WebGpuContext::CollectProfilingData(const std::vector<WebGpuProfiler*>& profilers) {
625625
if (!pending_queries_.empty()) {
626626
for (const auto& pending_query : pending_queries_) {
627627
const auto& pending_kernels = pending_query.kernels;
@@ -671,7 +671,7 @@ void WebGpuContext::CollectProfilingData() {
671671
static_cast<int64_t>(std::round((end_time - start_time) / 1000.0)),
672672
event_args);
673673

674-
for (auto* profiler : profilers_) {
674+
for (auto* profiler : profilers) {
675675
profiler->Events().emplace_back(event);
676676
}
677677
}

onnxruntime/core/providers/webgpu/webgpu_context.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ class WebGpuContext final {
160160
}
161161

162162
void StartProfiling();
163-
void CollectProfilingData();
163+
void CollectProfilingData(const std::vector<WebGpuProfiler*>& profilers);
164164
void EndProfiling();
165165

166166
void RegisterProfiler(WebGpuProfiler* profiler);

onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -957,15 +957,31 @@ WebGpuExecutionProvider::~WebGpuExecutionProvider() {
957957
WebGpuContextFactory::ReleaseContext(context_id_);
958958
}
959959

960-
std::unique_ptr<profiling::EpProfiler> WebGpuExecutionProvider::GetProfiler() {
961-
return std::make_unique<WebGpuProfiler>(context_);
960+
namespace {
961+
thread_local webgpu::WebGpuProfiler* t_current_profiler = nullptr;
962+
}
963+
964+
std::unique_ptr<profiling::EpProfiler> WebGpuExecutionProvider::GetProfiler(bool enable_profiling) {
965+
auto p = std::make_unique<WebGpuProfiler>(context_);
966+
if (enable_profiling) {
967+
t_current_profiler = p.get();
968+
} else if (!session_profiler_) {
969+
// If not for a run, it's for the session (RegisterExecutionProvider)
970+
session_profiler_ = p.get();
971+
}
972+
return p;
962973
}
963974

964975
Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_options) {
965976
if (context_.ValidationMode() >= ValidationMode::Basic) {
966977
context_.PushErrorScope();
967978
}
968979

980+
// Session-level profiling handling if needed
981+
if (run_options.enable_profiling || (session_profiler_ && session_profiler_->Enabled())) {
982+
context_.StartProfiling();
983+
}
984+
969985
if (IsGraphCaptureEnabled()) {
970986
auto graph_annotation_str = run_options.config_options.GetConfigEntry(kOrtRunOptionsConfigCudaGraphAnnotation);
971987
int graph_annotation_id = 0;
@@ -984,7 +1000,7 @@ Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_op
9841000
return Status::OK();
9851001
}
9861002

987-
Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxruntime::RunOptions& /* run_options */) {
1003+
Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxruntime::RunOptions& run_options) {
9881004
context_.Flush(BufferManager());
9891005

9901006
if (IsGraphCaptureEnabled() && !IsGraphCaptured(m_current_graph_annotation_id)) {
@@ -997,17 +1013,24 @@ Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxrunti
9971013
}
9981014
}
9991015

1000-
if (context_.IsProfilingEnabled()) {
1001-
context_.CollectProfilingData();
1016+
std::vector<WebGpuProfiler*> profilers;
1017+
if (session_profiler_ && session_profiler_->Enabled()) {
1018+
profilers.push_back(session_profiler_);
10021019
}
10031020

1004-
context_.OnRunEnd();
1021+
if (run_options.enable_profiling && t_current_profiler) {
1022+
if (t_current_profiler->Enabled()) {
1023+
profilers.push_back(t_current_profiler);
1024+
}
1025+
t_current_profiler = nullptr;
1026+
}
10051027

1006-
if (context_.ValidationMode() >= ValidationMode::Basic) {
1007-
return context_.PopErrorScope();
1008-
} else {
1009-
return Status::OK();
1028+
if (!profilers.empty()) {
1029+
context_.CollectProfilingData(profilers);
10101030
}
1031+
1032+
context_.OnRunEnd();
1033+
return Status::OK();
10111034
}
10121035

10131036
bool WebGpuExecutionProvider::IsGraphCaptureEnabled() const {

0 commit comments

Comments
 (0)