Skip to content

Commit 897438e

Browse files
committed
Adding Basic MKL C# Wrappers from torch.cs
1 parent 436de06 commit 897438e

File tree

7 files changed

+545
-0
lines changed

7 files changed

+545
-0
lines changed

examples/dotnet/Intel/Intel.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
public static partial class Intel {
2+
}

examples/dotnet/Intel/Intel.mkl.cs

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
#pragma warning disable CS8981
2+
3+
using System;
4+
using System.Diagnostics;
5+
using System.Runtime.InteropServices;
6+
using System.Text;
7+
8+
public static partial class Intel {
9+
public unsafe static partial class mkl {
10+
static readonly IntPtr mkl_rt;
11+
12+
public struct MKLVersion {
13+
public int MajorVersion;
14+
public int MinorVersion;
15+
public int UpdateVersion;
16+
public byte* ProductStatus;
17+
public byte* Build;
18+
public byte* Processor;
19+
public byte* Platform;
20+
}
21+
22+
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
23+
delegate void MKL_Get_Version(out MKLVersion version);
24+
25+
public static MKLVersion? Version;
26+
27+
static mkl() {
28+
string StringFromPChar(byte* pSz) {
29+
int len = 0;
30+
for (int i = 0; i < 1024; i++) {
31+
if (pSz[i] == 0) {
32+
break;
33+
}
34+
len++;
35+
}
36+
return Encoding.UTF8.GetString(pSz, len);
37+
}
38+
Version = null;
39+
if (kernel32.LoadLibraryW("c:\\python312\\library\\bin\\mkl_rt.2.dll", out mkl_rt)) {
40+
Debug.Write("\n> Found " + kernel32.GetModuleFileName(mkl_rt) + "\n");
41+
if (!kernel32.GetProcAddress(mkl_rt, "MKL_Get_Version", out MKL_Get_Version mkl_get_version)) {
42+
goto error;
43+
}
44+
mkl_get_version(out MKLVersion version);
45+
Debug.Write($"> Major version = {version.MajorVersion}\n");
46+
Debug.Write($"> Minor version = {version.MinorVersion}\n");
47+
Debug.Write($"> Update version = {version.UpdateVersion}\n");
48+
Debug.Write($"> Product status = {StringFromPChar(version.ProductStatus)}\n");
49+
Debug.Write($"> Build = {StringFromPChar(version.Build)}\n");
50+
Debug.Write($"> Platform = {StringFromPChar(version.Platform)}\n\n");
51+
Version = version;
52+
if (!kernel32.GetProcAddress(mkl_rt, "cblas_sdot", out sdot)) {
53+
goto error;
54+
}
55+
if (!kernel32.GetProcAddress(mkl_rt, "cblas_sscal", out sscal)) {
56+
goto error;
57+
}
58+
if (!kernel32.GetProcAddress(mkl_rt, "cblas_sgemm", out sgemm)) {
59+
goto error;
60+
}
61+
if (!kernel32.GetProcAddress(mkl_rt, "cblas_sgemv", out sgemv)) {
62+
goto error;
63+
}
64+
if (!kernel32.GetProcAddress(mkl_rt, "cblas_saxpy", out saxpy)) {
65+
goto error;
66+
}
67+
if (!kernel32.GetProcAddress(mkl_rt, "vsTanh", out tanh)) {
68+
goto error;
69+
}
70+
return;
71+
}
72+
error:
73+
Version = null;
74+
if (mkl_rt != IntPtr.Zero) {
75+
Debug.Write("> \u001b[33mWARNING: Version not supported.\u001b[0m\n");
76+
kernel32.FreeLibrary(mkl_rt);
77+
mkl_rt = IntPtr.Zero;
78+
}
79+
tanh = (n, a, y) => throw new NotSupportedException($"Not found '{nameof(tanh)}': Intel® oneAPI Math Kernel Library.");
80+
sdot = error_sdot;
81+
sscal = error_sscal;
82+
saxpy = error_saxpy;
83+
sgemv = error_sgemv;
84+
sgemm = error_sgemm;
85+
}
86+
87+
public static bool IsSupported => mkl_rt != IntPtr.Zero && Version != null;
88+
89+
[System.Security.SuppressUnmanagedCodeSecurity]
90+
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
91+
public unsafe delegate void vsTanh(int N, float* a, float* y);
92+
93+
/// <see href="https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2024-0/v-tanh.html"></see>
94+
public static readonly vsTanh tanh;
95+
96+
/// <summary>
97+
/// x = a * x
98+
/// </summary>
99+
/// <see href="https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2024-1/cblas-scal.html"></see>
100+
public static readonly CBLAS_SSCAL sscal;
101+
102+
/// <summary>
103+
/// y := a * x + y
104+
/// </summary>
105+
/// <see href="https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2024-1/cblas-axpy.html"></see>
106+
public static readonly CBLAS_SAXPY saxpy;
107+
108+
/// <summary>
109+
/// y := alpha * A * x + beta * y
110+
/// </summary>
111+
/// <see href="https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2024-1/cblas-gemv.html"></see>
112+
public static readonly CBLAS_SGEMV sgemv;
113+
114+
/// <summary>
115+
/// C := alpha* op(A)*op(B) + beta* C
116+
/// </summary>
117+
/// <see href="https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2024-1/cblas-gemm-001.html"></see>
118+
public static readonly CBLAS_SGEMM sgemm;
119+
120+
/// <summary>
121+
/// Computes a vector-vector dot product.
122+
/// </summary>
123+
/// <see href="https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2024-1/cblas-dot.html"></see>
124+
public static readonly CBLAS_SDOT sdot;
125+
126+
unsafe static void error_sscal(int N, float a, float* x, int incx) {
127+
throw new NotSupportedException($"Not found '{nameof(sscal)}': Intel® oneAPI Math Kernel Library.");
128+
}
129+
130+
unsafe static float error_sdot(int N, float* x, int incx, float* y, int incy) {
131+
throw new NotSupportedException($"Not found '{nameof(sdot)}': Intel® oneAPI Math Kernel Library.");
132+
}
133+
134+
unsafe static void error_saxpy(int N, float a, float* x, int incx, float* y, int incy) {
135+
throw new NotSupportedException($"Not found '{nameof(saxpy)}': Intel® oneAPI Math Kernel Library.");
136+
}
137+
138+
unsafe static void error_sgemv(CBLAS_LAYOUT Layout, CBLAS_TRANSPOSE Trans, int M, int N, float alpha, float* A, int lda, float* x, int incx, float beta, float* y, int incy) {
139+
throw new NotSupportedException($"Not found '{nameof(sgemv)}': Intel® oneAPI Math Kernel Library.");
140+
}
141+
142+
unsafe static void error_sgemm(CBLAS_LAYOUT Layout, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, int M, int N, int K, float alpha, float* A, int lda, float* B, int ldb, float beta, float* C, int ldc) {
143+
throw new NotSupportedException($"Not found '{nameof(sgemm)}': Intel® oneAPI Math Kernel Library.");
144+
}
145+
}
146+
}

