@@ -1266,15 +1266,17 @@ SIMSIMD_PUBLIC void simsimd_scale_i16_haswell(simsimd_i16_t const *a, simsimd_si
12661266 simsimd_f32_t beta_f32 = (simsimd_f32_t )beta ;
12671267 __m256 alpha_vec = _mm256_set1_ps (alpha_f32 );
12681268 __m256 beta_vec = _mm256_set1_ps (beta_f32 );
1269+ __m256 min_vec = _mm256_set1_ps (-32768.0f );
1270+ __m256 max_vec = _mm256_set1_ps (32767.0f );
12691271
12701272 // The main loop:
12711273 simsimd_size_t i = 0 ;
12721274 for (; i + 8 <= n ; i += 8 ) {
12731275 __m256 a_vec = _mm256_cvtepi32_ps (_mm256_cvtepi16_epi32 (_mm_lddqu_si128 ((__m128i * )(a + i ))));
12741276 __m256 sum_vec = _mm256_fmadd_ps (a_vec , alpha_vec , beta_vec );
1277+ sum_vec = _mm256_max_ps (sum_vec , min_vec );
1278+ sum_vec = _mm256_min_ps (sum_vec , max_vec );
12751279 __m256i sum_i32_vec = _mm256_cvtps_epi32 (sum_vec );
1276- sum_i32_vec = _mm256_max_epi32 (sum_i32_vec , _mm256_set1_epi32 (-32768 ));
1277- sum_i32_vec = _mm256_min_epi32 (sum_i32_vec , _mm256_set1_epi32 (32767 ));
12781280 // Casting down to 16-bit integers is tricky!
12791281 __m128i sum_i16_vec =
12801282 _mm_packs_epi32 (_mm256_castsi256_si128 (sum_i32_vec ), _mm256_extracti128_si256 (sum_i32_vec , 1 ));
@@ -1296,6 +1298,8 @@ SIMSIMD_PUBLIC void simsimd_fma_i16_haswell(
12961298 simsimd_f32_t beta_f32 = (simsimd_f32_t )beta ;
12971299 __m256 alpha_vec = _mm256_set1_ps (alpha_f32 );
12981300 __m256 beta_vec = _mm256_set1_ps (beta_f32 );
1301+ __m256 min_vec = _mm256_set1_ps (-32768.0f );
1302+ __m256 max_vec = _mm256_set1_ps (32767.0f );
12991303
13001304 // The main loop:
13011305 simsimd_size_t i = 0 ;
@@ -1306,9 +1310,9 @@ SIMSIMD_PUBLIC void simsimd_fma_i16_haswell(
13061310 __m256 ab_vec = _mm256_mul_ps (a_vec , b_vec );
13071311 __m256 ab_scaled_vec = _mm256_mul_ps (ab_vec , alpha_vec );
13081312 __m256 sum_vec = _mm256_fmadd_ps (c_vec , beta_vec , ab_scaled_vec );
1313+ sum_vec = _mm256_max_ps (sum_vec , min_vec );
1314+ sum_vec = _mm256_min_ps (sum_vec , max_vec );
13091315 __m256i sum_i32_vec = _mm256_cvtps_epi32 (sum_vec );
1310- sum_i32_vec = _mm256_max_epi32 (sum_i32_vec , _mm256_set1_epi32 (-32768 ));
1311- sum_i32_vec = _mm256_min_epi32 (sum_i32_vec , _mm256_set1_epi32 (32767 ));
13121316 // Casting down to 16-bit integers is tricky!
13131317 __m128i sum_i16_vec =
13141318 _mm_packs_epi32 (_mm256_castsi256_si128 (sum_i32_vec ), _mm256_extracti128_si256 (sum_i32_vec , 1 ));
@@ -1344,23 +1348,24 @@ SIMSIMD_PUBLIC void simsimd_sum_u16_haswell(simsimd_u16_t const *a, simsimd_u16_
13441348
13451349SIMSIMD_PUBLIC void simsimd_scale_u16_haswell (simsimd_u16_t const * a , simsimd_size_t n , simsimd_distance_t alpha ,
13461350 simsimd_distance_t beta , simsimd_u16_t * result ) {
1347-
13481351 simsimd_f32_t alpha_f32 = (simsimd_f32_t )alpha ;
13491352 simsimd_f32_t beta_f32 = (simsimd_f32_t )beta ;
13501353 __m256 alpha_vec = _mm256_set1_ps (alpha_f32 );
13511354 __m256 beta_vec = _mm256_set1_ps (beta_f32 );
1355+ __m256 min_vec = _mm256_setzero_ps ();
1356+ __m256 max_vec = _mm256_set1_ps (65535.0f );
13521357
13531358 // The main loop:
13541359 simsimd_size_t i = 0 ;
13551360 for (; i + 8 <= n ; i += 8 ) {
13561361 __m256 a_vec = _mm256_cvtepi32_ps (_mm256_cvtepu16_epi32 (_mm_lddqu_si128 ((__m128i * )(a + i ))));
13571362 __m256 sum_vec = _mm256_fmadd_ps (a_vec , alpha_vec , beta_vec );
1363+ sum_vec = _mm256_max_ps (sum_vec , min_vec );
1364+ sum_vec = _mm256_min_ps (sum_vec , max_vec );
13581365 __m256i sum_i32_vec = _mm256_cvtps_epi32 (sum_vec );
1359- sum_i32_vec = _mm256_max_epi32 (sum_i32_vec , _mm256_setzero_si256 ());
1360- sum_i32_vec = _mm256_min_epi32 (sum_i32_vec , _mm256_set1_epi32 (65535 ));
13611366 // Casting down to 16-bit integers is tricky!
13621367 __m128i sum_u16_vec =
1363- _mm_packs_epi32 (_mm256_castsi256_si128 (sum_i32_vec ), _mm256_extracti128_si256 (sum_i32_vec , 1 ));
1368+ _mm_packus_epi32 (_mm256_castsi256_si128 (sum_i32_vec ), _mm256_extracti128_si256 (sum_i32_vec , 1 ));
13641369 _mm_storeu_si128 ((__m128i * )(result + i ), sum_u16_vec );
13651370 }
13661371
@@ -1379,6 +1384,8 @@ SIMSIMD_PUBLIC void simsimd_fma_u16_haswell(
13791384 simsimd_f32_t beta_f32 = (simsimd_f32_t )beta ;
13801385 __m256 alpha_vec = _mm256_set1_ps (alpha_f32 );
13811386 __m256 beta_vec = _mm256_set1_ps (beta_f32 );
1387+ __m256 min_vec = _mm256_setzero_ps ();
1388+ __m256 max_vec = _mm256_set1_ps (65535.0f );
13821389
13831390 // The main loop:
13841391 simsimd_size_t i = 0 ;
@@ -1389,12 +1396,12 @@ SIMSIMD_PUBLIC void simsimd_fma_u16_haswell(
13891396 __m256 ab_vec = _mm256_mul_ps (a_vec , b_vec );
13901397 __m256 ab_scaled_vec = _mm256_mul_ps (ab_vec , alpha_vec );
13911398 __m256 sum_vec = _mm256_fmadd_ps (c_vec , beta_vec , ab_scaled_vec );
1399+ sum_vec = _mm256_max_ps (sum_vec , min_vec );
1400+ sum_vec = _mm256_min_ps (sum_vec , max_vec );
13921401 __m256i sum_i32_vec = _mm256_cvtps_epi32 (sum_vec );
1393- sum_i32_vec = _mm256_max_epi32 (sum_i32_vec , _mm256_setzero_si256 ());
1394- sum_i32_vec = _mm256_min_epi32 (sum_i32_vec , _mm256_set1_epi32 (65535 ));
13951402 // Casting down to 16-bit integers is tricky!
13961403 __m128i sum_u16_vec =
1397- _mm_packs_epi32 (_mm256_castsi256_si128 (sum_i32_vec ), _mm256_extracti128_si256 (sum_i32_vec , 1 ));
1404+ _mm_packus_epi32 (_mm256_castsi256_si128 (sum_i32_vec ), _mm256_extracti128_si256 (sum_i32_vec , 1 ));
13981405 _mm_storeu_si128 ((__m128i * )(result + i ), sum_u16_vec );
13991406 }
14001407
@@ -2677,18 +2684,34 @@ SIMSIMD_PUBLIC void simsimd_sum_u16_ice(simsimd_u16_t const *a, simsimd_u16_t co
26772684}
26782685
26792686SIMSIMD_INTERNAL __m512i _mm512_adds_epi32_ice (__m512i a , __m512i b ) {
2687+ // ! There are many flavors of addition with saturation in AVX-512: i8, u8, i16, and u16.
2688+ // ! But not for larger numeric types. We have to do it manually.
2689+ // ! https://stackoverflow.com/a/56531252/2766161
26802690 __m512i sum = _mm512_add_epi32 (a , b );
2681- __m512i sign_mask = _mm512_set1_epi32 (0x80000000 );
2682-
2683- __m512i overflow = _mm512_and_si512 (_mm512_xor_si512 (a , b ), sign_mask ); // Same sign inputs
2684- __m512i overflows = _mm512_or_si512 (overflow , _mm512_xor_si512 (sum , a )); // Overflow condition
26852691
2692+ // Set constants for overflow and underflow limits
26862693 __m512i max_val = _mm512_set1_epi32 (2147483647 );
2687- __m512i min_val = _mm512_set1_epi32 (-2147483647 - 1 );
2688- __m512i overflow_result =
2689- _mm512_mask_blend_epi32 (_mm512_cmp_epi32_mask (sum , min_val , _MM_CMPINT_LT ), max_val , min_val );
2694+ __m512i min_val = _mm512_set1_epi32 (-2147483648 );
2695+
2696+ // TODO: Consider using ternary operator for performance.
2697+ // Detect positive overflow: (a > 0) && (b > 0) && (sum < 0)
2698+ __mmask16 a_is_positive = _mm512_cmpgt_epi32_mask (a , _mm512_setzero_si512 ());
2699+ __mmask16 b_is_positive = _mm512_cmpgt_epi32_mask (b , _mm512_setzero_si512 ());
2700+ __mmask16 sum_is_negative = _mm512_cmplt_epi32_mask (sum , _mm512_setzero_si512 ());
2701+ __mmask16 pos_overflow_mask = _kand_mask16 (_kand_mask16 (a_is_positive , b_is_positive ), sum_is_negative );
26902702
2691- return _mm512_mask_blend_epi32 (_mm512_test_epi32_mask (overflows , overflows ), sum , overflow_result );
2703+ // TODO: Consider using ternary operator for performance.
2704+ // Detect negative overflow: (a < 0) && (b < 0) && (sum >= 0)
2705+ __mmask16 a_is_negative = _mm512_cmplt_epi32_mask (a , _mm512_setzero_si512 ());
2706+ __mmask16 b_is_negative = _mm512_cmplt_epi32_mask (b , _mm512_setzero_si512 ());
2707+ __mmask16 sum_is_non_negative = _mm512_cmpge_epi32_mask (sum , _mm512_setzero_si512 ());
2708+ __mmask16 neg_overflow_mask = _kand_mask16 (_kand_mask16 (a_is_negative , b_is_negative ), sum_is_non_negative );
2709+
2710+ // Apply saturation for positive overflow
2711+ sum = _mm512_mask_blend_epi32 (pos_overflow_mask , sum , max_val );
2712+ // Apply saturation for negative overflow
2713+ sum = _mm512_mask_blend_epi32 (neg_overflow_mask , sum , min_val );
2714+ return sum ;
26922715}
26932716
26942717SIMSIMD_INTERNAL __m512i _mm512_adds_epu32_ice (__m512i a , __m512i b ) {
@@ -3144,6 +3167,9 @@ SIMSIMD_PUBLIC void simsimd_fma_i8_sapphire(
31443167 __m512h c_f16_low_vec , c_f16_high_vec , ab_f16_low_vec , ab_f16_high_vec ;
31453168 __m512h ab_scaled_f16_low_vec , ab_scaled_f16_high_vec , sum_f16_low_vec , sum_f16_high_vec ;
31463169 __m512i sum_i16_low_vec , sum_i16_high_vec ;
3170+ __m512h min_f16_vec = _mm512_cvtepi16_ph (_mm512_set1_epi16 (-128 ));
3171+ __m512h max_f16_vec = _mm512_cvtepi16_ph (_mm512_set1_epi16 (127 ));
3172+
31473173simsimd_fma_i8_sapphire_cycle :
31483174 if (n < 64 ) {
31493175 mask = (__mmask64 )_bzhi_u64 (0xFFFFFFFFFFFFFFFFull , n );
@@ -3174,9 +3200,13 @@ SIMSIMD_PUBLIC void simsimd_fma_i8_sapphire(
31743200 // Add:
31753201 sum_f16_low_vec = _mm512_fmadd_ph (c_f16_low_vec , beta_vec , ab_scaled_f16_low_vec );
31763202 sum_f16_high_vec = _mm512_fmadd_ph (c_f16_high_vec , beta_vec , ab_scaled_f16_high_vec );
3203+ // Clip the 16-bit result to 8-bit:
3204+ sum_f16_low_vec = _mm512_max_ph (_mm512_min_ph (sum_f16_low_vec , max_f16_vec ), min_f16_vec );
3205+ sum_f16_high_vec = _mm512_max_ph (_mm512_min_ph (sum_f16_high_vec , max_f16_vec ), min_f16_vec );
31773206 // Downcast:
31783207 sum_i16_low_vec = _mm512_cvtph_epi16 (sum_f16_low_vec );
31793208 sum_i16_high_vec = _mm512_cvtph_epi16 (sum_f16_high_vec );
3209+ // Merge back:
31803210 sum_i8_vec = _mm512_inserti64x4 (_mm512_castsi256_si512 (_mm512_cvtsepi16_epi8 (sum_i16_low_vec )),
31813211 _mm512_cvtsepi16_epi8 (sum_i16_high_vec ), 1 );
31823212 _mm512_mask_storeu_epi8 (result , mask , sum_i8_vec );
@@ -3196,6 +3226,9 @@ SIMSIMD_PUBLIC void simsimd_fma_u8_sapphire(
31963226 __m512h c_f16_low_vec , c_f16_high_vec , ab_f16_low_vec , ab_f16_high_vec ;
31973227 __m512h ab_scaled_f16_low_vec , ab_scaled_f16_high_vec , sum_f16_low_vec , sum_f16_high_vec ;
31983228 __m512i sum_i16_low_vec , sum_i16_high_vec ;
3229+ __m512h min_f16_vec = _mm512_cvtepi16_ph (_mm512_set1_epi16 (0 ));
3230+ __m512h max_f16_vec = _mm512_cvtepi16_ph (_mm512_set1_epi16 (255 ));
3231+
31993232simsimd_fma_u8_sapphire_cycle :
32003233 if (n < 64 ) {
32013234 mask = (__mmask64 )_bzhi_u64 (0xFFFFFFFFFFFFFFFFull , n );
@@ -3226,9 +3259,13 @@ SIMSIMD_PUBLIC void simsimd_fma_u8_sapphire(
32263259 // Add:
32273260 sum_f16_low_vec = _mm512_fmadd_ph (c_f16_low_vec , beta_vec , ab_scaled_f16_low_vec );
32283261 sum_f16_high_vec = _mm512_fmadd_ph (c_f16_high_vec , beta_vec , ab_scaled_f16_high_vec );
3262+ // Clip the 16-bit result to 8-bit:
3263+ sum_f16_low_vec = _mm512_max_ph (_mm512_min_ph (sum_f16_low_vec , max_f16_vec ), min_f16_vec );
3264+ sum_f16_high_vec = _mm512_max_ph (_mm512_min_ph (sum_f16_high_vec , max_f16_vec ), min_f16_vec );
32293265 // Downcast:
32303266 sum_i16_low_vec = _mm512_cvtph_epi16 (sum_f16_low_vec );
32313267 sum_i16_high_vec = _mm512_cvtph_epi16 (sum_f16_high_vec );
3268+ // Merge back:
32323269 sum_u8_vec = _mm512_packus_epi16 (sum_i16_low_vec , sum_i16_high_vec );
32333270 _mm512_mask_storeu_epi8 (result , mask , sum_u8_vec );
32343271 result += 64 ;
0 commit comments