Skip to content

Commit bb298a2

Browse files
mdouzemeta-codesync[bot]
authored andcommitted
convert simdlib in distances_simd.cpp (#4884)
Summary: Pull Request resolved: #4884 Templatize fvec_sub, fvec_add, and compute_PQ_dis_tables_dsub2 on SIMDLevel. Move these 256-bit simdlib-based functions from distances_simd.cpp into a new header simd_impl/distances_simdlib256.h, templatized on THE_SIMDLEVEL. Add with_simd_level_256bit() to simd_dispatch.h that maps AVX512->AVX2 and ARM_SVE->ARM_NEON using simd256_level_selector. Dispatch wrappers use with_simd_level_256bit with lambdas. Reviewed By: algoriddle Differential Revision: D95570445 fbshipit-source-id: 557b6546e9526c40a44f7165584e96a9e39a3b89
1 parent 8d8268c commit bb298a2

9 files changed

Lines changed: 308 additions & 174 deletions

File tree

faiss/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,7 @@ set(FAISS_HEADERS
322322
utils/hamming_distance/avx2-inl.h
323323
utils/hamming_distance/avx512-inl.h
324324
utils/simd_impl/distances_autovec-inl.h
325+
utils/simd_impl/distances_simdlib256.h
325326
utils/simd_impl/distances_sse-inl.h
326327
)
327328

faiss/impl/simd_dispatch.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,4 +136,21 @@ inline auto with_simd_level(LambdaType&& action) {
136136
DISPATCH_SIMDLevel(action.template operator());
137137
}
138138

139+
/**
140+
* Like with_simd_level, but maps to the 256-bit SIMD equivalent:
141+
* AVX512, AVX512_SPR -> AVX2
142+
* ARM_SVE -> ARM_NEON
143+
* AVX2, ARM_NEON, NONE -> unchanged
144+
*
145+
* Use for functions implemented with simd8float32 (256-bit) operations
146+
* that don't have dedicated AVX512 or SVE implementations.
147+
*/
148+
template <typename LambdaType>
149+
inline auto with_simd_level_256bit(LambdaType&& action) {
150+
return with_simd_level([&]<SIMDLevel level>() {
151+
constexpr SIMDLevel level256 = simd256_level_selector<level>::value;
152+
return action.template operator()<level256>();
153+
});
154+
}
155+
139156
} // namespace faiss

faiss/utils/distances.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,30 @@ int fvec_madd_and_argmin(
172172
return fvec_madd_and_argmin_dispatch(n, a, bf, b, c);
173173
}
174174

175+
void fvec_sub(size_t d, const float* a, const float* b, float* c) {
176+
fvec_sub_dispatch(d, a, b, c);
177+
}
178+
179+
void fvec_add(size_t d, const float* a, const float* b, float* c) {
180+
fvec_add_dispatch(d, a, b, c);
181+
}
182+
183+
void fvec_add(size_t d, const float* a, float b, float* c) {
184+
fvec_add_scalar_dispatch(d, a, b, c);
185+
}
186+
187+
void compute_PQ_dis_tables_dsub2(
188+
size_t d,
189+
size_t ksub,
190+
const float* all_centroids,
191+
size_t nx,
192+
const float* x,
193+
bool is_inner_product,
194+
float* dis_tables) {
195+
compute_PQ_dis_tables_dsub2_dispatch(
196+
d, ksub, all_centroids, nx, x, is_inner_product, dis_tables);
197+
}
198+
175199
/***************************************************************************
176200
* Matrix/vector ops
177201
***************************************************************************/

