Skip to content

Commit 82fddd7

Browse files
authored
Cherry-pick telemetry changes from win-onnxruntime (#24957)
### Description
This change cherry-picks telemetry changes from win-onnxruntime to improve telemetry data collection for ONNX Runtime on Windows.

### Motivation and Context
These changes are already present in win-onnxruntime; they are cherry-picked here for Windows use cases that rely on the public ONNX Runtime.
1 parent 9ffc650 commit 82fddd7

File tree

11 files changed

+303
-26
lines changed

11 files changed

+303
-26
lines changed

cmake/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1669,6 +1669,10 @@ if (onnxruntime_ENABLE_DLPACK)
16691669
add_compile_definitions(ENABLE_DLPACK)
16701670
endif()
16711671

1672+
if (onnxruntime_CALLER_FRAMEWORK)
1673+
add_definitions(-DORT_CALLER_FRAMEWORK="${onnxruntime_CALLER_FRAMEWORK}")
1674+
endif()
1675+
16721676
if (UNIX OR onnxruntime_USE_NCCL)
16731677
# Find NCCL
16741678
if (onnxruntime_USE_NCCL)

include/onnxruntime/core/graph/graph.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1524,6 +1524,8 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
15241524

15251525
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Graph);
15261526

1527+
int32_t weight_data_type_freq_[ONNX_NAMESPACE::TensorProto_DataType_DataType_ARRAYSIZE] = {0};
1528+
15271529
private:
15281530
void InitializeStateFromModelFileGraphProto();
15291531

onnxruntime/core/graph/graph.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1346,6 +1346,10 @@ Graph::Graph(const Model& owning_model,
13461346
ORT_THROW("This is an invalid model. Tensor does not have type information.");
13471347
}
13481348

1349+
if (tensor.has_data_type() && (tensor.data_type() < TensorProto_DataType_DataType_ARRAYSIZE)) {
1350+
weight_data_type_freq_[tensor.data_type()]++;
1351+
}
1352+
13491353
if (ir_version_ < 4) {
13501354
// initializers can have matching graph inputs but are treated as constant,
13511355
// so we prefer the shape from the initializer

onnxruntime/core/platform/telemetry.cc

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,11 @@ void Telemetry::LogEvaluationStart() const {
5252
void Telemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, const std::string& model_producer_name,
5353
const std::string& model_producer_version, const std::string& model_domain,
5454
const std::unordered_map<std::string, int>& domain_to_version_map,
55+
const std::string& model_file_name,
5556
const std::string& model_graph_name,
57+
const std::string& model_weight_type,
58+
const std::string& model_graph_hash,
59+
const std::string& model_weight_hash,
5660
const std::unordered_map<std::string, std::string>& model_metadata,
5761
const std::string& loadedFrom, const std::vector<std::string>& execution_provider_ids,
5862
bool use_fp16, bool captureState) const {
@@ -62,7 +66,11 @@ void Telemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, cons
6266
ORT_UNUSED_PARAMETER(model_producer_version);
6367
ORT_UNUSED_PARAMETER(model_domain);
6468
ORT_UNUSED_PARAMETER(domain_to_version_map);
69+
ORT_UNUSED_PARAMETER(model_file_name);
6570
ORT_UNUSED_PARAMETER(model_graph_name);
71+
ORT_UNUSED_PARAMETER(model_weight_type);
72+
ORT_UNUSED_PARAMETER(model_graph_hash);
73+
ORT_UNUSED_PARAMETER(model_weight_hash);
6674
ORT_UNUSED_PARAMETER(model_metadata);
6775
ORT_UNUSED_PARAMETER(loadedFrom);
6876
ORT_UNUSED_PARAMETER(execution_provider_ids);
@@ -79,10 +87,12 @@ void Telemetry::LogRuntimeError(uint32_t session_id, const common::Status& statu
7987
ORT_UNUSED_PARAMETER(line);
8088
}
8189

82-
void Telemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last) const {
90+
void Telemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last,
91+
std::unordered_map<int64_t, long long> duration_per_batch_size) const {
8392
ORT_UNUSED_PARAMETER(session_id);
8493
ORT_UNUSED_PARAMETER(total_runs_since_last);
8594
ORT_UNUSED_PARAMETER(total_run_duration_since_last);
95+
ORT_UNUSED_PARAMETER(duration_per_batch_size);
8696
}
8797

8898
void Telemetry::LogExecutionProviderEvent(LUID* adapterLuid) const {

onnxruntime/core/platform/telemetry.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,20 @@ class Telemetry {
5757
virtual void LogSessionCreation(uint32_t session_id, int64_t ir_version, const std::string& model_producer_name,
5858
const std::string& model_producer_version, const std::string& model_domain,
5959
const std::unordered_map<std::string, int>& domain_to_version_map,
60+
const std::string& model_file_name,
6061
const std::string& model_graph_name,
62+
const std::string& model_weight_type,
63+
const std::string& model_graph_hash,
64+
const std::string& model_weight_hash,
6165
const std::unordered_map<std::string, std::string>& model_metadata,
6266
const std::string& loadedFrom, const std::vector<std::string>& execution_provider_ids,
6367
bool use_fp16, bool captureState) const;
6468

6569
virtual void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file,
6670
const char* function, uint32_t line) const;
6771

68-
virtual void LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last) const;
72+
virtual void LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last,
73+
std::unordered_map<int64_t, long long> duration_per_batch_size) const;
6974

