Skip to content

Commit bc8e05b

Browse files
bmehta001Aditya Rastogi
authored andcommitted
Record service in telemetry events (#27252)
### Description This change records the service name(s), if any, as part of the SessionCreation/ProcessInfo events. We cache the service names after the first time we calculate them in order to avoid unnecessary overhead. ### Motivation and Context These changes enable deeper understanding of ORT usage, since multiple services can run inside an application in svchost, which currently obscures our understanding of which services/use cases are most popular. Understanding which services are actually being used can help prioritize more investments in making ORT better targeted to end users. ### Testing Have tested that the logic in GetServiceNamesForCurrentProcess can accurately return service name for a given process
1 parent 2ff7447 commit bc8e05b

File tree

1 file changed

+86
-3
lines changed

1 file changed

+86
-3
lines changed

onnxruntime/core/platform/windows/telemetry.cc

Lines changed: 86 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33

44
#include "core/platform/windows/telemetry.h"
55
#include <mutex>
6+
#include <string>
7+
#include <vector>
8+
#include <cwchar>
9+
#include <winsvc.h>
610
#include "core/common/logging/logging.h"
711
#include "onnxruntime_config.h"
812

@@ -51,6 +55,80 @@ TRACELOGGING_DEFINE_PROVIDER(telemetry_provider_handle, "Microsoft.ML.ONNXRuntim
5155
// {3a26b1ff-7484-7484-7484-15261f42614d}
5256
(0x3a26b1ff, 0x7484, 0x7484, 0x74, 0x84, 0x15, 0x26, 0x1f, 0x42, 0x61, 0x4d),
5357
TraceLoggingOptionMicrosoftTelemetry());
58+
59+
std::string ConvertWideStringToUtf8(const std::wstring& wide) {
60+
if (wide.empty())
61+
return {};
62+
63+
const UINT code_page = CP_UTF8;
64+
const DWORD flags = 0;
65+
LPCWCH const src = wide.data();
66+
const int src_len = static_cast<int>(wide.size());
67+
int utf8_length = ::WideCharToMultiByte(code_page, flags, src, src_len, nullptr, 0, nullptr, nullptr);
68+
if (utf8_length == 0)
69+
return {};
70+
71+
std::string utf8(utf8_length, '\0');
72+
if (::WideCharToMultiByte(code_page, flags, src, src_len, utf8.data(), utf8_length, nullptr, nullptr) == 0)
73+
return {};
74+
75+
return utf8;
76+
}
77+
78+
std::string GetServiceNamesForCurrentProcess() {
79+
static std::once_flag once_flag;
80+
static std::string service_names;
81+
82+
std::call_once(once_flag, [] {
83+
SC_HANDLE service_manager = ::OpenSCManagerW(nullptr, nullptr, SC_MANAGER_ENUMERATE_SERVICE);
84+
if (service_manager == nullptr)
85+
return;
86+
87+
DWORD bytes_needed = 0;
88+
DWORD services_returned = 0;
89+
DWORD resume_handle = 0;
90+
if (!::EnumServicesStatusExW(service_manager, SC_ENUM_PROCESS_INFO, SERVICE_WIN32, SERVICE_ACTIVE, nullptr, 0, &bytes_needed,
91+
&services_returned, &resume_handle, nullptr) &&
92+
::GetLastError() != ERROR_MORE_DATA) {
93+
::CloseServiceHandle(service_manager);
94+
return;
95+
}
96+
97+
if (bytes_needed == 0) {
98+
::CloseServiceHandle(service_manager);
99+
return;
100+
}
101+
102+
std::vector<uint8_t> buffer(bytes_needed);
103+
auto* services = reinterpret_cast<ENUM_SERVICE_STATUS_PROCESSW*>(buffer.data());
104+
services_returned = 0;
105+
resume_handle = 0;
106+
if (!::EnumServicesStatusExW(service_manager, SC_ENUM_PROCESS_INFO, SERVICE_WIN32, SERVICE_ACTIVE, reinterpret_cast<LPBYTE>(services),
107+
bytes_needed, &bytes_needed, &services_returned, &resume_handle, nullptr)) {
108+
::CloseServiceHandle(service_manager);
109+
return;
110+
}
111+
112+
DWORD current_pid = ::GetCurrentProcessId();
113+
std::wstring aggregated;
114+
bool first = true;
115+
for (DWORD i = 0; i < services_returned; ++i) {
116+
if (services[i].ServiceStatusProcess.dwProcessId == current_pid) {
117+
if (!first) {
118+
aggregated.push_back(L',');
119+
}
120+
aggregated.append(services[i].lpServiceName);
121+
first = false;
122+
}
123+
}
124+
125+
::CloseServiceHandle(service_manager);
126+
127+
service_names = ConvertWideStringToUtf8(aggregated);
128+
});
129+
130+
return service_names;
131+
}
54132
} // namespace
55133

56134
#ifdef _MSC_VER
@@ -178,6 +256,7 @@ void WindowsTelemetry::LogProcessInfo() const {
178256
#if BUILD_INBOX
179257
isRedist = false;
180258
#endif
259+
const std::string service_names = GetServiceNamesForCurrentProcess();
181260
TraceLoggingWrite(telemetry_provider_handle,
182261
"ProcessInfo",
183262
TraceLoggingBool(true, "UTCReplace_AppSessionGuid"),
@@ -189,7 +268,8 @@ void WindowsTelemetry::LogProcessInfo() const {
189268
TraceLoggingString(ORT_VERSION, "runtimeVersion"),
190269
TraceLoggingBool(IsDebuggerPresent(), "isDebuggerAttached"),
191270
TraceLoggingBool(isRedist, "isRedist"),
192-
TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
271+
TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"),
272+
TraceLoggingString(service_names.c_str(), "serviceNames"));
193273

194274
process_info_logged = true;
195275
}
@@ -278,6 +358,7 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio
278358
execution_provider_string += i;
279359
}
280360

361+
const std::string service_names = GetServiceNamesForCurrentProcess();
281362
// Difference is MeasureEvent & isCaptureState, but keep in sync otherwise
282363
if (!captureState) {
283364
TraceLoggingWrite(telemetry_provider_handle,
@@ -304,7 +385,8 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio
304385
TraceLoggingString(model_weight_hash.c_str(), "modelWeightHash"),
305386
TraceLoggingString(model_metadata_string.c_str(), "modelMetaData"),
306387
TraceLoggingString(loaded_from.c_str(), "loadedFrom"),
307-
TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"));
388+
TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"),
389+
TraceLoggingString(service_names.c_str(), "serviceNames"));
308390
} else {
309391
TraceLoggingWrite(telemetry_provider_handle,
310392
"SessionCreation_CaptureState",
@@ -330,7 +412,8 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio
330412
TraceLoggingString(model_weight_hash.c_str(), "modelWeightHash"),
331413
TraceLoggingString(model_metadata_string.c_str(), "modelMetaData"),
332414
TraceLoggingString(loaded_from.c_str(), "loadedFrom"),
333-
TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"));
415+
TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"),
416+
TraceLoggingString(service_names.c_str(), "serviceNames"));
334417
}
335418
}
336419

0 commit comments

Comments
 (0)