Skip to content
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
555e951
matmul nbits to optimize memory layout for avx instructions
liqunfu Sep 24, 2024
076998c
Merge branch 'main' into liqun/avx-layout
liqunfu Nov 7, 2024
99aec95
intermediate push
liqunfu Nov 18, 2024
8ce1a2a
pass mlas and utest for blklen32 avx512
liqunfu Nov 27, 2024
f016555
Merge branch 'main' into liqun/avx-layout
liqunfu Nov 27, 2024
d371c59
pass avx512/vnni-blklen32
liqunfu Nov 29, 2024
790b03f
pass avx512vnni-blklen128. plan to compute blksum in different loop t…
liqunfu Nov 29, 2024
557fbb0
attempt to make blklen256 work. failed because blksum computation need…
liqunfu Nov 29, 2024
6b28657
avx512 blklen64 to compute blksum in a separate loop
liqunfu Nov 30, 2024
0b867f8
avx512 scaled_zp compute in a separate loop except blklen16
liqunfu Nov 30, 2024
2e74f56
avx512, all blklens, scaled_zp compute in a separate loop
liqunfu Nov 30, 2024
0bf47f7
Merge branch 'main' into liqun/avx-layout
liqunfu Dec 12, 2024
c19ae9e
avx2 passes
liqunfu Dec 13, 2024
b26b075
avxvnni, matmul_nbit kernel
liqunfu Dec 15, 2024
7e99d50
mlas nbit print correct compType
liqunfu Jan 8, 2025
f36ec96
clean up a bit
liqunfu Jan 8, 2025
6d0404f
Merge branch 'main' into liqun/avx-layout
liqunfu Jan 8, 2025
5901b52
lint
liqunfu Jan 9, 2025
eba1908
remove unused __m512 load_1blksum_512(const float* BlksumPtr)
liqunfu Jan 10, 2025
e8484eb
Merge branch 'main' into liqun/avx-layout
liqunfu Jan 10, 2025
6dac6ad
sqnbitgemm_kernel_avx512.cpp to apply -mavx512f
liqunfu Jan 10, 2025
429054a
undo sqnbitgemm_kernel_avx512.cpp to apply -mavx512f
liqunfu Jan 11, 2025
b1d7474
restore avx512 blklen32 from use special layout because related code …
liqunfu Jan 11, 2025
5647598
Merge branch 'main' into liqun/avx-layout
liqunfu Mar 17, 2025
af15c91
merge main
liqunfu Apr 2, 2025
42fc7e3
const Tensor* scales = scales_are_packed_ ? nullptr : ctx->Input<Tens…
liqunfu Apr 3, 2025
419822b
scales_are_packed_ set to true in x64
liqunfu Apr 3, 2025
534befe
use scales_are_packed_
liqunfu Apr 3, 2025
e3f1b29
check scales against nullptr
liqunfu Apr 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 158 additions & 0 deletions onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,160 @@ SQ4BitGemmKernel_BlkSum_CompInt8_avx512(
return CountM;
}

__m512
ComputeMulScal(const float* a_ptr, size_t step, float& scale)
{
    // Scan one block (up to `step` floats) for max(|a_i|), derive the int8
    // quantization scale (maxAbs / 127) into `scale`, and return the
    // reciprocal scale broadcast to all 16 lanes for use by the quantizer.
    const __m512 sign_mask = _mm512_set1_ps(-0.0f);
    __m512 abs_max = _mm512_setzero_ps();

    for (size_t kk = 0; kk < step; kk += 16) {
        const size_t klen = std::min(size_t(16), step - kk);

        // Masked load: lanes past the end of the block read as 0.0f, which
        // never affects the running max of absolute values.
        const uint32_t mask = 0xffff >> (16 - klen);
        const __m512 vals = _mm512_maskz_loadu_ps(__mmask16(mask), a_ptr + kk);

        // andnot against the sign bit is a per-lane fabs.
        abs_max = _mm512_max_ps(abs_max, _mm512_andnot_ps(sign_mask, vals));
    }

    // Horizontal max across all 16 lanes.
    const float maxScalar = _mm512_reduce_max_ps(abs_max);

    scale = maxScalar / 127.f;

    // Guard an all-zero block so we never divide by zero; a zero inverse
    // scale quantizes everything to 0, which is consistent with scale == 0.
    const float inverse_scale = (maxScalar != 0.0f) ? 127.f / maxScalar : 0.0f;
    return _mm512_set1_ps(inverse_scale);
}

