Skip to content

Commit 8b35dfa

Browse files
committed
Partial Optimization for IVFPQ
1 parent d0434be commit 8b35dfa

24 files changed

Lines changed: 7330 additions & 0 deletions

faiss/CMakeLists.txt

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,45 @@ if(FAISS_ENABLE_SVS)
335335
)
336336
endif()
337337

338+
# -------------------- ARM aarch64 sra_krl sources --------------------
339+
# On aarch64, add sra_krl sources. Source code uses #ifdef __aarch64__.
340+
#
341+
# These .c files contain C++ features (template, constexpr, #include <cstring>)
342+
# so they MUST be compiled as C++. The public API headers (krl.h) already wrap
343+
# function declarations with extern "C" { }, ensuring C linkage for symbols.
344+
# KRL_SRC collects the sra_krl sources. The value-less set() leaves the
# variable unset, so every unquoted ${KRL_SRC} expands to nothing when the
# processor check below does not match.
set(KRL_SRC)
345+
# Gate on the processor name CMake reports (aarch64 / arm64 / ARM64).
if(CMAKE_SYSTEM_PROCESSOR MATCHES "(aarch64|arm64|ARM64)")
346+
message(STATUS "ARM aarch64 detected: adding sra_krl sources")
347+
348+
# Relative source paths; CMake resolves them against the current source dir.
set(KRL_SRC
349+
sra_krl/src/IPdistance_simd_f16f32.c
350+
sra_krl/src/IPdistance_simd.c
351+
sra_krl/src/IPdistance_simd_s8.c
352+
sra_krl/src/krl_handles.c
353+
sra_krl/src/L2distance_simd_f16f32.c
354+
sra_krl/src/L2distance_simd.c
355+
sra_krl/src/L2distance_simd_u8.c
356+
sra_krl/src/matrix_block_transpose.c
357+
sra_krl/src/MinMax_quant.c
358+
sra_krl/src/pq_search_with_table_8bit.c
359+
)
360+
361+
362+
# Force C++ compilation: these .c files use template, constexpr, <cstring>
363+
set_source_files_properties(${KRL_SRC} PROPERTIES LANGUAGE CXX)
364+
365+
# Per-file architecture flags for the KRL kernels.
# NOTE(review): this enables SVE, FP16 and dot-product code generation for
# these files unconditionally - confirm the sources are only executed on CPUs
# that implement those extensions.
set(KRL_COMPILE_FLAGS "-march=armv8.2-a+fp16+fp16fml+dotprod+sve+rcpc")
366+
367+
# Apply the -march string only for GNU/Clang-family compilers; MATCHES is a
# regex search, so "Clang" already covers "AppleClang". Any other compiler
# still builds the sources, just without the extra flags.
if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang|AppleClang")
368+
set_source_files_properties(${KRL_SRC} PROPERTIES
369+
COMPILE_FLAGS "${KRL_COMPILE_FLAGS}"
370+
)
371+
endif()
372+
373+
# Add KRL sources to main FAISS_SRC
# (the same list is passed to every add_library(... ${FAISS_SRC}) call in this
# file, so each of those targets compiles the KRL sources).
374+
list(APPEND FAISS_SRC ${KRL_SRC})
375+
endif()
376+
338377
if(NOT WIN32)
339378
list(APPEND FAISS_SRC invlists/OnDiskInvertedLists.cpp)
340379
list(APPEND FAISS_HEADERS invlists/OnDiskInvertedLists.h)
@@ -345,6 +384,14 @@ set(FAISS_HEADERS ${FAISS_HEADERS} PARENT_SCOPE)
345384

346385
add_library(faiss ${FAISS_SRC})
347386