examples/dotnet/cBLAS.cs

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
// https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms
2+
3+
[System.Security.SuppressUnmanagedCodeSecurity]
4+
[System.Runtime.InteropServices.UnmanagedFunctionPointer(System.Runtime.InteropServices.CallingConvention.Cdecl)]
5+
public unsafe delegate void CBLAS_SSCAL(int N, float a, float* x, int incx);
6+
7+
[System.Security.SuppressUnmanagedCodeSecurity]
8+
[System.Runtime.InteropServices.UnmanagedFunctionPointer(System.Runtime.InteropServices.CallingConvention.Cdecl)]
9+
public unsafe delegate void CBLAS_SAXPY(int N, float a, float* x, int incx, float* y, int incy);
10+
11+
/// <summary>
12+
/// Computes a vector-vector dot product.
13+
/// </summary>
14+
/// <param name="N">Specifies the number of elements in vectors x and y.</param>
15+
/// <param name="x">Array, size at least (1+(n-1)*abs(incx)).</param>
16+
/// <param name="incx">Specifies the increment for the elements of x.</param>
17+
/// <param name="y">Array, size at least (1+(n-1)*abs(incy)).</param>
18+
/// <param name="incy">Specifies the increment for the elements of y.</param>
19+
/// <returns>The result of the dot product of x and y, if n is positive. Otherwise, returns 0.</returns>
20+
[System.Security.SuppressUnmanagedCodeSecurity]
21+
[System.Runtime.InteropServices.UnmanagedFunctionPointer(System.Runtime.InteropServices.CallingConvention.Cdecl)]
22+
public unsafe delegate float CBLAS_SDOT(int N, float* x, int incx, float* y, int incy);
23+
24+
public enum CBLAS_LAYOUT : int {
25+
RowMajor = 101,
26+
ColMajor = 102
27+
};
28+
29+
public enum CBLAS_TRANSPOSE : int {
30+
NoTrans = 111,
31+
Trans = 112,
32+
/*ConjTrans = 113*/
33+
};
34+
35+
[System.Security.SuppressUnmanagedCodeSecurity]
36+
[System.Runtime.InteropServices.UnmanagedFunctionPointer(System.Runtime.InteropServices.CallingConvention.Cdecl)]
37+
public unsafe delegate void CBLAS_SGEMV(
38+
CBLAS_LAYOUT Layout,
39+
CBLAS_TRANSPOSE TransA,
40+
int M,
41+
int N,
42+
float alpha,
43+
float* A,
44+
int lda,
45+
float* X,
46+
int incX,
47+
float beta,
48+
float* Y,
49+
int incY);
50+
51+
[System.Security.SuppressUnmanagedCodeSecurity]
52+
[System.Runtime.InteropServices.UnmanagedFunctionPointer(System.Runtime.InteropServices.CallingConvention.Cdecl)]
53+
public unsafe delegate void CBLAS_SGEMM(
54+
CBLAS_LAYOUT Layout,
55+
CBLAS_TRANSPOSE TransA,
56+
CBLAS_TRANSPOSE TransB,
57+
int M,
58+
int N,
59+
int K,
60+
float alpha,
61+
float* A,
62+
int lda,
63+
float* B,
64+
int ldb,
65+
float beta,
66+
float* C,
67+
int ldc);
68+
69+
public unsafe static partial class cBLAS {
70+
public const CBLAS_LAYOUT CblasColMajor = CBLAS_LAYOUT.ColMajor;
71+
public const CBLAS_LAYOUT CblasRowMajor = CBLAS_LAYOUT.RowMajor;
72+
73+
public const CBLAS_TRANSPOSE CblasNoTrans = CBLAS_TRANSPOSE.NoTrans;
74+
public const CBLAS_TRANSPOSE CblasTrans = CBLAS_TRANSPOSE.Trans;
75+
}
76+
77+
public unsafe static partial class cBLAS {
78+
public static bool IsSupported => true;
79+
}

examples/dotnet/dotnet.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ namespace dotnet {
1010
class Demo {
1111
static unsafe void Main() {
1212

13+
Console.WriteLine(Intel.mkl.Version.ToString());
14+
1315
IntPtr db = Embeddings.AllocDb();
1416

1517
// OPEN (like mode "a+"): GENERIC_READ | FILE_APPEND_DATA, OPEN_ALWAYS

examples/dotnet/dotnet.csproj

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,15 @@
4949
<Compile Include="..\..\src\embeddings.cs">
5050
<Link>embeddings.cs</Link>
5151
</Compile>
52+
<Compile Include="cBLAS.cs" />
5253
<Compile Include="dotnet.cs" />
54+
<Compile Include="Intel\Intel.cs" />
55+
<Compile Include="Intel\Intel.mkl.cs" />
56+
<Compile Include="kernel32.cs" />
5357
</ItemGroup>
5458
<ItemGroup>
5559
<None Include="App.config" />
5660
</ItemGroup>
61+
<ItemGroup />
5762
<Import Project="$(MSBuildToolsPath)\Microsoft.CSharp.targets" />
5863
</Project>

0 commit comments

Comments
 (0)