Skip to content

Commit 1ff228e

Browse files
committed
Add TurboQuant scalar quantizer SIMD backends and training coverage
2 parents b068fd9 + 85074f0 commit 1ff228e

18 files changed

+1256
-18
lines changed

faiss/factory_tools.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ const std::map<faiss::ScalarQuantizer::QuantizerType, std::string> sq_types = {
3838
{faiss::ScalarQuantizer::QT_bf16, "SQbf16"},
3939
{faiss::ScalarQuantizer::QT_8bit_direct_signed, "SQ8_direct_signed"},
4040
{faiss::ScalarQuantizer::QT_8bit_direct, "SQ8_direct"},
41+
{faiss::ScalarQuantizer::QT_1bit_tqmse, "SQtqmse1"},
42+
{faiss::ScalarQuantizer::QT_2bit_tqmse, "SQtqmse2"},
43+
{faiss::ScalarQuantizer::QT_3bit_tqmse, "SQtqmse3"},
44+
{faiss::ScalarQuantizer::QT_4bit_tqmse, "SQtqmse4"},
45+
{faiss::ScalarQuantizer::QT_8bit_tqmse, "SQtqmse8"},
4146
};
4247

4348
int get_hnsw_M(const faiss::IndexHNSW* index) {

faiss/impl/ScalarQuantizer.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,29 @@ ScalarQuantizer::ScalarQuantizer() {}
3838

3939
void ScalarQuantizer::set_derived_sizes() {
4040
switch (qtype) {
41+
case QT_1bit_tqmse:
42+
code_size = (d + 7) / 8;
43+
bits = 1;
44+
break;
45+
case QT_2bit_tqmse:
46+
code_size = (d * 2 + 7) / 8;
47+
bits = 2;
48+
break;
49+
case QT_3bit_tqmse:
50+
code_size = (d * 3 + 7) / 8;
51+
bits = 3;
52+
break;
4153
case QT_8bit:
4254
case QT_8bit_uniform:
4355
case QT_8bit_direct:
4456
case QT_8bit_direct_signed:
57+
case QT_8bit_tqmse:
4558
code_size = d;
4659
bits = 8;
4760
break;
4861
case QT_4bit:
4962
case QT_4bit_uniform:
63+
case QT_4bit_tqmse:
5064
code_size = (d + 1) / 2;
5165
bits = 4;
5266
break;
@@ -107,6 +121,21 @@ void ScalarQuantizer::train(size_t n, const float* x) {
107121
case QT_8bit_direct_signed:
108122
// no training necessary
109123
break;
124+
case QT_1bit_tqmse:
125+
scalar_quantizer::train_TurboQuantMSE(d, 1, trained);
126+
break;
127+
case QT_2bit_tqmse:
128+
scalar_quantizer::train_TurboQuantMSE(d, 2, trained);
129+
break;
130+
case QT_3bit_tqmse:
131+
scalar_quantizer::train_TurboQuantMSE(d, 3, trained);
132+
break;
133+
case QT_4bit_tqmse:
134+
scalar_quantizer::train_TurboQuantMSE(d, 4, trained);
135+
break;
136+
case QT_8bit_tqmse:
137+
scalar_quantizer::train_TurboQuantMSE(d, 8, trained);
138+
break;
110139
default:
111140
break;
112141
}

faiss/impl/ScalarQuantizer.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@ struct ScalarQuantizer : Quantizer {
3333
QT_bf16,
3434
QT_8bit_direct_signed, ///< fast indexing of signed int8s ranging from
3535
///< [-128 to 127]
36+
    QT_1bit_tqmse, ///< TurboQuant MSE-optimized, n bits per component
                   ///< (n = 1, 2, 3, 4, 8 for the _tqmse variants below)
37+
QT_2bit_tqmse,
38+
QT_3bit_tqmse,
39+
QT_4bit_tqmse,
40+
QT_8bit_tqmse,
3641
QT_count
3742
};
3843

faiss/impl/index_read.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -868,7 +868,7 @@ void read_ScalarQuantizer(
868868
READ1(qtype_int);
869869
FAISS_THROW_IF_NOT_FMT(
870870
qtype_int >= ScalarQuantizer::QT_8bit &&
871-
qtype_int <= ScalarQuantizer::QT_8bit_direct_signed,
871+
qtype_int < ScalarQuantizer::QT_count,
872872
"invalid ScalarQuantizer qtype %d",
873873
qtype_int);
874874
ivsc->qtype = static_cast<ScalarQuantizer::QuantizerType>(qtype_int);
@@ -906,6 +906,21 @@ void read_ScalarQuantizer(
906906
case ScalarQuantizer::QT_count:
907907
expected = 0;
908908
break;
909+
case ScalarQuantizer::QT_1bit_tqmse:
910+
expected = 2 + 1; // 2^bits centroids + (2^bits - 1) boundaries
911+
break;
912+
case ScalarQuantizer::QT_2bit_tqmse:
913+
expected = 4 + 3;
914+
break;
915+
case ScalarQuantizer::QT_3bit_tqmse:
916+
expected = 8 + 7;
917+
break;
918+
case ScalarQuantizer::QT_4bit_tqmse:
919+
expected = 16 + 15;
920+
break;
921+
case ScalarQuantizer::QT_8bit_tqmse:
922+
expected = 256 + 255;
923+
break;
909924
}
910925
if (ivsc->trained.empty() && expected > 0) {
911926
// Empty trained is only valid for untrained indices.

faiss/impl/scalar_quantizer/quantizers.h

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
#pragma once
99

10+
#include <algorithm>
11+
12+
#include <faiss/impl/FaissAssert.h>
1013
#include <faiss/impl/ScalarQuantizer.h>
1114
#include <faiss/impl/simdlib/simdlib_dispatch.h>
1215
#include <faiss/utils/bf16.h>
@@ -113,6 +116,90 @@ struct QuantizerTemplate<
113116
}
114117
};
115118

119+
/*******************************************************************
120+
* TurboQuant MSE quantizer
121+
*******************************************************************/
122+
template <int NBits, SIMDLevel SL>
123+
struct QuantizerTurboQuantMSE;
124+
125+
template <int NBits>
126+
struct QuantizerTurboQuantMSE<NBits, SIMDLevel::NONE>
127+
: ScalarQuantizer::SQuantizer {
128+
static_assert(NBits >= 1 && NBits <= 8);
129+
130+
static constexpr size_t kCentroidsCount = size_t(1) << NBits;
131+
static constexpr uint16_t kIndexMask =
132+
static_cast<uint16_t>((1u << NBits) - 1);
133+
134+
const size_t d;
135+
const float* centroids;
136+
const float* boundaries;
137+
138+
QuantizerTurboQuantMSE(size_t d_in, const std::vector<float>& trained)
139+
: d(d_in), centroids(nullptr), boundaries(nullptr) {
140+
FAISS_THROW_IF_NOT(trained.size() == 2 * kCentroidsCount - 1);
141+
centroids = trained.data();
142+
boundaries = trained.data() + kCentroidsCount;
143+
}
144+
145+
FAISS_ALWAYS_INLINE uint8_t select_index(float x) const {
146+
return static_cast<uint8_t>(
147+
std::upper_bound(
148+
boundaries, boundaries + (kCentroidsCount - 1), x) -
149+
boundaries);
150+
}
151+
152+
FAISS_ALWAYS_INLINE void encode_index(uint8_t idx, uint8_t* code, size_t i)
153+
const {
154+
const size_t bit_offset = i * NBits;
155+
const size_t byte_offset = bit_offset >> 3;
156+
const size_t bit_shift = bit_offset & 7;
157+
const uint16_t packed = static_cast<uint16_t>(idx & kIndexMask)
158+
<< bit_shift;
159+
code[byte_offset] |= packed & 0xff;
160+
if (bit_shift + NBits > 8) {
161+
code[byte_offset + 1] |= packed >> 8;
162+
}
163+
}
164+
165+
FAISS_ALWAYS_INLINE uint8_t
166+
decode_index(const uint8_t* code, size_t i) const {
167+
const size_t bit_offset = i * NBits;
168+
const size_t byte_offset = bit_offset >> 3;
169+
const size_t bit_shift = bit_offset & 7;
170+
171+
uint16_t packed = code[byte_offset];
172+
if (bit_shift + NBits > 8) {
173+
packed |= static_cast<uint16_t>(code[byte_offset + 1]) << 8;
174+
}
175+
return static_cast<uint8_t>((packed >> bit_shift) & kIndexMask);
176+
}
177+
178+
void encode_vector(const float* x, uint8_t* code) const final {
179+
for (size_t i = 0; i < d; i++) {
180+
encode_index(select_index(x[i]), code, i);
181+
}
182+
}
183+
184+
void decode_vector(const uint8_t* code, float* x) const final {
185+
for (size_t i = 0; i < d; i++) {
186+
x[i] = centroids[decode_index(code, i)];
187+
}
188+
}
189+
190+
FAISS_ALWAYS_INLINE float reconstruct_component(
191+
const uint8_t* code,
192+
size_t i) const {
193+
return centroids[decode_index(code, i)];
194+
}
195+
};
196+
197+
template <int NBits, SIMDLevel SL>
198+
struct QuantizerTurboQuantMSE : QuantizerTurboQuantMSE<NBits, SIMDLevel::NONE> {
199+
using QuantizerTurboQuantMSE<NBits, SIMDLevel::NONE>::
200+
QuantizerTurboQuantMSE;
201+
};
202+
116203
/*******************************************************************
117204
* FP16 quantizer
118205
*******************************************************************/

faiss/impl/scalar_quantizer/sq-avx2.cpp

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
#include <faiss/impl/simdlib/simdlib_avx2.h>
1111

12+
#include <cstring>
13+
1214
#include <faiss/impl/scalar_quantizer/codecs.h>
1315
#include <faiss/impl/scalar_quantizer/distance_computers.h>
1416
#include <faiss/impl/scalar_quantizer/quantizers.h>
@@ -21,6 +23,61 @@ namespace scalar_quantizer {
2123

2224
using simd8float32 = faiss::simd8float32_tpl<SIMDLevel::AVX2>;
2325

26+
namespace {
27+
28+
FAISS_ALWAYS_INLINE uint16_t load_u16(const uint8_t* ptr) {
29+
uint16_t value;
30+
std::memcpy(&value, ptr, sizeof(value));
31+
return value;
32+
}
33+
34+
FAISS_ALWAYS_INLINE uint32_t load_u32(const uint8_t* ptr) {
35+
uint32_t value;
36+
std::memcpy(&value, ptr, sizeof(value));
37+
return value;
38+
}
39+
40+
FAISS_ALWAYS_INLINE uint32_t load_u24(const uint8_t* ptr) {
41+
return static_cast<uint32_t>(ptr[0]) |
42+
(static_cast<uint32_t>(ptr[1]) << 8) |
43+
(static_cast<uint32_t>(ptr[2]) << 16);
44+
}
45+
46+
FAISS_ALWAYS_INLINE __m256i unpack_8x1bit_to_u32(const uint8_t* code, int i) {
47+
const uint32_t packed = code[static_cast<size_t>(i) >> 3];
48+
const __m256i shifts = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
49+
const __m256i indices =
50+
_mm256_srlv_epi32(_mm256_set1_epi32(packed), shifts);
51+
return _mm256_and_si256(indices, _mm256_set1_epi32(0x1));
52+
}
53+
54+
FAISS_ALWAYS_INLINE __m256i unpack_8x2bit_to_u32(const uint8_t* code, int i) {
55+
const uint32_t packed = load_u16(code + (static_cast<size_t>(i) >> 2));
56+
const __m256i shifts = _mm256_setr_epi32(0, 2, 4, 6, 8, 10, 12, 14);
57+
const __m256i indices =
58+
_mm256_srlv_epi32(_mm256_set1_epi32(packed), shifts);
59+
return _mm256_and_si256(indices, _mm256_set1_epi32(0x3));
60+
}
61+
62+
FAISS_ALWAYS_INLINE __m256i unpack_8x3bit_to_u32(const uint8_t* code, int i) {
63+
const uint32_t packed =
64+
load_u24(code + ((static_cast<size_t>(i) >> 3) * 3));
65+
const __m256i shifts = _mm256_setr_epi32(0, 3, 6, 9, 12, 15, 18, 21);
66+
const __m256i indices =
67+
_mm256_srlv_epi32(_mm256_set1_epi32(packed), shifts);
68+
return _mm256_and_si256(indices, _mm256_set1_epi32(0x7));
69+
}
70+
71+
FAISS_ALWAYS_INLINE __m256i unpack_8x4bit_to_u32(const uint8_t* code, int i) {
72+
const uint32_t packed = load_u32(code + (static_cast<size_t>(i) >> 1));
73+
const __m256i shifts = _mm256_setr_epi32(0, 4, 8, 12, 16, 20, 24, 28);
74+
const __m256i indices =
75+
_mm256_srlv_epi32(_mm256_set1_epi32(packed), shifts);
76+
return _mm256_and_si256(indices, _mm256_set1_epi32(0xf));
77+
}
78+
79+
} // namespace
80+
2481
/**********************************************************
2582
* Codecs
2683
**********************************************************/
@@ -168,6 +225,56 @@ struct QuantizerTemplate<
168225
}
169226
};
170227

228+
/**********************************************************
229+
* TurboQuant MSE quantizer
230+
**********************************************************/
231+
232+
#define DEFINE_TQMSE_AVX2_SPECIALIZATION(NBITS, INDEX_EXPR) \
233+
template <> \
234+
struct QuantizerTurboQuantMSE<NBITS, SIMDLevel::AVX2> \
235+
: QuantizerTurboQuantMSE<NBITS, SIMDLevel::NONE> { \
236+
using Base = QuantizerTurboQuantMSE<NBITS, SIMDLevel::NONE>; \
237+
\
238+
QuantizerTurboQuantMSE(size_t d, const std::vector<float>& trained) \
239+
: Base(d, trained) { \
240+
assert(d % 8 == 0); \
241+
} \
242+
\
243+
FAISS_ALWAYS_INLINE simd8float32 \
244+
reconstruct_8_components(const uint8_t* code, int i) const { \
245+
const __m256i indices = (INDEX_EXPR); \
246+
return simd8float32(_mm256_i32gather_ps( \
247+
this->centroids, indices, sizeof(float))); \
248+
} \
249+
}
250+
251+
DEFINE_TQMSE_AVX2_SPECIALIZATION(1, unpack_8x1bit_to_u32(code, i));
252+
DEFINE_TQMSE_AVX2_SPECIALIZATION(2, unpack_8x2bit_to_u32(code, i));
253+
DEFINE_TQMSE_AVX2_SPECIALIZATION(3, unpack_8x3bit_to_u32(code, i));
254+
DEFINE_TQMSE_AVX2_SPECIALIZATION(4, unpack_8x4bit_to_u32(code, i));
255+
256+
#undef DEFINE_TQMSE_AVX2_SPECIALIZATION
257+
258+
template <>
259+
struct QuantizerTurboQuantMSE<8, SIMDLevel::AVX2>
260+
: QuantizerTurboQuantMSE<8, SIMDLevel::NONE> {
261+
using Base = QuantizerTurboQuantMSE<8, SIMDLevel::NONE>;
262+
263+
QuantizerTurboQuantMSE(size_t d, const std::vector<float>& trained)
264+
: Base(d, trained) {
265+
assert(d % 8 == 0);
266+
}
267+
268+
FAISS_ALWAYS_INLINE simd8float32
269+
reconstruct_8_components(const uint8_t* code, int i) const {
270+
const __m128i packed = _mm_loadl_epi64(
271+
(const __m128i*)(code + static_cast<size_t>(i)));
272+
const __m256i indices = _mm256_cvtepu8_epi32(packed);
273+
return simd8float32(
274+
_mm256_i32gather_ps(this->centroids, indices, sizeof(float)));
275+
}
276+
};
277+
171278
/**********************************************************
172279
* FP16 Quantizer
173280
**********************************************************/

0 commit comments

Comments
 (0)