Skip to content

Add support for half in CAGRA+HNSW #813

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: branch-25.06
Choose a base branch
from
4 changes: 2 additions & 2 deletions conda/environments/go_cuda-118_arch-aarch64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ dependencies:
- libcusolver=11.4.1.48
- libcusparse-dev=11.7.5.86
- libcusparse=11.7.5.86
- libcuvs==25.4.*,>=0.0.0a0
- libraft==25.4.*,>=0.0.0a0
- libcuvs==25.6.*,>=0.0.0a0
- libraft==25.6.*,>=0.0.0a0
- nccl>=2.19
- ninja
- nvcc_linux-aarch64=11.8
Expand Down
4 changes: 2 additions & 2 deletions conda/environments/go_cuda-118_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ dependencies:
- libcusolver=11.4.1.48
- libcusparse-dev=11.7.5.86
- libcusparse=11.7.5.86
- libcuvs==25.4.*,>=0.0.0a0
- libraft==25.4.*,>=0.0.0a0
- libcuvs==25.6.*,>=0.0.0a0
- libraft==25.6.*,>=0.0.0a0
- nccl>=2.19
- ninja
- nvcc_linux-64=11.8
Expand Down
4 changes: 2 additions & 2 deletions conda/environments/go_cuda-128_arch-aarch64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ dependencies:
- libcurand-dev
- libcusolver-dev
- libcusparse-dev
- libcuvs==25.4.*,>=0.0.0a0
- libraft==25.4.*,>=0.0.0a0
- libcuvs==25.6.*,>=0.0.0a0
- libraft==25.6.*,>=0.0.0a0
- nccl>=2.19
- ninja
- sysroot_linux-aarch64==2.28
Expand Down
4 changes: 2 additions & 2 deletions conda/environments/go_cuda-128_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ dependencies:
- libcurand-dev
- libcusolver-dev
- libcusparse-dev
- libcuvs==25.4.*,>=0.0.0a0
- libraft==25.4.*,>=0.0.0a0
- libcuvs==25.6.*,>=0.0.0a0
- libraft==25.6.*,>=0.0.0a0
- nccl>=2.19
- ninja
- sysroot_linux-64==2.28
Expand Down
4 changes: 2 additions & 2 deletions conda/environments/rust_cuda-128_arch-aarch64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ dependencies:
- libcurand-dev
- libcusolver-dev
- libcusparse-dev
- libcuvs==25.4.*,>=0.0.0a0
- libraft==25.4.*,>=0.0.0a0
- libcuvs==25.6.*,>=0.0.0a0
- libraft==25.6.*,>=0.0.0a0
- make
- nccl>=2.19
- ninja
Expand Down
4 changes: 2 additions & 2 deletions conda/environments/rust_cuda-128_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ dependencies:
- libcurand-dev
- libcusolver-dev
- libcusparse-dev
- libcuvs==25.4.*,>=0.0.0a0
- libraft==25.4.*,>=0.0.0a0
- libcuvs==25.6.*,>=0.0.0a0
- libraft==25.6.*,>=0.0.0a0
- make
- nccl>=2.19
- ninja
Expand Down
208 changes: 203 additions & 5 deletions cpp/cmake/patches/hnswlib.diff
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,209 @@ index bef0017..0ee7931 100644
}
}
}
diff --git a/hnswlib/space_ip.h b/hnswlib/space_ip.h
index 2b1c359..e311f9d 100644
--- a/hnswlib/space_ip.h
+++ b/hnswlib/space_ip.h
@@ -3,19 +3,22 @@

