Skip to content
85 changes: 51 additions & 34 deletions cpp/src/TypedIndex.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,9 @@ class TypedIndex : public Index {
: TypedIndex(space, dimensions, /* M */ 12, /* efConstruction */ 200,
/* randomSeed */ 1, /* maxElements */ 1,
/* enableOrderPreservingTransform */ false) {
auto inputStream = std::make_shared<FileInputStream>(indexFilename);
algorithmImpl = std::make_unique<hnswlib::HierarchicalNSW<dist_t, data_t>>(
spaceImpl.get(), indexFilename, 0, searchOnly);
spaceImpl.get(), inputStream, 0, searchOnly);
Comment on lines +155 to +157
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The constructor for HierarchicalNSW is

HierarchicalNSW(Space<dist_t, data_t> *s,
                  std::shared_ptr<InputStream> inputStream,
                  size_t max_elements = 0, bool search_only = false)

currentLabel = algorithmImpl->cur_element_count;
}

Expand Down Expand Up @@ -187,11 +188,11 @@ class TypedIndex : public Index {
currentLabel = algorithmImpl->cur_element_count;
}

int getNumDimensions() const { return dimensions; }
int getNumDimensions() const override { return dimensions; }

SpaceType getSpace() const { return space; }
SpaceType getSpace() const override { return space; }