void
QuantizeInt8ComputeBlksum(const float* a_ptr, size_t step, __m512& mul, float scale, __m256i& i0_32_epi8, float& blksum)
{
    // Quantize up to 32 floats to int8 using the precomputed inverse scale
    // `mul`, pack the results into `i0_32_epi8`, and produce
    // blksum = scale * Sum(quantized a_i) for the block.
    //
    // `i_16_epi8` holds at most two 16-lane chunks, so the block length is
    // capped at 32 (callers pass BlkLen == 32 or a shorter tail).
    assert(step <= 32);

    const __m256i one_16_epi16 = _mm256_set1_epi16(1);
    __m256i sum_16_epi16 = _mm256_setzero_si256();
    __m128i i_16_epi8[2] = {_mm_setzero_si128(), _mm_setzero_si128()};
    int index = 0;
    for (size_t kk = 0; kk < step; kk += 16, index++) {
        const size_t klen = std::min(size_t(16), step - kk);

        // Masked load: lanes past the block end are zero, quantize to 0, and
        // therefore do not perturb the packed data or the block sum.
        uint32_t mask = 0xffff >> (16 - klen);
        __m512 v0 = _mm512_maskz_loadu_ps(__mmask16(mask), a_ptr + kk);
        v0 = _mm512_mul_ps(v0, mul);

        // Round to nearest integer. Use the _MM_FROUND_* constants — the
        // macro family _mm512_roundscale_ps's imm8 is defined in terms of —
        // rather than the MXCSR-oriented _MM_ROUND_NEAREST (which only
        // happens to share the value 0); also suppress FP exceptions.
        v0 = _mm512_roundscale_ps(v0, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);

        // Convert floats to int32, then narrow int32 -> int8 (16 lanes).
        __m512i i0 = _mm512_cvtps_epi32(v0);
        i_16_epi8[index] = _mm512_cvtepi32_epi8(i0);

        // Accumulate Sum(a_i): widen int8 -> int16 and add pairwise with
        // saturation. Each quantized value is in [-127, 127] and a block has
        // at most 32 elements, so the 16-bit accumulation cannot saturate.
        __m256i i_16_epi16 = _mm256_cvtepi8_epi16(i_16_epi8[index]);
        sum_16_epi16 = _mm256_hadds_epi16(sum_16_epi16, i_16_epi16);
    }
    i0_32_epi8 = _mm256_set_m128i(i_16_epi8[1], i_16_epi8[0]);
    // madd against ones widens adjacent int16 pairs into int32 partial sums;
    // hsum_8_epi32 (defined elsewhere in this file) reduces the 8 lanes to a
    // scalar, which is then scaled to give the block sum contribution.
    const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16);
    blksum = scale * hsum_8_epi32(sum_8_epi32);
}

void
Quantize1BlkBlkLen32(const float* a_ptr, size_t step, __m256i& i_32_epi8, float& scale, float& blksum)
{
    // Quantize a single block (up to 32 floats) into 32 int8s: first derive
    // the per-block scale and its broadcast reciprocal, then quantize and
    // compute the scaled block sum in one pass over the same data.
    __m512 inv_scale_ps = ComputeMulScal(a_ptr, step, scale);
    QuantizeInt8ComputeBlksum(a_ptr, step, inv_scale_ps, scale, i_32_epi8, blksum);
}

