Skip to content

Commit 878dff6

Browse files
authored
Merge branch 'microsoft:main' into main
2 parents 7acbfcf + 06fe9a4 commit 878dff6

39 files changed

+1146
-227
lines changed

.vscode/settings.json

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,23 @@
1414
"-build/include_subdir",
1515
"-runtime/references"
1616
],
17-
"C_Cpp.autoAddFileAssociations": false
17+
"C_Cpp.autoAddFileAssociations": false,
18+
19+
// Exclude build directories and non-essential folders from C++ parsing
20+
"C_Cpp.files.exclude": {
21+
"**/build/**": true,
22+
"**/build_*/**": true,
23+
"**/cmake/external/**": true,
24+
"**/node_modules/**": true,
25+
"**/.git/**": true
26+
},
27+
28+
// Exclude from search but keep in explorer
29+
"search.exclude": {
30+
"**/build/**": true,
31+
"**/build_*/**": true,
32+
"**/cmake/external/**": true,
33+
"**/node_modules/**": true,
34+
"**/.git/**": true
35+
}
1836
}

cmake/CMakeLists.txt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1801,8 +1801,11 @@ if (onnxruntime_ENABLE_EXTERNAL_CUSTOM_OP_SCHEMAS)
18011801
)
18021802
endif()
18031803

1804-
if(NOT onnxruntime_BUILD_SHARED_LIB AND onnxruntime_USE_WEBGPU)
1805-
message(WARNING "CMake target files will not be generated for static onnxruntime builds with webgpu support")
1804+
if (NOT onnxruntime_BUILD_SHARED_LIB AND
1805+
(onnxruntime_USE_WEBGPU OR (CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND onnxruntime_USE_XNNPACK)))
1806+
message(WARNING
1807+
"CMake target files will not be generated for static onnxruntime builds "
1808+
"with WebGPU or Emscripten+XNNPACK support")
18061809
else()
18071810
# Install
18081811
include(CMakePackageConfigHelpers)

cmake/external/abseil-cpp.cmake

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,8 @@ set(ABSL_USE_EXTERNAL_GOOGLETEST ON)
1212

1313
# Both abseil and xnnpack create a target called memory, which
1414
# results in a duplicate target if ABSL_ENABLE_INSTALL is on.
15-
if (onnxruntime_USE_XNNPACK)
16-
set(ABSL_ENABLE_INSTALL OFF)
17-
else()
18-
if (NOT CMAKE_SYSTEM_NAME MATCHES "AIX")
15+
if (NOT CMAKE_SYSTEM_NAME MATCHES "AIX")
1916
set(ABSL_ENABLE_INSTALL ON)
20-
endif()
2117
endif()
2218

2319
if(Patch_FOUND)

cmake/external/onnxruntime_external_deps.cmake

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -764,7 +764,11 @@ if (onnxruntime_USE_WEBGPU)
764764
# - (private) Fix compatibility issues with Safari. Contains the following changes:
765765
# - Polyfill for `device.AdapterInfo` (returns `undefined` in Safari v26.0)
766766
#
767-
${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/safari_polyfill.patch)
767+
${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/safari_polyfill.patch &&
768+
769+
# Remove the test folder to speed up potential file scan operations (70k+ files not needed for build).
770+
# Using <SOURCE_DIR> token ensures the correct absolute path regardless of working directory.
771+
${CMAKE_COMMAND} -E rm -rf <SOURCE_DIR>/test)
768772