std::string getSpaceName() const {
std::string getSpaceName() const override {
// TODO: Use magic_enum?
switch (space) {
case SpaceType::Euclidean:
Expand All @@ -205,35 +206,38 @@ class TypedIndex : public Index {
}
}

StorageDataType getStorageDataType() const {
StorageDataType getStorageDataType() const override {
return storageDataType<data_t>();
}

std::string getStorageDataTypeName() const {
std::string getStorageDataTypeName() const override {
return storageDataTypeName<data_t>();
}

void setEF(size_t ef) {
void setEF(size_t ef) override {
defaultEF = ef;
if (algorithmImpl)
algorithmImpl->ef_ = ef;
}

void setNumThreads(int numThreads) { numThreadsDefault = numThreads; }
void setNumThreads(int numThreads) override {
numThreadsDefault = numThreads;
}

void loadIndex(const std::string &pathToIndex, bool searchOnly = false) {
void loadIndex(const std::string &pathToIndex,
bool searchOnly = false) override {
throw std::runtime_error("Not implemented.");
}

void loadIndex(std::shared_ptr<InputStream> inputStream,
bool searchOnly = false) {
bool searchOnly = false) override {
throw std::runtime_error("Not implemented.");
}

/**
* Save this index to the provided file path on disk.
*/
void saveIndex(const std::string &pathToIndex) {
void saveIndex(const std::string &pathToIndex) override {
algorithmImpl->saveIndex(pathToIndex);
saveIndex(std::make_shared<FileOutputStream>(pathToIndex));
}
Expand All @@ -243,14 +247,14 @@ class TypedIndex : public Index {
* The bytes written to the given output stream can be passed to the
* TypedIndex constructor to reload this index.
*/
void saveIndex(std::shared_ptr<OutputStream> outputStream) {
void saveIndex(std::shared_ptr<OutputStream> outputStream) override {
metadata->setMaxNorm(max_norm);
metadata->setUseOrderPreservingTransform(useOrderPreservingTransform);
metadata->serializeToStream(outputStream);
algorithmImpl->saveIndex(outputStream);
}

float getDistance(std::vector<float> _a, std::vector<float> _b) {
float getDistance(std::vector<float> _a, std::vector<float> _b) override {
if ((int)_a.size() != dimensions || (int)_b.size() != dimensions) {
throw std::runtime_error("Index has " + std::to_string(dimensions) +
" dimensions, but received vectors of size: " +
Expand Down Expand Up @@ -285,7 +289,7 @@ class TypedIndex : public Index {
}

hnswlib::labeltype addItem(std::vector<float> vector,
std::optional<hnswlib::labeltype> id) {
std::optional<hnswlib::labeltype> id) override {
std::vector<size_t> ids;

if (id) {
Expand All @@ -297,13 +301,15 @@ class TypedIndex : public Index {

std::vector<hnswlib::labeltype>
addItems(const std::vector<std::vector<float>> vectors,
std::vector<hnswlib::labeltype> ids = {}, int numThreads = -1) {
std::vector<hnswlib::labeltype> ids = {},
int numThreads = -1) override {
return addItems(vectorsToNDArray(vectors), ids, numThreads);
}

std::vector<hnswlib::labeltype>
addItems(NDArray<float, 2> floatInput,
std::vector<hnswlib::labeltype> ids = {}, int numThreads = -1) {
std::vector<hnswlib::labeltype> ids = {},
int numThreads = -1) override {
if (numThreads <= 0)
numThreads = numThreadsDefault;

Expand Down Expand Up @@ -477,13 +483,13 @@ class TypedIndex : public Index {
return algorithmImpl->getDataByLabel(id);
}

std::vector<float> getVector(hnswlib::labeltype id) {
std::vector<float> getVector(hnswlib::labeltype id) override {
std::vector<data_t> rawData = getRawVector(id);
NDArray<data_t, 2> output(rawData.data(), {1, (int)dimensions});
return dataTypeToFloat<data_t, scalefactor>(output).data;
}

NDArray<float, 2> getVectors(std::vector<hnswlib::labeltype> ids) {
NDArray<float, 2> getVectors(std::vector<hnswlib::labeltype> ids) override {
NDArray<float, 2> output = NDArray<float, 2>({(int)ids.size(), dimensions});

for (unsigned long i = 0; i < ids.size(); i++) {
Expand All @@ -495,7 +501,7 @@ class TypedIndex : public Index {
return output;
}

std::vector<hnswlib::labeltype> getIDs() const {
std::vector<hnswlib::labeltype> getIDs() const override {
std::vector<hnswlib::labeltype> ids;
ids.reserve(algorithmImpl->label_lookup_.size());

Expand All @@ -506,22 +512,24 @@ class TypedIndex : public Index {
return ids;
}

long long getIDsCount() const { return algorithmImpl->label_lookup_.size(); }
long long getIDsCount() const override {
return algorithmImpl->label_lookup_.size();
}

const std::unordered_map<hnswlib::labeltype, hnswlib::tableint> &
getIDsMap() const {
getIDsMap() const override {
return algorithmImpl->label_lookup_;
}

std::tuple<NDArray<hnswlib::labeltype, 2>, NDArray<dist_t, 2>>
query(std::vector<std::vector<float>> floatQueryVectors, int k = 1,
int numThreads = -1, long queryEf = -1) {
int numThreads = -1, long queryEf = -1) override {
return query(vectorsToNDArray(floatQueryVectors), k, numThreads, queryEf);
}

std::tuple<NDArray<hnswlib::labeltype, 2>, NDArray<dist_t, 2>>
query(NDArray<float, 2> floatQueryVectors, int k = 1, int numThreads = -1,
long queryEf = -1) {
long queryEf = -1) override {
if (queryEf > 0 && queryEf < k) {
throw std::runtime_error("queryEf must be equal to or greater than the "
"requested number of neighbors");
Expand Down Expand Up @@ -635,8 +643,9 @@ class TypedIndex : public Index {
return {labels, distances};
}

std::tuple<std::vector<hnswlib::labeltype>, std::vector<float>>
query(std::vector<float> floatQueryVector, int k = 1, long queryEf = -1) {
std::tuple<std::vector<hnswlib::labeltype>, std::vector<dist_t>>
query(std::vector<float> floatQueryVector, int k = 1,
long queryEf = -1) override {
if (queryEf > 0 && queryEf < k) {
throw std::runtime_error("queryEf must be equal to or greater than the "
"requested number of neighbors");
Expand Down Expand Up @@ -710,32 +719,40 @@ class TypedIndex : public Index {
return {labels, distances};
}

void markDeleted(hnswlib::labeltype label) {
void markDeleted(hnswlib::labeltype label) override {
algorithmImpl->markDelete(label);
}

void unmarkDeleted(hnswlib::labeltype label) {
void unmarkDeleted(hnswlib::labeltype label) override {
algorithmImpl->unmarkDelete(label);
}

void resizeIndex(size_t new_size) { algorithmImpl->resizeIndex(new_size); }
void resizeIndex(size_t new_size) override {
algorithmImpl->resizeIndex(new_size);
}

size_t getMaxElements() const { return algorithmImpl->max_elements_; }
size_t getMaxElements() const override {
return algorithmImpl->max_elements_;
}

size_t getNumElements() const { return algorithmImpl->cur_element_count; }
size_t getNumElements() const override {
return algorithmImpl->cur_element_count;
}

int getEF() const {
int getEF() const override {
if (algorithmImpl)
return algorithmImpl->ef_;
else
return defaultEF;
}

int getNumThreads() { return numThreadsDefault; }
int getNumThreads() override { return numThreadsDefault; }

size_t getEfConstruction() const { return algorithmImpl->ef_construction_; }
size_t getEfConstruction() const override {
return algorithmImpl->ef_construction_;
}

size_t getM() const { return algorithmImpl->M_; }
size_t getM() const override { return algorithmImpl->M_; }
};

std::unique_ptr<Index>
Expand Down
Loading