|
22 | 22 | // NOLINTNEXTLINE(facebook-hte-InlineHeader) |
23 | 23 | #include <faiss/utils/simd_impl/distances_autovec-inl.h> |
24 | 24 |
|
| 25 | +#define THE_SIMDLEVEL SIMDLevel::NONE |
| 26 | +// NOLINTNEXTLINE(facebook-hte-InlineHeader) |
| 27 | +#include <faiss/utils/simd_impl/distances_simdlib256.h> |
| 28 | + |
25 | 29 | namespace faiss { |
26 | 30 |
|
27 | 31 | /******* |
@@ -168,177 +172,3 @@ int fvec_madd_and_argmin<SIMDLevel::NONE>( |
168 | 172 | } |
169 | 173 |
|
170 | 174 | } // namespace faiss |
171 | | - |
172 | | -namespace faiss { |
173 | | - |
174 | | -/*************************************************************************** |
175 | | - * PQ tables computations |
176 | | - ***************************************************************************/ |
177 | | - |
178 | | -namespace { |
179 | | - |
180 | | -/// compute the IP for dsub = 2 for 8 centroids and 4 sub-vectors at a time |
181 | | -template <bool is_inner_product> |
182 | | -void pq2_8cents_table( |
183 | | - const simd8float32 centroids[8], |
184 | | - const simd8float32 x, |
185 | | - float* out, |
186 | | - size_t ldo, |
187 | | - size_t nout = 4) { |
188 | | - simd8float32 ips[4]; |
189 | | - |
190 | | - for (int i = 0; i < 4; i++) { |
191 | | - simd8float32 p1, p2; |
192 | | - if (is_inner_product) { |
193 | | - p1 = x * centroids[2 * i]; |
194 | | - p2 = x * centroids[2 * i + 1]; |
195 | | - } else { |
196 | | - p1 = (x - centroids[2 * i]); |
197 | | - p1 = p1 * p1; |
198 | | - p2 = (x - centroids[2 * i + 1]); |
199 | | - p2 = p2 * p2; |
200 | | - } |
201 | | - ips[i] = hadd(p1, p2); |
202 | | - } |
203 | | - |
204 | | - simd8float32 ip02a = geteven(ips[0], ips[1]); |
205 | | - simd8float32 ip02b = geteven(ips[2], ips[3]); |
206 | | - simd8float32 ip0 = getlow128(ip02a, ip02b); |
207 | | - simd8float32 ip2 = gethigh128(ip02a, ip02b); |
208 | | - |
209 | | - simd8float32 ip13a = getodd(ips[0], ips[1]); |
210 | | - simd8float32 ip13b = getodd(ips[2], ips[3]); |
211 | | - simd8float32 ip1 = getlow128(ip13a, ip13b); |
212 | | - simd8float32 ip3 = gethigh128(ip13a, ip13b); |
213 | | - |
214 | | - switch (nout) { |
215 | | - case 4: |
216 | | - ip3.storeu(out + 3 * ldo); |
217 | | - [[fallthrough]]; |
218 | | - case 3: |
219 | | - ip2.storeu(out + 2 * ldo); |
220 | | - [[fallthrough]]; |
221 | | - case 2: |
222 | | - ip1.storeu(out + 1 * ldo); |
223 | | - [[fallthrough]]; |
224 | | - case 1: |
225 | | - ip0.storeu(out); |
226 | | - } |
227 | | -} |
228 | | - |
229 | | -simd8float32 load_simd8float32_partial(const float* x, int n) { |
230 | | - ALIGNED(32) float tmp[8] = {0, 0, 0, 0, 0, 0, 0, 0}; |
231 | | - float* wp = tmp; |
232 | | - for (int i = 0; i < n; i++) { |
233 | | - *wp++ = *x++; |
234 | | - } |
235 | | - return simd8float32(tmp); |
236 | | -} |
237 | | - |
238 | | -} // anonymous namespace |
239 | | - |
240 | | -void compute_PQ_dis_tables_dsub2( |
241 | | - size_t d, |
242 | | - size_t ksub, |
243 | | - const float* all_centroids, |
244 | | - size_t nx, |
245 | | - const float* x, |
246 | | - bool is_inner_product, |
247 | | - float* dis_tables) { |
248 | | - size_t M = d / 2; |
249 | | - FAISS_THROW_IF_NOT(ksub % 8 == 0); |
250 | | - |
251 | | - for (size_t m0 = 0; m0 < M; m0 += 4) { |
252 | | - int m1 = std::min(M, m0 + 4); |
253 | | - for (int k0 = 0; k0 < ksub; k0 += 8) { |
254 | | - simd8float32 centroids[8]; |
255 | | - for (int k = 0; k < 8; k++) { |
256 | | - ALIGNED(32) float centroid[8]; |
257 | | - size_t wp = 0; |
258 | | - size_t rp = (m0 * ksub + k + k0) * 2; |
259 | | - for (int m = m0; m < m1; m++) { |
260 | | - centroid[wp++] = all_centroids[rp]; |
261 | | - centroid[wp++] = all_centroids[rp + 1]; |
262 | | - rp += 2 * ksub; |
263 | | - } |
264 | | - centroids[k] = simd8float32(centroid); |
265 | | - } |
266 | | - for (size_t i = 0; i < nx; i++) { |
267 | | - simd8float32 xi; |
268 | | - if (m1 == m0 + 4) { |
269 | | - xi.loadu(x + i * d + m0 * 2); |
270 | | - } else { |
271 | | - xi = load_simd8float32_partial( |
272 | | - x + i * d + m0 * 2, 2 * (m1 - m0)); |
273 | | - } |
274 | | - |
275 | | - if (is_inner_product) { |
276 | | - pq2_8cents_table<true>( |
277 | | - centroids, |
278 | | - xi, |
279 | | - dis_tables + (i * M + m0) * ksub + k0, |
280 | | - ksub, |
281 | | - m1 - m0); |
282 | | - } else { |
283 | | - pq2_8cents_table<false>( |
284 | | - centroids, |
285 | | - xi, |
286 | | - dis_tables + (i * M + m0) * ksub + k0, |
287 | | - ksub, |
288 | | - m1 - m0); |
289 | | - } |
290 | | - } |
291 | | - } |
292 | | - } |
293 | | -} |
294 | | - |
295 | | -/********************************************************* |
296 | | - * Vector to vector functions |
297 | | - *********************************************************/ |
298 | | - |
299 | | -void fvec_sub(size_t d, const float* a, const float* b, float* c) { |
300 | | - size_t i; |
301 | | - for (i = 0; i + 7 < d; i += 8) { |
302 | | - simd8float32 ci, ai, bi; |
303 | | - ai.loadu(a + i); |
304 | | - bi.loadu(b + i); |
305 | | - ci = ai - bi; |
306 | | - ci.storeu(c + i); |
307 | | - } |
308 | | - // finish non-multiple of 8 remainder |
309 | | - for (; i < d; i++) { |
310 | | - c[i] = a[i] - b[i]; |
311 | | - } |
312 | | -} |
313 | | - |
314 | | -void fvec_add(size_t d, const float* a, const float* b, float* c) { |
315 | | - size_t i; |
316 | | - for (i = 0; i + 7 < d; i += 8) { |
317 | | - simd8float32 ci, ai, bi; |
318 | | - ai.loadu(a + i); |
319 | | - bi.loadu(b + i); |
320 | | - ci = ai + bi; |
321 | | - ci.storeu(c + i); |
322 | | - } |
323 | | - // finish non-multiple of 8 remainder |
324 | | - for (; i < d; i++) { |
325 | | - c[i] = a[i] + b[i]; |
326 | | - } |
327 | | -} |
328 | | - |
329 | | -void fvec_add(size_t d, const float* a, float b, float* c) { |
330 | | - size_t i; |
331 | | - simd8float32 bv(b); |
332 | | - for (i = 0; i + 7 < d; i += 8) { |
333 | | - simd8float32 ci, ai; |
334 | | - ai.loadu(a + i); |
335 | | - ci = ai + bv; |
336 | | - ci.storeu(c + i); |
337 | | - } |
338 | | - // finish non-multiple of 8 remainder |
339 | | - for (; i < d; i++) { |
340 | | - c[i] = a[i] + b; |
341 | | - } |
342 | | -} |
343 | | - |
344 | | -} // namespace faiss |
0 commit comments