namespace hnswlib {

-static float
+template <typename DataType, typename DistanceType>
+static DistanceType
InnerProduct(const void *pVect1, const void *pVect2, const void *qty_ptr) {
size_t qty = *((size_t *) qty_ptr);
- float res = 0;
+ DistanceType res = 0;
for (unsigned i = 0; i < qty; i++) {
- res += ((float *) pVect1)[i] * ((float *) pVect2)[i];
+ const DistanceType t = ((DataType *) pVect1)[i] * ((DataType *) pVect2)[i];
+ res += t;
}
return res;
}

-static float
+template <typename DataType, typename DistanceType>
+static DistanceType
InnerProductDistance(const void *pVect1, const void *pVect2, const void *qty_ptr) {
- return 1.0f - InnerProduct(pVect1, pVect2, qty_ptr);
+ return DistanceType{1.0f} - InnerProduct<DataType, DistanceType>(pVect1, pVect2, qty_ptr);
}

#if defined(USE_AVX)
@@ -294,7 +297,7 @@ InnerProductDistanceSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v,
float *pVect2 = (float *) pVect2v + qty16;

size_t qty_left = qty - qty16;
- float res_tail = InnerProduct(pVect1, pVect2, &qty_left);
+ float res_tail = InnerProduct<float, float>(pVect1, pVect2, &qty_left);
return 1.0f - (res + res_tail);
}

@@ -308,20 +311,21 @@ InnerProductDistanceSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v,

float *pVect1 = (float *) pVect1v + qty4;
float *pVect2 = (float *) pVect2v + qty4;
- float res_tail = InnerProduct(pVect1, pVect2, &qty_left);
+ float res_tail = InnerProduct<float, float>(pVect1, pVect2, &qty_left);

return 1.0f - (res + res_tail);
}
#endif

