Skip to content

Commit 0e7c656

Browse files
committed
Add: i16 element-wise kernels for Haswell
1 parent 38df49c commit 0e7c656

File tree

2 files changed

+227
-97
lines changed

2 files changed

+227
-97
lines changed

include/simsimd/elementwise.h

Lines changed: 93 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1242,11 +1242,11 @@ SIMSIMD_PUBLIC void simsimd_scale_i16_haswell(simsimd_i16_t const *a, simsimd_si
12421242
simsimd_size_t i = 0;
12431243
for (; i + 8 <= n; i += 8) {
12441244
__m256 a_vec = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm_lddqu_si128((__m128i *)(a + i))));
1245-
__m256 b_vec = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm_lddqu_si128((__m128i *)(a + i))));
12461245
__m256 sum_vec = _mm256_fmadd_ps(a_vec, alpha_vec, beta_vec);
12471246
__m256i sum_i32_vec = _mm256_cvtps_epi32(sum_vec);
12481247
sum_i32_vec = _mm256_max_epi32(sum_i32_vec, _mm256_set1_epi32(-32768));
12491248
sum_i32_vec = _mm256_min_epi32(sum_i32_vec, _mm256_set1_epi32(32767));
1249+
// Casting down to 16-bit integers is tricky!
12501250
__m128i sum_i16_vec =
12511251
_mm_packs_epi32(_mm256_castsi256_si128(sum_i32_vec), _mm256_extracti128_si256(sum_i32_vec, 1));
12521252
_mm_storeu_si128((__m128i *)(result + i), sum_i16_vec);
@@ -1263,47 +1263,27 @@ SIMSIMD_PUBLIC void simsimd_scale_i16_haswell(simsimd_i16_t const *a, simsimd_si
12631263
SIMSIMD_PUBLIC void simsimd_fma_i16_haswell( //
12641264
simsimd_i16_t const *a, simsimd_i16_t const *b, simsimd_i16_t const *c, simsimd_size_t n, //
12651265
simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_i16_t *result) {
1266-
#if 0
12671266
simsimd_f32_t alpha_f32 = (simsimd_f32_t)alpha;
12681267
simsimd_f32_t beta_f32 = (simsimd_f32_t)beta;
12691268
__m256 alpha_vec = _mm256_set1_ps(alpha_f32);
12701269
__m256 beta_vec = _mm256_set1_ps(beta_f32);
1271-
int sum_i32s[8], a_i32s[8], b_i32s[8], c_i32s[8];
12721270

12731271
// The main loop:
12741272
simsimd_size_t i = 0;
12751273
for (; i + 8 <= n; i += 8) {
1276-
//? Handling loads and stores with SIMD is tricky. Not because of upcasting, but the
1277-
//? downcasting at the end of the loop. In AVX2 it's a drag! Keep it for another day.
1278-
a_i32s[0] = a[i + 0], a_i32s[1] = a[i + 1], a_i32s[2] = a[i + 2], a_i32s[3] = a[i + 3], //
1279-
a_i32s[4] = a[i + 4], a_i32s[5] = a[i + 5], a_i32s[6] = a[i + 6], a_i32s[7] = a[i + 7];
1280-
b_i32s[0] = b[i + 0], b_i32s[1] = b[i + 1], b_i32s[2] = b[i + 2], b_i32s[3] = b[i + 3], //
1281-
b_i32s[4] = b[i + 4], b_i32s[5] = b[i + 5], b_i32s[6] = b[i + 6], b_i32s[7] = b[i + 7];
1282-
c_i32s[0] = c[i + 0], c_i32s[1] = c[i + 1], c_i32s[2] = c[i + 2], c_i32s[3] = c[i + 3], //
1283-
c_i32s[4] = c[i + 4], c_i32s[5] = c[i + 5], c_i32s[6] = c[i + 6], c_i32s[7] = c[i + 7];
1284-
//! This can be done at least 50% faster if we convert 8-bit integers to floats instead
1285-
//! of relying on the slow `_mm256_cvtepi32_ps` instruction.
1286-
__m256 a_vec = _mm256_cvtepi32_ps(_mm256_lddqu_si256((__m256i *)a_i32s));
1287-
__m256 b_vec = _mm256_cvtepi32_ps(_mm256_lddqu_si256((__m256i *)b_i32s));
1288-
__m256 c_vec = _mm256_cvtepi32_ps(_mm256_lddqu_si256((__m256i *)c_i32s));
1289-
// The normal part.
1274+
__m256 a_vec = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm_lddqu_si128((__m128i *)(a + i))));
1275+
__m256 b_vec = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm_lddqu_si128((__m128i *)(b + i))));
1276+
__m256 c_vec = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm_lddqu_si128((__m128i *)(c + i))));
12901277
__m256 ab_vec = _mm256_mul_ps(a_vec, b_vec);
12911278
__m256 ab_scaled_vec = _mm256_mul_ps(ab_vec, alpha_vec);
12921279
__m256 sum_vec = _mm256_fmadd_ps(c_vec, beta_vec, ab_scaled_vec);
1293-
// Instead of serial calls to expensive `_simsimd_f32_to_u8`, convert and clip with SIMD.
12941280
__m256i sum_i32_vec = _mm256_cvtps_epi32(sum_vec);
1295-
sum_i32_vec = _mm256_max_epi32(sum_i32_vec, _mm256_set1_epi32(-128));
1296-
sum_i32_vec = _mm256_min_epi32(sum_i32_vec, _mm256_set1_epi32(127));
1297-
// Export into a serial buffer.
1298-
_mm256_storeu_si256((__m256i *)sum_i32s, sum_i32_vec);
1299-
result[i + 0] = (simsimd_i16_t)sum_i32s[0];
1300-
result[i + 1] = (simsimd_i16_t)sum_i32s[1];
1301-
result[i + 2] = (simsimd_i16_t)sum_i32s[2];
1302-
result[i + 3] = (simsimd_i16_t)sum_i32s[3];
1303-
result[i + 4] = (simsimd_i16_t)sum_i32s[4];
1304-
result[i + 5] = (simsimd_i16_t)sum_i32s[5];
1305-
result[i + 6] = (simsimd_i16_t)sum_i32s[6];
1306-
result[i + 7] = (simsimd_i16_t)sum_i32s[7];
1281+
sum_i32_vec = _mm256_max_epi32(sum_i32_vec, _mm256_set1_epi32(-32768));
1282+
sum_i32_vec = _mm256_min_epi32(sum_i32_vec, _mm256_set1_epi32(32767));
1283+
// Casting down to 16-bit integers is tricky!
1284+
__m128i sum_i16_vec =
1285+
_mm_packs_epi32(_mm256_castsi256_si128(sum_i32_vec), _mm256_extracti128_si256(sum_i32_vec, 1));
1286+
_mm_storeu_si128((__m128i *)(result + i), sum_i16_vec);
13071287
}
13081288