void
store_4blk_blklen32_interleaved(__m256i i_32_epi8[4], int8_t* blob)
{
    // Store 4 quantized blocks (4 x 32 int8 = 128 bytes) in an interleaved
    // layout: each 16-byte row below holds one 4-byte chunk from each of the
    // 4 blocks in turn. Numbers are source element indices
    // (block_index * 32 + lane):
    //
    // 0 1 2 3 32 33 34 35 64 65 66 67 96 97 98 99
    // 4 5 6 7 36 37 38 39 68 69 70 71 100 101 102 103
    // 8 9 10 11 40 41 42 43 72 73 74 75 104 105 106 107
    // 12 13 14 15 44 45 46 47 76 77 78 79 108 109 110 111
    //
    // 16 17 18 19 48 49 50 51 80 81 82 83 112 113 114 115
    // 20 21 22 23 52 53 54 55 84 85 86 87 116 117 118 119
    // 24 25 26 27 56 57 58 59 88 89 90 91 120 121 122 123
    // 28 29 30 31 60 61 62 63 92 93 94 95 124 125 126 127

    // Gather the low 128-bit halves (lanes 0..15) of blocks 0/1 and 2/3, and
    // likewise the high halves (lanes 16..31), via 128-bit lane shuffles.
    __m256i a0_lower = _mm256_permute2x128_si256(i_32_epi8[0], i_32_epi8[1], 0x20);
    __m256i a0_higher = _mm256_permute2x128_si256(i_32_epi8[0], i_32_epi8[1], 0x31);
    __m256i a1_lower = _mm256_permute2x128_si256(i_32_epi8[2], i_32_epi8[3], 0x20);
    __m256i a1_higher = _mm256_permute2x128_si256(i_32_epi8[2], i_32_epi8[3], 0x31);

    // Widen to 512 bits: a_lower now holds lanes 0..15 of all 4 blocks
    // (dwords 0-3 = block 0, 4-7 = block 1, 8-11 = block 2, 12-15 = block 3);
    // a_higher holds lanes 16..31 in the same order.
    __m512i a_lower = _mm512_inserti64x4(_mm512_castsi256_si512(a0_lower), a1_lower, 1);
    __m512i a_higher = _mm512_inserti64x4(_mm512_castsi256_si512(a0_higher), a1_higher, 1);

    // Permute 32-bit groups so 4-byte chunks from the 4 blocks alternate:
    // result dword k comes from block (k % 4), chunk (k / 4), producing the
    // row layout shown above.
    __m512i idx = _mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0);
    __m512i a_lower_interleaved = _mm512_permutexvar_epi32(idx, a_lower);
    __m512i a_higher_interleaved = _mm512_permutexvar_epi32(idx, a_higher);

    _mm512_storeu_si512(reinterpret_cast<__m512i*>(blob + 0 * 64), a_lower_interleaved);
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(blob + 1 * 64), a_higher_interleaved);
}

void MLASCALL
QuantizeARow_CompInt8_avx512_blklen32(
    const float* A,
    size_t CountK,
    std::byte* QuantA,
    float* QuantAScale,
    float* AScaledBlkSum  // scale_k * Sum_blklen(a_i)
)
{
    // Quantize one row of A for BlkLen == 32: the main loop handles 4 blocks
    // (128 floats) per iteration and stores them in the interleaved layout;
    // the tail loop quantizes the remaining blocks one at a time in plain
    // (non-interleaved) layout.
    constexpr size_t BlkLen = 32;
    constexpr size_t SubBlkLen = 4 * BlkLen;

    const float* src = A;
    int8_t* dst = reinterpret_cast<int8_t*>(QuantA);
    float* scales_out = QuantAScale;
    float* blksums_out = AScaledBlkSum;

    size_t remaining = CountK;

    // Main path: full 128-element sub-blocks, written interleaved.
    while (remaining >= SubBlkLen) {
        __m256i quantized[4];
        for (int i = 0; i < 4; i++) {
            // Scales and block sums go straight into the output arrays.
            Quantize1BlkBlkLen32(src, BlkLen, quantized[i], scales_out[i], blksums_out[i]);
            src += BlkLen;
        }
        store_4blk_blklen32_interleaved(quantized, dst);
        dst += SubBlkLen;
        scales_out += 4;
        blksums_out += 4;
        remaining -= SubBlkLen;
    }

    // Tail: fewer than 4 blocks left (the last one possibly partial).
    while (remaining > 0) {
        const size_t step = std::min(BlkLen, remaining);
        __m256i quantized;
        float blk_scale;
        float blk_sum;
        Quantize1BlkBlkLen32(src, step, quantized, blk_scale, blk_sum);
        // Full 32-byte store is safe even for a partial block: lanes past
        // `step` were zeroed by the masked load and quantize to 0.
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), quantized);
        src += BlkLen;
        dst += BlkLen;
        *scales_out++ = blk_scale;
        *blksums_out++ = blk_sum;
        remaining -= step;
    }
}