769773
onnxruntime_fetchcontent_declare(
770774
dawn

cmake/external/xnnpack.cmake

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ ELSEIF(CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64")
6262
SET(ORT_TARGET_PROCESSOR "arm64")
6363
ELSEIF(CMAKE_SYSTEM_PROCESSOR STREQUAL "ppc64le")
6464
SET(ORT_TARGET_PROCESSOR "ppc64")
65+
ELSEIF(CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
66+
SET(ORT_TARGET_PROCESSOR "wasm")
6567
ELSEIF(NOT ORT_TARGET_PROCESSOR MATCHES "^(x86(_64)?|arm64|riscv(32|64|128)|Hexagon|ppc64)$")
6668
SET(ORT_TARGET_PROCESSOR "${CMAKE_SYSTEM_PROCESSOR}")
6769
ELSE()
@@ -90,18 +92,21 @@ onnxruntime_fetchcontent_makeavailable(googlexnnpack)
9092
set(XNNPACK_DIR ${googlexnnpack_SOURCE_DIR})
9193
set(XNNPACK_INCLUDE_DIR ${XNNPACK_DIR}/include)
9294

93-
set(onnxruntime_EXTERNAL_LIBRARIES_XNNPACK XNNPACK xnnpack-microkernels-prod pthreadpool)
95+
set(onnxruntime_EXTERNAL_LIBRARIES_XNNPACK XNNPACK pthreadpool)
9496
if(ORT_TARGET_PROCESSOR MATCHES "^arm64.*" AND NOT CMAKE_C_COMPILER_ID STREQUAL "MSVC")
9597
list(APPEND onnxruntime_EXTERNAL_LIBRARIES_XNNPACK kleidiai)
9698
endif()
99+
if(NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
100+
list(APPEND onnxruntime_EXTERNAL_LIBRARIES_XNNPACK xnnpack-microkernels-prod)
101+
endif()
97102

98103
# the XNNPACK CMake setup doesn't include the WASM kernels so we have to manually set those up
99104
if(CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
100105
# See source lists in _deps/googlexnnpack-src/BUILD.bazel for wasm_prod_microkernels
101106
message("Adding WebAssembly Source Files to XNNPACK")
102107
set(wasm_srcs "")
103108

104-
file(READ "${XNNPACK_DIR}/BUILD.bazel" xnnpack_bazel_config)
109+
file(READ "${XNNPACK_DIR}/build_srcs.bzl" xnnpack_bazel_config)
105110

106111
# Replace newlines with semicolon so that it is treated as a list by CMake
107112
# Also replace '[' and ']' so the bazel source lists don't get parsed as a nested list by cmake
@@ -139,19 +144,26 @@ if(CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
139144
GetSrcListFromBazel("TABLE_SRCS" table_srcs)
140145
list(APPEND wasm_srcs ${operator_srcs} ${table_srcs})
141146

142-
# kernels
143-
list(APPEND wasm_srcs ${XNNPACK_DIR}/src/amalgam/gen/scalar.c)
144-
list(APPEND wasm_srcs ${XNNPACK_DIR}/src/amalgam/gen/wasm.c)
147+
set(microkernel_src "")
148+
149+
include(${XNNPACK_DIR}/cmake/gen/scalar_microkernels.cmake)
150+
list(APPEND microkernel_src ${PROD_SCALAR_MICROKERNEL_SRCS})
151+
list(APPEND microkernel_src ${PROD_WASM_MICROKERNEL_SRCS})
145152

146153
if(onnxruntime_ENABLE_WEBASSEMBLY_RELAXED_SIMD)
147-
list(APPEND wasm_srcs ${XNNPACK_DIR}/src/amalgam/gen/wasmsimd.c)
148-
list(APPEND wasm_srcs ${XNNPACK_DIR}/src/amalgam/gen/wasmrelaxedsimd.c)
154+
include(${XNNPACK_DIR}/cmake/gen/wasmsimd_microkernels.cmake)
155+
include(${XNNPACK_DIR}/cmake/gen/wasmrelaxedsimd_microkernels.cmake)
156+
list(APPEND microkernel_src ${PROD_WASMSIMD_MICROKERNEL_SRCS})
157+
list(APPEND microkernel_src ${PROD_WASMRELAXEDSIMD_MICROKERNEL_SRCS})
149158
target_compile_options(XNNPACK PRIVATE "-msimd128")
150159
target_compile_options(XNNPACK PRIVATE "-mrelaxed-simd")
151160
elseif(onnxruntime_ENABLE_WEBASSEMBLY_SIMD)
152-
list(APPEND wasm_srcs ${XNNPACK_DIR}/src/amalgam/gen/wasmsimd.c)
161+
include(${XNNPACK_DIR}/cmake/gen/wasmsimd_microkernels.cmake)
162+
list(APPEND microkernel_src ${PROD_WASMSIMD_MICROKERNEL_SRCS})
153163
target_compile_options(XNNPACK PRIVATE "-msimd128")
154164
endif()
165+
list(TRANSFORM microkernel_src PREPEND "${XNNPACK_DIR}/")
166+
list(APPEND wasm_srcs ${microkernel_src})
155167

156168
message(DEBUG "wasm_srcs: ${wasm_srcs}\n")
157169
target_sources(XNNPACK PRIVATE ${wasm_srcs})

docs/Versioning.md

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,67 @@ The version number of the current stable release can be found
1111
## Release cadence
1212
See [Release Management](ReleaseManagement.md)
1313

14+
## Updating the Version for a Release
15+
16+
When preparing a release, follow these steps to update the version number across the codebase. This applies both when creating an initial release branch (updating `main`) and when preparing patch releases on release branches:
17+
18+
### Prerequisites
19+
- Node.js (check [js/.nvmrc](../js/.nvmrc) for the required version)
20+
- npm (comes with Node.js)
21+
- Python 3
22+
23+
Verify your setup:
24+
```bash
25+
node --version # Should match the version in js/.nvmrc
26+
npm --version # Should be v8.0 or newer
27+
```
28+
29+
### Steps
30+
31+
1. **Update the VERSION_NUMBER file**
32+
33+
Edit [VERSION_NUMBER](../VERSION_NUMBER) in the repository root to reflect the new version (e.g., `1.23.3`).
34+
35+
2. **Run the version update script**
36+
37+
From the repository root, run:
38+
```bash
39+
python tools/python/update_version.py
40+
```
41+
42+
This script automatically updates version numbers in:
43+
- `docs/Versioning.md` - Adds a new row to the version table
44+
- `docs/python/README.rst` - Adds release notes entry
45+
- `onnxruntime/__init__.py` - Python package version
46+
- `js/` packages - All NPM package versions and lock files
47+
48+
3. **Update the C API static_assert (Manual Step)**
49+
50+
The script does **not** update the version check in the C API. You must manually update the `static_assert` in [onnxruntime/core/session/onnxruntime_c_api.cc](../onnxruntime/core/session/onnxruntime_c_api.cc).
51+
52+
Search for `static_assert(std::string_view(ORT_VERSION)` and update the version string:
53+
```cpp
54+
static_assert(std::string_view(ORT_VERSION) == "X.Y.Z",
55+
"ORT_Version change detected, please follow below steps to ensure OrtApi is updated properly");
56+
```
57+
58+
Replace `X.Y.Z` with your new version number. The comments following this assert explain additional steps if new APIs were added to this release.
59+
60+
4. **Review all changes**
61+
62+
Review all modified files. Verify:
63+
- Version numbers are correct in all updated files
64+
- The release notes URL format is correct (e.g., `https://github.com/Microsoft/onnxruntime/releases/tag/vX.Y.Z`)
65+
66+
5. **Commit and create PR**
67+
68+
Commit all changes and create a PR targeting `main` or a release branch as appropriate.
69+
70+
### Notes
71+
72+
- The version table in this file and the ONNX opset compatibility information on [onnxruntime.ai](https://onnxruntime.ai/docs/reference/compatibility.html#onnx-opset-support) are the canonical sources for version compatibility information.
73+
- For ONNX version/opset/IR reference numbers, see the [ONNX Versioning documentation](https://github.com/onnx/onnx/blob/main/docs/Versioning.md#released-versions).
74+
1475
# Compatibility
1576
1677
## Backwards compatibility

js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ export const parseConvTransposeAttributes = (attributes: Record<string, unknown>
132132
typeof attributes.autoPad == 'undefined' ? 0 : (attributes.autoPad as number)
133133
];
134134
const dilations = attributes.dilations as [number, number];
135-
const group = attributes.group as number;
135+
const group = (attributes.group as number) ?? 1; // default to 1 per ONNX spec
136136
const kernelShape = attributes.kernelShape as [number, number];
137137
const pads = attributes.pads as [number, number, number, number];
138138
const strides = attributes.strides as [number, number];

onnxruntime/core/platform/windows/telemetry.cc

Lines changed: 86 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33

44
#include "core/platform/windows/telemetry.h"
55
#include <mutex>
6+
#include <string>
7+
#include <vector>
8+
#include <cwchar>
9+
#include <winsvc.h>
610
#include "core/common/logging/logging.h"
711
#include "onnxruntime_config.h"
812

@@ -51,6 +55,80 @@ TRACELOGGING_DEFINE_PROVIDER(telemetry_provider_handle, "Microsoft.ML.ONNXRuntim
5155
// {3a26b1ff-7484-7484-7484-15261f42614d}
5256
(0x3a26b1ff, 0x7484, 0x7484, 0x74, 0x84, 0x15, 0x26, 0x1f, 0x42, 0x61, 0x4d),
5357
TraceLoggingOptionMicrosoftTelemetry());
58+
59+
std::string ConvertWideStringToUtf8(const std::wstring& wide) {
60+
if (wide.empty())
61+
return {};
62+
63+
const UINT code_page = CP_UTF8;
64+
const DWORD flags = 0;
65+
LPCWCH const src = wide.data();
66+
const int src_len = static_cast<int>(wide.size());
67+
int utf8_length = ::WideCharToMultiByte(code_page, flags, src, src_len, nullptr, 0, nullptr, nullptr);
68+
if (utf8_length == 0)
69+
return {};
70+
71+
std::string utf8(utf8_length, '\0');
72+
if (::WideCharToMultiByte(code_page, flags, src, src_len, utf8.data(), utf8_length, nullptr, nullptr) == 0)
73+
return {};
74+
75+
return utf8;
76+
}
77+
78+
std::string GetServiceNamesForCurrentProcess() {
79+
static std::once_flag once_flag;
80+
static std::string service_names;
81+
82+
std::call_once(once_flag, [] {
83+
SC_HANDLE service_manager = ::OpenSCManagerW(nullptr, nullptr, SC_MANAGER_ENUMERATE_SERVICE);
84+
if (service_manager == nullptr)
85+
return;
86+
87+
DWORD bytes_needed = 0;
88+
DWORD services_returned = 0;
89+
DWORD resume_handle = 0;
90+
if (!::EnumServicesStatusExW(service_manager, SC_ENUM_PROCESS_INFO, SERVICE_WIN32, SERVICE_ACTIVE, nullptr, 0, &bytes_needed,
91+
&services_returned, &resume_handle, nullptr) &&
92+
::GetLastError() != ERROR_MORE_DATA) {
93+
::CloseServiceHandle(service_manager);
94+
return;
95+
}
96+
97+
if (bytes_needed == 0) {
98+
::CloseServiceHandle(service_manager);
99+
return;
100+
}
101+
102+
std::vector<uint8_t> buffer(bytes_needed);
103+
auto* services = reinterpret_cast<ENUM_SERVICE_STATUS_PROCESSW*>(buffer.data());
104+
services_returned = 0;
105+
resume_handle = 0;
106+
if (!::EnumServicesStatusExW(service_manager, SC_ENUM_PROCESS_INFO, SERVICE_WIN32, SERVICE_ACTIVE, reinterpret_cast<LPBYTE>(services),
107+
bytes_needed, &bytes_needed, &services_returned, &resume_handle, nullptr)) {
108+
::CloseServiceHandle(service_manager);
109+
return;
110+
}
111+
112+
DWORD current_pid = ::GetCurrentProcessId();
113+
std::wstring aggregated;
114+
bool first = true;
115+
for (DWORD i = 0; i < services_returned; ++i) {
116+
if (services[i].ServiceStatusProcess.dwProcessId == current_pid) {
117+
if (!first) {
118+
aggregated.push_back(L',');
119+
}
120+
aggregated.append(services[i].lpServiceName);
121+
first = false;
122+
}
123+
}
124+
125+
::CloseServiceHandle(service_manager);
126+
127+
service_names = ConvertWideStringToUtf8(aggregated);
128+
});
129+
130+
return service_names;
131+
}
54132
} // namespace
55133

56134
#ifdef _MSC_VER
@@ -178,6 +256,7 @@ void WindowsTelemetry::LogProcessInfo() const {
178256
#if BUILD_INBOX
179257
isRedist = false;
180258
#endif
259+
const std::string service_names = GetServiceNamesForCurrentProcess();
181260
TraceLoggingWrite(telemetry_provider_handle,
182261
"ProcessInfo",
183262
TraceLoggingBool(true, "UTCReplace_AppSessionGuid"),
@@ -189,7 +268,8 @@ void WindowsTelemetry::LogProcessInfo() const {
189268
TraceLoggingString(ORT_VERSION, "runtimeVersion"),
190269
TraceLoggingBool(IsDebuggerPresent(), "isDebuggerAttached"),
191270
TraceLoggingBool(isRedist, "isRedist"),
192-
TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
271+
TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"),
272+
TraceLoggingString(service_names.c_str(), "serviceNames"));
193273

194274
process_info_logged = true;
195275
}
@@ -278,6 +358,7 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio
278358
execution_provider_string += i;
279359
}
280360

361+
const std::string service_names = GetServiceNamesForCurrentProcess();
281362
// Difference is MeasureEvent & isCaptureState, but keep in sync otherwise
282363
if (!captureState) {
283364
TraceLoggingWrite(telemetry_provider_handle,
@@ -304,7 +385,8 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio
304385
TraceLoggingString(model_weight_hash.c_str(), "modelWeightHash"),
305386
TraceLoggingString(model_metadata_string.c_str(), "modelMetaData"),
306387
TraceLoggingString(loaded_from.c_str(), "loadedFrom"),
307-
TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"));
388+
TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"),
389+
TraceLoggingString(service_names.c_str(), "serviceNames"));
308390
} else {
309391
TraceLoggingWrite(telemetry_provider_handle,
310392
"SessionCreation_CaptureState",
@@ -330,7 +412,8 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio
330412
TraceLoggingString(model_weight_hash.c_str(), "modelWeightHash"),
331413
TraceLoggingString(model_metadata_string.c_str(), "modelMetaData"),
332414
TraceLoggingString(loaded_from.c_str(), "loadedFrom"),
333-
TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"));
415+
TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"),
416+
TraceLoggingString(service_names.c_str(), "serviceNames"));
334417
}
335418
}
336419

onnxruntime/core/providers/cpu/cpu_provider_shared.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,9 @@ struct ProviderHostCPUImpl : ProviderHostCPU {
186186
std::unique_ptr<EinsumTypedComputeProcessor<float>> EinsumTypedComputeProcessor_float__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) override { return std::make_unique<EinsumTypedComputeProcessor<float>>(context, allocator, tp, einsum_compute_preprocessor, einsum_cuda_assets); }
187187
std::unique_ptr<EinsumTypedComputeProcessor<double>> EinsumTypedComputeProcessor_double__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) override { return std::make_unique<EinsumTypedComputeProcessor<double>>(context, allocator, tp, einsum_compute_preprocessor, einsum_cuda_assets); }
188188
std::unique_ptr<EinsumTypedComputeProcessor<MLFloat16>> EinsumTypedComputeProcessor_MLFloat16__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) override { return std::make_unique<EinsumTypedComputeProcessor<MLFloat16>>(context, allocator, tp, einsum_compute_preprocessor, einsum_cuda_assets); }
189-
void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor<float>* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul<float>& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum<float>& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func); }
190-
void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor<double>* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul<double>& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum<double>& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func); }
191-
void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor<MLFloat16>* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul<MLFloat16>& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum<MLFloat16>& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func); }
189+
void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor<float>* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul<float>& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum<float>& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::ZeroBuffer& device_zero_buffer_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, device_zero_buffer_func); }
190+
void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor<double>* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul<double>& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum<double>& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::ZeroBuffer& device_zero_buffer_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, device_zero_buffer_func); }
191+
void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor<MLFloat16>* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul<MLFloat16>& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum<MLFloat16>& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::ZeroBuffer& device_zero_buffer_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, device_zero_buffer_func); }
192192
Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor<float>* p) override { return p->Run(); }
193193
Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor<double>* p) override { return p->Run(); }
194194
Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor<MLFloat16>* p) override { return p->Run(); }

0 commit comments

Comments
 (0)