Skip to content

Commit 07b5d73

Browse files
committed
Also apply to iq2_tn
1 parent 94cdadd commit 07b5d73

File tree

1 file changed

+17
-5
lines changed

1 file changed

+17
-5
lines changed

ggml/src/iqk/iqk_mul_mat.cpp

+17-5
Original file line numberDiff line numberDiff line change
@@ -1198,18 +1198,30 @@ static void mul_mat_iq2tn_q8_K_AVX512(int n, const void * vx, size_t bx, const D
11981198
sumi_2 = _mm512_dpbusd_epi32(sumi_2, deq2.bits.values[3], q8q);
11991199
// The scale is supposed to be per per tensor, so we can use the same scale
12001200
auto vd = _mm512_set1_ps(d*q8.scale(iy, i));
1201-
accd[2*iy+0] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_1), accd[2*iy+0]);
1202-
accd[2*iy+1] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_2), accd[2*iy+1]);
1201+
accd[iy+ 0] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_1), accd[iy+ 0]);
1202+
accd[iy+nrc_y] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_2), accd[iy+nrc_y]);
12031203
// Leaving this here just in case ternary models start using per row scales
12041204
//accd[2*iy+0] = _mm512_fmadd_ps(_mm512_set1_ps(deq1.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi_1), accd[2*iy+0]);
12051205
//accd[2*iy+1] = _mm512_fmadd_ps(_mm512_set1_ps(deq2.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi_2), accd[2*iy+1]);
12061206
}
12071207

12081208
}
12091209

1210-
for (int iy = 0; iy < nrc_y; ++iy) {
1211-
info.store(ix+0, iy, _mm512_reduce_add_ps(accd[2*iy+0]));
1212-
info.store(ix+1, iy, _mm512_reduce_add_ps(accd[2*iy+1]));
1210+
if constexpr (nrc_y == 8) {
1211+
__m256 sums[8];
1212+
for (int iy = 0; iy < nrc_y; ++iy) {
1213+
sums[iy] = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));
1214+
}
1215+
store_8(ix+0, sums, info);
1216+
for (int iy = 0; iy < nrc_y; ++iy) {
1217+
sums[iy] = _mm256_add_ps(_mm512_castps512_ps256(accd[iy+nrc_y]), _mm512_extractf32x8_ps(accd[iy+nrc_y], 1));
1218+
}
1219+
store_8(ix+1, sums, info);
1220+
} else {
1221+
for (int iy = 0; iy < nrc_y; ++iy) {
1222+
info.store(ix+0, iy, _mm512_reduce_add_ps(accd[iy+ 0]));
1223+
info.store(ix+1, iy, _mm512_reduce_add_ps(accd[iy+nrc_y]));
1224+
}
12131225
}
12141226

12151227
}

0 commit comments

Comments
 (0)