void MLASCALL
QuantizeARow_CompInt8_avx512(
size_t BlkLen,
Expand All @@ -257,6 +411,10 @@ QuantizeARow_CompInt8_avx512(
float* AScaledBlkSum // scale_k * Sum_blklen(a_i)
)
{
if (BlkLen == 32) {
QuantizeARow_CompInt8_avx512_blklen32(A, CountK, QuantA, QuantAScale, AScaledBlkSum);
return;
}
// port from MlasQ80BlkQuantRow
assert(BlkLen % 16 == 0);
const __m512 signBit = _mm512_set1_ps(-0.0f);
Expand Down
120 changes: 50 additions & 70 deletions onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,15 @@ load_4blk_4b_packed_blklen32(const std::byte* QuantBDataPtr, __m512i& bv0_64_epi
bv1_64_epi8 = _mm512_srli_epi16(_mm512_sub_epi8(bv_packed, bv0_64_epi8), 4); // 64~127
}

static const uint32_t index_array[16] = {0, 0, 2, 2, 0, 0, 2, 2, 1, 1, 3, 3, 1, 1, 3, 3};
static MLAS_FORCEINLINE
__m512 load_4blksum_512(const float* BlksumPtr)
{
// Load 128-bit data into __m128 register
__m128 blksum4_4_ps = _mm_loadu_ps(BlksumPtr);

// Insert the __m256 register into the lower 256 bits of the __m512 register
return _mm512_insertf32x4(_mm512_setzero_ps(), blksum4_4_ps, 0);
}

static MLAS_FORCEINLINE void
accumulate_blklen32_r1c1blk4_avx512(
Expand All @@ -36,20 +44,13 @@ accumulate_blklen32_r1c1blk4_avx512(
{
const __m128 scale_a0_ps = _mm_loadu_ps(scale_a); // 0123
const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps);
__m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123
__m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps);

__m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0);
// __m512i idx = _mm512_loadu_epi8(&index_array[0]);
scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0022002211331133

const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av0_64_epi8); // 0~0,1~1
const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av1_64_epi8); // 2~2,3~3

const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333
const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333
const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); // 00002222000022221111333311113333
const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av0_64_epi8);
const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av1_64_epi8);
const __m512i sum_32_epi16 = _mm512_add_epi16(dot0_32_epi16, dot1_32_epi16);
const __m512i one_32_epi16 = generate_ones_32_epi16();
const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // 0022002211331133
const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16);
const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32);
acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0);
}
Expand All @@ -70,47 +71,37 @@ accumulate_blklen32_r2c1blk4_avx512(
)
{
__m512i bv0_64_epi8, bv1_64_epi8;
load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8);
load_4blk_4b_packed_blklen32(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8);

const __m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123
{
const __m128 scale_a0_ps = _mm_loadu_ps(scale_a0); // 0123
const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps);
__m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123

__m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0);
// __m512i idx = _mm512_loadu_epi8(&index_array[0]);
scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0022002211331133

const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av00_64_epi8); // 0~0,1~1
const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av01_64_epi8); // 2~2,3~3
const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av00_64_epi8); // 00112233 x 4 epi16s
const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av01_64_epi8); // 00112233 x 4 epi16s
const __m512i sum_32_epi16 = _mm512_add_epi16(dot0_32_epi16, dot1_32_epi16);

const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333
const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333
const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); // 00002222000022221111333311113333
const __m512i one_32_epi16 = generate_ones_32_epi16();
const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // 0022002211331133
const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // 0123 x 4 epi32s
const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32);

const __m128 scale_a0_ps = _mm_loadu_ps(scale_a0); // 0123
const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps);
const __m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123

acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0);
}
{
const __m128 scale_a1_ps = _mm_loadu_ps(scale_a1); // 0123
const __m128 scale_a1b_ps = _mm_mul_ps(scale_b_ps, scale_a1_ps);
__m512 scale_a1b_16_ps = _mm512_broadcast_f32x4(scale_a1b_ps); // 0123012301230123

__m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0);
// __m512i idx = _mm512_loadu_epi8(&index_array[0]);
scale_a1b_16_ps = _mm512_permutexvar_ps(idx, scale_a1b_16_ps); // 0022002211331133
const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av10_64_epi8);
const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av11_64_epi8);
const __m512i sum_32_epi16 = _mm512_add_epi16(dot0_32_epi16, dot1_32_epi16);

const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av10_64_epi8); // 0~0,1~1
const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av11_64_epi8); // 2~2,3~3

const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333
const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333
const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); // 00002222000022221111333311113333
const __m512i one_32_epi16 = generate_ones_32_epi16();
const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // 0022002211331133
const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16);
const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32);

