Skip to content

Java: Fix for a crash situation when using different threads (#680) #705

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: branch-25.02
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 117 additions & 62 deletions java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,29 +21,38 @@

import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.InputStream;
import java.lang.invoke.MethodHandles;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.carrotsearch.randomizedtesting.RandomizedRunner;
import com.nvidia.cuvs.CagraIndexParams.CagraGraphBuildAlgo;
import com.nvidia.cuvs.CagraIndexParams.CuvsDistanceType;

@RunWith(RandomizedRunner.class)
public class CagraBuildAndSearchIT extends CuVSTestCase {

private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());

@Before
public void setup() {
assumeTrue("not supported on " + System.getProperty("os.name"), isLinuxAmd64());
initializeRandom();
log.info("Random context initialized for test.");
}

/**
Expand All @@ -60,14 +69,14 @@ public void testIndexingAndSearchingFlow() throws Throwable {
{ 0.03902049f, 0.9689629f },
{ 0.92514056f, 0.4463501f },
{ 0.6673192f, 0.10993068f }
};
};
List<Integer> map = List.of(0, 1, 2, 3);
float[][] queries = {
{ 0.48216683f, 0.0428398f },
{ 0.5084142f, 0.6545497f },
{ 0.51260436f, 0.2643005f },
{ 0.05198065f, 0.5789965f }
};
};

// Expected search results
List<Map<Integer, Float>> expectedResults = Arrays.asList(
Expand All @@ -76,68 +85,114 @@ public void testIndexingAndSearchingFlow() throws Throwable {
Map.of(3, 0.047766715f, 2, 0.20332818f, 0, 0.48305473f),
Map.of(1, 0.15224178f, 0, 0.59063464f, 3, 0.5986642f));

for (int j = 0; j < 10; j++) {

try (CuVSResources resources = CuVSResources.create()) {

// Configure index parameters
CagraIndexParams indexParams = new CagraIndexParams.Builder()
.withCagraGraphBuildAlgo(CagraGraphBuildAlgo.NN_DESCENT)
.withGraphDegree(1)
.withIntermediateGraphDegree(2)
.withNumWriterThreads(32)
.withMetric(CuvsDistanceType.L2Expanded)
.build();

// Create the index with the dataset
CagraIndex index = CagraIndex.newBuilder(resources)
.withDataset(dataset)
.withIndexParams(indexParams)
.build();

// Saving the index on to the disk.
String indexFileName = UUID.randomUUID().toString() + ".cag";
index.serialize(new FileOutputStream(indexFileName));

// Loading a CAGRA index from disk.
File indexFile = new File(indexFileName);
InputStream inputStream = new FileInputStream(indexFile);
CagraIndex loadedIndex = CagraIndex.newBuilder(resources)
.from(inputStream)
.build();

// Configure search parameters
CagraSearchParams searchParams = new CagraSearchParams.Builder(resources)
.build();

// Create a query object with the query vectors
CagraQuery cuvsQuery = new CagraQuery.Builder()
.withTopK(3)
.withSearchParams(searchParams)
.withQueryVectors(queries)
.withMapping(map)
.build();

// Perform the search
SearchResults results = index.search(cuvsQuery);

// Check results
log.info(results.getResults().toString());
assertEquals(expectedResults, results.getResults());

// Search from deserialized index
results = loadedIndex.search(cuvsQuery);

// Check results
log.info(results.getResults().toString());
assertEquals(expectedResults, results.getResults());

// Cleanup
if (indexFile.exists()) {
indexFile.delete();
int numTestsRuns = 10;

try (CuVSResources resources = CuVSResources.create()) {
// sometimes run this test using different threads?
boolean runTestInDifferentThreads = random.nextBoolean();
// if running in different threads, run concurrently or one after the other?
boolean runConcurrently = runTestInDifferentThreads ? random.nextBoolean(): false;

log.info("Running in different threads? " + runTestInDifferentThreads);
log.info("Running concurrently? " + runConcurrently);

ExecutorService parallelExecutor = runConcurrently ? Executors.newFixedThreadPool(numTestsRuns): null;

for (int j = 0; j < numTestsRuns; j++) {
Runnable testLogic = indexAndQueryOnce(dataset, map, queries, expectedResults, resources);
if (runTestInDifferentThreads) {
if (runConcurrently) {
parallelExecutor.submit(testLogic);
} else {
ExecutorService singleExecutor = Executors.newSingleThreadExecutor();
singleExecutor.submit(testLogic);
singleExecutor.shutdown();
singleExecutor.awaitTermination(2000, TimeUnit.SECONDS);
}
} else {
// run the test logic in the main thread
testLogic.run();
}
index.destroyIndex();
}
if (parallelExecutor != null) {
parallelExecutor.shutdown();
parallelExecutor.awaitTermination(2000, TimeUnit.SECONDS);
}

}
}

private Runnable indexAndQueryOnce(float[][] dataset, List<Integer> map, float[][] queries,
List<Map<Integer, Float>> expectedResults, CuVSResources resources) throws Throwable, FileNotFoundException {

Runnable thread = new Runnable() {

@Override
public void run() {
try {

// Configure index parameters
CagraIndexParams indexParams = new CagraIndexParams.Builder()
.withCagraGraphBuildAlgo(CagraGraphBuildAlgo.NN_DESCENT)
.withGraphDegree(1)
.withIntermediateGraphDegree(2)
.withNumWriterThreads(32)
.withMetric(CuvsDistanceType.L2Expanded)
.build();

// Create the index with the dataset
CagraIndex index = CagraIndex.newBuilder(resources)
.withDataset(dataset)
.withIndexParams(indexParams)
.build();

// Saving the index on to the disk.
String indexFileName = UUID.randomUUID().toString() + ".cag";
index.serialize(new FileOutputStream(indexFileName));

// Loading a CAGRA index from disk.
File indexFile = new File(indexFileName);
InputStream inputStream = new FileInputStream(indexFile);
CagraIndex loadedIndex = CagraIndex.newBuilder(resources)
.from(inputStream)
.build();

// Configure search parameters
CagraSearchParams searchParams = new CagraSearchParams.Builder(resources)
.build();

// Create a query object with the query vectors
CagraQuery cuvsQuery = new CagraQuery.Builder()
.withTopK(3)
.withSearchParams(searchParams)
.withQueryVectors(queries)
.withMapping(map)
.build();

// Perform the search
SearchResults results = index.search(cuvsQuery);

// Check results
log.info(results.getResults().toString());
assertEquals(expectedResults, results.getResults());

// Search from deserialized index
results = loadedIndex.search(cuvsQuery);

// Check results
log.info(results.getResults().toString());
assertEquals(expectedResults, results.getResults());

// Cleanup
if (indexFile.exists()) {
indexFile.delete();
}
index.destroyIndex();
} catch (Throwable ex) {
throw new RuntimeException("Exception during indexing/querying", ex);
}
}
};
return thread;
}
}
2 changes: 0 additions & 2 deletions java/internal/src/cuvs_java.c
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ cuvsCagraIndex_t build_cagra_index(float *dataset, long rows, long dimensions, c
cuvsStreamGet(cuvs_resources, &stream);

omp_set_num_threads(n_writer_threads);
cuvsRMMPoolMemoryResourceEnable(95, 95, false);

int64_t dataset_shape[2] = {rows, dimensions};
DLManagedTensor dataset_tensor = prepare_tensor(dataset, dataset_shape, kDLFloat, 32, 2, kDLCUDA);
Expand Down Expand Up @@ -226,7 +225,6 @@ cuvsBruteForceIndex_t build_brute_force_index(float *dataset, long rows, long di
int *return_value, int n_writer_threads) {

omp_set_num_threads(n_writer_threads);
cuvsRMMPoolMemoryResourceEnable(95, 95, false);

cudaStream_t stream;
cuvsStreamGet(cuvs_resources, &stream);
Expand Down
Loading