Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions cmake/external/abseil-cpp.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,13 @@ else()
endif()
endif()

if(Patch_FOUND AND WIN32)
set(ABSL_PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/abseil/absl_windows.patch)
if(Patch_FOUND)
if (WIN32)
set(ABSL_PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/abseil/absl_windows.patch &&
${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/abseil/absl_cuda_warnings.patch)
else()
set(ABSL_PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/abseil/absl_cuda_warnings.patch)
endif()
else()
set(ABSL_PATCH_COMMAND "")
endif()
Expand Down
22 changes: 15 additions & 7 deletions cmake/external/cuda_configuration.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ macro(setup_cuda_architectures)
# * Always use accelerated (`-a` suffix) target for supported real architectures.
# cmake-format: on

# Allow override via CUDAARCHS environment variable (standard CMake variable)
if(NOT CMAKE_CUDA_ARCHITECTURES AND DEFINED ENV{CUDAARCHS})
set(CMAKE_CUDA_ARCHITECTURES "$ENV{CUDAARCHS}")
endif()

if(CMAKE_CUDA_ARCHITECTURES STREQUAL "native")
# Detect highest available compute capability
set(OUTPUTFILE ${PROJECT_BINARY_DIR}/detect_cuda_arch)
Expand Down Expand Up @@ -139,12 +144,12 @@ macro(setup_cuda_architectures)
continue()
endif()

if(CUDA_ARCH MATCHES "^([1-9])([0-9])+a?-virtual$")
if(CUDA_ARCH MATCHES "^([1-9])([0-9])+[af]?-virtual$")
set(CMAKE_CUDA_ARCHITECTURES_LAST_VIRTUAL ${CUDA_ARCH})
elseif(CUDA_ARCH MATCHES "^(([1-9])([0-9])+)a?-real$")
list(APPEND CMAKE_CUDA_ARCHITECTURES_CLEAN ${CMAKE_MATCH_1})
elseif(CUDA_ARCH MATCHES "^(([1-9])([0-9])+)a?$")
elseif(CUDA_ARCH MATCHES "^(([1-9])([0-9])+)[af]?-real$")
list(APPEND CMAKE_CUDA_ARCHITECTURES_CLEAN ${CMAKE_MATCH_1})
elseif(CUDA_ARCH MATCHES "^(([1-9])([0-9])+)([af]?)$")
list(APPEND CMAKE_CUDA_ARCHITECTURES_CLEAN ${CMAKE_MATCH_1}${CMAKE_MATCH_4})
else()
message(FATAL_ERROR "Unrecognized CUDA architecture: ${CUDA_ARCH}")
endif()
Expand All @@ -156,7 +161,7 @@ macro(setup_cuda_architectures)
set(CMAKE_CUDA_ARCHITECTURES_ORIG "${CMAKE_CUDA_ARCHITECTURES}")
message(STATUS "GPU architectures: ${CMAKE_CUDA_ARCHITECTURES_ORIG}")

set(ARCHITECTURES_WITH_KERNELS "80" "86" "89" "90" "100" "120")
set(ARCHITECTURES_WITH_KERNELS "80" "86" "89" "90" "100" "110" "120")
foreach(CUDA_ARCH IN LISTS ARCHITECTURES_WITH_KERNELS)
if(NOT "${CUDA_ARCH}" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
add_definitions("-DEXCLUDE_SM_${CUDA_ARCH}")
Expand All @@ -165,10 +170,13 @@ macro(setup_cuda_architectures)
endforeach()

# Enable accelerated features (like WGMMA, TMA and setmaxnreg) for SM >= 90.
set(ARCHITECTURES_WITH_ACCEL "90" "100" "101" "120")
set(ARCHITECTURES_WITH_ACCEL "90" "100" "101" "110" "120")
unset(CMAKE_CUDA_ARCHITECTURES_NORMALIZED)
foreach(CUDA_ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES)
if("${CUDA_ARCH}" IN_LIST ARCHITECTURES_WITH_ACCEL)
if(CUDA_ARCH MATCHES "^([0-9]+)f$")
# Family code, no -real suffix
list(APPEND CMAKE_CUDA_ARCHITECTURES_NORMALIZED "${CUDA_ARCH}")
elseif("${CUDA_ARCH}" IN_LIST ARCHITECTURES_WITH_ACCEL)
list(APPEND CMAKE_CUDA_ARCHITECTURES_NORMALIZED "${CUDA_ARCH}a-real")
else()
list(APPEND CMAKE_CUDA_ARCHITECTURES_NORMALIZED "${CUDA_ARCH}-real")
Expand Down
40 changes: 40 additions & 0 deletions cmake/patches/abseil/absl_cuda_warnings.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
diff --git a/absl/hash/internal/hash.h b/absl/hash/internal/hash.h
index 1234567..abcdefg 100644
--- a/absl/hash/internal/hash.h
+++ b/absl/hash/internal/hash.h
@@ -477,7 +477,7 @@ H AbslHashValue(H hash_state, T (&)[N]) {
template <typename H, typename T, size_t N>
H AbslHashValue(H hash_state, T (&)[N]) {
static_assert(
- sizeof(T) == -1,
+ sizeof(T) == size_t(-1),
"Hashing C arrays is not allowed. For string literals, wrap the literal "
"in absl::string_view(). To hash the array contents, use "
"absl::MakeSpan() or make the array an std::array. To hash the array "
diff --git a/absl/hash/hash.h b/absl/hash/hash.h
index 1234567..abcdefg 100644
--- a/absl/hash/hash.h
+++ b/absl/hash/hash.h
@@ -333,7 +333,8 @@ class HashState : public hash_internal::HashStateBase<HashState> {
absl::enable_if_t<
std::is_base_of<hash_internal::HashStateBase<T>, T>::value, int> = 0>
static HashState Create(T* state) {
- HashState s;
+ HashState s = {};
+ (void)s;
s.Init(state);
return s;
}
diff --git a/absl/container/internal/raw_hash_set.h b/absl/container/internal/raw_hash_set.h
index 1234567..abcdefg 100644
--- a/absl/container/internal/raw_hash_set.h
+++ b/absl/container/internal/raw_hash_set.h
@@ -464,7 +464,7 @@ inline uint16_t NextSeed() {
inline uint16_t NextSeed() {
static_assert(PerTableSeed::kBitCount == 16);
thread_local uint16_t seed =
- static_cast<uint16_t>(reinterpret_cast<uintptr_t>(&seed));
+ static_cast<uint16_t>(reinterpret_cast<uintptr_t>(&seed) & 0xFFFFu);
seed += uint16_t{0xad53};
return seed;
}
40 changes: 40 additions & 0 deletions cmake/vcpkg-ports/abseil/absl_cuda_warnings.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
diff --git a/absl/hash/internal/hash.h b/absl/hash/internal/hash.h
index 1234567..abcdefg 100644
--- a/absl/hash/internal/hash.h
+++ b/absl/hash/internal/hash.h
@@ -477,7 +477,7 @@ H AbslHashValue(H hash_state, T (&)[N]) {
template <typename H, typename T, size_t N>
H AbslHashValue(H hash_state, T (&)[N]) {
static_assert(
- sizeof(T) == -1,
+ sizeof(T) == size_t(-1),
"Hashing C arrays is not allowed. For string literals, wrap the literal "
"in absl::string_view(). To hash the array contents, use "
"absl::MakeSpan() or make the array an std::array. To hash the array "
diff --git a/absl/hash/hash.h b/absl/hash/hash.h
index 1234567..abcdefg 100644
--- a/absl/hash/hash.h
+++ b/absl/hash/hash.h
@@ -333,7 +333,8 @@ class HashState : public hash_internal::HashStateBase<HashState> {
absl::enable_if_t<
std::is_base_of<hash_internal::HashStateBase<T>, T>::value, int> = 0>
static HashState Create(T* state) {
- HashState s;
+ HashState s = {};
+ (void)s;
s.Init(state);
return s;
}
diff --git a/absl/container/internal/raw_hash_set.h b/absl/container/internal/raw_hash_set.h
index 1234567..abcdefg 100644
--- a/absl/container/internal/raw_hash_set.h
+++ b/absl/container/internal/raw_hash_set.h
@@ -464,7 +464,7 @@ inline uint16_t NextSeed() {
inline uint16_t NextSeed() {
static_assert(PerTableSeed::kBitCount == 16);
thread_local uint16_t seed =
- static_cast<uint16_t>(reinterpret_cast<uintptr_t>(&seed));
+ static_cast<uint16_t>(reinterpret_cast<uintptr_t>(&seed) & 0xFFFFu);
seed += uint16_t{0xad53};
return seed;
}
1 change: 1 addition & 0 deletions cmake/vcpkg-ports/abseil/portfile.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ vcpkg_from_github(
SHA512 4ee1a217203933382e728d354a149253a517150eee7580a0abecc69584b2eb200d91933ef424487e3a3fe0e8ab5e77b0288485cac982171b3585314a4417e7d4
HEAD_REF master
PATCHES absl_windows.patch
absl_cuda_warnings.patch
)


Expand Down
140 changes: 132 additions & 8 deletions csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

using System;
using System.Reflection;
using System.Runtime.InteropServices;
using static Microsoft.ML.OnnxRuntime.NativeMethods;

Expand Down Expand Up @@ -474,6 +475,12 @@ internal static class NativeMethods

static NativeMethods()
{
#if !NETSTANDARD2_0 && !__ANDROID__ && !__IOS__
// Register a custom DllImportResolver to handle platform-specific library loading.
// Replaces default resolution specifically on Windows for case-sensitivity.
NativeLibrary.SetDllImportResolver(typeof(NativeMethods).Assembly, DllImportResolver);
#endif

#if NETSTANDARD2_0
IntPtr ortApiBasePtr = OrtGetApiBase();
OrtApiBase ortApiBase = (OrtApiBase)Marshal.PtrToStructure(ortApiBasePtr, typeof(OrtApiBase));
Expand Down Expand Up @@ -847,7 +854,7 @@ static NativeMethods()
api_.CreateSyncStreamForEpDevice,
typeof(DOrtCreateSyncStreamForEpDevice));

OrtSyncStream_GetHandle =
OrtSyncStream_GetHandle =
(DOrtSyncStream_GetHandle)Marshal.GetDelegateForFunctionPointer(
api_.SyncStream_GetHandle,
typeof(DOrtSyncStream_GetHandle));
Expand All @@ -872,11 +879,127 @@ internal class NativeLib
// Define the library name required for iOS
internal const string DllName = "__Internal";
#else
// Note: the file name in ONNX Runtime nuget package must be onnxruntime.dll instead of onnxruntime.DLL(Windows filesystem can be case sensitive)
internal const string DllName = "onnxruntime.dll";
// For desktop platforms (including .NET Standard 2.0), we use the simple name
// to allow .NET's automatic platform-specific resolution (lib*.so, lib*.dylib, *.dll).
// For .NET Core 3.0+, case-sensitivity on Windows is handled by DllImportResolver.
internal const string DllName = "onnxruntime";
#endif
}

#if !NETSTANDARD2_0 && !__ANDROID__ && !__IOS__
/// <summary>
/// Custom DllImportResolver to handle platform-specific library loading.
/// On Windows, it explicitly loads the library with a lowercase .dll extension to handle
/// case-sensitive filesystems.
/// </summary>
private static IntPtr DllImportResolver(string libraryName, Assembly assembly, DllImportSearchPath? searchPath)
{
if (libraryName == NativeLib.DllName || libraryName == OrtExtensionsNativeMethods.ExtensionsDllName)
{
string mappedName = null;
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
// Explicitly load with .dll extension to avoid issues where the OS might try .DLL
mappedName = libraryName + ".dll";
}
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
{
// Explicitly load with .so extension and lib prefix
mappedName = "lib" + libraryName + ".so";
}
else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
{
// Explicitly load with .dylib extension and lib prefix
mappedName = "lib" + libraryName + ".dylib";
}

if (mappedName != null)
{
// 1. Try default loading (name only)
if (NativeLibrary.TryLoad(mappedName, assembly, searchPath, out IntPtr handle))
{
return handle;
}

// 2. Try relative to assembly location (look into runtimes subfolders)
string assemblyLocation = null;
try { assemblyLocation = assembly.Location; } catch { }
if (!string.IsNullOrEmpty(assemblyLocation))
{
string assemblyDir = System.IO.Path.GetDirectoryName(assemblyLocation);
string rid = RuntimeInformation.RuntimeIdentifier;

// Probe the specific RID first, then common fallbacks for the current OS
string[] ridsToTry;
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
ridsToTry = new[] { rid, "win-x64", "win-arm64" };
}
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
{
ridsToTry = new[] { rid, "linux-x64", "linux-arm64" };
}
else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
{
// We no longer provide osx-x64 in official package since 1.24.
// However, we keep it in the list for build-from-source users.
ridsToTry = new[] { rid, "osx-arm64", "osx-x64" };
}
else
{
ridsToTry = new[] { rid };
}

foreach (var tryRid in ridsToTry)
{
string probePath = System.IO.Path.Combine(assemblyDir, "runtimes", tryRid, "native", mappedName);
if (System.IO.File.Exists(probePath) && NativeLibrary.TryLoad(probePath, assembly, searchPath, out handle))
{
LogLibLoad($"[DllImportResolver] Loaded {mappedName} from: {probePath}");
return handle;
}
}
}

// 3. Try AppContext.BaseDirectory as a fallback
string baseDir = AppContext.BaseDirectory;
if (!string.IsNullOrEmpty(baseDir))
{
string probePath = System.IO.Path.Combine(baseDir, mappedName);
if (NativeLibrary.TryLoad(probePath, assembly, searchPath, out handle))
{
LogLibLoad($"[DllImportResolver] Loaded {mappedName} from: {probePath}");
return handle;
}

string rid = RuntimeInformation.RuntimeIdentifier;
probePath = System.IO.Path.Combine(baseDir, "runtimes", rid, "native", mappedName);
if (NativeLibrary.TryLoad(probePath, assembly, searchPath, out handle))
{
LogLibLoad($"[DllImportResolver] Loaded {mappedName} from: {probePath}");
return handle;
}
}

LogLibLoad($"[DllImportResolver] Failed loading {mappedName} (RID: {RuntimeInformation.RuntimeIdentifier}, Assembly: {assemblyLocation})");

}
}

// Fall back to default resolution
return IntPtr.Zero;
}

private static void LogLibLoad(string message)
{
System.Diagnostics.Trace.WriteLine(message);
if (!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("ORT_LOADER_VERBOSITY")))
{
Console.WriteLine(message);
}
}
#endif

[DllImport(NativeLib.DllName, CharSet = CharSet.Ansi)]
#if NETSTANDARD2_0
public static extern IntPtr OrtGetApiBase();
Expand Down Expand Up @@ -2644,7 +2767,7 @@ public delegate void DOrtAddKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps,
byte[] /* const char* */ value);

/// <summary>
/// Get the value for the provided key.
/// Get the value for the provided key.
/// </summary>
/// <returns>Value. Returns IntPtr.Zero if key was not found.</returns>
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
Expand Down Expand Up @@ -2767,7 +2890,7 @@ out IntPtr /* OrtSyncStream** */ stream
// Auto Selection EP registration and selection customization

/// <summary>
/// Register an execution provider library.
/// Register an execution provider library.
/// The library must implement CreateEpFactories and ReleaseEpFactory.
/// </summary>
/// <param name="env">Environment to add the EP library to.</param>
Expand Down Expand Up @@ -2952,9 +3075,10 @@ internal static class OrtExtensionsNativeMethods
#elif __IOS__
internal const string ExtensionsDllName = "__Internal";
#else
// For desktop platforms, explicitly specify the DLL name with extension to avoid
// issues on case-sensitive filesystems. See NativeLib.DllName for detailed explanation.
internal const string ExtensionsDllName = "ortextensions.dll";
// For desktop platforms, use the simple name to allow .NET's
// automatic platform-specific resolution (lib*.so, lib*.dylib, *.dll).
// Case-sensitivity on Windows is handled by DllImportResolver.
internal const string ExtensionsDllName = "ortextensions";
#endif

[DllImport(ExtensionsDllName, CharSet = CharSet.Ansi,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@

<!-- arm64 -->
<None Include="$(MSBuildThisFileDirectory)..\..\runtimes\win-arm64\native\onnxruntime.dll"
Condition="'$(PlatformTarget)' == 'ARM64'">
Condition="'$(PlatformTarget)' == 'ARM64' AND
Exists('$(MSBuildThisFileDirectory)..\..\runtimes\win-arm64\native\onnxruntime.dll')">
<Link>onnxruntime.dll</Link>
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Visible>false</Visible>
Expand All @@ -128,7 +129,8 @@

<!-- arm -->
<None Include="$(MSBuildThisFileDirectory)..\..\runtimes\win-arm\native\onnxruntime.dll"
Condition="'$(PlatformTarget)' == 'ARM'">
Condition="'$(PlatformTarget)' == 'ARM' AND
Exists('$(MSBuildThisFileDirectory)..\..\runtimes\win-arm\native\onnxruntime.dll')">
<Link>onnxruntime.dll</Link>
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Visible>false</Visible>
Expand Down
Loading
Loading