Skip to content

Add support for half in CAGRA+HNSW #813

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: branch-25.06
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion cpp/bench/ann/src/cuvs/cuvs_cagra_hnswlib.cu
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ auto create_algo(const std::string& algo_name,
[[maybe_unused]] cuvs::bench::Metric metric = parse_metric(distance);
std::unique_ptr<cuvs::bench::algo<T>> a;

if constexpr (std::is_same_v<T, float> or std::is_same_v<T, std::uint8_t>) {
if constexpr (std::is_same_v<T, float> || std::is_same_v<T, half> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, uint8_t>) {
if (algo_name == "raft_cagra_hnswlib" || algo_name == "cuvs_cagra_hnswlib") {
typename cuvs::bench::cuvs_cagra_hnswlib<T, uint32_t>::build_param bparam;
::parse_build_param<T, uint32_t>(conf, bparam.cagra_build_param);
Expand Down Expand Up @@ -97,6 +98,7 @@ auto create_search_param(const std::string& algo_name, const nlohmann::json& con
} // namespace cuvs::bench

REGISTER_ALGO_INSTANCE(float);
REGISTER_ALGO_INSTANCE(half);
REGISTER_ALGO_INSTANCE(std::int8_t);
REGISTER_ALGO_INSTANCE(std::uint8_t);

Expand Down
6 changes: 1 addition & 5 deletions cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -340,11 +340,7 @@ void cuvs_cagra<T, IdxT>::save(const std::string& file) const
template <typename T, typename IdxT>
void cuvs_cagra<T, IdxT>::save_to_hnswlib(const std::string& file) const
{
if constexpr (!std::is_same_v<T, half>) {
cuvs::neighbors::cagra::serialize_to_hnswlib(handle_, file, *index_);
} else {
RAFT_FAIL("Cannot save fp16 index to hnswlib format");
}
cuvs::neighbors::cagra::serialize_to_hnswlib(handle_, file, *index_);
}

template <typename T, typename IdxT>
Expand Down
8 changes: 3 additions & 5 deletions cpp/bench/ann/src/hnswlib/hnswlib_benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,8 @@ auto create_algo(const std::string& algo_name,
cuvs::bench::Metric metric = parse_metric(distance);
std::unique_ptr<cuvs::bench::algo<T>> a;

if constexpr (std::is_same_v<T, float>) {
if (algo_name == "hnswlib") { a = make_algo<T, cuvs::bench::hnsw_lib>(metric, dim, conf); }
}

if constexpr (std::is_same_v<T, uint8_t>) {
if constexpr (std::is_same_v<T, float> || std::is_same_v<T, half> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, uint8_t>) {
if (algo_name == "hnswlib") { a = make_algo<T, cuvs::bench::hnsw_lib>(metric, dim, conf); }
}

Expand All @@ -90,6 +87,7 @@ auto create_search_param(const std::string& algo_name, const nlohmann::json& con
}; // namespace cuvs::bench

REGISTER_ALGO_INSTANCE(float);
REGISTER_ALGO_INSTANCE(half);
REGISTER_ALGO_INSTANCE(std::int8_t);
REGISTER_ALGO_INSTANCE(std::uint8_t);

Expand Down
41 changes: 23 additions & 18 deletions cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ struct hnsw_dist_t<float> {
using type = float;
};

template <>
struct hnsw_dist_t<half> {
using type = float;
};

template <>
struct hnsw_dist_t<uint8_t> {
using type = int;
Expand Down Expand Up @@ -122,12 +127,9 @@ template <typename T>
hnsw_lib<T>::hnsw_lib(Metric metric, int dim, const build_param& param) : algo<T>(metric, dim)
{
assert(dim_ > 0);
static_assert(std::is_same_v<T, float> || std::is_same_v<T, uint8_t>);
if constexpr (std::is_same_v<T, uint8_t>) {
if (metric_ != Metric::kEuclidean) {
throw std::runtime_error("hnswlib<uint8_t> only supports Euclidean distance");
}
}
static_assert(std::is_same_v<T, float> || std::is_same_v<T, half> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, uint8_t>,
"Only float, half, uint8, and int8 are supported");

ef_construction_ = param.ef_construction;
m_ = param.m;
Expand All @@ -137,14 +139,17 @@ hnsw_lib<T>::hnsw_lib(Metric metric, int dim, const build_param& param) : algo<T
template <typename T>
void hnsw_lib<T>::build(const T* dataset, size_t nrow)
{
if constexpr (std::is_same_v<T, float>) {
if (metric_ == Metric::kInnerProduct) {
space_ = std::make_shared<hnswlib::InnerProductSpace>(dim_);
static_assert(std::is_same_v<T, float> || std::is_same_v<T, half> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, uint8_t>,
"Only float, half, uint8, and int8 are supported");
if (metric_ == Metric::kInnerProduct) {
space_ = std::make_shared<hnswlib::InnerProductSpace<T, typename hnsw_dist_t<T>::type>>(dim_);
} else {
if constexpr (std::is_same_v<T, float> || std::is_same_v<T, half>) {
space_ = std::make_shared<hnswlib::L2Space<T, typename hnsw_dist_t<T>::type>>(dim_);
} else {
space_ = std::make_shared<hnswlib::L2Space>(dim_);
space_ = std::make_shared<hnswlib::L2SpaceI<T>>(dim_);
}
} else if constexpr (std::is_same_v<T, uint8_t>) {
space_ = std::make_shared<hnswlib::L2SpaceI<T>>(dim_);
}

appr_alg_ = std::make_shared<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type>>(
Expand Down Expand Up @@ -209,14 +214,14 @@ void hnsw_lib<T>::save(const std::string& path_to_index) const
template <typename T>
void hnsw_lib<T>::load(const std::string& path_to_index)
{
if constexpr (std::is_same_v<T, float>) {
if (metric_ == Metric::kInnerProduct) {
space_ = std::make_shared<hnswlib::InnerProductSpace>(dim_);
if (metric_ == Metric::kInnerProduct) {
space_ = std::make_shared<hnswlib::InnerProductSpace<T, typename hnsw_dist_t<T>::type>>(dim_);
} else {
if constexpr (std::is_same_v<T, float> || std::is_same_v<T, half>) {
space_ = std::make_shared<hnswlib::L2Space<T, typename hnsw_dist_t<T>::type>>(dim_);
} else {
space_ = std::make_shared<hnswlib::L2Space>(dim_);
space_ = std::make_shared<hnswlib::L2SpaceI<T>>(dim_);
}
} else if constexpr (std::is_same_v<T, uint8_t>) {
space_ = std::make_shared<hnswlib::L2SpaceI<T>>(dim_);
}

appr_alg_ = std::make_shared<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type>>(
Expand Down
208 changes: 203 additions & 5 deletions cpp/cmake/patches/hnswlib.diff
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,209 @@ index bef0017..0ee7931 100644
}
}
}
diff --git a/hnswlib/space_ip.h b/hnswlib/space_ip.h
index 2b1c359..e311f9d 100644
--- a/hnswlib/space_ip.h
+++ b/hnswlib/space_ip.h
@@ -3,19 +3,22 @@

namespace hnswlib {

-static float
+template <typename DataType, typename DistanceType>
+static DistanceType
InnerProduct(const void *pVect1, const void *pVect2, const void *qty_ptr) {
size_t qty = *((size_t *) qty_ptr);
- float res = 0;
+ DistanceType res = 0;
for (unsigned i = 0; i < qty; i++) {
- res += ((float *) pVect1)[i] * ((float *) pVect2)[i];
+ const DistanceType t = ((DataType *) pVect1)[i] * ((DataType *) pVect2)[i];
+ res += t;
}
return res;
}

-static float
+template <typename DataType, typename DistanceType>
+static DistanceType
InnerProductDistance(const void *pVect1, const void *pVect2, const void *qty_ptr) {
- return 1.0f - InnerProduct(pVect1, pVect2, qty_ptr);
+ return DistanceType{1} - InnerProduct<DataType, DistanceType>(pVect1, pVect2, qty_ptr);
}

#if defined(USE_AVX)
@@ -294,7 +297,7 @@ InnerProductDistanceSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v,
float *pVect2 = (float *) pVect2v + qty16;

size_t qty_left = qty - qty16;
- float res_tail = InnerProduct(pVect1, pVect2, &qty_left);
+ float res_tail = InnerProduct<float, float>(pVect1, pVect2, &qty_left);
return 1.0f - (res + res_tail);
}

@@ -308,20 +311,21 @@ InnerProductDistanceSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v,

float *pVect1 = (float *) pVect1v + qty4;
float *pVect2 = (float *) pVect2v + qty4;
- float res_tail = InnerProduct(pVect1, pVect2, &qty_left);
+ float res_tail = InnerProduct<float, float>(pVect1, pVect2, &qty_left);

return 1.0f - (res + res_tail);
}
#endif

-class InnerProductSpace : public SpaceInterface<float> {
- DISTFUNC<float> fstdistfunc_;
+template <typename DataType, typename DistanceType>
+class InnerProductSpace : public SpaceInterface<DistanceType> {
+ DISTFUNC<DistanceType> fstdistfunc_;
size_t data_size_;
size_t dim_;

public:
InnerProductSpace(size_t dim) {
- fstdistfunc_ = InnerProductDistance;
+ fstdistfunc_ = InnerProductDistance<DataType, DistanceType>;
#if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512)
#if defined(USE_AVX512)
if (AVX512Capable()) {
@@ -344,24 +348,26 @@ class InnerProductSpace : public SpaceInterface<float> {
}
#endif

- if (dim % 16 == 0)
- fstdistfunc_ = InnerProductDistanceSIMD16Ext;
- else if (dim % 4 == 0)
- fstdistfunc_ = InnerProductDistanceSIMD4Ext;
- else if (dim > 16)
- fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals;
- else if (dim > 4)
- fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals;
+ if constexpr (std::is_same_v<DataType, float>) {
+ if (dim % 16 == 0)
+ fstdistfunc_ = InnerProductDistanceSIMD16Ext;
+ else if (dim % 4 == 0)
+ fstdistfunc_ = InnerProductDistanceSIMD4Ext;
+ else if (dim > 16)
+ fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals;
+ else if (dim > 4)
+ fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals;
+ }
#endif
dim_ = dim;
- data_size_ = dim * sizeof(float);
+ data_size_ = dim * sizeof(DataType);
}

size_t get_data_size() {
return data_size_;
}

- DISTFUNC<float> get_dist_func() {
+ DISTFUNC<DistanceType> get_dist_func() {
return fstdistfunc_;
}

diff --git a/hnswlib/space_l2.h b/hnswlib/space_l2.h
index 834d19f..0c0af26 100644
index 834d19f..c57c87a 100644
--- a/hnswlib/space_l2.h
+++ b/hnswlib/space_l2.h
@@ -252,12 +252,13 @@ class L2Space : public SpaceInterface<float> {
@@ -3,15 +3,16 @@

namespace hnswlib {

-static float
+template <typename DataType, typename DistanceType>
+static DistanceType
L2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
- float *pVect1 = (float *) pVect1v;
- float *pVect2 = (float *) pVect2v;
+ DataType *pVect1 = (DataType *) pVect1v;
+ DataType *pVect2 = (DataType *) pVect2v;
size_t qty = *((size_t *) qty_ptr);

- float res = 0;
+ DistanceType res = 0;
for (size_t i = 0; i < qty; i++) {
- float t = *pVect1 - *pVect2;
+ DistanceType t = *pVect1 - *pVect2;
pVect1++;
pVect2++;
res += t * t;
@@ -155,7 +156,7 @@ L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qt
float *pVect2 = (float *) pVect2v + qty16;

size_t qty_left = qty - qty16;
- float res_tail = L2Sqr(pVect1, pVect2, &qty_left);
+ float res_tail = L2Sqr<float, float>(pVect1, pVect2, &qty_left);
return (res + res_tail);
}
#endif
@@ -199,20 +200,21 @@ L2SqrSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty

float *pVect1 = (float *) pVect1v + qty4;
float *pVect2 = (float *) pVect2v + qty4;
- float res_tail = L2Sqr(pVect1, pVect2, &qty_left);
+ float res_tail = L2Sqr<float, float>(pVect1, pVect2, &qty_left);

return (res + res_tail);
}
#endif

-class L2Space : public SpaceInterface<float> {
- DISTFUNC<float> fstdistfunc_;
+template <typename DataType, typename DistanceType=float>
+class L2Space : public SpaceInterface<DistanceType> {
+ DISTFUNC<DistanceType> fstdistfunc_;
size_t data_size_;
size_t dim_;

public:
L2Space(size_t dim) {
- fstdistfunc_ = L2Sqr;
+ fstdistfunc_ = L2Sqr<DataType, DistanceType>;
#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
#if defined(USE_AVX512)
if (AVX512Capable())
@@ -224,24 +226,26 @@ class L2Space : public SpaceInterface<float> {
L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
#endif

- if (dim % 16 == 0)
- fstdistfunc_ = L2SqrSIMD16Ext;
- else if (dim % 4 == 0)
- fstdistfunc_ = L2SqrSIMD4Ext;
- else if (dim > 16)
- fstdistfunc_ = L2SqrSIMD16ExtResiduals;
- else if (dim > 4)
- fstdistfunc_ = L2SqrSIMD4ExtResiduals;
+ if constexpr (std::is_same_v<DataType, float>) {
+ if (dim % 16 == 0)
+ fstdistfunc_ = L2SqrSIMD16Ext;
+ else if (dim % 4 == 0)
+ fstdistfunc_ = L2SqrSIMD4Ext;
+ else if (dim > 16)
+ fstdistfunc_ = L2SqrSIMD16ExtResiduals;
+ else if (dim > 4)
+ fstdistfunc_ = L2SqrSIMD4ExtResiduals;
+ }
#endif
dim_ = dim;
- data_size_ = dim * sizeof(float);
+ data_size_ = dim * sizeof(DataType);
}

size_t get_data_size() {
return data_size_;
}

- DISTFUNC<float> get_dist_func() {
+ DISTFUNC<DistanceType> get_dist_func() {
return fstdistfunc_;
}

@@ -252,12 +256,13 @@ class L2Space : public SpaceInterface<float> {
~L2Space() {}
};

Expand All @@ -122,7 +320,7 @@ index 834d19f..0c0af26 100644

qty = qty >> 2;
for (size_t i = 0; i < qty; i++) {
@@ -277,11 +278,12 @@ L2SqrI4x(const void *__restrict pVect1, const void *__restrict pVect2, const voi
@@ -277,11 +282,12 @@ L2SqrI4x(const void *__restrict pVect1, const void *__restrict pVect2, const voi
return (res);
}

Expand All @@ -137,15 +335,15 @@ index 834d19f..0c0af26 100644

for (size_t i = 0; i < qty; i++) {
res += ((*a) - (*b)) * ((*a) - (*b));
@@ -291,6 +293,7 @@ static int L2SqrI(const void* __restrict pVect1, const void* __restrict pVect2,
@@ -291,6 +297,7 @@ static int L2SqrI(const void* __restrict pVect1, const void* __restrict pVect2,
return (res);
}

+template <typename T>
class L2SpaceI : public SpaceInterface<int> {
DISTFUNC<int> fstdistfunc_;
size_t data_size_;
@@ -299,9 +302,9 @@ class L2SpaceI : public SpaceInterface<int> {
@@ -299,9 +306,9 @@ class L2SpaceI : public SpaceInterface<int> {
public:
L2SpaceI(size_t dim) {
if (dim % 4 == 0) {
Expand Down
Loading