387+
# -------------------- aarch64 sra_krl include dirs for faiss --------------------
388+
# Expose the KRL public headers to faiss and to its consumers on ARM64.
# The install-tree path has to name the directory the KRL headers are actually
# installed into by the install(DIRECTORY ...) rule in this file, i.e.
# <includedir>/faiss/sra_krl/include. Pointing it at the parent directory made
# the build-tree and install-tree include paths disagree.
if(CMAKE_SYSTEM_PROCESSOR MATCHES "(aarch64|arm64|ARM64)")
389+
target_include_directories(faiss PUBLIC
390+
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/sra_krl/include>
391+
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/faiss/sra_krl/include>
392+
)
393+
endif()
394+
348395
add_library(faiss_avx2 ${FAISS_SRC})
349396
if(NOT FAISS_OPT_LEVEL STREQUAL "avx2" AND NOT FAISS_OPT_LEVEL STREQUAL "avx512" AND NOT FAISS_OPT_LEVEL STREQUAL "avx512_spr")
350397
set_target_properties(faiss_avx2 PROPERTIES EXCLUDE_FROM_ALL TRUE)
@@ -422,6 +469,14 @@ endif()
422469
target_sources(faiss_sve PRIVATE ${FAISS_SIMD_NEON_SRC} ${FAISS_SIMD_SVE_SRC})
423470
target_compile_definitions(faiss_sve PRIVATE COMPILE_SIMD_ARM_NEON COMPILE_SIMD_ARM_SVE)
424471

472+
# -------------------- aarch64 sra_krl include dirs for faiss_sve --------------------
473+
# Expose the KRL public headers to faiss_sve and to its consumers on ARM64.
# The install-tree path has to name the directory the KRL headers are actually
# installed into by the install(DIRECTORY ...) rule in this file, i.e.
# <includedir>/faiss/sra_krl/include. Pointing it at the parent directory made
# the build-tree and install-tree include paths disagree.
if(CMAKE_SYSTEM_PROCESSOR MATCHES "(aarch64|arm64|ARM64)")
474+
target_include_directories(faiss_sve PUBLIC
475+
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/sra_krl/include>
476+
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/faiss/sra_krl/include>
477+
)
478+
endif()
479+
425480
# =============================================================================
426481
# Dynamic Dispatch Mode
427482
# When FAISS_OPT_LEVEL=dd, the main faiss library is built with runtime SIMD
@@ -616,6 +671,13 @@ endif()
616671
# Note: FAISS_OPT_LEVEL=dd builds DD support into main faiss target,
617672
# so no separate faiss_dd install is needed
618673