7075
virtual void LogExecutionProviderEvent(LUID* adapterLuid) const;
7176

onnxruntime/core/platform/windows/telemetry.cc

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ TRACELOGGING_DEFINE_PROVIDER(telemetry_provider_handle, "Microsoft.ML.ONNXRuntim
5757
#pragma warning(pop)
5858
#endif
5959

60+
#ifndef ORT_CALLER_FRAMEWORK
61+
#define ORT_CALLER_FRAMEWORK ""
62+
#endif
63+
6064
std::mutex WindowsTelemetry::mutex_;
6165
std::mutex WindowsTelemetry::provider_change_mutex_;
6266
uint32_t WindowsTelemetry::global_register_count_ = 0;
@@ -184,7 +188,8 @@ void WindowsTelemetry::LogProcessInfo() const {
184188
TraceLoggingUInt8(0, "schemaVersion"),
185189
TraceLoggingString(ORT_VERSION, "runtimeVersion"),
186190
TraceLoggingBool(IsDebuggerPresent(), "isDebuggerAttached"),
187-
TraceLoggingBool(isRedist, "isRedist"));
191+
TraceLoggingBool(isRedist, "isRedist"),
192+
TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
188193

189194
process_info_logged = true;
190195
}
@@ -220,7 +225,11 @@ void WindowsTelemetry::LogEvaluationStart() const {
220225
void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, const std::string& model_producer_name,
221226
const std::string& model_producer_version, const std::string& model_domain,
222227
const std::unordered_map<std::string, int>& domain_to_version_map,
228+
const std::string& model_file_name,
223229
const std::string& model_graph_name,
230+
const std::string& model_weight_type,
231+
const std::string& model_graph_hash,
232+
const std::string& model_weight_hash,
224233
const std::unordered_map<std::string, std::string>& model_metadata,
225234
const std::string& loaded_from, const std::vector<std::string>& execution_provider_ids,
226235
bool use_fp16, bool captureState) const {
@@ -285,7 +294,11 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio
285294
TraceLoggingString(model_domain.c_str(), "modelDomain"),
286295
TraceLoggingBool(use_fp16, "usefp16"),
287296
TraceLoggingString(domain_to_version_string.c_str(), "domainToVersionMap"),
297+
TraceLoggingString(model_file_name.c_str(), "modelFileName"),
288298
TraceLoggingString(model_graph_name.c_str(), "modelGraphName"),
299+
TraceLoggingString(model_weight_type.c_str(), "modelWeightType"),
300+
TraceLoggingString(model_graph_hash.c_str(), "modelGraphHash"),
301+
TraceLoggingString(model_weight_hash.c_str(), "modelWeightHash"),
289302
TraceLoggingString(model_metadata_string.c_str(), "modelMetaData"),
290303
TraceLoggingString(loaded_from.c_str(), "loadedFrom"),
291304
TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"));
@@ -307,7 +320,11 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio
307320
TraceLoggingString(model_domain.c_str(), "modelDomain"),
308321
TraceLoggingBool(use_fp16, "usefp16"),
309322
TraceLoggingString(domain_to_version_string.c_str(), "domainToVersionMap"),
323+
TraceLoggingString(model_file_name.c_str(), "modelFileName"),
310324
TraceLoggingString(model_graph_name.c_str(), "modelGraphName"),
325+
TraceLoggingString(model_weight_type.c_str(), "modelWeightType"),
326+
TraceLoggingString(model_graph_hash.c_str(), "modelGraphHash"),
327+
TraceLoggingString(model_weight_hash.c_str(), "modelWeightHash"),
311328
TraceLoggingString(model_metadata_string.c_str(), "modelMetaData"),
312329
TraceLoggingString(loaded_from.c_str(), "loadedFrom"),
313330
TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"));
@@ -356,10 +373,22 @@ void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status
356373
#endif
357374
}
358375

