diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc index 9b71f4ba2ebec..d1a3d3cb1fea1 100644 --- a/onnxruntime/core/platform/windows/telemetry.cc +++ b/onnxruntime/core/platform/windows/telemetry.cc @@ -3,6 +3,10 @@ #include "core/platform/windows/telemetry.h" #include +#include +#include +#include +#include #include "core/common/logging/logging.h" #include "onnxruntime_config.h" @@ -51,6 +55,80 @@ TRACELOGGING_DEFINE_PROVIDER(telemetry_provider_handle, "Microsoft.ML.ONNXRuntim // {3a26b1ff-7484-7484-7484-15261f42614d} (0x3a26b1ff, 0x7484, 0x7484, 0x74, 0x84, 0x15, 0x26, 0x1f, 0x42, 0x61, 0x4d), TraceLoggingOptionMicrosoftTelemetry()); + +std::string ConvertWideStringToUtf8(const std::wstring& wide) { + if (wide.empty()) + return {}; + + const UINT code_page = CP_UTF8; + const DWORD flags = 0; + LPCWCH const src = wide.data(); + const int src_len = static_cast(wide.size()); + int utf8_length = ::WideCharToMultiByte(code_page, flags, src, src_len, nullptr, 0, nullptr, nullptr); + if (utf8_length == 0) + return {}; + + std::string utf8(utf8_length, '\0'); + if (::WideCharToMultiByte(code_page, flags, src, src_len, utf8.data(), utf8_length, nullptr, nullptr) == 0) + return {}; + + return utf8; +} + +std::string GetServiceNamesForCurrentProcess() { + static std::once_flag once_flag; + static std::string service_names; + + std::call_once(once_flag, [] { + SC_HANDLE service_manager = ::OpenSCManagerW(nullptr, nullptr, SC_MANAGER_ENUMERATE_SERVICE); + if (service_manager == nullptr) + return; + + DWORD bytes_needed = 0; + DWORD services_returned = 0; + DWORD resume_handle = 0; + if (!::EnumServicesStatusExW(service_manager, SC_ENUM_PROCESS_INFO, SERVICE_WIN32, SERVICE_ACTIVE, nullptr, 0, &bytes_needed, + &services_returned, &resume_handle, nullptr) && + ::GetLastError() != ERROR_MORE_DATA) { + ::CloseServiceHandle(service_manager); + return; + } + + if (bytes_needed == 0) { + ::CloseServiceHandle(service_manager); + return; + } + + std::vector buffer(bytes_needed); + auto* services = reinterpret_cast(buffer.data()); + services_returned = 0; + resume_handle = 0; + if (!::EnumServicesStatusExW(service_manager, SC_ENUM_PROCESS_INFO, SERVICE_WIN32, SERVICE_ACTIVE, reinterpret_cast(services), + bytes_needed, &bytes_needed, &services_returned, &resume_handle, nullptr)) { + ::CloseServiceHandle(service_manager); + return; + } + + DWORD current_pid = ::GetCurrentProcessId(); + std::wstring aggregated; + bool first = true; + for (DWORD i = 0; i < services_returned; ++i) { + if (services[i].ServiceStatusProcess.dwProcessId == current_pid) { + if (!first) { + aggregated.push_back(L','); + } + aggregated.append(services[i].lpServiceName); + first = false; + } + } + + ::CloseServiceHandle(service_manager); + + service_names = ConvertWideStringToUtf8(aggregated); + }); + + return service_names; +} } // namespace #ifdef _MSC_VER @@ -178,6 +256,7 @@ void WindowsTelemetry::LogProcessInfo() const { #if BUILD_INBOX isRedist = false; #endif + const std::string service_names = GetServiceNamesForCurrentProcess(); TraceLoggingWrite(telemetry_provider_handle, "ProcessInfo", TraceLoggingBool(true, "UTCReplace_AppSessionGuid"), @@ -189,7 +268,8 @@ void WindowsTelemetry::LogProcessInfo() const { TraceLoggingString(ORT_VERSION, "runtimeVersion"), TraceLoggingBool(IsDebuggerPresent(), "isDebuggerAttached"), TraceLoggingBool(isRedist, "isRedist"), - TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); + TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"), + TraceLoggingString(service_names.c_str(), "serviceNames")); process_info_logged = true; } @@ -278,6 +358,7 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio execution_provider_string += i; } + const std::string service_names = GetServiceNamesForCurrentProcess(); // Difference is MeasureEvent & isCaptureState, but keep in sync otherwise if (!captureState) { TraceLoggingWrite(telemetry_provider_handle, @@ -304,7 +385,8 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio TraceLoggingString(model_weight_hash.c_str(), "modelWeightHash"), TraceLoggingString(model_metadata_string.c_str(), "modelMetaData"), TraceLoggingString(loaded_from.c_str(), "loadedFrom"), - TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds")); + TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"), + TraceLoggingString(service_names.c_str(), "serviceNames")); } else { TraceLoggingWrite(telemetry_provider_handle, "SessionCreation_CaptureState", @@ -330,7 +412,8 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio TraceLoggingString(model_weight_hash.c_str(), "modelWeightHash"), TraceLoggingString(model_metadata_string.c_str(), "modelMetaData"), TraceLoggingString(loaded_from.c_str(), "loadedFrom"), - TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds")); + TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"), + TraceLoggingString(service_names.c_str(), "serviceNames")); } }