Skip to content

Commit 45e806f

Browse files
committed
Improve: Clipping on x86
1 parent ac60194 commit 45e806f

File tree

2 files changed

+84
-44
lines changed

2 files changed

+84
-44
lines changed

include/simsimd/elementwise.h

Lines changed: 56 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1266,15 +1266,17 @@ SIMSIMD_PUBLIC void simsimd_scale_i16_haswell(simsimd_i16_t const *a, simsimd_si
12661266
simsimd_f32_t beta_f32 = (simsimd_f32_t)beta;
12671267
__m256 alpha_vec = _mm256_set1_ps(alpha_f32);
12681268
__m256 beta_vec = _mm256_set1_ps(beta_f32);
1269+
__m256 min_vec = _mm256_set1_ps(-32768.0f);
1270+
__m256 max_vec = _mm256_set1_ps(32767.0f);
12691271

12701272
// The main loop:
12711273
simsimd_size_t i = 0;
12721274
for (; i + 8 <= n; i += 8) {
12731275
__m256 a_vec = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm_lddqu_si128((__m128i *)(a + i))));
12741276
__m256 sum_vec = _mm256_fmadd_ps(a_vec, alpha_vec, beta_vec);
1277+
sum_vec = _mm256_max_ps(sum_vec, min_vec);
1278+
sum_vec = _mm256_min_ps(sum_vec, max_vec);
12751279
__m256i sum_i32_vec = _mm256_cvtps_epi32(sum_vec);
1276-
sum_i32_vec = _mm256_max_epi32(sum_i32_vec, _mm256_set1_epi32(-32768));
1277-
sum_i32_vec = _mm256_min_epi32(sum_i32_vec, _mm256_set1_epi32(32767));
12781280
// Casting down to 16-bit integers is tricky!
12791281
__m128i sum_i16_vec =
12801282
_mm_packs_epi32(_mm256_castsi256_si128(sum_i32_vec), _mm256_extracti128_si256(sum_i32_vec, 1));
@@ -1296,6 +1298,8 @@ SIMSIMD_PUBLIC void simsimd_fma_i16_haswell(
12961298
simsimd_f32_t beta_f32 = (simsimd_f32_t)beta;
12971299
__m256 alpha_vec = _mm256_set1_ps(alpha_f32);
12981300
__m256 beta_vec = _mm256_set1_ps(beta_f32);
1301+
__m256 min_vec = _mm256_set1_ps(-32768.0f);
1302+
__m256 max_vec = _mm256_set1_ps(32767.0f);
12991303

13001304
// The main loop:
13011305
simsimd_size_t i = 0;
@@ -1306,9 +1310,9 @@ SIMSIMD_PUBLIC void simsimd_fma_i16_haswell(
13061310
__m256 ab_vec = _mm256_mul_ps(a_vec, b_vec);
13071311
__m256 ab_scaled_vec = _mm256_mul_ps(ab_vec, alpha_vec);
13081312
__m256 sum_vec = _mm256_fmadd_ps(c_vec, beta_vec, ab_scaled_vec);
1313+
sum_vec = _mm256_max_ps(sum_vec, min_vec);
1314+
sum_vec = _mm256_min_ps(sum_vec, max_vec);
13091315
__m256i sum_i32_vec = _mm256_cvtps_epi32(sum_vec);
1310-
sum_i32_vec = _mm256_max_epi32(sum_i32_vec, _mm256_set1_epi32(-32768));
1311-
sum_i32_vec = _mm256_min_epi32(sum_i32_vec, _mm256_set1_epi32(32767));
13121316
// Casting down to 16-bit integers is tricky!
13131317
__m128i sum_i16_vec =
13141318
_mm_packs_epi32(_mm256_castsi256_si128(sum_i32_vec), _mm256_extracti128_si256(sum_i32_vec, 1));
@@ -1344,23 +1348,24 @@ SIMSIMD_PUBLIC void simsimd_sum_u16_haswell(simsimd_u16_t const *a, simsimd_u16_
13441348

13451349
SIMSIMD_PUBLIC void simsimd_scale_u16_haswell(simsimd_u16_t const *a, simsimd_size_t n, simsimd_distance_t alpha,
13461350
simsimd_distance_t beta, simsimd_u16_t *result) {
1347-
13481351
simsimd_f32_t alpha_f32 = (simsimd_f32_t)alpha;
13491352
simsimd_f32_t beta_f32 = (simsimd_f32_t)beta;
13501353
__m256 alpha_vec = _mm256_set1_ps(alpha_f32);
13511354
__m256 beta_vec = _mm256_set1_ps(beta_f32);
1355+
__m256 min_vec = _mm256_setzero_ps();
1356+
__m256 max_vec = _mm256_set1_ps(65535.0f);
13521357

13531358
// The main loop:
13541359
simsimd_size_t i = 0;
13551360
for (; i + 8 <= n; i += 8) {
13561361
__m256 a_vec = _mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm_lddqu_si128((__m128i *)(a + i))));
13571362
__m256 sum_vec = _mm256_fmadd_ps(a_vec, alpha_vec, beta_vec);
1363+
sum_vec = _mm256_max_ps(sum_vec, min_vec);
1364+
sum_vec = _mm256_min_ps(sum_vec, max_vec);
13581365
__m256i sum_i32_vec = _mm256_cvtps_epi32(sum_vec);
1359-
sum_i32_vec = _mm256_max_epi32(sum_i32_vec, _mm256_setzero_si256());
1360-
sum_i32_vec = _mm256_min_epi32(sum_i32_vec, _mm256_set1_epi32(65535));
13611366
// Casting down to 16-bit integers is tricky!
13621367
__m128i sum_u16_vec =
1363-
_mm_packs_epi32(_mm256_castsi256_si128(sum_i32_vec), _mm256_extracti128_si256(sum_i32_vec, 1));
1368+
_mm_packus_epi32(_mm256_castsi256_si128(sum_i32_vec), _mm256_extracti128_si256(sum_i32_vec, 1));
13641369
_mm_storeu_si128((__m128i *)(result + i), sum_u16_vec);
13651370
}
13661371

@@ -1379,6 +1384,8 @@ SIMSIMD_PUBLIC void simsimd_fma_u16_haswell(
13791384
simsimd_f32_t beta_f32 = (simsimd_f32_t)beta;
13801385
__m256 alpha_vec = _mm256_set1_ps(alpha_f32);
13811386
__m256 beta_vec = _mm256_set1_ps(beta_f32);
1387+
__m256 min_vec = _mm256_setzero_ps();
1388+
__m256 max_vec = _mm256_set1_ps(65535.0f);
13821389

13831390
// The main loop:
13841391
simsimd_size_t i = 0;
@@ -1389,12 +1396,12 @@ SIMSIMD_PUBLIC void simsimd_fma_u16_haswell(
13891396
__m256 ab_vec = _mm256_mul_ps(a_vec, b_vec);
13901397
__m256 ab_scaled_vec = _mm256_mul_ps(ab_vec, alpha_vec);
13911398
__m256 sum_vec = _mm256_fmadd_ps(c_vec, beta_vec, ab_scaled_vec);
1399+
sum_vec = _mm256_max_ps(sum_vec, min_vec);
1400+
sum_vec = _mm256_min_ps(sum_vec, max_vec);
13921401
__m256i sum_i32_vec = _mm256_cvtps_epi32(sum_vec);
1393-
sum_i32_vec = _mm256_max_epi32(sum_i32_vec, _mm256_setzero_si256());
1394-
sum_i32_vec = _mm256_min_epi32(sum_i32_vec, _mm256_set1_epi32(65535));
13951402
// Casting down to 16-bit integers is tricky!
13961403
__m128i sum_u16_vec =
1397-
_mm_packs_epi32(_mm256_castsi256_si128(sum_i32_vec), _mm256_extracti128_si256(sum_i32_vec, 1));
1404+
_mm_packus_epi32(_mm256_castsi256_si128(sum_i32_vec), _mm256_extracti128_si256(sum_i32_vec, 1));
13981405
_mm_storeu_si128((__m128i *)(result + i), sum_u16_vec);
13991406
}
14001407

@@ -2677,18 +2684,34 @@ SIMSIMD_PUBLIC void simsimd_sum_u16_ice(simsimd_u16_t const *a, simsimd_u16_t co
26772684
}
26782685

26792686
SIMSIMD_INTERNAL __m512i _mm512_adds_epi32_ice(__m512i a, __m512i b) {
2687+
// ! There are many flavors of addition with saturation in AVX-512: i8, u8, i16, and u16.
2688+
// ! But not for larger numeric types. We have to do it manually.
2689+
// ! https://stackoverflow.com/a/56531252/2766161
26802690
__m512i sum = _mm512_add_epi32(a, b);
2681-
__m512i sign_mask = _mm512_set1_epi32(0x80000000);
2682-
2683-
__m512i overflow = _mm512_and_si512(_mm512_xor_si512(a, b), sign_mask); // Same sign inputs
2684-
__m512i overflows = _mm512_or_si512(overflow, _mm512_xor_si512(sum, a)); // Overflow condition
26852691

2692+
// Set constants for overflow and underflow limits
26862693
__m512i max_val = _mm512_set1_epi32(2147483647);
2687-
__m512i min_val = _mm512_set1_epi32(-2147483647 - 1);
2688-
__m512i overflow_result =
2689-
_mm512_mask_blend_epi32(_mm512_cmp_epi32_mask(sum, min_val, _MM_CMPINT_LT), max_val, min_val);
2694+
__m512i min_val = _mm512_set1_epi32(-2147483648);
2695+
2696+
// TODO: Consider using ternary operator for performance.
2697+
// Detect positive overflow: (a > 0) && (b > 0) && (sum < 0)
2698+
__mmask16 a_is_positive = _mm512_cmpgt_epi32_mask(a, _mm512_setzero_si512());
2699+
__mmask16 b_is_positive = _mm512_cmpgt_epi32_mask(b, _mm512_setzero_si512());
2700+
__mmask16 sum_is_negative = _mm512_cmplt_epi32_mask(sum, _mm512_setzero_si512());
2701+
__mmask16 pos_overflow_mask = _kand_mask16(_kand_mask16(a_is_positive, b_is_positive), sum_is_negative);
26902702

2691-
return _mm512_mask_blend_epi32(_mm512_test_epi32_mask(overflows, overflows), sum, overflow_result);
2703+
// TODO: Consider using ternary operator for performance.
2704+
// Detect negative overflow: (a < 0) && (b < 0) && (sum >= 0)
2705+
__mmask16 a_is_negative = _mm512_cmplt_epi32_mask(a, _mm512_setzero_si512());
2706+
__mmask16 b_is_negative = _mm512_cmplt_epi32_mask(b, _mm512_setzero_si512());
2707+
__mmask16 sum_is_non_negative = _mm512_cmpge_epi32_mask(sum, _mm512_setzero_si512());
2708+
__mmask16 neg_overflow_mask = _kand_mask16(_kand_mask16(a_is_negative, b_is_negative), sum_is_non_negative);
2709+
2710+
// Apply saturation for positive overflow
2711+
sum = _mm512_mask_blend_epi32(pos_overflow_mask, sum, max_val);
2712+
// Apply saturation for negative overflow
2713+
sum = _mm512_mask_blend_epi32(neg_overflow_mask, sum, min_val);
2714+
return sum;
26922715
}
26932716

26942717
SIMSIMD_INTERNAL __m512i _mm512_adds_epu32_ice(__m512i a, __m512i b) {
@@ -3144,6 +3167,9 @@ SIMSIMD_PUBLIC void simsimd_fma_i8_sapphire(
31443167
__m512h c_f16_low_vec, c_f16_high_vec, ab_f16_low_vec, ab_f16_high_vec;
31453168
__m512h ab_scaled_f16_low_vec, ab_scaled_f16_high_vec, sum_f16_low_vec, sum_f16_high_vec;
31463169
__m512i sum_i16_low_vec, sum_i16_high_vec;
3170+
__m512h min_f16_vec = _mm512_cvtepi16_ph(_mm512_set1_epi16(-128));
3171+
__m512h max_f16_vec = _mm512_cvtepi16_ph(_mm512_set1_epi16(127));
3172+
31473173
simsimd_fma_i8_sapphire_cycle:
31483174
if (n < 64) {
31493175
mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFFull, n);
@@ -3174,9 +3200,13 @@ SIMSIMD_PUBLIC void simsimd_fma_i8_sapphire(
31743200
// Add:
31753201
sum_f16_low_vec = _mm512_fmadd_ph(c_f16_low_vec, beta_vec, ab_scaled_f16_low_vec);
31763202
sum_f16_high_vec = _mm512_fmadd_ph(c_f16_high_vec, beta_vec, ab_scaled_f16_high_vec);
3203+
// Clip the 16-bit result to 8-bit:
3204+
sum_f16_low_vec = _mm512_max_ph(_mm512_min_ph(sum_f16_low_vec, max_f16_vec), min_f16_vec);
3205+
sum_f16_high_vec = _mm512_max_ph(_mm512_min_ph(sum_f16_high_vec, max_f16_vec), min_f16_vec);
31773206
// Downcast:
31783207
sum_i16_low_vec = _mm512_cvtph_epi16(sum_f16_low_vec);
31793208
sum_i16_high_vec = _mm512_cvtph_epi16(sum_f16_high_vec);
3209+
// Merge back:
31803210
sum_i8_vec = _mm512_inserti64x4(_mm512_castsi256_si512(_mm512_cvtsepi16_epi8(sum_i16_low_vec)),
31813211
_mm512_cvtsepi16_epi8(sum_i16_high_vec), 1);
31823212
_mm512_mask_storeu_epi8(result, mask, sum_i8_vec);
@@ -3196,6 +3226,9 @@ SIMSIMD_PUBLIC void simsimd_fma_u8_sapphire(
31963226
__m512h c_f16_low_vec, c_f16_high_vec, ab_f16_low_vec, ab_f16_high_vec;
31973227
__m512h ab_scaled_f16_low_vec, ab_scaled_f16_high_vec, sum_f16_low_vec, sum_f16_high_vec;
31983228
__m512i sum_i16_low_vec, sum_i16_high_vec;
3229+
__m512h min_f16_vec = _mm512_cvtepi16_ph(_mm512_set1_epi16(0));
3230+
__m512h max_f16_vec = _mm512_cvtepi16_ph(_mm512_set1_epi16(255));
3231+
31993232
simsimd_fma_u8_sapphire_cycle:
32003233
if (n < 64) {
32013234
mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFFull, n);
@@ -3226,9 +3259,13 @@ SIMSIMD_PUBLIC void simsimd_fma_u8_sapphire(
32263259
// Add:
32273260
sum_f16_low_vec = _mm512_fmadd_ph(c_f16_low_vec, beta_vec, ab_scaled_f16_low_vec);
32283261
sum_f16_high_vec = _mm512_fmadd_ph(c_f16_high_vec, beta_vec, ab_scaled_f16_high_vec);
3262+
// Clip the 16-bit result to 8-bit:
3263+
sum_f16_low_vec = _mm512_max_ph(_mm512_min_ph(sum_f16_low_vec, max_f16_vec), min_f16_vec);
3264+
sum_f16_high_vec = _mm512_max_ph(_mm512_min_ph(sum_f16_high_vec, max_f16_vec), min_f16_vec);
32293265
// Downcast:
32303266
sum_i16_low_vec = _mm512_cvtph_epi16(sum_f16_low_vec);
32313267
sum_i16_high_vec = _mm512_cvtph_epi16(sum_f16_high_vec);
3268+
// Merge back:
32323269
sum_u8_vec = _mm512_packus_epi16(sum_i16_low_vec, sum_i16_high_vec);
32333270
_mm512_mask_storeu_epi8(result, mask, sum_u8_vec);
32343271
result += 64;

scripts/test.py

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -67,20 +67,20 @@
6767
baseline_intersect = lambda x, y: len(np.intersect1d(x, y))
6868
baseline_bilinear = lambda x, y, z: x @ z @ y
6969

70-
def _normalize_element_wise(r, dtype):
70+
def _normalize_element_wise(r, dtype_new):
7171
"""Clips higher-resolution results to the smaller target dtype without overflow."""
72-
if np.issubdtype(dtype, np.integer):
72+
if np.issubdtype(dtype_new, np.integer):
7373
r = np.round(r)
7474
#! We need non-overflowing saturating addition for small integers, that NumPy lacks:
7575
#! https://stackoverflow.com/questions/29611185/avoid-overflow-when-adding-numpy-arrays
76-
if np.issubdtype(dtype, np.integer):
76+
if np.issubdtype(dtype_new, np.integer):
7777
dtype_old_info = np.iinfo(r.dtype) if np.issubdtype(r.dtype, np.integer) else np.finfo(r.dtype)
78-
dtype_new_info = np.iinfo(dtype)
78+
dtype_new_info = np.iinfo(dtype_new)
7979
new_min = dtype_new_info.min if dtype_new_info.min > dtype_old_info.min else None
8080
new_max = dtype_new_info.max if dtype_new_info.max < dtype_old_info.max else None
8181
if new_min is not None or new_max is not None:
8282
r = np.clip(r, new_min, new_max, out=r)
83-
return r.astype(dtype)
83+
return r.astype(dtype_new)
8484

8585
def _computation_dtype(x, y):
8686
x = np.asarray(x)
@@ -102,35 +102,38 @@ def _computation_dtype(x, y):
102102
return larger_dtype, larger_dtype
103103

104104
def baseline_scale(x, alpha, beta):
105-
return _normalize_element_wise(alpha * x + beta, x.dtype)
105+
compute_dtype, _ = _computation_dtype(x, alpha)
106+
result = alpha * x.astype(compute_dtype) + beta
107+
return _normalize_element_wise(result, x.dtype)
106108

107109
def baseline_sum(x, y):
108-
if x.dtype == np.uint8:
109-
return _normalize_element_wise(x.astype(np.uint16) + y, x.dtype)
110-
elif x.dtype == np.int8:
111-
return _normalize_element_wise(x.astype(np.int16) + y, x.dtype)
112-
else:
113-
return _normalize_element_wise(x + y, x.dtype)
110+
compute_dtype, _ = _computation_dtype(x, y)
111+
result = x.astype(compute_dtype) + y.astype(compute_dtype)
112+
return _normalize_element_wise(result, x.dtype)
114113

115114
def baseline_wsum(x, y, alpha, beta):
116-
return _normalize_element_wise(alpha * x + beta * y, x.dtype)
115+
compute_dtype, _ = _computation_dtype(x, y)
116+
result = x.astype(compute_dtype) * alpha + y.astype(compute_dtype) * beta
117+
return _normalize_element_wise(result, x.dtype)
117118

118119
def baseline_fma(x, y, z, alpha, beta):
119-
return _normalize_element_wise(np.multiply((alpha * x), y) + beta * z, x.dtype)
120+
compute_dtype, _ = _computation_dtype(x, y)
121+
result = x.astype(compute_dtype) * y.astype(compute_dtype) * alpha + z.astype(compute_dtype) * beta
122+
return _normalize_element_wise(result, x.dtype)
120123

121124
def baseline_add(x, y, out=None):
122-
comput_dtype, final_dtype = _computation_dtype(x, y)
123-
a = x.astype(comput_dtype) if isinstance(x, np.ndarray) else x
124-
b = y.astype(comput_dtype) if isinstance(y, np.ndarray) else y
125+
compute_dtype, final_dtype = _computation_dtype(x, y)
126+
a = x.astype(compute_dtype) if isinstance(x, np.ndarray) else x
127+
b = y.astype(compute_dtype) if isinstance(y, np.ndarray) else y
125128
# If the input types are identical, we want to perform addition with saturation
126129
result = np.add(a, b, out=out, casting="unsafe")
127130
result = _normalize_element_wise(result, final_dtype)
128131
return result
129132

130133
def baseline_multiply(x, y, out=None):
131-
comput_dtype, final_dtype = _computation_dtype(x, y)
132-
a = x.astype(comput_dtype) if isinstance(x, np.ndarray) else x
133-
b = y.astype(comput_dtype) if isinstance(y, np.ndarray) else y
134+
compute_dtype, final_dtype = _computation_dtype(x, y)
135+
a = x.astype(compute_dtype) if isinstance(x, np.ndarray) else x
136+
b = y.astype(compute_dtype) if isinstance(y, np.ndarray) else y
134137
# If the input types are identical, we want to perform multiplication with saturation
135138
result = np.multiply(a, b, out=out, casting="unsafe")
136139
result = _normalize_element_wise(result, final_dtype)
@@ -1103,7 +1106,7 @@ def test_dot_complex_explicit(ndim, capability):
11031106
@pytest.mark.parametrize("dtype", ["uint16", "uint32"])
11041107
@pytest.mark.parametrize("first_length_bound", [10, 100, 1000])
11051108
@pytest.mark.parametrize("second_length_bound", [10, 100, 1000])
1106-
@pytest.mark.parametrize("capability", possible_capabilities)
1109+
@pytest.mark.parametrize("capability", ["serial"] + possible_capabilities)
11071110
def test_intersect(dtype, first_length_bound, second_length_bound, capability):
11081111
"""Compares the simd.intersect() function with numpy.intersect1d."""
11091112

@@ -1133,7 +1136,7 @@ def test_intersect(dtype, first_length_bound, second_length_bound, capability):
11331136
@pytest.mark.parametrize("ndim", [11, 97, 1536])
11341137
@pytest.mark.parametrize("dtype", ["float64", "float32", "float16", "int8", "uint8"])
11351138
@pytest.mark.parametrize("kernel", ["scale"])
1136-
@pytest.mark.parametrize("capability", possible_capabilities)
1139+
@pytest.mark.parametrize("capability", ["serial"] + possible_capabilities)
11371140
def test_scale(ndim, dtype, kernel, capability, stats_fixture):
11381141
""""""
11391142

@@ -1188,7 +1191,7 @@ def test_scale(ndim, dtype, kernel, capability, stats_fixture):
11881191
@pytest.mark.parametrize("ndim", [11, 97, 1536])
11891192
@pytest.mark.parametrize("dtype", ["float64", "float32", "float16", "int8", "uint8"])
11901193
@pytest.mark.parametrize("kernel", ["sum"])
1191-
@pytest.mark.parametrize("capability", possible_capabilities)
1194+
@pytest.mark.parametrize("capability", ["serial"] + possible_capabilities)
11921195
def test_sum(ndim, dtype, kernel, capability, stats_fixture):
11931196
""""""
11941197

@@ -1240,7 +1243,7 @@ def test_sum(ndim, dtype, kernel, capability, stats_fixture):
12401243
@pytest.mark.parametrize("ndim", [11, 97, 1536])
12411244
@pytest.mark.parametrize("dtype", ["float64", "float32", "float16", "int8", "uint8"])
12421245
@pytest.mark.parametrize("kernel", ["wsum"])
1243-
@pytest.mark.parametrize("capability", possible_capabilities)
1246+
@pytest.mark.parametrize("capability", ["serial"] + possible_capabilities)
12441247
def test_wsum(ndim, dtype, kernel, capability, stats_fixture):
12451248
""""""
12461249

@@ -1298,7 +1301,7 @@ def test_wsum(ndim, dtype, kernel, capability, stats_fixture):
12981301
@pytest.mark.parametrize("ndim", [11, 97, 1536])
12991302
@pytest.mark.parametrize("dtype", ["float64", "float32", "float16", "int8", "uint8"])
13001303
@pytest.mark.parametrize("kernel", ["fma"])
1301-
@pytest.mark.parametrize("capability", possible_capabilities)
1304+
@pytest.mark.parametrize("capability", ["serial"] + possible_capabilities)
13021305
def test_fma(ndim, dtype, kernel, capability, stats_fixture):
13031306
""""""
13041307

0 commit comments

Comments
 (0)