359-
void WindowsTelemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last) const {
376+
void WindowsTelemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last,
377+
std::unordered_map<int64_t, long long> duration_per_batch_size) const {
360378
if (global_register_count_ == 0 || enabled_ == false)
361379
return;
362380

381+
// Convert duration_per_batch_size to a formatted string
382+
std::string total_duration_per_batch_size;
383+
for (const auto& entry : duration_per_batch_size) {
384+
if (!total_duration_per_batch_size.empty()) {
385+
total_duration_per_batch_size += ", ";
386+
}
387+
total_duration_per_batch_size += std::to_string(entry.first);
388+
total_duration_per_batch_size += ": ";
389+
total_duration_per_batch_size += std::to_string(entry.second);
390+
}
391+
363392
TraceLoggingWrite(telemetry_provider_handle,
364393
"RuntimePerf",
365394
TraceLoggingBool(true, "UTCReplace_AppSessionGuid"),
@@ -369,7 +398,8 @@ void WindowsTelemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_s
369398
TraceLoggingUInt8(0, "schemaVersion"),
370399
TraceLoggingUInt32(session_id, "sessionId"),
371400
TraceLoggingUInt32(total_runs_since_last, "totalRuns"),
372-
TraceLoggingInt64(total_run_duration_since_last, "totalRunDuration"));
401+
TraceLoggingInt64(total_run_duration_since_last, "totalRunDuration"),
402+
TraceLoggingString(total_duration_per_batch_size.c_str(), "totalRunDurationPerBatchSize"));
373403
}
374404

375405
void WindowsTelemetry::LogExecutionProviderEvent(LUID* adapterLuid) const {

onnxruntime/core/platform/windows/telemetry.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
#include <mutex>
1414
#include "core/platform/windows/TraceLoggingConfig.h"
1515

16+
static constexpr size_t TelemetrySampleCount = 10;
17+
1618
namespace onnxruntime {
1719

1820
/**
@@ -47,15 +49,20 @@ class WindowsTelemetry : public Telemetry {
4749
void LogSessionCreation(uint32_t session_id, int64_t ir_version, const std::string& model_producer_name,
4850
const std::string& model_producer_version, const std::string& model_domain,
4951
const std::unordered_map<std::string, int>& domain_to_version_map,
52+
const std::string& model_file_name,
5053
const std::string& model_graph_name,
54+
const std::string& model_weight_type,
55+
const std::string& model_graph_hash,
56+
const std::string& model_weight_hash,
5157
const std::unordered_map<std::string, std::string>& model_metadata,
5258
const std::string& loadedFrom, const std::vector<std::string>& execution_provider_ids,
5359
bool use_fp16, bool captureState) const override;
5460

5561
void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file,
5662
const char* function, uint32_t line) const override;
5763

58-
void LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last) const override;
64+
void LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last,
65+
std::unordered_map<int64_t, long long> duration_per_batch_size) const override;
5966

6067
void LogExecutionProviderEvent(LUID* adapterLuid) const override;
6168

0 commit comments

Comments
 (0)