faiss/utils/distances.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,9 @@ void inner_product_to_L2sqr(
261261
*/
262262
void fvec_add(size_t d, const float* a, const float* b, float* c);
263263

264+
template <SIMDLevel>
265+
void fvec_add(size_t d, const float* a, const float* b, float* c);
266+
264267
/** compute c := a + b for a, c vectors and b a scalar
265268
*
266269
* c and a can overlap
@@ -270,6 +273,9 @@ void fvec_add(size_t d, const float* a, const float* b, float* c);
270273
*/
271274
void fvec_add(size_t d, const float* a, float b, float* c);
272275

276+
template <SIMDLevel>
277+
void fvec_add(size_t d, const float* a, float b, float* c);
278+
273279
/** compute c := a - b for vectors
274280
*
275281
* c and a can overlap, c and b can overlap
@@ -280,6 +286,9 @@ void fvec_add(size_t d, const float* a, float b, float* c);
280286
*/
281287
void fvec_sub(size_t d, const float* a, const float* b, float* c);
282288

289+
template <SIMDLevel>
290+
void fvec_sub(size_t d, const float* a, const float* b, float* c);
291+
283292
/***************************************************************************
284293
* Compute a subset of distances
285294
***************************************************************************/
@@ -542,6 +551,16 @@ void compute_PQ_dis_tables_dsub2(
542551
bool is_inner_product,
543552
float* dis_tables);
544553

554+
template <SIMDLevel>
555+
void compute_PQ_dis_tables_dsub2(
556+
size_t d,
557+
size_t ksub,
558+
const float* centroids,
559+
size_t nx,
560+
const float* x,
561+
bool is_inner_product,
562+
float* dis_tables);
563+
545564
/***************************************************************************
546565
* Templatized versions of distance functions
547566
***************************************************************************/

faiss/utils/distances_dispatch.h

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,4 +167,45 @@ inline int fvec_madd_and_argmin_dispatch(
167167
DISPATCH_SIMDLevel(fvec_madd_and_argmin, n, a, bf, b, c);
168168
}
169169

170+
inline void fvec_sub_dispatch(
171+
size_t d,
172+
const float* a,
173+
const float* b,
174+
float* c) {
175+
with_simd_level_256bit(
176+
[&]<SIMDLevel level>() { fvec_sub<level>(d, a, b, c); });
177+
}
178+
179+
inline void fvec_add_dispatch(
180+
size_t d,
181+
const float* a,
182+
const float* b,
183+
float* c) {
184+
with_simd_level_256bit(
185+
[&]<SIMDLevel level>() { fvec_add<level>(d, a, b, c); });
186+
}
187+
188+
inline void fvec_add_scalar_dispatch(
189+
size_t d,
190+
const float* a,
191+
float b,
192+
float* c) {
193+
with_simd_level_256bit(
194+
[&]<SIMDLevel level>() { fvec_add<level>(d, a, b, c); });
195+
}
196+
197+
inline void compute_PQ_dis_tables_dsub2_dispatch(
198+
size_t d,
199+
size_t ksub,
200+
const float* centroids,
201+
size_t nx,
202+
const float* x,
203+
bool is_inner_product,
204+
float* dis_tables) {
205+
with_simd_level_256bit([&]<SIMDLevel level>() {
206+
compute_PQ_dis_tables_dsub2<level>(
207+
d, ksub, centroids, nx, x, is_inner_product, dis_tables);
208+
});
209+
}
210+
170211
} // namespace faiss

faiss/utils/distances_simd.cpp

Lines changed: 4 additions & 174 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@
2222
// NOLINTNEXTLINE(facebook-hte-InlineHeader)
2323
#include <faiss/utils/simd_impl/distances_autovec-inl.h>
2424