674+
# -------------------- Install KRL headers (aarch64 only) --------------------
675+
# Copy the KRL public headers into the install tree on ARM64 builds.
# The trailing slash on the source directory installs its contents (not the
# directory itself) under <includedir>/faiss/sra_krl/include, which keeps
# includes of the form <faiss/sra_krl/include/krl.h> resolvable after install.
# FILES_MATCHING PATTERN "*.h" restricts the copy to header files.
if(CMAKE_SYSTEM_PROCESSOR MATCHES "(aarch64|arm64|ARM64)")
676+
install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/sra_krl/include/
677+
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/faiss/sra_krl/include
678+
FILES_MATCHING PATTERN "*.h")
679+
endif()
680+
619681
foreach(header ${FAISS_HEADERS})
620682
get_filename_component(dir ${header} DIRECTORY )
621683
install(FILES ${header}

faiss/IndexIVF.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@
2828
#include <faiss/impl/ResultHandler.h>
2929
#include <faiss/impl/expanded_scanners.h>
3030

31+
#ifdef __aarch64__
32+
extern "C" {
33+
#include <faiss/sra_krl/include/krl.h>
34+
}
35+
#endif
36+
3137
namespace faiss {
3238

3339
using ScopedIds = InvertedLists::ScopedIds;
@@ -464,6 +470,9 @@ void IndexIVF::search_preassigned(
464470
std::unique_ptr<InvertedListScanner> scanner(
465471
get_InvertedListScanner(store_pairs, sel, params));
466472

473+
#ifdef __aarch64__
474+
krl_create_LUT8b_handle(&(scanner->klh),(int)(sel != nullptr), tmp_buffer_size);
475+
#endif
467476
/*****************************************************
468477
* Depending on parallel_mode, there are two possible ways
469478
* to organize the search. Here we define local functions

faiss/IndexIVF.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,14 @@
2020
#include <faiss/invlists/InvertedLists.h>
2121
#include <faiss/utils/Heap.h>
2222

23+
// aarch64 only: C-linkage forward declarations for the KRL 8-bit lookup-table
// handle. Only an incomplete struct type and the cleanup routine are named
// here, which is enough to hold the handle by pointer and release it.
// NOTE(review): presumably declared here to avoid including krl.h from this
// header - confirm these declarations stay in sync with krl.h.
#ifdef __aarch64__
24+
extern "C" {
25+
// Opaque handle type; its definition is not visible in this header.
typedef struct KRLLookupTable8bitHandle KRLLUT8bHandle;
26+
// Releases a handle; takes the address of the caller's pointer.
void krl_clean_LUT8b_handle(KRLLUT8bHandle** klh);
27+
}
28+
#endif
29+
30+
2331
namespace faiss {
2432

2533
/** Encapsulates a quantizer object for the IndexIVF
@@ -198,6 +206,10 @@ struct IndexIVF : Index, IndexIVFInterface {
198206
/// centroids?
199207
bool by_residual = true;
200208

209+
#ifdef __aarch64__
210+
/// aarch64 only. Passed as the last argument of krl_create_LUT8b_handle()
/// when search_preassigned() sets up a scanner; defaults to 0.
/// NOTE(review): presumably the capacity of the handle's scratch buffers -
/// confirm the expected unit and minimum value against krl.h.
size_t tmp_buffer_size = 0;
211+
#endif
212+
201213
/** The Inverted file takes a quantizer (an Index) on input,
202214
* which implements the function mapping a vector to a list
203215
* identifier.
@@ -553,7 +565,17 @@ struct InvertedListScanner {
553565
const idx_t* ids,
554566
ResultHandler& handler) const;
555567

568+
// On aarch64 the scanner carries a KRL lookup-table handle and releases it in
// its destructor; on other platforms the destructor is empty.
#ifdef __aarch64__
569+
/// KRL lookup-table handle; null until one is created for this scanner
/// (IndexIVF::search_preassigned does so via krl_create_LUT8b_handle).
/// NOTE(review): this is a raw pointer freed in the destructor, so a copied
/// scanner would free the same handle twice - confirm scanners are never
/// copied.
KRLLUT8bHandle* klh = nullptr;
570+
/// Releases the KRL handle if one was attached.
virtual ~InvertedListScanner() {
571+
if(klh){
572+
krl_clean_LUT8b_handle(&klh);
573+
// Reset explicitly rather than relying on the callee to null the pointer.
klh = nullptr;
574+
}
575+
}
576+
#else
556577
virtual ~InvertedListScanner() {}
578+
#endif
557579
};
558580

559581
// whether to check that coarse quantizers are the same

faiss/IndexIVFPQ.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@
3333
#include <faiss/impl/pq_code_distance/pq_code_distance-inl.h>
3434
#include <faiss/impl/simd_dispatch.h>
3535

36+
#ifdef __aarch64__
37+
extern "C" {
38+
#include <faiss/sra_krl/include/krl.h>
39+
}
40+
#endif
41+
3642
namespace faiss {
3743

3844
/*****************************************
@@ -1229,7 +1235,46 @@ struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQCodeDist>,
12291235
assert(precompute_mode == 2);
12301236
this->scan_list_polysemous(ncode, codes, res);
12311237
} else if (precompute_mode == 2) {
1238+
#ifdef __aarch64__
1239+
if (this->pq.nbits == 8 && this->klh) {
1240+
if constexpr (use_sel) {
1241+
size_t j = 0;
1242+
size_t* idx_tmp_buffer = krl_get_idx_pointer(klh);
1243+
float* distance_tmp_buffer = krl_get_dist_pointer(klh);
1244+
for(size_t i = 0; i < ncode; ++i) {
1245+
if(this->sel->is_member(ids[i])){
1246+
idx_tmp_buffer[j++] = i;
1247+
}
1248+
}
1249+
ncode = j;
1250+
krl_table_lookup_8b_f32_by_idx(
1251+
this->pq.M, ncode, codes, this->sim_table, distance_tmp_buffer, this->dis0, idx_tmp_buffer,
1252+
this->pq.M * ncode, this->pq.M * 256, ncode);
1253+
for(size_t i = 0; i < ncode; ++i) {
1254+
res.add(idx_tmp_buffer[i],distance_tmp_buffer[i]);
1255+
}
1256+
} else {
1257+
float* distance_tmp_buffer = krl_get_dist_pointer(klh);
1258+
krl_table_lookup_8b_f32(
1259+
this->pq.M, ncode, codes, this->sim_table, distance_tmp_buffer, this->dis0, this->pq.M * ncode,
1260+
this->pq.M * 256, ncode);
1261+
size_t i = 0;
1262+
for(; i + 4 <= ncode; i+=4) {
1263+
res.add(i, distance_tmp_buffer[i]);
1264+
res.add(i + 1, distance_tmp_buffer[i + 1]);
1265+
res.add(i + 2, distance_tmp_buffer[i + 2]);
1266+
res.add(i + 3, distance_tmp_buffer[i + 3]);
1267+
}
1268+
for(; i < ncode; ++i) {
1269+
res.add(i, distance_tmp_buffer[i]);
1270+
}
1271+
}
1272+
} else {
1273+
this->scan_list_with_table(ncode, codes, res);
1274+
}
1275+
#else
12321276
this->scan_list_with_table(ncode, codes, res);
1277+
#endif
12331278
} else if (precompute_mode == 1) {
12341279
this->scan_list_with_pointer(ncode, codes, res);
12351280
} else if (precompute_mode == 0) {

faiss/impl/ProductQuantizer.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@
2424

2525
extern "C" {
2626

27+
#ifdef __aarch64__
28+
#include <faiss/sra_krl/include/krl.h>
29+
#endif
30+
2731
/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
2832

2933
int sgemm_(
@@ -436,6 +440,12 @@ void ProductQuantizer::compute_codes(const float* x, uint8_t* codes, size_t n)
436440

437441
void ProductQuantizer::compute_distance_table(const float* x, float* dis_table)
438442
const {
443+
#ifdef __aarch64__
444+
if(use_transpose) {
445+
krl_L2sqr_ny_with_handle(kdh, dis_table, x, M * ksub, dsub * M);
446+
return;
447+
}
448+
#endif
439449
with_simd_level([&]<SIMDLevel SL>() {
440450
if (transposed_centroids.empty()) {
441451
// use regular version
@@ -466,6 +476,12 @@ void ProductQuantizer::compute_distance_table(const float* x, float* dis_table)
466476
void ProductQuantizer::compute_inner_prod_table(
467477
const float* x,
468478
float* dis_table) const {
479+
#ifdef __aarch64__
480+
if(use_transpose) {
481+
krl_inner_product_ny_with_handle(kdh, dis_table, x, M * ksub, dsub * M);
482+
return;
483+
}
484+
#endif
469485
with_simd_level([&]<SIMDLevel SL>() {
470486
for (size_t m = 0; m < M; m++) {
471487
fvec_inner_products_ny<SL>(
@@ -482,6 +498,15 @@ void ProductQuantizer::compute_distance_tables(
482498
size_t nx,
483499
const float* x,
484500
float* dis_tables) const {
501+
#ifdef __aarch64__
502+
if(use_transpose) {
503+
#pragma omp parallel for if (nx > 1)
504+
for (int64_t i = 0; i < nx; i++) {
505+
krl_L2sqr_ny_with_handle(kdh, dis_tables + i * ksub * M ,x + i * d, M * ksub, dsub * M);
506+
}
507+
return;
508+
}
509+
#endif
485510
#if defined(__AVX2__) || defined(__aarch64__)
486511
if (dsub == 2 && nbits < 8) { // interesting for a narrow range of settings
487512
compute_PQ_dis_tables_dsub2(
@@ -516,6 +541,15 @@ void ProductQuantizer::compute_inner_prod_tables(
516541
size_t nx,
517542
const float* x,
518543
float* dis_tables) const {
544+
#ifdef __aarch64__
545+
if(use_transpose) {
546+
#pragma omp parallel for if (nx > 1)
547+
for (int64_t i = 0; i < nx; i++) {
548+
krl_inner_product_ny_with_handle(kdh, dis_tables + i * ksub * M ,x + i * d, M * ksub, dsub * M);
549+
}
550+
return;
551+
}
552+
#endif
519553
#if defined(__AVX2__) || defined(__aarch64__)
520554
if (dsub == 2 && nbits < 8) {
521555
compute_PQ_dis_tables_dsub2(

faiss/impl/ProductQuantizer.h

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@
1818
#include <faiss/impl/Quantizer.h>
1919
#include <faiss/impl/platform_macros.h>
2020
#include <faiss/utils/Heap.h>
21+
#include <iostream>
22+
23+
#ifdef __aarch64__
24+
extern "C" {
25+
#include <faiss/sra_krl/include/krl.h>
26+
}
27+
#endif
2128

2229
namespace faiss {
2330

@@ -51,6 +58,67 @@ struct ProductQuantizer : Quantizer {
5158
/// d / M)
5259
Index* assign_index;
5360

61+
#ifdef __aarch64__
62+
/// True once a KRL distance handle has been created successfully; set by
/// initialize_krl_transpose_centroids() and cleared whenever the handle is
/// released.
bool use_transpose = false;
63+
/// KRL distance handle owned by this quantizer; null when no handle exists.
KRLDistanceHandle* kdh = nullptr;
64+
65+
/// Copy constructor: copies the quantizer parameters and centroid tables
/// member by member but deliberately does NOT copy the KRL handle, so the
/// copy starts with kdh == nullptr and use_transpose == false and never
/// shares (or double-frees) the source object's handle.
ProductQuantizer(const ProductQuantizer& other)
66+
: Quantizer(other),
67+
M(other.M),
68+
nbits(other.nbits),
69+
dsub(other.dsub),
70+
ksub(other.ksub),
71+
verbose(other.verbose),
72+
train_type(other.train_type),
73+
cp(other.cp),
74+
// Shallow pointer copy: both objects refer to the same assign_index.
assign_index(other.assign_index),
75+
// The handle is not transferable; the copy starts without one.
use_transpose(false),
76+
kdh(nullptr),
77+
centroids(other.centroids),
78+
transposed_centroids(other.transposed_centroids),
79+
centroids_sq_lengths(other.centroids_sq_lengths),
80+
sdc_table(other.sdc_table) {}
81+
82+
/// Copy assignment: copies the parameters and centroid tables from `other`,
/// then releases any KRL handle this object held. The handle is not copied,
/// so after assignment kdh is null and use_transpose is false until
/// initialize_krl_transpose_centroids() is called again.
/// Self-assignment is a no-op.
ProductQuantizer& operator=(const ProductQuantizer& other) {
83+
if (this != &other) {
84+
Quantizer::operator=(other);
85+
M = other.M;
86+
nbits = other.nbits;
87+
dsub = other.dsub;
88+
ksub = other.ksub;
89+
verbose = other.verbose;
90+
train_type = other.train_type;
91+
cp = other.cp;
92+
// Shallow pointer copy: both objects refer to the same assign_index.
assign_index = other.assign_index;
93+
centroids = other.centroids;
94+
transposed_centroids = other.transposed_centroids;
95+
centroids_sq_lengths = other.centroids_sq_lengths;
96+
sdc_table = other.sdc_table;
97+
// The existing handle was created before this assignment replaced the
// centroids, so drop it.
if (kdh) {
98+
krl_clean_distance_handle(&kdh);
99+
kdh = nullptr;
100+
}
101+
use_transpose = false;
102+
}
103+
return *this;
104+
}
105+
106+
/// Creates the KRL distance handle from the current centroid table and
/// enables the KRL code path (use_transpose) if creation succeeded.
/// Does nothing when a handle already exists or when centroids is empty.
///
/// @param batchsize    forwarded unchanged to krl_create_distance_handle
/// @param metric_type  forwarded unchanged to krl_create_distance_handle
///
/// NOTE(review): because an existing handle short-circuits this call, a
/// handle created before the centroids were modified is not rebuilt here -
/// confirm callers only invoke this after training is final.
void initialize_krl_transpose_centroids(size_t batchsize, int metric_type) {
107+
if(!kdh && !centroids.empty()) {
108+
// NOTE(review): the literal 3 is an opaque selector defined by the KRL API -
// confirm its meaning in krl.h. The centroid table is passed as raw bytes
// together with its size in bytes (M * ksub * dsub floats).
krl_create_distance_handle(
109+
&kdh, 3, batchsize, ksub, dsub, M, metric_type, (const uint8_t *)centroids.data(), M * ksub * dsub * sizeof(float));
110+
// Any status returned by the call is not inspected; success is inferred from
// the handle pointer being non-null.
use_transpose = (kdh != nullptr);
111+
}
112+
}
113+
/// Releases the KRL distance handle, if one was created, and clears the
/// associated state.
~ProductQuantizer() {
114+
if(kdh) {
115+
krl_clean_distance_handle(&kdh);
116+
// Reset explicitly rather than relying on the callee to null the pointer.
kdh = nullptr;
117+
use_transpose = false;
118+
}
119+
}
120+
#endif
121+
54122
/// Centroid table, size M * ksub * dsub.
55123
/// Layout: (M, ksub, dsub)
56124
std::vector<float> centroids;

0 commit comments

Comments
 (0)