@@ -187,21 +187,53 @@ get_bias_scale()
     return 3;
 }

+static inline void
+MlasAvx2LoaduDeinterleave32Ps(const float* src, __m256& v0, __m256& v1, __m256& v2, __m256& v3)
+{
+    // Deinterleave 32 contiguous activations using loadu + shuffle so that lane i of
+    // v0..v3 holds src[4*i + 0], src[4*i + 1], src[4*i + 2], src[4*i + 3] respectively,
+    // which matches the T-MAC weight packing.
+    // We use loadu + shuffle instead of gather to avoid potential issues with gathers
+    // on some hardware and to ensure deterministic behavior.
+    __m256 vec_b0 = _mm256_loadu_ps(src + 0);
+    __m256 vec_b1 = _mm256_loadu_ps(src + 8);
+    __m256 vec_b2 = _mm256_loadu_ps(src + 16);
+    __m256 vec_b3 = _mm256_loadu_ps(src + 24);
+
+    __m256 t0 = _mm256_unpacklo_ps(vec_b0, vec_b1);
+    __m256 t1 = _mm256_unpackhi_ps(vec_b0, vec_b1);
+    __m256 t2 = _mm256_unpacklo_ps(vec_b2, vec_b3);
+    __m256 t3 = _mm256_unpackhi_ps(vec_b2, vec_b3);
+
+    __m256 u0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t2)));
+    __m256 u1 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t2)));
+    __m256 u2 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t1), _mm256_castps_pd(t3)));
+    __m256 u3 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t1), _mm256_castps_pd(t3)));
+
+    const __m256i perm_idx = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
+    v0 = _mm256_permutevar8x32_ps(u0, perm_idx);
+    v1 = _mm256_permutevar8x32_ps(u1, perm_idx);
+    v2 = _mm256_permutevar8x32_ps(u2, perm_idx);
+    v3 = _mm256_permutevar8x32_ps(u3, perm_idx);
+}
+
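For reference, a minimal scalar sketch of the lane mapping this helper is meant to produce (the name ScalarDeinterleave32Ref is illustrative, not part of the change): lane i of v0..v3 receives src[4*i + 0..3], the same layout the removed gather loads with byte offsets 0, 16, ..., 112 produced.

static void
ScalarDeinterleave32Ref(const float* src, float v0[8], float v1[8], float v2[8], float v3[8])
{
    // The four neighbors of group i land in the same lane i across the four outputs.
    for (int i = 0; i < 8; ++i) {
        v0[i] = src[4 * i + 0];
        v1[i] = src[4 * i + 1];
        v2[i] = src[4 * i + 2];
        v3[i] = src[4 * i + 3];
    }
}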
 void
 partial_max_g4_int8_k8(float* lut_scales, const float* b)
 {
-    // TODO(vraspar): add support for arm neon
-    const __m256i vec_bi = _mm256_set_epi32(112, 96, 80, 64, 48, 32, 16, 0);
-    __m256 vec_b0 = _mm256_i32gather_ps(b + 0, vec_bi, 1);
-    __m256 vec_b1 = _mm256_i32gather_ps(b + 1, vec_bi, 1);
-    __m256 vec_b2 = _mm256_i32gather_ps(b + 2, vec_bi, 1);
-    __m256 vec_b3 = _mm256_i32gather_ps(b + 3, vec_bi, 1);
+    __m256 vec_b0, vec_b1, vec_b2, vec_b3;
+    MlasAvx2LoaduDeinterleave32Ps(b, vec_b0, vec_b1, vec_b2, vec_b3);
+
     const __m256 vec_sign = _mm256_set1_ps(-0.0f);
     __m256 vec_babs0 = _mm256_andnot_ps(vec_sign, vec_b0);
     __m256 vec_babs1 = _mm256_andnot_ps(vec_sign, vec_b1);
     __m256 vec_babs2 = _mm256_andnot_ps(vec_sign, vec_b2);
     __m256 vec_babs3 = _mm256_andnot_ps(vec_sign, vec_b3);
+
+    // The upper bound for the LUT values (mixtures of 4 activations) is the sum
+    // of their absolute values.
     __m256 abssum = _mm256_add_ps(_mm256_add_ps(vec_babs0, vec_babs1), _mm256_add_ps(vec_babs2, vec_babs3));
+
+    // Reduce max across lanes to find the global maximum sum in this chunk.
     __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(abssum, 1), _mm256_castps256_ps128(abssum));
     max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
     max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
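In scalar form, the reduction above computes the largest per-group sum of absolute values over the 32-activation chunk; a sketch follows (the helper name is hypothetical, and the code that folds this maximum into *lut_scales sits after the lines visible in this hunk).

#include <algorithm>
#include <cmath>

static float
PartialMaxG4Ref(const float* b)
{
    float max_abssum = 0.0f;
    for (int i = 0; i < 8; ++i) {
        // |s0*b0 + s1*b1 + s2*b2 + s3*b3| <= |b0| + |b1| + |b2| + |b3| for any signs s_j.
        const float abssum = std::fabs(b[4 * i + 0]) + std::fabs(b[4 * i + 1]) +
                             std::fabs(b[4 * i + 2]) + std::fabs(b[4 * i + 3]);
        max_abssum = std::max(max_abssum, abssum);
    }
    return max_abssum;
}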
@@ -222,16 +254,14 @@ lut_ctor_g4_int8_impl(
 )
 {
     __m256 vec_lut[16];
-    float biases = 0.0;
-    const __m256i vec_bi = _mm256_set_epi32(112, 96, 80, 64, 48, 32, 16, 0);
+    float biases = 0.0f;
     float scales = *lut_scales;
     float t_scales = scales ? 1.0f / scales : 0.0f;

     for (int k = 0; k < act_k / 32; ++k) {
-        __m256 vec_b0 = _mm256_i32gather_ps(b + k * 32 + 0, vec_bi, 1);
-        __m256 vec_b1 = _mm256_i32gather_ps(b + k * 32 + 1, vec_bi, 1);
-        __m256 vec_b2 = _mm256_i32gather_ps(b + k * 32 + 2, vec_bi, 1);
-        __m256 vec_b3 = _mm256_i32gather_ps(b + k * 32 + 3, vec_bi, 1);
+        const float* b_chunk = b + k * 32;
+        __m256 vec_b0, vec_b1, vec_b2, vec_b3;
+        MlasAvx2LoaduDeinterleave32Ps(b_chunk, vec_b0, vec_b1, vec_b2, vec_b3);

         PRAGMA_UNROLL
         for (int g = 1; g < 16; g += 2) {
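The loop body that fills vec_lut continues past this hunk, but the bound the partial-max step relies on can be checked with a short sketch (the name below is hypothetical and the actual vec_lut fill order is not reproduced): each of the 16 signed mixtures of a 4-activation group is bounded in magnitude by the group's abs-sum, which is the quantity partial_max_g4_int8_k8 maximizes and which t_scales presumably inverts for int8 quantization.

#include <cassert>
#include <cmath>

static void
CheckGroupLutBound(const float* g /* 4 activations */)
{
    const float abssum = std::fabs(g[0]) + std::fabs(g[1]) + std::fabs(g[2]) + std::fabs(g[3]);
    for (int m = 0; m < 16; ++m) {
        float mix = 0.0f;
        for (int j = 0; j < 4; ++j) {
            mix += (m & (1 << j)) ? g[j] : -g[j];
        }
        assert(std::fabs(mix) <= abssum + 1e-6f);
    }
}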