@@ -1198,18 +1198,30 @@ static void mul_mat_iq2tn_q8_K_AVX512(int n, const void * vx, size_t bx, const D
                 sumi_2 = _mm512_dpbusd_epi32(sumi_2, deq2.bits.values[3], q8q);
                 // The scale is supposed to be per tensor, so we can use the same scale
                 auto vd = _mm512_set1_ps(d*q8.scale(iy, i));
-                accd[2*iy+0] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_1), accd[2*iy+0]);
-                accd[2*iy+1] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_2), accd[2*iy+1]);
+                accd[iy+0] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_1), accd[iy+0]);
+                accd[iy+nrc_y] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_2), accd[iy+nrc_y]);
                 // Leaving this here just in case ternary models start using per row scales
                 // accd[2*iy+0] = _mm512_fmadd_ps(_mm512_set1_ps(deq1.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi_1), accd[2*iy+0]);
                 // accd[2*iy+1] = _mm512_fmadd_ps(_mm512_set1_ps(deq2.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi_2), accd[2*iy+1]);
             }
 
         }
 
-        for (int iy = 0; iy < nrc_y; ++iy) {
-            info.store(ix+0, iy, _mm512_reduce_add_ps(accd[2*iy+0]));
-            info.store(ix+1, iy, _mm512_reduce_add_ps(accd[2*iy+1]));
+        if constexpr (nrc_y == 8) {
+            __m256 sums[8];
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                sums[iy] = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));
+            }
+            store_8(ix+0, sums, info);
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                sums[iy] = _mm256_add_ps(_mm512_castps512_ps256(accd[iy+nrc_y]), _mm512_extractf32x8_ps(accd[iy+nrc_y], 1));
+            }
+            store_8(ix+1, sums, info);
+        } else {
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                info.store(ix+0, iy, _mm512_reduce_add_ps(accd[iy+0]));
+                info.store(ix+1, iy, _mm512_reduce_add_ps(accd[iy+nrc_y]));
+            }
         }
 
     }
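For context on the change: the accumulators are re-indexed so that the nrc_y partial sums for output row ix+0 sit in accd[0..nrc_y-1] and those for ix+1 in accd[nrc_y..2*nrc_y-1]. With that layout, the nrc_y == 8 fast path can fold each 512-bit accumulator into a single __m256 (low half plus high half) and hand all eight partial sums to store_8() at once, instead of issuing a full _mm512_reduce_add_ps() per result as the generic path still does. store_8() is not shown in this hunk; below is a minimal sketch of what such a helper could look like, assuming it only needs the info.store(ix, iy, value) interface used above and that it batches the remaining horizontal reduction across all eight vectors. The actual helper in the repository may be implemented differently.

#include <immintrin.h>

// Hypothetical sketch of store_8(): jointly reduce eight 256-bit partial sums
// to eight scalars and hand them to the output descriptor.
template <typename Info>
static inline void store_8(int ix, const __m256 * sums, Info & info) {
    // Two rounds of pairwise horizontal adds: afterwards each 128-bit lane of
    // s0123/s4567 holds the per-lane partial sums of four input vectors.
    __m256 s01   = _mm256_hadd_ps(sums[0], sums[1]);
    __m256 s23   = _mm256_hadd_ps(sums[2], sums[3]);
    __m256 s45   = _mm256_hadd_ps(sums[4], sums[5]);
    __m256 s67   = _mm256_hadd_ps(sums[6], sums[7]);
    __m256 s0123 = _mm256_hadd_ps(s01, s23);
    __m256 s4567 = _mm256_hadd_ps(s45, s67);
    // Gather the low 128-bit lanes into one vector and the high lanes into
    // another, then add: element iy of r is the full horizontal sum of sums[iy].
    __m256 lo = _mm256_permute2f128_ps(s0123, s4567, 0x20);
    __m256 hi = _mm256_permute2f128_ps(s0123, s4567, 0x31);
    __m256 r  = _mm256_add_ps(lo, hi);
    float tmp[8];
    _mm256_storeu_ps(tmp, r);
    for (int iy = 0; iy < 8; ++iy) info.store(ix, iy, tmp[iy]);
}

In a sketch like this the shuffle work of the reduction is shared across all eight results for a given output row, rather than paying a full log2 reduction per accumulator as _mm512_reduce_add_ps does, which is presumably the motivation for special-casing nrc_y == 8.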