Skip to content

Commit df00e91

Browse files
bmehta001tianleiwu
authored andcommitted
Record service in telemetry events (#27252)
This change records the service name(s), if any, as part of the SessionCreation/ProcessInfo events. We cache the service names after the first time we calculate them in order to avoid unnecessary overhead. These changes enable deeper understanding of ORT usage, since multiple services can run inside an application in svchost, which currently obscures our understanding of which services/use cases are most popular. Understanding which services are actually being used can help prioritize more investments in making ORT better targeted to end users. Have tested that the logic in GetServiceNamesForCurrentProcess can accurately return service name for a given process
1 parent 4bae1b4 commit df00e91

File tree

1 file changed

+82
-1
lines changed

1 file changed

+82
-1
lines changed

onnxruntime/core/platform/windows/telemetry.cc

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33

44
#include "core/platform/windows/telemetry.h"
55
#include <mutex>
6+
#include <string>
7+
#include <vector>
8+
#include <cwchar>
9+
#include <winsvc.h>
610
#include "core/common/logging/logging.h"
711
#include "onnxruntime_config.h"
812

@@ -51,6 +55,80 @@ TRACELOGGING_DEFINE_PROVIDER(telemetry_provider_handle, "Microsoft.ML.ONNXRuntim
5155
// {3a26b1ff-7484-7484-7484-15261f42614d}
5256
(0x3a26b1ff, 0x7484, 0x7484, 0x74, 0x84, 0x15, 0x26, 0x1f, 0x42, 0x61, 0x4d),
5357
TraceLoggingOptionMicrosoftTelemetry());
58+
59+
std::string ConvertWideStringToUtf8(const std::wstring& wide) {
60+
if (wide.empty())
61+
return {};
62+
63+
const UINT code_page = CP_UTF8;
64+
const DWORD flags = 0;
65+
LPCWCH const src = wide.data();
66+
const int src_len = static_cast<int>(wide.size());
67+
int utf8_length = ::WideCharToMultiByte(code_page, flags, src, src_len, nullptr, 0, nullptr, nullptr);
68+
if (utf8_length == 0)
69+
return {};
70+
71+
std::string utf8(utf8_length, '\0');
72+
if (::WideCharToMultiByte(code_page, flags, src, src_len, utf8.data(), utf8_length, nullptr, nullptr) == 0)
73+
return {};
74+
75+
return utf8;
76+
}
77+
78+
std::string GetServiceNamesForCurrentProcess() {
79+
static std::once_flag once_flag;
80+
static std::string service_names;
81+
82+
std::call_once(once_flag, [] {
83+
SC_HANDLE service_manager = ::OpenSCManagerW(nullptr, nullptr, SC_MANAGER_ENUMERATE_SERVICE);
84+
if (service_manager == nullptr)
85+
return;
86+
87+
DWORD bytes_needed = 0;
88+
DWORD services_returned = 0;
89+
DWORD resume_handle = 0;
90+
if (!::EnumServicesStatusExW(service_manager, SC_ENUM_PROCESS_INFO, SERVICE_WIN32, SERVICE_ACTIVE, nullptr, 0, &bytes_needed,
91+
&services_returned, &resume_handle, nullptr) &&
92+
::GetLastError() != ERROR_MORE_DATA) {
93+
::CloseServiceHandle(service_manager);
94+
return;
95+
}
96+
97+
if (bytes_needed == 0) {
98+
::CloseServiceHandle(service_manager);
99+
return;
100+
}
101+
102+
std::vector<uint8_t> buffer(bytes_needed);
103+
auto* services = reinterpret_cast<ENUM_SERVICE_STATUS_PROCESSW*>(buffer.data());
104+
services_returned = 0;
105+
resume_handle = 0;
106+
if (!::EnumServicesStatusExW(service_manager, SC_ENUM_PROCESS_INFO, SERVICE_WIN32, SERVICE_ACTIVE, reinterpret_cast<LPBYTE>(services),
107+
bytes_needed, &bytes_needed, &services_returned, &resume_handle, nullptr)) {
108+
::CloseServiceHandle(service_manager);
109+
return;
110+
}
111+
112+
DWORD current_pid = ::GetCurrentProcessId();
113+
std::wstring aggregated;
114+
bool first = true;
115+
for (DWORD i = 0; i < services_returned; ++i) {
116+
if (services[i].ServiceStatusProcess.dwProcessId == current_pid) {
117+
if (!first) {
118+
aggregated.push_back(L',');
119+
}
120+
aggregated.append(services[i].lpServiceName);
121+
first = false;
122+
}
123+
}
124+
125+
::CloseServiceHandle(service_manager);
126+
127+
service_names = ConvertWideStringToUtf8(aggregated);
128+
});
129+
130+
return service_names;
131+
}
54132
} // namespace
55133

56134
#ifdef _MSC_VER
@@ -178,6 +256,7 @@ void WindowsTelemetry::LogProcessInfo() const {
178256
#if BUILD_INBOX
179257
isRedist = false;
180258
#endif
259+
const std::string service_names = GetServiceNamesForCurrentProcess();
181260
TraceLoggingWrite(telemetry_provider_handle,
182261
"ProcessInfo",
183262
TraceLoggingBool(true, "UTCReplace_AppSessionGuid"),
@@ -189,7 +268,8 @@ void WindowsTelemetry::LogProcessInfo() const {
189268
TraceLoggingString(ORT_VERSION, "runtimeVersion"),
190269
TraceLoggingBool(IsDebuggerPresent(), "isDebuggerAttached"),
191270
TraceLoggingBool(isRedist, "isRedist"),
192-
TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
271+
TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"),
272+
TraceLoggingString(service_names.c_str(), "serviceNames"));
193273

194274
process_info_logged = true;
195275
}
@@ -279,6 +359,7 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio
279359
execution_provider_string += i;
280360
}
281361

362+
const std::string service_names = GetServiceNamesForCurrentProcess();
282363
// Difference is MeasureEvent & isCaptureState, but keep in sync otherwise
283364
if (!captureState) {
284365
TraceLoggingWrite(telemetry_provider_handle,

0 commit comments

Comments
 (0)