25+
#define THE_SIMDLEVEL SIMDLevel::NONE
26+
// NOLINTNEXTLINE(facebook-hte-InlineHeader)
27+
#include <faiss/utils/simd_impl/distances_simdlib256.h>
28+
2529
namespace faiss {
2630

2731
/*******
@@ -168,177 +172,3 @@ int fvec_madd_and_argmin<SIMDLevel::NONE>(
168172
}
169173

170174
} // namespace faiss
171-
172-
namespace faiss {
173-
174-
/***************************************************************************
175-
* PQ tables computations
176-
***************************************************************************/
177-
178-
namespace {
179-
180-
/// compute the IP for dsub = 2 for 8 centroids and 4 sub-vectors at a time
181-
template <bool is_inner_product>
182-
void pq2_8cents_table(
183-
const simd8float32 centroids[8],
184-
const simd8float32 x,
185-
float* out,
186-
size_t ldo,
187-
size_t nout = 4) {
188-
simd8float32 ips[4];
189-
190-
for (int i = 0; i < 4; i++) {
191-
simd8float32 p1, p2;
192-
if (is_inner_product) {
193-
p1 = x * centroids[2 * i];
194-
p2 = x * centroids[2 * i + 1];
195-
} else {
196-
p1 = (x - centroids[2 * i]);
197-
p1 = p1 * p1;
198-
p2 = (x - centroids[2 * i + 1]);
199-
p2 = p2 * p2;
200-
}
201-
ips[i] = hadd(p1, p2);
202-
}
203-
204-
simd8float32 ip02a = geteven(ips[0], ips[1]);
205-
simd8float32 ip02b = geteven(ips[2], ips[3]);
206-
simd8float32 ip0 = getlow128(ip02a, ip02b);
207-
simd8float32 ip2 = gethigh128(ip02a, ip02b);
208-
209-
simd8float32 ip13a = getodd(ips[0], ips[1]);
210-
simd8float32 ip13b = getodd(ips[2], ips[3]);
211-
simd8float32 ip1 = getlow128(ip13a, ip13b);
212-
simd8float32 ip3 = gethigh128(ip13a, ip13b);
213-
214-
switch (nout) {
215-
case 4:
216-
ip3.storeu(out + 3 * ldo);
217-
[[fallthrough]];
218-
case 3:
219-
ip2.storeu(out + 2 * ldo);
220-
[[fallthrough]];
221-
case 2:
222-
ip1.storeu(out + 1 * ldo);
223-
[[fallthrough]];
224-
case 1:
225-
ip0.storeu(out);
226-
}
227-
}
228-
229-
simd8float32 load_simd8float32_partial(const float* x, int n) {
230-
ALIGNED(32) float tmp[8] = {0, 0, 0, 0, 0, 0, 0, 0};
231-
float* wp = tmp;
232-
for (int i = 0; i < n; i++) {
233-
*wp++ = *x++;
234-
}
235-
return simd8float32(tmp);
236-
}
237-
238-
} // anonymous namespace
239-
240-
void compute_PQ_dis_tables_dsub2(
241-
size_t d,
242-
size_t ksub,
243-
const float* all_centroids,
244-
size_t nx,
245-
const float* x,
246-
bool is_inner_product,
247-
float* dis_tables) {
248-
size_t M = d / 2;
249-
FAISS_THROW_IF_NOT(ksub % 8 == 0);
250-
251-
for (size_t m0 = 0; m0 < M; m0 += 4) {
252-
int m1 = std::min(M, m0 + 4);
253-
for (int k0 = 0; k0 < ksub; k0 += 8) {
254-
simd8float32 centroids[8];
255-
for (int k = 0; k < 8; k++) {
256-
ALIGNED(32) float centroid[8];
257-
size_t wp = 0;
258-
size_t rp = (m0 * ksub + k + k0) * 2;
259-
for (int m = m0; m < m1; m++) {
260-
centroid[wp++] = all_centroids[rp];
261-
centroid[wp++] = all_centroids[rp + 1];
262-
rp += 2 * ksub;
263-
}
264-
centroids[k] = simd8float32(centroid);
265-
}
266-
for (size_t i = 0; i < nx; i++) {
267-
simd8float32 xi;
268-
if (m1 == m0 + 4) {
269-
xi.loadu(x + i * d + m0 * 2);
270-
} else {
271-
xi = load_simd8float32_partial(
272-
x + i * d + m0 * 2, 2 * (m1 - m0));
273-
}
274-
275-
if (is_inner_product) {
276-
pq2_8cents_table<true>(
277-
centroids,
278-
xi,
279-
dis_tables + (i * M + m0) * ksub + k0,
280-
ksub,
281-
m1 - m0);
282-
} else {
283-
pq2_8cents_table<false>(
284-
centroids,
285-
xi,
286-
dis_tables + (i * M + m0) * ksub + k0,
287-
ksub,
288-
m1 - m0);
289-
}
290-
}
291-
}
292-
}
293-
}
294-
295-
/*********************************************************
296-
* Vector to vector functions
297-
*********************************************************/
298-
299-
void fvec_sub(size_t d, const float* a, const float* b, float* c) {
300-
size_t i;
301-
for (i = 0; i + 7 < d; i += 8) {
302-
simd8float32 ci, ai, bi;
303-
ai.loadu(a + i);
304-
bi.loadu(b + i);
305-
ci = ai - bi;
306-
ci.storeu(c + i);
307-
}
308-
// finish non-multiple of 8 remainder
309-
for (; i < d; i++) {
310-
c[i] = a[i] - b[i];
311-
}
312-
}
313-
314-
void fvec_add(size_t d, const float* a, const float* b, float* c) {
315-
size_t i;
316-
for (i = 0; i + 7 < d; i += 8) {
317-
simd8float32 ci, ai, bi;
318-
ai.loadu(a + i);
319-
bi.loadu(b + i);
320-
ci = ai + bi;
321-
ci.storeu(c + i);
322-
}
323-
// finish non-multiple of 8 remainder
324-
for (; i < d; i++) {
325-
c[i] = a[i] + b[i];
326-
}
327-
}
328-
329-
void fvec_add(size_t d, const float* a, float b, float* c) {
330-
size_t i;
331-
simd8float32 bv(b);
332-
for (i = 0; i + 7 < d; i += 8) {
333-
simd8float32 ci, ai;
334-
ai.loadu(a + i);
335-
ci = ai + bv;
336-
ci.storeu(c + i);
337-
}
338-
// finish non-multiple of 8 remainder
339-
for (; i < d; i++) {
340-
c[i] = a[i] + b;
341-
}
342-
}
343-
344-
} // namespace faiss

faiss/utils/simd_impl/distances_aarch64.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
#define AUTOVEC_LEVEL SIMDLevel::ARM_NEON
1616
#include <faiss/utils/simd_impl/distances_autovec-inl.h>
1717

18+
#define THE_SIMDLEVEL SIMDLevel::ARM_NEON
19+
#include <faiss/utils/simd_impl/distances_simdlib256.h>
20+
1821
namespace faiss {
1922

2023
template <>

faiss/utils/simd_impl/distances_avx2.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313
// NOLINTNEXTLINE(facebook-hte-InlineHeader)
1414
#include <faiss/utils/simd_impl/distances_autovec-inl.h>
1515

16+
#define THE_SIMDLEVEL SIMDLevel::AVX2
17+
// NOLINTNEXTLINE(facebook-hte-InlineHeader)
18+
#include <faiss/utils/simd_impl/distances_simdlib256.h>
19+
1620
// NOLINTNEXTLINE(facebook-hte-InlineHeader)
1721
#include <faiss/utils/simd_impl/distances_sse-inl.h>
1822
// NOLINTNEXTLINE(facebook-hte-InlineHeader)

0 commit comments

Comments
 (0)