-class InnerProductSpace : public SpaceInterface<float> {
- DISTFUNC<float> fstdistfunc_;
+template <typename DataType, typename DistanceType>
+class InnerProductSpace : public SpaceInterface<DistanceType> {
+ DISTFUNC<DistanceType> fstdistfunc_;
size_t data_size_;
size_t dim_;

public:
InnerProductSpace(size_t dim) {
- fstdistfunc_ = InnerProductDistance;
+ fstdistfunc_ = InnerProductDistance<DataType, DistanceType>;
#if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512)
#if defined(USE_AVX512)
if (AVX512Capable()) {
@@ -344,24 +348,26 @@ class InnerProductSpace : public SpaceInterface<float> {
}
#endif

- if (dim % 16 == 0)
- fstdistfunc_ = InnerProductDistanceSIMD16Ext;
- else if (dim % 4 == 0)
- fstdistfunc_ = InnerProductDistanceSIMD4Ext;
- else if (dim > 16)
- fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals;
- else if (dim > 4)
- fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals;
+ if constexpr (std::is_same_v<DataType, float>) {
+ if (dim % 16 == 0)
+ fstdistfunc_ = InnerProductDistanceSIMD16Ext;
+ else if (dim % 4 == 0)
+ fstdistfunc_ = InnerProductDistanceSIMD4Ext;
+ else if (dim > 16)
+ fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals;
+ else if (dim > 4)
+ fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals;
+ }
#endif
dim_ = dim;
- data_size_ = dim * sizeof(float);
+ data_size_ = dim * sizeof(DataType);
}

size_t get_data_size() {
return data_size_;
}

- DISTFUNC<float> get_dist_func() {
+ DISTFUNC<DistanceType> get_dist_func() {
return fstdistfunc_;
}

diff --git a/hnswlib/space_l2.h b/hnswlib/space_l2.h
index 834d19f..0c0af26 100644
index 834d19f..c57c87a 100644
--- a/hnswlib/space_l2.h
+++ b/hnswlib/space_l2.h
@@ -252,12 +252,13 @@ class L2Space : public SpaceInterface<float> {
@@ -3,15 +3,16 @@

namespace hnswlib {

-static float
+template <typename DataType, typename DistanceType>
+static DistanceType
L2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
- float *pVect1 = (float *) pVect1v;
- float *pVect2 = (float *) pVect2v;
+ DataType *pVect1 = (DataType *) pVect1v;
+ DataType *pVect2 = (DataType *) pVect2v;
size_t qty = *((size_t *) qty_ptr);

- float res = 0;
+ DistanceType res = 0;
for (size_t i = 0; i < qty; i++) {
- float t = *pVect1 - *pVect2;
+ DistanceType t = *pVect1 - *pVect2;
pVect1++;
pVect2++;
res += t * t;
@@ -155,7 +156,7 @@ L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qt
float *pVect2 = (float *) pVect2v + qty16;

size_t qty_left = qty - qty16;
- float res_tail = L2Sqr(pVect1, pVect2, &qty_left);
+ float res_tail = L2Sqr<float, float>(pVect1, pVect2, &qty_left);
return (res + res_tail);
}
#endif
@@ -199,20 +200,21 @@ L2SqrSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty

float *pVect1 = (float *) pVect1v + qty4;
float *pVect2 = (float *) pVect2v + qty4;
- float res_tail = L2Sqr(pVect1, pVect2, &qty_left);
+ float res_tail = L2Sqr<float, float>(pVect1, pVect2, &qty_left);

return (res + res_tail);
}
#endif

-class L2Space : public SpaceInterface<float> {
- DISTFUNC<float> fstdistfunc_;
+template <typename DataType, typename DistanceType=float>
+class L2Space : public SpaceInterface<DistanceType> {
+ DISTFUNC<DistanceType> fstdistfunc_;
size_t data_size_;
size_t dim_;

public:
L2Space(size_t dim) {
- fstdistfunc_ = L2Sqr;
+ fstdistfunc_ = L2Sqr<DataType, DistanceType>;
#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
#if defined(USE_AVX512)
if (AVX512Capable())
@@ -224,24 +226,26 @@ class L2Space : public SpaceInterface<float> {
L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
#endif

- if (dim % 16 == 0)
- fstdistfunc_ = L2SqrSIMD16Ext;
- else if (dim % 4 == 0)
- fstdistfunc_ = L2SqrSIMD4Ext;
- else if (dim > 16)
- fstdistfunc_ = L2SqrSIMD16ExtResiduals;
- else if (dim > 4)
- fstdistfunc_ = L2SqrSIMD4ExtResiduals;
+ if constexpr (std::is_same_v<DataType, float>) {
+ if (dim % 16 == 0)
+ fstdistfunc_ = L2SqrSIMD16Ext;
+ else if (dim % 4 == 0)
+ fstdistfunc_ = L2SqrSIMD4Ext;
+ else if (dim > 16)
+ fstdistfunc_ = L2SqrSIMD16ExtResiduals;
+ else if (dim > 4)
+ fstdistfunc_ = L2SqrSIMD4ExtResiduals;
+ }
#endif
dim_ = dim;
- data_size_ = dim * sizeof(float);
+ data_size_ = dim * sizeof(DataType);
}

size_t get_data_size() {
return data_size_;
}

- DISTFUNC<float> get_dist_func() {
+ DISTFUNC<DistanceType> get_dist_func() {
return fstdistfunc_;
}

@@ -252,12 +256,13 @@ class L2Space : public SpaceInterface<float> {
~L2Space() {}
};

Expand All @@ -122,7 +320,7 @@ index 834d19f..0c0af26 100644

qty = qty >> 2;
for (size_t i = 0; i < qty; i++) {
@@ -277,11 +278,12 @@ L2SqrI4x(const void *__restrict pVect1, const void *__restrict pVect2, const voi
@@ -277,11 +282,12 @@ L2SqrI4x(const void *__restrict pVect1, const void *__restrict pVect2, const voi
return (res);
}

Expand All @@ -137,15 +335,15 @@ index 834d19f..0c0af26 100644

for (size_t i = 0; i < qty; i++) {
res += ((*a) - (*b)) * ((*a) - (*b));
@@ -291,6 +293,7 @@ static int L2SqrI(const void* __restrict pVect1, const void* __restrict pVect2,
@@ -291,6 +297,7 @@ static int L2SqrI(const void* __restrict pVect1, const void* __restrict pVect2,
return (res);
}

+template <typename T>
class L2SpaceI : public SpaceInterface<int> {
DISTFUNC<int> fstdistfunc_;
size_t data_size_;
@@ -299,9 +302,9 @@ class L2SpaceI : public SpaceInterface<int> {
@@ -299,9 +306,9 @@ class L2SpaceI : public SpaceInterface<int> {
public:
L2SpaceI(size_t dim) {
if (dim % 4 == 0) {
Expand Down
Loading
Loading