Skip to content

Commit 556d9cb

Browse files
authored
Merge branch 'main' into dev/tirupath/fusedmatmul_op
2 parents 00cd7c6 + ba11af4 commit 556d9cb

File tree

113 files changed

+6785
-539
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

113 files changed

+6785
-539
lines changed

cmake/onnxruntime_mlas.cmake

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ onnxruntime_add_static_library(onnxruntime_mlas
4545
${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp
4646
${MLAS_SRC_DIR}/qnbitgemm.h
4747
${MLAS_SRC_DIR}/qnbitgemm.cpp
48+
${MLAS_SRC_DIR}/qlutgemm.h
49+
${MLAS_SRC_DIR}/qlutgemm.cpp
4850
${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h
4951
${MLAS_SRC_DIR}/flashattn.cpp
5052
${MLAS_SRC_DIR}/cast.cpp
@@ -209,6 +211,8 @@ function(setup_mlas_source_for_windows)
209211
${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp
210212
${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp
211213
${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp
214+
${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.h
215+
${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.cpp
212216
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
213217
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp
214218
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp
@@ -693,6 +697,8 @@ else()
693697
${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp
694698
${MLAS_SRC_DIR}/intrinsics/avx2/saturation_check_avx2.cpp
695699
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
700+
${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.h
701+
${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.cpp
696702
${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.h
697703
${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.cpp
698704
${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.cpp

include/onnxruntime/core/framework/data_types.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "core/common/float8.h"
1717
#include "core/common/float16.h"
1818
#include "core/framework/int4.h"
19+
#include "core/framework/int2.h"
1920
#include "core/framework/float4.h"
2021
#include "core/graph/onnx_protobuf.h"
2122
#include "core/framework/to_tensor_proto_element_type.h"
@@ -211,6 +212,7 @@ class DataTypeImpl {
211212
static const std::vector<MLDataType>& AllTensorTypesIRv9();
212213
static const std::vector<MLDataType>& AllTensorTypesIRv10();
213214
static const std::vector<MLDataType>& AllTensorTypesIRv11();
215+
static const std::vector<MLDataType>& AllTensorTypesIRv13();
214216

215217
static const std::vector<MLDataType>& AllFixedSizeTensorTypes(); // up to IR4 (no float 8), deprecated
216218
static const std::vector<MLDataType>& AllFixedSizeTensorTypesIRv4();
@@ -285,7 +287,7 @@ template <typename T>
285287
struct IsTensorContainedType : public IsAnyOf<T, float, uint8_t, int8_t, uint16_t, int16_t,
286288
int32_t, int64_t, std::string, bool, MLFloat16,
287289
double, uint32_t, uint64_t, BFloat16,
288-
Int4x2, UInt4x2
290+
Int4x2, UInt4x2, Int2x4, UInt2x4
289291
#if !defined(DISABLE_FLOAT8_TYPES)
290292
,
291293
Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ
@@ -304,7 +306,8 @@ struct IsTensorContainedType : public IsAnyOf<T, float, uint8_t, int8_t, uint16_
304306
template <typename T>
305307
struct IsSparseTensorContainedType : public IsAnyOf<T, float, uint8_t, int8_t, uint16_t, int16_t,
306308
int32_t, int64_t, std::string, bool, MLFloat16,
307-
double, uint32_t, uint64_t, BFloat16
309+
double, uint32_t, uint64_t, BFloat16,
310+
Int4x2, UInt4x2, Int2x4, UInt2x4
308311
#if !defined(DISABLE_FLOAT8_TYPES)
309312
,
310313
Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ

include/onnxruntime/core/framework/data_types_internal.h

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,12 @@ namespace utils {
102102
case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \
103103
function<UInt4x2>(__VA_ARGS__); \
104104
break; \
105+
case ONNX_NAMESPACE::TensorProto_DataType_INT2: \
106+
function<Int2x4>(__VA_ARGS__); \
107+
break; \
108+
case ONNX_NAMESPACE::TensorProto_DataType_UINT2: \
109+
function<UInt2x4>(__VA_ARGS__); \
110+
break; \
105111
default: \
106112
ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
107113
}
@@ -171,6 +177,12 @@ namespace utils {
171177
case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \
172178
retval = function<UInt4x2>(__VA_ARGS__); \
173179
break; \
180+
case ONNX_NAMESPACE::TensorProto_DataType_INT2: \
181+
retval = function<Int2x4>(__VA_ARGS__); \
182+
break; \
183+
case ONNX_NAMESPACE::TensorProto_DataType_UINT2: \
184+
retval = function<UInt2x4>(__VA_ARGS__); \
185+
break; \
174186
default: \
175187
ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
176188
}
@@ -230,6 +242,12 @@ namespace utils {
230242
case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \
231243
function<UInt4x2>(__VA_ARGS__); \
232244
break; \
245+
case ONNX_NAMESPACE::TensorProto_DataType_INT2: \
246+
function<Int2x4>(__VA_ARGS__); \
247+
break; \
248+
case ONNX_NAMESPACE::TensorProto_DataType_UINT2: \
249+
function<UInt2x4>(__VA_ARGS__); \
250+
break; \
233251
default: \
234252
ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
235253
}
@@ -287,6 +305,12 @@ namespace utils {
287305
case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \
288306
retval = function<UInt4x2>(__VA_ARGS__); \
289307
break; \
308+
case ONNX_NAMESPACE::TensorProto_DataType_INT2: \
309+
retval = function<Int2x4>(__VA_ARGS__); \
310+
break; \
311+
case ONNX_NAMESPACE::TensorProto_DataType_UINT2: \
312+
retval = function<UInt2x4>(__VA_ARGS__); \
313+
break; \
290314
default: \
291315
ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
292316
}
@@ -355,6 +379,12 @@ namespace utils {
355379
case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \
356380
function<UInt4x2>(__VA_ARGS__); \
357381
break; \
382+
case ONNX_NAMESPACE::TensorProto_DataType_INT2: \
383+
function<Int2x4>(__VA_ARGS__); \
384+
break; \
385+
case ONNX_NAMESPACE::TensorProto_DataType_UINT2: \
386+
function<UInt2x4>(__VA_ARGS__); \
387+
break; \
358388
default: \
359389
ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
360390
}
@@ -421,6 +451,12 @@ namespace utils {
421451
case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \
422452
retval = function<UInt4x2>(__VA_ARGS__); \
423453
break; \
454+
case ONNX_NAMESPACE::TensorProto_DataType_INT2: \
455+
retval = function<Int2x4>(__VA_ARGS__); \
456+
break; \
457+
case ONNX_NAMESPACE::TensorProto_DataType_UINT2: \
458+
retval = function<UInt2x4>(__VA_ARGS__); \
459+
break; \
424460
default: \
425461
ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
426462
}
@@ -477,6 +513,12 @@ namespace utils {
477513
case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \
478514
function<UInt4x2>(__VA_ARGS__); \
479515
break; \
516+
case ONNX_NAMESPACE::TensorProto_DataType_INT2: \
517+
function<Int2x4>(__VA_ARGS__); \
518+
break; \
519+
case ONNX_NAMESPACE::TensorProto_DataType_UINT2: \
520+
function<UInt2x4>(__VA_ARGS__); \
521+
break; \
480522
default: \
481523
ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
482524
}
@@ -531,6 +573,12 @@ namespace utils {
531573
case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \
532574
retval = function<UInt4x2>(__VA_ARGS__); \
533575
break; \
576+
case ONNX_NAMESPACE::TensorProto_DataType_INT2: \
577+
retval = function<Int2x4>(__VA_ARGS__); \
578+
break; \
579+
case ONNX_NAMESPACE::TensorProto_DataType_UINT2: \
580+
retval = function<UInt2x4>(__VA_ARGS__); \
581+
break; \
534582
default: \
535583
ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
536584
}
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include <cassert>
7+
#include <type_traits>
8+
#include "core/common/common.h"
9+
#include <gsl/gsl>
10+
11+
namespace onnxruntime {
12+
13+
template <bool Signed>
14+
struct Int2Traits;
15+
16+
template <>
17+
struct Int2Traits<true> {
18+
using UnpackedType = int8_t;
19+
static constexpr int8_t min_val = -2;
20+
static constexpr int8_t max_val = 1;
21+
};
22+
23+
template <>
24+
struct Int2Traits<false> {
25+
using UnpackedType = uint8_t;
26+
static constexpr uint8_t min_val = 0;
27+
static constexpr uint8_t max_val = 3;
28+
};
29+
30+
/// <summary>
31+
/// Stores 4 packed 2-bit elements in 1 byte.
32+
/// Packing follows ONNX spec: x0 | (x1 << 2) | (x2 << 4) | (x3 << 6)
33+
/// </summary>
34+
/// <typeparam name="Signed">Set to true if signed int2, or false if unsigned uint2.</typeparam>
35+
template <bool Signed>
36+
struct Int2x4Base {
37+
using UnpackedType = typename Int2Traits<Signed>::UnpackedType;
38+
static constexpr UnpackedType min_val = Int2Traits<Signed>::min_val;
39+
static constexpr UnpackedType max_val = Int2Traits<Signed>::max_val;
40+
41+
std::byte bits_{};
42+
43+
Int2x4Base() = default;
44+
45+
explicit Int2x4Base(std::byte bits) {
46+
bits_ = bits;
47+
}
48+
49+
Int2x4Base(UnpackedType val0, UnpackedType val1, UnpackedType val2, UnpackedType val3) {
50+
bits_ = static_cast<std::byte>(
51+
(val0 & 0x3) |
52+
((val1 & 0x3) << 2) |
53+
((val2 & 0x3) << 4) |
54+
((val3 & 0x3) << 6));
55+
}
56+
57+
static inline int8_t SignExtendLower2Bits(std::byte bits) {
58+
// Sign-extend lower 2-bits by left shifting and then doing an arithmetic right shift.
59+
constexpr uint8_t shift = (sizeof(int32_t) * 8) - 2;
60+
return static_cast<int8_t>((static_cast<int32_t>(bits) << shift) >> shift);
61+
}
62+
63+
inline UnpackedType GetElem(size_t index) const {
64+
assert(index <= 3);
65+
const uint8_t shift = 2 * static_cast<uint8_t>(index);
66+
const std::byte val = (bits_ >> shift) & std::byte{0x3};
67+
68+
if constexpr (Signed) {
69+
return SignExtendLower2Bits(val);
70+
} else {
71+
return static_cast<UnpackedType>(val);
72+
}
73+
}
74+
75+
inline void SetElem(size_t index, UnpackedType val) {
76+
assert(index <= 3);
77+
const uint8_t shift = 2 * static_cast<uint8_t>(index);
78+
const std::byte clear_mask = ~(std::byte{0x3} << shift);
79+
80+
bits_ &= clear_mask; // Clear 2-bit element to 0
81+
bits_ |= static_cast<std::byte>((val & 0x3) << shift); // Set 2-bit element to val
82+
}
83+
84+
inline std::byte ToBits() const {
85+
return bits_;
86+
}
87+
88+
/// <summary>
89+
/// Calculates the number of packed byte units needed to store the given number of 2-bit elements.
90+
/// Each byte stores 4 x 2-bit elements.
91+
/// </summary>
92+
static size_t CalcNumInt2Quads(size_t num_int2_elems) {
93+
return (num_int2_elems + 3) / 4;
94+
}
95+
96+
/// <summary>
97+
/// Copy a source buffer of 2-bit elements (packed) into a destination buffer of 8-bit elements (unpacked).
98+
/// </summary>
99+
/// <param name="dst">Destination buffer to store unpacked 8-bit elements</param>
100+
/// <param name="src">Source buffer with 2-bit elements</param>
101+
/// <returns>True on success</returns>
102+
static bool Unpack(gsl::span<UnpackedType> dst, gsl::span<const Int2x4Base<Signed>> src) {
103+
if (CalcNumInt2Quads(dst.size()) != src.size()) {
104+
return false;
105+
}
106+
107+
if (src.empty()) {
108+
return true;
109+
}
110+
111+
for (size_t i = 0; i < dst.size(); i++) {
112+
size_t byte_idx = i >> 2; // i / 4
113+
size_t elem_idx = i & 0x3; // i % 4
114+
dst[i] = src[byte_idx].GetElem(elem_idx);
115+
}
116+
117+
return true;
118+
}
119+
120+
/// <summary>
121+
/// Copy a source buffer of 8-bit elements (unpacked) into a destination buffer of 2-bit elements (packed).
122+
/// </summary>
123+
/// <param name="dst">Destination buffer to store packed 2-bit elements</param>
124+
/// <param name="src">Source buffer with 8-bit elements</param>
125+
/// <returns>True on success</returns>
126+
static bool Pack(gsl::span<Int2x4Base<Signed>> dst, gsl::span<const UnpackedType> src) {
127+
if (CalcNumInt2Quads(src.size()) != dst.size()) {
128+
return false;
129+
}
130+
131+
if (src.empty()) {
132+
return true;
133+
}
134+
135+
size_t src_i = 0;
136+
size_t dst_i = 0;
137+
const size_t full_quads = src.size() / 4;
138+
139+
// Process complete groups of 4 elements
140+
for (; dst_i < full_quads; dst_i++) {
141+
dst[dst_i] = Int2x4Base<Signed>(src[src_i], src[src_i + 1], src[src_i + 2], src[src_i + 3]);
142+
src_i += 4;
143+
}
144+
145+
// Handle remaining elements (1-3)
146+
if (src_i < src.size()) {
147+
UnpackedType vals[4] = {0, 0, 0, 0};
148+
size_t remaining = src.size() - src_i;
149+
for (size_t j = 0; j < remaining; j++) {
150+
vals[j] = src[src_i + j];
151+
}
152+
dst[dst_i] = Int2x4Base<Signed>(vals[0], vals[1], vals[2], vals[3]);
153+
}
154+
155+
return true;
156+
}
157+
158+
/// <summary>
159+
/// Returns hierarchical indices for a packed int2 element from the given element index.
160+
///
161+
/// Usage:
162+
/// Int2x4* data = ...;
163+
/// auto indices = GetTensorElemIndices(5); // 6th int2 element
164+
/// int8_t elem = data[indices.first].GetElem(indices.second);
165+
/// </summary>
166+
/// <param name="index">Index of 2-bit element</param>
167+
/// <returns>Pair of (byte_index, element_index_within_byte)</returns>
168+
static inline std::pair<size_t, size_t> GetTensorElemIndices(size_t index) {
169+
return {index >> 2, index & 0x3};
170+
}
171+
};
172+
173+
using Int2x4 = Int2x4Base<true>;
174+
using UInt2x4 = Int2x4Base<false>;
175+
static_assert(sizeof(Int2x4) == sizeof(std::byte));
176+
static_assert(sizeof(UInt2x4) == sizeof(std::byte));
177+
178+
} // namespace onnxruntime

include/onnxruntime/core/framework/to_tensor_proto_element_type.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "core/framework/float4.h"
1414
#include "core/common/float8.h"
1515
#include "core/common/float16.h"
16+
#include "core/framework/int2.h"
1617
#include "core/framework/int4.h"
1718

1819
namespace onnxruntime {
@@ -116,5 +117,14 @@ constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType<UInt4x2>
116117
return ONNX_NAMESPACE::TensorProto_DataType_UINT4;
117118
}
118119

120+
template <>
121+
constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType<Int2x4>() {
122+
return ONNX_NAMESPACE::TensorProto_DataType_INT2;
123+
}
124+
template <>
125+
constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType<UInt2x4>() {
126+
return ONNX_NAMESPACE::TensorProto_DataType_UINT2;
127+
}
128+
119129
} // namespace utils
120130
} // namespace onnxruntime

0 commit comments

Comments
 (0)