const __m128 scale_a1_ps = _mm_loadu_ps(scale_a1); // 0123
const __m128 scale_a1b_ps = _mm_mul_ps(scale_b_ps, scale_a1_ps);
__m512 scale_a1b_16_ps = _mm512_broadcast_f32x4(scale_a1b_ps);

acc1 = _mm512_fmadd_ps(sum_16_ps, scale_a1b_16_ps, acc1);
}
}
Expand All @@ -122,30 +113,29 @@ accumulate_blklen32_r1c1blk4_avx512vnni(
const std::byte* QuantBDataPtr,
const float* scale_a,
const float* scale_b,
//const float* blksum_a,
//const float* blksum_b,
__m512& acc0
)
{
__m512i bv0_64_epi8, bv1_64_epi8;
load_4blk_4b_packed_blklen32(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8);
load_4blk_4b_packed_blklen32(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); // 0000111122223333 x 4 (64 unsigned int8)

const __m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123
{
const __m128 scale_a0_ps = _mm_loadu_ps(scale_a); // 0123
const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps);
__m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123

__m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0);
//__m512i idx = _mm512_loadu_epi8(&index_array[0]);
scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0022002211331133
const __m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123

const __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av0_64_epi8); // 0000000011111111
const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av1_64_epi8); // 2222222233333333
__m512i sum_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av0_64_epi8);
sum_16_epi32 = _mm512_dpbusd_epi32(sum_16_epi32, bv1_64_epi8, av1_64_epi8); // 0123012301230123

const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133
const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133
const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // 0022002211331133
const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32);
acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0);

//const __m512 blksum_a0_ps = load_4blksum_512(blksum_a); // 0123000000000000
//const __m512 blksum_b0_ps = load_4blksum_512(blksum_b); // 0123000000000000
//acc0 = _mm512_fmadd_ps(blksum_a0_ps, blksum_b0_ps, acc0);
}
}

Expand All @@ -164,24 +154,17 @@ accumulate_blklen32_r2c1blk4_avx512vnni(
)
{
__m512i bv0_64_epi8, bv1_64_epi8;
load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8);
__m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0);
//__m512i idx = _mm512_loadu_epi8(&index_array[0]);
load_4blk_4b_packed_blklen32(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); // 0000111122223333 x 4 (64 unsigned int8)

const __m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123
{
const __m128 scale_a0_ps = _mm_loadu_ps(scale_a0); // 0123
const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps);
__m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123

scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0022002211331133

const __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av00_64_epi8); // 0000000011111111
const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av01_64_epi8); // 2222222233333333
__m512i sum_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av00_64_epi8);
sum_16_epi32 = _mm512_dpbusd_epi32(sum_16_epi32, bv1_64_epi8, av01_64_epi8); // 0123012301230123

const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133
const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133
const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // 0022002211331133
const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32);
acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0);
}
Expand All @@ -190,14 +173,9 @@ accumulate_blklen32_r2c1blk4_avx512vnni(
const __m128 scale_a1b_ps = _mm_mul_ps(scale_b_ps, scale_a1_ps);
__m512 scale_a1b_16_ps = _mm512_broadcast_f32x4(scale_a1b_ps); // 0123012301230123

scale_a1b_16_ps = _mm512_permutexvar_ps(idx, scale_a1b_16_ps); // 0022002211331133

const __m512i dot0_32_epi16 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av10_64_epi8); // 0000000011111111
const __m512i dot1_32_epi16 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av11_64_epi8); // 2222222233333333
__m512i sum_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av10_64_epi8);
sum_16_epi32 = _mm512_dpbusd_epi32(sum_16_epi32, bv1_64_epi8, av11_64_epi8); // 0123012301230123

const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 0022002211331133
const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); // 0022002211331133
const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // 0022002211331133
const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32);
acc1 = _mm512_fmadd_ps(sum_16_ps, scale_a1b_16_ps, acc1);
}
Expand All @@ -208,6 +186,8 @@ accumulate_1blk_dot_avx512vnni(const __m256i& av_32_epi8, const __m256i& bv_32_e
{
__m256i sum_8_epi32 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bv_32_epi8, av_32_epi8);
const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32);
// TODO: to compare with:
// acc = _mm256_fmadd_ps(sum_ps, _mm256_broadcast_ps((__m128 const*)(&combined_scale)), acc);
acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc);
}

Expand Down
Loading