Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 6d8bb4a

Browse files
authored
Sync itrex1.3 (#12)
1 parent 799f67c commit 6d8bb4a

File tree

9 files changed

+127
-59
lines changed

9 files changed

+127
-59
lines changed

CMakeLists.txt

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,7 @@ option(NE_AVX512_VBMI "neural_engine: enable AVX512-VBMI"
5959
option(NE_AVX512_VNNI "neural_engine: enable AVX512-VNNI" OFF)
6060
option(NE_FMA "neural_engine: enable FMA" ON)
6161
option(NE_AMX "neural_engine: enable AMX" OFF)
62-
63-
# in MSVC F16C is implied with AVX2/AVX512
64-
if (NOT MSVC)
65-
option(NE_F16C "neural_engine: enable F16C" ON)
66-
endif()
62+
option(NE_F16C "neural_engine: enable F16C" ON)
6763

6864
# 3rd party libs
6965
option(NE_ONEDNN "neural_engine: use oneDNN" ON)
@@ -93,6 +89,8 @@ if (NE_GELU_VEC)
9389
endif()
9490
option(NE_PYTHON_API "neural_engine: use python api" OFF)
9591
option(NE_SIMD_VEC_DOT_F16 "neural_engine: enable vec_dot_fp16 SIMD optimization" ON)
92+
option(BUILD_SHARED_LIBS "If build as shared libs" ON)
93+
9694
if (NE_SIMD_VEC_DOT_F16)
9795
add_compile_definitions(NE_SIMD_VEC_DOT_F16)
9896
endif()
@@ -103,7 +101,6 @@ endif()
103101

104102
if (MSVC)
105103
add_compile_definitions(_CRT_SECURE_NO_WARNINGS NOMINMAX)
106-
107104
if (BUILD_SHARED_LIBS)
108105
set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
109106
endif()

bestla/jblas/jit_blas_parallel.h

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ class SchedulerBase : public Scheduler2D {
204204
mL2Use += static_cast<size_t>(mBlock[1]) * mBlock[2] * mEleSize[1];
205205
mL2Use += static_cast<size_t>(mStep[0]) * mBlock[2] * mEleSize[0];
206206
}
207-
const float DensityThres = 32;
207+
const float DensityThres = 16;
208208
static size_t constexpr ReservedSize = 32ULL * 1024ULL;
209209

210210
virtual float calculate_score() {
@@ -364,7 +364,7 @@ class SchedulerKBlock : public Scheduler2D {
364364
mL2Use += static_cast<size_t>(mBlock[1]) * mBlock[2] * mEleSize[1];
365365
mL2Use += static_cast<size_t>(mStep[0]) * mBlock[2] * mEleSize[0];
366366
}
367-
const float DensityThres = 32;
367+
const float DensityThres = 16;
368368

369369
float calculate_score() {
370370
int tmpnstep = mThdSize[1] < _GemmCore_T::PREFERRED_N ? mThdSize[1] : _GemmCore_T::PREFERRED_N;
@@ -489,13 +489,14 @@ class SchedulerKBlockS : public SchedulerBase<_GemmCore_T> {
489489
this->mL2Use += static_cast<size_t>(blks) * (this->mBlock[1] + this->mStep[0]) *
490490
(sizeof(float) + sizeof(int8_t) + sizeof(float)); // scale+zp+reduce
491491
assert(this->mL2Use <= this->mL2Size - ReservedSize);
492-
assert(this->mBlock[0]>0);
493-
assert(this->mBlock[1]>0);
494-
assert(this->mBlock[2]>0);
492+
assert(this->mBlock[0] > 0);
493+
assert(this->mBlock[1] > 0);
494+
assert(this->mBlock[2] > 0);
495+
assert(this->mBlock[2] % _GemmCore_T::KTILE == 0);
495496
}
496497

497498
protected:
498-
const float DensityThres = 32;
499+
const float DensityThres = 16;
499500
static size_t constexpr ReservedSize = 32ULL * 1024ULL;
500501

501502
void cache_blocking_compute() override {
@@ -529,6 +530,11 @@ class SchedulerKBlockS : public SchedulerBase<_GemmCore_T> {
529530
(this->mStep[0] * this->mEleSize[0] +
530531
float(CorSize * (this->mStep[0] + this->mBlock[1])) / this->mKBlock +
531532
this->mBlock[1] * this->mEleSize[1]));
533+
if (rawk < this->mKBlock) {
534+
rawk = static_cast<int>((valid_total - this->mBlock[0] * this->mBlock[1] * this->mEleSize[2] -
535+
1 * CorSize * (this->mStep[0] + this->mBlock[1])) /
536+
(this->mStep[0] * this->mEleSize[0] + this->mBlock[1] * this->mEleSize[1]));
537+
}
532538
rawk = std::min(rawk, this->mSizePadded[2]);
533539
this->mBlock[2] = utils::padto_le(rawk, this->mStep[2]);
534540
if (this->mBlock[2] > this->mKBlock) {
@@ -569,9 +575,6 @@ class SchedulerKBlockS : public SchedulerBase<_GemmCore_T> {
569575
this->mBlock[2] = static_cast<int>(getMaxK(this->mBlock[1]));
570576
this->mBlock[2] = utils::padto_le(this->mBlock[2], this->mStep[2]);
571577
this->mBlock[2] = std::min(mKBlock, this->mBlock[2]);
572-
auto tmp = utils::updiv(mKBlock, this->mBlock[2]);
573-
while (mKBlock % tmp != 0) tmp++; // TODO(Yu) optimize
574-
this->mBlock[2] = utils::downdiv(mKBlock, tmp);
575578
}
576579
}
577580

bestla/jblas/kernel_avx2.h

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -412,10 +412,14 @@ inline JBLAS_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int
412412
for (; j < align_col; j += 8) quant();
413413
for (; j < col; j++) {
414414
auto fp_v = ref::f8_to_fp32(srcptr[i * ld_src + j], src_f8_type);
415-
if constexpr (std::is_same_v<_S_T, utils::f8>) {
416-
dstptr[i * ld_dst + j] = fp_v * std::pow(2, sptr[j / _PACK_ROW].x);
417-
} else if constexpr (std::is_same_v<_S_T, float>) {
418-
dstptr[i * ld_dst + j] = fp_v * sptr[j / _PACK_ROW];
415+
if constexpr (WITH_SCALE) {
416+
if constexpr (std::is_same_v<_S_T, utils::f8>) {
417+
dstptr[i * ld_dst + j] = fp_v * std::pow(2, sptr[j / _PACK_ROW].x);
418+
} else if constexpr (std::is_same_v<_S_T, float>) {
419+
dstptr[i * ld_dst + j] = fp_v * sptr[j / _PACK_ROW];
420+
}
421+
} else {
422+
dstptr[i * ld_dst + j] = fp_v;
419423
}
420424
}
421425
}
@@ -636,6 +640,14 @@ static inline JBLAS_CODE decompress_kblock_bit4_packrow1(
636640
vzps[iv] = _mm256_cvtepi8_epi32(tmp);
637641
}
638642
}
643+
auto rowre = row - irow;
644+
int rowpad4 = utils::padto_le(rowre, UnrollRow) + irow;
645+
for (; irow < rowpad4; irow += UnrollRow) {
646+
for (int iter16 = 0; iter16 < Loop16; iter16++)
647+
pad_bit4_16(tmpbuf + iter16 * 16, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2 + 8 * iter16));
648+
for (int iterr = 0; iterr < UnrollRow; iterr++)
649+
dequantize(dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * _NCOL, vscales, vzps);
650+
}
639651
for (; irow < row; irow++) {
640652
if constexpr (_NCOL == 24) {
641653
pad_bit4_16(tmpbuf, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2));

bestla/jblas/kernel_avx512f.h

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -321,13 +321,28 @@ static inline JBLAS_CODE decompress_kblock_bit4_packrow1(utils::bit4x2* srcptr,
321321
vzps[iv] = _mm512_cvtepi8_epi32(tmp);
322322
}
323323
}
324-
}
325-
for (; irow < row; irow++) {
326-
pad_bit4(tmpbuf, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2), zmm_mask, LoadMask48);
327-
if constexpr (_IS_SYM) {
328-
dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, nullptr);
329-
} else {
330-
dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, vzps);
324+
auto rowre = row - irow;
325+
int rowpad4 = utils::padto_le(rowre, UnrollRow) + irow;
326+
for (; irow < rowpad4; irow += UnrollRow) {
327+
for (int iter64 = 0; iter64 < Loop64; iter64++) {
328+
pad_bit4(tmpbuf + iter64 * 64, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2 + 32 * iter64), zmm_mask,
329+
LoadMask64);
330+
}
331+
for (int iterr = 0; iterr < UnrollRow; iterr++) {
332+
if constexpr (_IS_SYM) {
333+
dequantize(dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * ColTile, vscales, nullptr);
334+
} else {
335+
dequantize(dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * ColTile, vscales, vzps);
336+
}
337+
}
338+
}
339+
for (; irow < row; irow++) {
340+
pad_bit4(tmpbuf, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2), zmm_mask, LoadMask48);
341+
if constexpr (_IS_SYM) {
342+
dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, nullptr);
343+
} else {
344+
dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, vzps);
345+
}
331346
}
332347
}
333348
return JblasSuccess;
@@ -565,7 +580,7 @@ inline JBLAS_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int
565580
auto quant = [&](__mmask16 mask) {
566581
__m128i f8_src;
567582
auto sign_revert =
568-
_mm512_cvtepi8_epi32(_mm_mask_loadu_epi8(f8_src, mask, reinterpret_cast<__m128i*>(srcptr + i * ld_src + j)));
583+
_mm512_cvtepi8_epi32(_mm_maskz_loadu_epi8(mask, reinterpret_cast<__m128i*>(srcptr + i * ld_src + j)));
569584
auto e_revert = sign_revert;
570585
auto mantissa_revert = sign_revert;
571586
sign_revert = _mm512_slli_epi32(sign_revert, 24);
@@ -888,10 +903,10 @@ inline void f32_f4_quantize_4x16(const float* srcptr, int8_t* dstptr, int ld_src
888903
zmm2 = _mm512_add_ps(zmm2, zmm_zp);
889904
zmm3 = _mm512_add_ps(zmm3, zmm_zp);
890905
} else {
891-
mask4 = _mm512_cmplt_ps_mask(zmm0, zmm_v0);
892-
mask5 = _mm512_cmplt_ps_mask(zmm1, zmm_v0);
893-
mask6 = _mm512_cmplt_ps_mask(zmm2, zmm_v0);
894-
mask7 = _mm512_cmplt_ps_mask(zmm3, zmm_v0);
906+
mask4 = _mm512_cmp_ps_mask(zmm0, zmm_v0, 1);
907+
mask5 = _mm512_cmp_ps_mask(zmm1, zmm_v0, 1);
908+
mask6 = _mm512_cmp_ps_mask(zmm2, zmm_v0, 1);
909+
mask7 = _mm512_cmp_ps_mask(zmm3, zmm_v0, 1);
895910

896911
zmm0 = _mm512_abs_ps(zmm0);
897912
zmm1 = _mm512_abs_ps(zmm1);
@@ -908,10 +923,10 @@ inline void f32_f4_quantize_4x16(const float* srcptr, int8_t* dstptr, int ld_src
908923
zmm5 = _mm512_sub_ps(zmm1, sub_v);
909924
zmm6 = _mm512_sub_ps(zmm2, sub_v);
910925
zmm7 = _mm512_sub_ps(zmm3, sub_v);
911-
mask0 = _mm512_cmple_ps_mask(zmm4, zmm_v0);
912-
mask1 = _mm512_cmple_ps_mask(zmm5, zmm_v0);
913-
mask2 = _mm512_cmple_ps_mask(zmm6, zmm_v0);
914-
mask3 = _mm512_cmple_ps_mask(zmm7, zmm_v0);
926+
mask0 = _mm512_cmp_ps_mask(zmm4, zmm_v0, 2);
927+
mask1 = _mm512_cmp_ps_mask(zmm5, zmm_v0, 2);
928+
mask2 = _mm512_cmp_ps_mask(zmm6, zmm_v0, 2);
929+
mask3 = _mm512_cmp_ps_mask(zmm7, zmm_v0, 2);
915930
xmm0 = _mm_mask_blend_epi8(mask0, xmm0, _mm_loadu_si128(reinterpret_cast<const __m128i*>(broadcast_f4_v + i * 16)));
916931
xmm1 = _mm_mask_blend_epi8(mask1, xmm1, _mm_loadu_si128(reinterpret_cast<const __m128i*>(broadcast_f4_v + i * 16)));
917932
xmm2 = _mm_mask_blend_epi8(mask2, xmm2, _mm_loadu_si128(reinterpret_cast<const __m128i*>(broadcast_f4_v + i * 16)));
@@ -949,7 +964,7 @@ inline void f32_f4_quantize_1x16(const float* srcptr, int8_t* dstptr, int ld_src
949964
auto zp = _mm512_set1_ps(0.8480964004993439f);
950965
zmm0 = _mm512_add_ps(zmm0, zp);
951966
} else {
952-
mask1 = _mm512_cmplt_ps_mask(zmm0, zmm_v0);
967+
mask1 = _mm512_cmp_ps_mask(zmm0, zmm_v0, 1);
953968
zmm0 = _mm512_abs_ps(zmm0);
954969
}
955970
constexpr int loop_num = F4_T == JBLAS_DTYPE::F4_NF4 ? 16 : 8;
@@ -959,7 +974,7 @@ inline void f32_f4_quantize_1x16(const float* srcptr, int8_t* dstptr, int ld_src
959974
if constexpr (F4_T == JBLAS_DTYPE::F4_BNB) sub_v = _mm512_set1_ps(F4_BNB_quant_sub_helper[i]);
960975
if constexpr (F4_T == JBLAS_DTYPE::F4_E2M1) sub_v = _mm512_set1_ps(F4_E2M1_quant_sub_helper[i]);
961976
zmm1 = _mm512_sub_ps(zmm0, sub_v);
962-
mask0 = _mm512_cmple_ps_mask(zmm1, zmm_v0);
977+
mask0 = _mm512_cmp_ps_mask(zmm1, zmm_v0, 2);
963978
xmm0 = _mm_mask_blend_epi8(mask0, xmm0, _mm_loadu_si128(reinterpret_cast<const __m128i*>(broadcast_f4_v + i * 16)));
964979
zmm0 = _mm512_mask_add_ps(zmm0, mask0, zmm0, avoid_double_cmp);
965980
}

bestla/jblas/kernel_ref.h

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -230,25 +230,47 @@ inline void convert_s4_s8_8(int8_t* dstptr, int8_t* srcptr) {
230230
dstptr[7] = tmp;
231231
}
232232

233+
inline void convert_s4_s8_8_lowbits(int8_t* dstptr, int8_t* srcptr) {
234+
auto src32 = *reinterpret_cast<uint32_t*>(srcptr);
235+
auto tmp = static_cast<int>(src32 & 0xf);
236+
dstptr[0] = static_cast<int8_t>(tmp);
237+
tmp = static_cast<int>(src32 & 0xf0) >> 4;
238+
dstptr[1] = static_cast<int8_t>(tmp);
239+
tmp = static_cast<int>((src32 & 0xf00) >> 8);
240+
dstptr[2] = static_cast<int8_t>(tmp);
241+
tmp = static_cast<int>((src32 & 0xf000) >> 12);
242+
dstptr[3] = static_cast<int8_t>(tmp);
243+
tmp = static_cast<int>((src32 & 0xf0000) >> 16);
244+
dstptr[4] = static_cast<int8_t>(tmp);
245+
tmp = static_cast<int>((src32 & 0xf00000) >> 20);
246+
dstptr[5] = static_cast<int8_t>(tmp);
247+
tmp = static_cast<int>((src32 & 0xf000000) >> 24);
248+
dstptr[6] = static_cast<int8_t>(tmp);
249+
tmp = static_cast<int>((src32 & 0xf0000000) >> 28);
250+
dstptr[7] = static_cast<int8_t>(tmp);
251+
}
252+
233253
template <>
234254
inline void convert_s4_s8_8<JBLAS_DTYPE::S4_FULLRANGE>(int8_t* dstptr, int8_t* srcptr) {
235-
auto src32 = *reinterpret_cast<uint32_t*>(srcptr);
236-
auto tmp = static_cast<int8_t>(src32 & 0xf);
237-
dstptr[0] = tmp - 8;
238-
tmp = static_cast<int8_t>(src32 & 0xf0) >> 4;
239-
dstptr[1] = tmp - 8;
240-
tmp = static_cast<int8_t>((src32 & 0xf00) >> 8);
241-
dstptr[2] = tmp - 8;
242-
tmp = static_cast<int8_t>((src32 & 0xf000) >> 12);
243-
dstptr[3] = tmp - 8;
244-
tmp = static_cast<int8_t>((src32 & 0xf0000) >> 16);
245-
dstptr[4] = tmp - 8;
246-
tmp = static_cast<int8_t>((src32 & 0xf00000) >> 20);
247-
dstptr[5] = tmp - 8;
248-
tmp = static_cast<int8_t>((src32 & 0xf000000) >> 24);
249-
dstptr[6] = tmp - 8;
250-
tmp = static_cast<int8_t>((src32 & 0xf0000000) >> 28);
251-
dstptr[7] = tmp - 8;
255+
convert_s4_s8_8_lowbits(dstptr, srcptr);
256+
for (size_t i = 0; i < 8; i++) {
257+
dstptr[i] -= 8;
258+
}
259+
}
260+
261+
template <>
262+
inline void convert_s4_s8_8<JBLAS_DTYPE::F4_BNB>(int8_t* dstptr, int8_t* srcptr) {
263+
convert_s4_s8_8_lowbits(dstptr, srcptr);
264+
}
265+
266+
template <>
267+
inline void convert_s4_s8_8<JBLAS_DTYPE::F4_NF4>(int8_t* dstptr, int8_t* srcptr) {
268+
convert_s4_s8_8_lowbits(dstptr, srcptr);
269+
}
270+
271+
template <>
272+
inline void convert_s4_s8_8<JBLAS_DTYPE::F4_E2M1>(int8_t* dstptr, int8_t* srcptr) {
273+
convert_s4_s8_8_lowbits(dstptr, srcptr);
252274
}
253275

254276
template <JBLAS_DTYPE S4_T>

neural_speed/cmake/Common.cmake

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,25 @@ function(add_executable_w_warning TARGET)
3636
warning_check(${TARGET})
3737
endfunction()
3838

39-
function(add_library_w_warning TARGET)
40-
add_library(${TARGET} STATIC ${ARGN})
39+
function(add_library_w_warning_ TARGET)
40+
add_library(${TARGET} ${ARGN})
4141
set_target_properties(${TARGET} PROPERTIES C_STANDARD 11 C_STANDARD_REQUIRED ON C_EXTENSIONS OFF)
4242
set_target_properties(${TARGET} PROPERTIES CXX_STANDARD 11 CXX_STANDARD_REQUIRED ON CXX_EXTENSIONS OFF)
4343
warning_check(${TARGET})
4444
endfunction()
45+
46+
function(add_library_w_warning TARGET)
47+
add_library_w_warning_(${TARGET} STATIC ${ARGN})
48+
endfunction()
49+
50+
function(add_shared_library_w_warning TARGET)
51+
add_library_w_warning_(${TARGET} SHARED ${ARGN})
52+
endfunction()
53+
54+
function(add_shareable_library_w_warning TARGET)
55+
if (BUILD_SHARED_LIBS)
56+
add_library_w_warning_(${TARGET} SHARED ${ARGN})
57+
else()
58+
add_library_w_warning_(${TARGET} STATIC ${ARGN})
59+
endif()
60+
endfunction()

neural_speed/cmake/ISA.cmake

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
# limitations under the License.
1414

1515
if (MSVC)
16+
if(NE_F16C)
17+
add_compile_definitions(__F16C__)
18+
endif()
1619
if (NE_AVX512)
1720
add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX512>)
1821
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX512>)

neural_speed/core/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ find_package(Threads REQUIRED)
1616
file(GLOB layers_srcs "layers/*.cpp")
1717
set(sources ne_layers.c ${layers_srcs})
1818

19-
add_library_w_warning(ne_layers "${sources}")
19+
add_shareable_library_w_warning(ne_layers "${sources}")
2020

2121
target_include_directories(ne_layers PUBLIC .)
2222
target_compile_features(ne_layers PUBLIC c_std_11) # don't bump

neural_speed/scripts/convert_mistral.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -855,8 +855,8 @@ def lazy_load_torch_file(outer_fp: IO[bytes], path: Path) -> ModelPlus:
855855
return ModelPlus(model=as_dict, paths=[path], format='torch', vocab=None)
856856

857857

858-
SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {'F16': DT_F16, 'F32': DT_F32, 'I32': DT_I32, 'BOOL': DT_BOOL}
859-
858+
SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {'F16': DT_F16, 'F32': DT_F32, 'I32': DT_I32, 'BOOL': DT_BOOL,
859+
'BF16': DT_BF16}
860860

861861
def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus:
862862
header_size, = struct.unpack('<Q', fp.read(8))

0 commit comments

Comments
 (0)