Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 81 additions & 3 deletions onnxruntime/core/platform/windows/telemetry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@

#include "core/platform/windows/telemetry.h"
#include <mutex>
#include <string>
#include <vector>
#include <cwchar>
#include <winsvc.h>

Check warning on line 9 in onnxruntime/core/platform/windows/telemetry.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Found C system header after C++ system header. Should be: telemetry.h, c system, c++ system, other. [build/include_order] [4] Raw Output: onnxruntime/core/platform/windows/telemetry.cc:9: Found C system header after C++ system header. Should be: telemetry.h, c system, c++ system, other. [build/include_order] [4]
#include "core/common/logging/logging.h"
#include "onnxruntime_config.h"

Expand Down Expand Up @@ -51,6 +55,75 @@
// {3a26b1ff-7484-7484-7484-15261f42614d}
(0x3a26b1ff, 0x7484, 0x7484, 0x74, 0x84, 0x15, 0x26, 0x1f, 0x42, 0x61, 0x4d),
TraceLoggingOptionMicrosoftTelemetry());

std::string WideToUtf8(const std::wstring& wide) {
if (wide.empty())
return {};

int utf8_length = ::WideCharToMultiByte(CP_UTF8, 0, wide.data(), static_cast<int>(wide.size()), nullptr, 0, nullptr, nullptr);
if (utf8_length == 0)
return {};

std::string utf8(utf8_length, '\0');
if (::WideCharToMultiByte(CP_UTF8, 0, wide.data(), static_cast<int>(wide.size()), utf8.data(), utf8_length, nullptr, nullptr) == 0)
return {};

return utf8;
}

std::string GetServiceNamesForCurrentProcess() {
static std::once_flag once_flag;
static std::string service_names;

std::call_once(once_flag, [] {
SC_HANDLE service_manager = ::OpenSCManagerW(nullptr, nullptr, SC_MANAGER_ENUMERATE_SERVICE);
if (service_manager == nullptr)
return;

DWORD bytes_needed = 0;
DWORD services_returned = 0;
DWORD resume_handle = 0;
if (!::EnumServicesStatusExW(service_manager, SC_ENUM_PROCESS_INFO, SERVICE_WIN32, SERVICE_STATE_ALL, nullptr, 0, &bytes_needed,
&services_returned, &resume_handle, nullptr) && ::GetLastError() != ERROR_MORE_DATA) {
::CloseServiceHandle(service_manager);
return;
}

if (bytes_needed == 0) {
::CloseServiceHandle(service_manager);
return;
}

std::vector<uint8_t> buffer(bytes_needed);
auto* services = reinterpret_cast<ENUM_SERVICE_STATUS_PROCESSW*>(buffer.data());
services_returned = 0;
resume_handle = 0;
if (!::EnumServicesStatusExW(service_manager, SC_ENUM_PROCESS_INFO, SERVICE_WIN32, SERVICE_STATE_ALL, reinterpret_cast<LPBYTE>(services),
bytes_needed, &bytes_needed, &services_returned, &resume_handle, nullptr)) {
::CloseServiceHandle(service_manager);
return;
}

DWORD current_pid = ::GetCurrentProcessId();
std::wstring aggregated;
bool first = true;
for (DWORD i = 0; i < services_returned; ++i) {
if (services[i].ServiceStatusProcess.dwProcessId == current_pid) {
if (!first) {
aggregated.push_back(L',');
}
aggregated.append(services[i].lpServiceName);
first = false;
}
}

::CloseServiceHandle(service_manager);

service_names = WideToUtf8(aggregated);
});

return service_names;
}
} // namespace

#ifdef _MSC_VER
Expand Down Expand Up @@ -178,6 +251,7 @@
#if BUILD_INBOX
isRedist = false;
#endif
const std::string service_names = GetServiceNamesForCurrentProcess();
TraceLoggingWrite(telemetry_provider_handle,
"ProcessInfo",
TraceLoggingBool(true, "UTCReplace_AppSessionGuid"),
Expand All @@ -189,7 +263,8 @@
TraceLoggingString(ORT_VERSION, "runtimeVersion"),
TraceLoggingBool(IsDebuggerPresent(), "isDebuggerAttached"),
TraceLoggingBool(isRedist, "isRedist"),
TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"),
TraceLoggingString(service_names.c_str(), "serviceNames"));

process_info_logged = true;
}
Expand Down Expand Up @@ -278,6 +353,7 @@
execution_provider_string += i;
}

const std::string service_names = GetServiceNamesForCurrentProcess();
// Difference is MeasureEvent & isCaptureState, but keep in sync otherwise
if (!captureState) {
TraceLoggingWrite(telemetry_provider_handle,
Expand All @@ -304,7 +380,8 @@
TraceLoggingString(model_weight_hash.c_str(), "modelWeightHash"),
TraceLoggingString(model_metadata_string.c_str(), "modelMetaData"),
TraceLoggingString(loaded_from.c_str(), "loadedFrom"),
TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"));
TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"),
TraceLoggingString(service_names.c_str(), "serviceNames"));
} else {
TraceLoggingWrite(telemetry_provider_handle,
"SessionCreation_CaptureState",
Expand All @@ -330,7 +407,8 @@
TraceLoggingString(model_weight_hash.c_str(), "modelWeightHash"),
TraceLoggingString(model_metadata_string.c_str(), "modelMetaData"),
TraceLoggingString(loaded_from.c_str(), "loadedFrom"),
TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"));
TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"),
TraceLoggingString(service_names.c_str(), "serviceNames"));
}
}

Expand Down
Loading