
[Fix] Various fixes for 25.02.01 point release #695

Merged: 7 commits, Feb 24, 2025. The diff below shows changes from 5 commits.
8 changes: 5 additions & 3 deletions cpp/include/cuvs/neighbors/cagra.hpp
@@ -337,6 +337,8 @@ struct index : cuvs::neighbors::index {
using search_params_type = cagra::search_params;
using index_type = IdxT;
using value_type = T;
+ using dataset_index_type = int64_t;

Member:
I think this is where @achirkin is suggesting not to hard code the type. Ideally we should use a template for this so that it can be propagated outside of this class (and not hardcoded within it).

Contributor:
I think in this case it's fine; this says "the dataset member of the cagra index uses int64_t as the indexing type", so one can argue it belongs to the index, and it's also an implementation detail of the cagra index. Before this, my problem was that it was hardcoded in two different places with no compile-time relation between them (inside the cagra index and in the merge function).

static_assert(!raft::is_narrowing_v<uint32_t, IdxT>,
"IdxT must be able to represent all values of uint32_t");

@@ -510,14 +512,14 @@ struct index : cuvs::neighbors::index {
*/
template <typename DatasetT>
auto update_dataset(raft::resources const& res, DatasetT&& dataset)
- -> std::enable_if_t<std::is_base_of_v<cuvs::neighbors::dataset<int64_t>, DatasetT>>
+ -> std::enable_if_t<std::is_base_of_v<cuvs::neighbors::dataset<dataset_index_type>, DatasetT>>
{
dataset_ = std::make_unique<DatasetT>(std::move(dataset));
}

template <typename DatasetT>
auto update_dataset(raft::resources const& res, std::unique_ptr<DatasetT>&& dataset)
- -> std::enable_if_t<std::is_base_of_v<neighbors::dataset<int64_t>, DatasetT>>
+ -> std::enable_if_t<std::is_base_of_v<neighbors::dataset<dataset_index_type>, DatasetT>>
{
dataset_ = std::move(dataset);
}
@@ -561,7 +563,7 @@ struct index : cuvs::neighbors::index {
cuvs::distance::DistanceType metric_;
raft::device_matrix<IdxT, int64_t, raft::row_major> graph_;
raft::device_matrix_view<const IdxT, int64_t, raft::row_major> graph_view_;
- std::unique_ptr<neighbors::dataset<int64_t>> dataset_;
+ std::unique_ptr<neighbors::dataset<dataset_index_type>> dataset_;
};
/**
* @}
9 changes: 5 additions & 4 deletions cpp/src/neighbors/detail/cagra/cagra_merge.cuh
@@ -43,14 +43,16 @@ index<T, IdxT> merge(raft::resources const& handle,
const cagra::merge_params& params,
std::vector<cuvs::neighbors::cagra::index<T, IdxT>*>& indices)
{
+ using cagra_index_t = cuvs::neighbors::cagra::index<T, IdxT>;
+ using ds_idx_type = typename cagra_index_t::dataset_index_type;

std::size_t dim = 0;
std::size_t new_dataset_size = 0;
int64_t stride = -1;

- for (auto index : indices) {
+ for (cagra_index_t* index : indices) {
RAFT_EXPECTS(index != nullptr,
"Null pointer detected in 'indices'. Ensure all elements are valid before usage.");
- using ds_idx_type = decltype(index->data().n_rows());
if (auto* strided_dset = dynamic_cast<const strided_dataset<T, ds_idx_type>*>(&index->data());
strided_dset != nullptr) {
if (dim == 0) {
@@ -74,8 +76,7 @@ index<T, IdxT> merge(raft::resources const& handle,
IdxT offset = 0;

auto merge_dataset = [&](T* dst) {
- for (auto index : indices) {
-   using ds_idx_type = decltype(index->data().n_rows());
+ for (cagra_index_t* index : indices) {
auto* strided_dset = dynamic_cast<const strided_dataset<T, ds_idx_type>*>(&index->data());

RAFT_CUDA_TRY(cudaMemcpy2DAsync(dst + offset * dim,
24 changes: 16 additions & 8 deletions cpp/src/neighbors/detail/nn_descent.cuh

@@ -1047,24 +1047,32 @@ void GnndGraph<Index_t>::init_random_graph()
for (size_t seg_idx = 0; seg_idx < static_cast<size_t>(num_segments); seg_idx++) {
// random sequence (range: 0~nrow)
// segment_x stores neighbors which id % num_segments == x
- std::vector<Index_t> rand_seq(nrow / num_segments);
+ std::vector<Index_t> rand_seq((nrow + num_segments - 1) / num_segments);
std::iota(rand_seq.begin(), rand_seq.end(), 0);
auto gen = std::default_random_engine{seg_idx};
std::shuffle(rand_seq.begin(), rand_seq.end(), gen);

#pragma omp parallel for
for (size_t i = 0; i < nrow; i++) {
- size_t base_idx = i * node_degree + seg_idx * segment_size;
- auto h_neighbor_list = h_graph + base_idx;
- auto h_dist_list = h_dists.data_handle() + base_idx;
+ size_t base_idx         = i * node_degree + seg_idx * segment_size;
+ auto h_neighbor_list    = h_graph + base_idx;
+ auto h_dist_list        = h_dists.data_handle() + base_idx;
+ size_t idx              = base_idx;
+ size_t self_in_this_seg = 0;
for (size_t j = 0; j < static_cast<size_t>(segment_size); j++) {
- size_t idx = base_idx + j;
Index_t id = rand_seq[idx % rand_seq.size()] * num_segments + seg_idx;
if ((size_t)id == i) {
- id = rand_seq[(idx + segment_size) % rand_seq.size()] * num_segments + seg_idx;
+ idx++;
+ id = rand_seq[idx % rand_seq.size()] * num_segments + seg_idx;
+ self_in_this_seg = 1;
}
- h_neighbor_list[j].id_with_flag() = id;
- h_dist_list[j] = std::numeric_limits<DistData_t>::max();
+ h_neighbor_list[j].id_with_flag() =
+   j < (rand_seq.size() - self_in_this_seg) && size_t(id) < nrow
+     ? id
+     : std::numeric_limits<Index_t>::max();
+ h_dist_list[j] = std::numeric_limits<DistData_t>::max();
+ idx++;
}
}
}
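
A note on the init_random_graph() change above, as far as the diff shows: sizing rand_seq with the ceiling division (nrow + num_segments - 1) / num_segments keeps the candidate sequence non-empty and covering every row when nrow is not a multiple of num_segments (with the old nrow / num_segments, a dataset smaller than num_segments produced an empty vector, so idx % rand_seq.size() was a modulo by zero). The reworked loop also skips a row's own id when drawing candidates and writes std::numeric_limits<Index_t>::max() as an explicit invalid-entry sentinel for slots that cannot be filled with a distinct in-range neighbor.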
24 changes: 24 additions & 0 deletions cpp/tests/neighbors/ann_cagra.cuh

@@ -952,6 +952,12 @@ class AnnCagraIndexMergeTest : public ::testing::TestWithParam<AnnCagraInputs> {
(ps.k * ps.dim * 8 / 5 /*(=magic number)*/ < ps.n_rows))
GTEST_SKIP();

+ // Avoid splitting datasets with a size of 0
+ if (ps.n_rows <= 3) GTEST_SKIP();
+
+ // IVF_PQ requires `n_rows >= n_lists`.
+ if (ps.n_rows < 8 && ps.build_algo == graph_build_algo::IVF_PQ) GTEST_SKIP();

size_t queries_size = ps.n_queries * ps.k;
std::vector<IdxT> indices_Cagra(queries_size);
std::vector<IdxT> indices_naive(queries_size);
@@ -1161,6 +1167,24 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{0.995});
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());

+ // Corner cases for small datasets
+ inputs2 = raft::util::itertools::product<AnnCagraInputs>(
+   {2},
+   {3, 5, 31, 32, 64, 101},
+   {1, 10},
+   {2},  // k
+   {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT},
+   {search_algo::SINGLE_CTA, search_algo::MULTI_CTA, search_algo::MULTI_KERNEL},
+   {0},  // query size
+   {0},
+   {256},
+   {1},
+   {cuvs::distance::DistanceType::L2Expanded},
+   {false},
+   {true},
+   {0.995});
+ inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());

// Varying dim and build algo.
inputs2 = raft::util::itertools::product<AnnCagraInputs>(
{100},
43 changes: 30 additions & 13 deletions java/cuvs-java/src/main/java/com/nvidia/cuvs/BruteForceQuery.java

@@ -17,6 +17,7 @@
package com.nvidia.cuvs;

import java.util.Arrays;
+ import java.util.BitSet;
import java.util.List;

/**
@@ -28,7 +29,8 @@ public class BruteForceQuery {

private List<Integer> mapping;
private float[][] queryVectors;
- private long[] prefilter;
+ private BitSet[] prefilters;
+ private int numDocs = -1;
private int topK;

/**
@@ -40,12 +42,15 @@ public class BruteForceQuery {
* @param topK the top k results to return
* @param prefilter the prefilter data to use while searching the BRUTEFORCE
* index
+ * @param numDocs maximum number of bits in each prefilter, representing the number of
+ *                documents in this index. Used only when prefilters are passed.
*/
- public BruteForceQuery(float[][] queryVectors, List<Integer> mapping, int topK, long[] prefilter) {
+ public BruteForceQuery(float[][] queryVectors, List<Integer> mapping, int topK, BitSet[] prefilters, int numDocs) {
this.queryVectors = queryVectors;
this.mapping = mapping;
this.topK = topK;
- this.prefilter = prefilter;
+ this.prefilters = prefilters;
+ this.numDocs = numDocs;
}

/**
@@ -78,16 +83,25 @@ public int getTopK() {
/**
* Gets the prefilter long array
*
- * @return a long array
+ * @return an array of bitsets
*/
- public long[] getPrefilter() {
-   return prefilter;
+ public BitSet[] getPrefilters() {
+   return prefilters;
}

+ /**
+  * Gets the number of documents supposed to be in this index, as used for prefilters
+  *
+  * @return number of documents as an integer
+  */
+ public int getNumDocs() {
+   return numDocs;
+ }

@Override
public String toString() {
return "BruteForceQuery [mapping=" + mapping + ", queryVectors=" + Arrays.toString(queryVectors) + ", prefilter="
-   + Arrays.toString(prefilter) + ", topK=" + topK + "]";
+   + Arrays.toString(prefilters) + ", topK=" + topK + "]";
}

/**
@@ -96,7 +110,8 @@ public String toString() {
public static class Builder {

private float[][] queryVectors;
- private long[] prefilter;
+ private BitSet[] prefilters;
+ private int numDocs;
private List<Integer> mapping;
private int topK = 2;

@@ -134,13 +149,15 @@ public Builder withTopK(int topK) {
}

/**
- * Sets the prefilter data for building the {@link BruteForceQuery}.
+ * Sets the prefilters data for building the {@link BruteForceQuery}.
*
- * @param prefilter a one-dimensional long array
+ * @param prefilters array of bitsets, as many as queries, each containing as
+ *                   many bits as there are vectors in the index
* @return an instance of this Builder
*/
- public Builder withPrefilter(long[] prefilter) {
-   this.prefilter = prefilter;
+ public Builder withPrefilter(BitSet[] prefilters, int numDocs) {
+   this.prefilters = prefilters;
+   this.numDocs = numDocs;
return this;
}

@@ -150,7 +167,7 @@ public Builder withPrefilter(long[] prefilter) {
* @return an instance of {@link BruteForceQuery}
*/
public BruteForceQuery build() {
- return new BruteForceQuery(queryVectors, mapping, topK, prefilter);
+ return new BruteForceQuery(queryVectors, mapping, topK, prefilters, numDocs);
}
}
}
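
For orientation, here is a minimal sketch (not part of the diff) of how the reworked prefilter API might be called. The queryVectors array and the withQueryVectors setter name are assumptions for illustration; withTopK, withPrefilter, and build are taken from the diff above:

```java
import java.util.BitSet;

// Two queries against an index holding 5 vectors (numDocs = 5).
// Each BitSet marks the documents its query is allowed to match.
BitSet filter0 = new BitSet(5);
filter0.set(0, 3);  // query 0 may only match documents 0..2

BitSet filter1 = new BitSet(5);
filter1.set(3, 5);  // query 1 may only match documents 3..4

BruteForceQuery query = new BruteForceQuery.Builder()
    .withQueryVectors(queryVectors)  // assumed setter; a float[2][dim] array
    .withTopK(2)
    .withPrefilter(new BitSet[] { filter0, filter1 }, 5)
    .build();
```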
BruteForceIndexImpl.java

@@ -26,8 +26,12 @@
import java.lang.foreign.MemorySegment;
import java.lang.foreign.SequenceLayout;
import java.lang.invoke.MethodHandle;
+ import java.nio.ByteBuffer;
+ import java.nio.ByteOrder;
import java.nio.file.Files;
import java.nio.file.Path;
+ import java.util.Arrays;
+ import java.util.BitSet;
import java.util.Objects;
import java.util.UUID;

Expand Down Expand Up @@ -59,7 +63,7 @@ public class BruteForceIndexImpl implements BruteForceIndex{
FunctionDescriptor.of(ADDRESS, ADDRESS, C_LONG, C_LONG, ADDRESS, ADDRESS, C_INT));

private static final MethodHandle searchMethodHandle = downcallHandle("search_brute_force_index",
-   FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, C_INT, C_LONG, C_INT, ADDRESS, ADDRESS, ADDRESS, ADDRESS, ADDRESS, C_LONG, C_LONG));
+   FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, C_INT, C_LONG, C_INT, ADDRESS, ADDRESS, ADDRESS, ADDRESS, ADDRESS, C_LONG));

private static final MethodHandle destroyIndexMethodHandle = downcallHandle("destroy_brute_force_index",
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS));
@@ -169,16 +173,24 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
long numQueries = cuvsQuery.getQueryVectors().length;
long numBlocks = cuvsQuery.getTopK() * numQueries;
int vectorDimension = numQueries > 0 ? cuvsQuery.getQueryVectors()[0].length : 0;
- long prefilterDataLength = cuvsQuery.getPrefilter() != null ? cuvsQuery.getPrefilter().length : 0;
long numRows = dataset != null ? dataset.length : 0;

SequenceLayout neighborsSequenceLayout = MemoryLayout.sequenceLayout(numBlocks, C_LONG);
SequenceLayout distancesSequenceLayout = MemoryLayout.sequenceLayout(numBlocks, C_FLOAT);
MemorySegment neighborsMemorySegment = resources.getArena().allocate(neighborsSequenceLayout);
MemorySegment distancesMemorySegment = resources.getArena().allocate(distancesSequenceLayout);
- MemorySegment prefilterDataMemorySegment = cuvsQuery.getPrefilter() != null
-     ? Util.buildMemorySegment(resources.getArena(), cuvsQuery.getPrefilter())
-     : MemorySegment.NULL;

+ // prepare the prefiltering data
+ long prefilterDataLength = 0;
+ MemorySegment prefilterDataMemorySegment = MemorySegment.NULL;
+ BitSet[] prefilters = cuvsQuery.getPrefilters();
+ if (prefilters != null && prefilters.length > 0) {
+   BitSet concatenatedFilters = Util.concatenate(prefilters, cuvsQuery.getNumDocs());
+   long filters[] = concatenatedFilters.toLongArray();
+   prefilterDataMemorySegment = Util.buildMemorySegment(resources.getArena(), filters);
+   prefilterDataLength = cuvsQuery.getNumDocs() * prefilters.length;
+ }

MemorySegment querySeg = Util.buildMemorySegment(resources.getArena(), cuvsQuery.getQueryVectors());
try (var localArena = Arena.ofConfined()) {
MemorySegment returnValue = localArena.allocate(C_INT);
@@ -193,7 +205,7 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
distancesMemorySegment,
returnValue,
prefilterDataMemorySegment,
-   prefilterDataLength, numRows
+   prefilterDataLength
);
checkError(returnValue.get(C_INT, 0L), "searchMethodHandle");
}
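
Read together with the descriptor change above: the native search_brute_force_index entry point no longer takes a trailing numRows argument (one C_LONG was dropped from the FunctionDescriptor), and the prefilter now arrives as one packed bit array of numDocs * numQueries bits, produced by concatenating the per-query BitSets and converting the result to a long[].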
Util.java

@@ -25,6 +25,8 @@
import java.lang.invoke.MethodHandle;
import java.lang.invoke.VarHandle;
import java.util.ArrayList;
+ import java.util.Arrays;
+ import java.util.BitSet;
import java.util.List;

import com.nvidia.cuvs.GPUInfo;
@@ -184,6 +186,14 @@ public static MemorySegment buildMemorySegment(Arena arena, long[] data) {
return dataMemorySegment;
}

+ public static MemorySegment buildMemorySegment(Arena arena, byte[] data) {
+   int cells = data.length;
+   MemoryLayout dataMemoryLayout = MemoryLayout.sequenceLayout(cells, C_CHAR);
+   MemorySegment dataMemorySegment = arena.allocate(dataMemoryLayout);
+   MemorySegment.copy(data, 0, dataMemorySegment, C_CHAR, 0, cells);
+   return dataMemorySegment;
+ }

/**
* A utility method for building a {@link MemorySegment} for a 2D float array.
*
@@ -201,4 +211,20 @@ public static MemorySegment buildMemorySegment(Arena arena, float[][] data) {
}
return dataMemorySegment;
}

+ public static BitSet concatenate(BitSet[] arr, int maxSizeOfEachBitSet) {
+   BitSet ret = new BitSet(maxSizeOfEachBitSet * arr.length);
+   for (int i = 0; i < arr.length; i++) {
+     BitSet b = arr[i];
+     if (b == null || b.length() == 0) {
+       ret.set(i * maxSizeOfEachBitSet, (i + 1) * maxSizeOfEachBitSet);
+     } else {
+       for (int j = 0; j < maxSizeOfEachBitSet; j++) {
+         ret.set(i * maxSizeOfEachBitSet + j, b.get(j));
+       }
+     }
+   }
+   return ret;
+ }

}
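
To make the concatenate() semantics concrete, a small self-contained sketch (assuming Util is visible from the caller's package). Worth noting: a null or empty BitSet expands to all ones, i.e. "no filtering" for that query:

```java
import java.util.Arrays;
import java.util.BitSet;

public class ConcatenateDemo {
  public static void main(String[] args) {
    BitSet allow02 = new BitSet();
    allow02.set(0);
    allow02.set(2);

    // Query 0 may match docs {0, 2}; query 1 passes null, so all 4 docs match.
    BitSet packed = Util.concatenate(new BitSet[] { allow02, null }, 4);

    // Bits 0..3 belong to query 0, bits 4..7 to query 1.
    System.out.println(packed);                                 // {0, 2, 4, 5, 6, 7}
    System.out.println(Arrays.toString(packed.toLongArray()));  // [245]
  }
}
```

The long[] from toLongArray() is what buildMemorySegment() copies into native memory, which is why prefilterDataLength is reported in bits (numDocs times the number of queries) rather than in longs.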