13091289
// The tail:
@@ -1312,7 +1292,89 @@ SIMSIMD_PUBLIC void simsimd_fma_i16_haswell(
13121292
simsimd_f32_t sum = alpha_f32 * ai * bi + beta_f32 * ci;
13131293
_simsimd_f32_to_i16(&sum, result + i);
13141294
}
1315-
#endif
1295+
}
1296+
1297+
SIMSIMD_PUBLIC void simsimd_sum_u16_haswell(simsimd_u16_t const *a, simsimd_u16_t const *b, simsimd_size_t n,
                                            simsimd_u16_t *result) {
    // Element-wise saturating sum of two `u16` vectors: result[i] = sat(a[i] + b[i]).
    simsimd_size_t i = 0;

    // Vectorized body: 16 lanes of `u16` per 256-bit register.
    // `_mm256_adds_epu16` saturates to 65535 on overflow, matching the scalar tail.
    for (; i + 16 <= n; i += 16) {
        __m256i lhs_vec = _mm256_lddqu_si256((__m256i *)(a + i));
        __m256i rhs_vec = _mm256_lddqu_si256((__m256i *)(b + i));
        _mm256_storeu_si256((__m256i *)(result + i), _mm256_adds_epu16(lhs_vec, rhs_vec));
    }

    // Scalar tail: widen to 64 bits so the addition cannot wrap,
    // then saturate back down to `u16` with the shared helper.
    for (; i < n; ++i) {
        simsimd_u64_t wide_sum = (simsimd_u64_t)a[i] + (simsimd_u64_t)b[i];
        _simsimd_u64_to_u16(&wide_sum, result + i);
    }
}
1315+
1316+
SIMSIMD_PUBLIC void simsimd_scale_u16_haswell(simsimd_u16_t const *a, simsimd_size_t n, simsimd_distance_t alpha,
                                              simsimd_distance_t beta, simsimd_u16_t *result) {

    // Affine scaling of a `u16` vector: result[i] = sat(alpha * a[i] + beta),
    // computed in `f32` precision, rounded, and saturated to [0, 65535].
    simsimd_f32_t alpha_f32 = (simsimd_f32_t)alpha;
    simsimd_f32_t beta_f32 = (simsimd_f32_t)beta;
    __m256 alpha_vec = _mm256_set1_ps(alpha_f32);
    __m256 beta_vec = _mm256_set1_ps(beta_f32);

    // The main loop: 8 lanes at a time, upcasting u16 -> i32 -> f32.
    simsimd_size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        __m256 a_vec = _mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm_lddqu_si128((__m128i *)(a + i))));
        __m256 sum_vec = _mm256_fmadd_ps(a_vec, alpha_vec, beta_vec);
        __m256i sum_i32_vec = _mm256_cvtps_epi32(sum_vec);
        // Clip to the `u16` range before narrowing.
        sum_i32_vec = _mm256_max_epi32(sum_i32_vec, _mm256_setzero_si256());
        sum_i32_vec = _mm256_min_epi32(sum_i32_vec, _mm256_set1_epi32(65535));
        // Casting down to 16-bit integers is tricky! The pack must be the *unsigned*
        // `_mm_packus_epi32`: the signed `_mm_packs_epi32` saturates to [-32768, 32767]
        // and would corrupt every result in (32767, 65535].
        __m128i sum_u16_vec =
            _mm_packus_epi32(_mm256_castsi256_si128(sum_i32_vec), _mm256_extracti128_si256(sum_i32_vec, 1));
        _mm_storeu_si128((__m128i *)(result + i), sum_u16_vec);
    }

    // The tail: scalar fallback for the last `n % 8` entries.
    for (; i < n; ++i) {
        simsimd_f32_t ai = a[i];
        simsimd_f32_t sum = alpha_f32 * ai + beta_f32;
        _simsimd_f32_to_u16(&sum, result + i);
    }
}
1345+
1346+
SIMSIMD_PUBLIC void simsimd_fma_u16_haswell(                                                     //
    simsimd_u16_t const *a, simsimd_u16_t const *b, simsimd_u16_t const *c, simsimd_size_t n,    //
    simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_u16_t *result) {
    // Fused multiply-add over `u16` vectors: result[i] = sat(alpha * a[i] * b[i] + beta * c[i]),
    // computed in `f32` precision, rounded, and saturated to [0, 65535].
    simsimd_f32_t alpha_f32 = (simsimd_f32_t)alpha;
    simsimd_f32_t beta_f32 = (simsimd_f32_t)beta;
    __m256 alpha_vec = _mm256_set1_ps(alpha_f32);
    __m256 beta_vec = _mm256_set1_ps(beta_f32);

    // The main loop: 8 lanes at a time, upcasting u16 -> i32 -> f32.
    simsimd_size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        __m256 a_vec = _mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm_lddqu_si128((__m128i *)(a + i))));
        __m256 b_vec = _mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm_lddqu_si128((__m128i *)(b + i))));
        __m256 c_vec = _mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm_lddqu_si128((__m128i *)(c + i))));
        __m256 ab_vec = _mm256_mul_ps(a_vec, b_vec);
        __m256 ab_scaled_vec = _mm256_mul_ps(ab_vec, alpha_vec);
        __m256 sum_vec = _mm256_fmadd_ps(c_vec, beta_vec, ab_scaled_vec);
        __m256i sum_i32_vec = _mm256_cvtps_epi32(sum_vec);
        // Clip to the `u16` range before narrowing.
        sum_i32_vec = _mm256_max_epi32(sum_i32_vec, _mm256_setzero_si256());
        sum_i32_vec = _mm256_min_epi32(sum_i32_vec, _mm256_set1_epi32(65535));
        // Casting down to 16-bit integers is tricky! The pack must be the *unsigned*
        // `_mm_packus_epi32`: the signed `_mm_packs_epi32` saturates to [-32768, 32767]
        // and would corrupt every result in (32767, 65535].
        __m128i sum_u16_vec =
            _mm_packus_epi32(_mm256_castsi256_si128(sum_i32_vec), _mm256_extracti128_si256(sum_i32_vec, 1));
        _mm_storeu_si128((__m128i *)(result + i), sum_u16_vec);
    }

    // The tail: scalar fallback for the last `n % 8` entries.
    for (; i < n; ++i) {
        simsimd_f32_t ai = a[i], bi = b[i], ci = c[i];
        simsimd_f32_t sum = alpha_f32 * ai * bi + beta_f32 * ci;
        _simsimd_f32_to_u16(&sum, result + i);
    }
}
13171379

13181380
#pragma clang attribute pop

0 commit comments

Comments
 (0)