diff --git a/java/cuvs-java/pom.xml b/java/cuvs-java/pom.xml
index aacbad2ca2..e4b5a43bd4 100644
--- a/java/cuvs-java/pom.xml
+++ b/java/cuvs-java/pom.xml
@@ -58,10 +58,18 @@
runtime
+
+ junit
+ junit
+ 4.13.1
+ test
+
+
- org.junit.jupiter
- junit-jupiter-api
- 5.10.0
+ org.apache.lucene
+ lucene-test-framework
+ 9.12.0
+ test
diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java
index 8ad64d7d69..2c9b51825d 100644
--- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java
+++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java
@@ -117,6 +117,9 @@ private void initializeMethodHandles() throws IOException {
* index
*/
private IndexReference build() throws Throwable {
+ if (dataset == null || dataset.length == 0 || dataset[0].length == 0) {
+ throw new IllegalArgumentException("Dataset cannot be null or empty");
+ }
long rows = dataset.length;
long cols = dataset[0].length;
MemoryLayout layout = resources.linker.canonicalLayouts().get("int");
@@ -138,6 +141,9 @@ private IndexReference build() throws Throwable {
* @return an instance of {@link CagraSearchResults} containing the results
*/
public CagraSearchResults search(CagraQuery query) throws Throwable {
+ if (query.getQueryVectors() == null) {
+ throw new IllegalArgumentException("Query vectors cannot be null");
+ }
long numQueries = query.getQueryVectors().length;
long numBlocks = query.getTopK() * numQueries;
int vectorDimension = numQueries > 0 ? query.getQueryVectors()[0].length : 0;
@@ -168,6 +174,9 @@ public CagraSearchResults search(CagraQuery query) throws Throwable {
* bytes into
*/
public void serialize(OutputStream outputStream) throws Throwable {
+ if (outputStream == null) {
+ throw new IllegalArgumentException("Output stream cannot be null");
+ }
serialize(outputStream, File.createTempFile(UUID.randomUUID().toString(), ".cag"));
}
diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchTest.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchTest.java
index c5788d3427..bfedd646c3 100644
--- a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchTest.java
+++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchTest.java
@@ -16,7 +16,8 @@
package com.nvidia.cuvs;
-import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import static org.junit.Assert.*;
import java.io.File;
import java.io.FileInputStream;
@@ -28,7 +29,7 @@
import java.util.Map;
import java.util.UUID;
-import org.junit.jupiter.api.Test;
+import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -100,14 +101,14 @@ public void testIndexingAndSearchingFlow() throws Throwable {
// Check results
log.info(results.getResults().toString());
- assertEquals(expectedResults, results.getResults(), "Results different than expected");
+ assertEquals("Results different than expected", expectedResults, results.getResults());
// Search from deserialized index
results = loadedIndex.search(cuvsQuery);
// Check results
log.info(results.getResults().toString());
- assertEquals(expectedResults, results.getResults(), "Results different than expected");
+ assertEquals("Results different than expected", expectedResults, results.getResults());
// Cleanup
if (indexFile.exists()) {
diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraIndexTest.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraIndexTest.java
new file mode 100644
index 0000000000..bc0d2f5dbd
--- /dev/null
+++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraIndexTest.java
@@ -0,0 +1,321 @@
+package com.nvidia.cuvs;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.InputStream;
+import java.lang.invoke.MethodHandles;
+import java.util.Map;
+import java.util.Random;
+import java.util.UUID;
+
+import org.apache.lucene.tests.util.LuceneTestCase;
+import org.junit.Before;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.carrotsearch.randomizedtesting.RandomizedContext;
+
+public class CagraIndexTest extends LuceneTestCase {
+ Random random;
+ private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
+
+ @Before
+ public void setup() {
+ this.random = random();
+ log.info("Test seed: " +RandomizedContext.current().getRunnerSeedAsString());
+ }
+
+ @Ignore
+ @Test
+ public void testInvalidDataset() {
+ Throwable exception = assertThrows(IllegalArgumentException.class, () -> {
+ // Use consistent dataset parameters as the working test
+ float[][] invalidDataset = null; // Simulate an invalid dataset
+ CuVSResources resources = new CuVSResources();
+ CagraIndexParams indexParams = new CagraIndexParams.Builder(resources)
+ .withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT)
+ .build();
+ new CagraIndex.Builder(resources)
+ .withDataset(invalidDataset)
+ .withIndexParams(indexParams)
+ .build();
+ });
+
+ assertEquals("Dataset cannot be null or empty", exception.getMessage());
+ }
+ @Ignore
+ @Test
+ public void testSerializationWithoutOutputStream() throws Throwable {
+ // Use the same dataset as the working test
+ float[][] dataset = {
+ {0.74021935f, 0.9209938f},
+ {0.03902049f, 0.9689629f},
+ {0.92514056f, 0.4463501f},
+ {0.6673192f, 0.10993068f}
+ };
+
+ CuVSResources resources = new CuVSResources();
+ CagraIndexParams indexParams = new CagraIndexParams.Builder(resources)
+ .withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT)
+ .build();
+
+ CagraIndex index = new CagraIndex.Builder(resources)
+ .withDataset(dataset)
+ .withIndexParams(indexParams)
+ .build();
+
+ Throwable exception = assertThrows(IllegalArgumentException.class, () -> {
+ index.serialize(null); // Pass null output stream
+ });
+
+ assertEquals("Output stream cannot be null", exception.getMessage());
+ }
+ @Ignore
+ @Test
+ public void testSingleElementDataset() throws Throwable {
+ // Match dataset and parameters to the working test
+ float[][] dataset = {
+ {0.74021935f, 0.9209938f},
+ {0.03902049f, 0.9689629f},
+ {0.92514056f, 0.4463501f},
+ {0.6673192f, 0.10993068f}
+ };
+
+ CuVSResources resources = new CuVSResources();
+ CagraIndexParams indexParams = new CagraIndexParams.Builder(resources)
+ .withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT)
+ .build();
+
+ CagraIndex index = new CagraIndex.Builder(resources)
+ .withDataset(dataset)
+ .withIndexParams(indexParams)
+ .build();
+
+ float[][] query = {
+ {0.48216683f, 0.0428398f},
+ {0.5084142f, 0.6545497f},
+ {0.51260436f, 0.2643005f},
+ {0.05198065f, 0.5789965f}
+ };
+
+ CagraQuery cuvsQuery = new CagraQuery.Builder()
+ .withTopK(3)
+ .withSearchParams(new CagraSearchParams.Builder(resources).build())
+ .withQueryVectors(query)
+ .build();
+
+ CagraSearchResults results = index.search(cuvsQuery);
+
+ // Verify the results size matches the queries
+ assertEquals("Expected one result for each query", query.length, results.getResults().size());
+ }
+ @Ignore
+ @Test
+ public void testSearchResultMapping() throws Throwable {
+ // Match dataset and parameters to the working test
+ float[][] dataset = {
+ {0.74021935f, 0.9209938f},
+ {0.03902049f, 0.9689629f},
+ {0.92514056f, 0.4463501f},
+ {0.6673192f, 0.10993068f}
+ };
+
+ CuVSResources resources = new CuVSResources();
+ CagraIndexParams indexParams = new CagraIndexParams.Builder(resources)
+ .withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT)
+ .build();
+
+ CagraIndex index = new CagraIndex.Builder(resources)
+ .withDataset(dataset)
+ .withIndexParams(indexParams)
+ .build();
+
+ // Use consistent query and mapping
+ Map mapping = Map.of(0, 100, 1, 200, 2, 300, 3, 400);
+ float[][] query = {
+ {0.48216683f, 0.0428398f},
+ {0.5084142f, 0.6545497f},
+ {0.51260436f, 0.2643005f},
+ {0.05198065f, 0.5789965f}
+ };
+
+ CagraQuery cuvsQuery = new CagraQuery.Builder()
+ .withTopK(3)
+ .withSearchParams(new CagraSearchParams.Builder(resources).build())
+ .withQueryVectors(query)
+ .withMapping(mapping)
+ .build();
+
+ CagraSearchResults results = index.search(cuvsQuery);
+
+ // Verify mapped results contain expected keys
+ results.getResults().forEach(result -> {
+ assertTrue(result.containsKey(100) || result.containsKey(200) || result.containsKey(300) || result.containsKey(400));
+ });
+ }
+
+ @Test
+ public void testResultsTopKWithRandomValues() throws Throwable {
+ int numRows = random.nextInt(10) + 1; // 1 - 10 rows
+ int numCols = random.nextInt(5) + 1; // 1 - 5 columns
+ float[][] dataset = new float[numRows][numCols];
+ for (int i = 0; i < numRows; i++) {
+ for (int j = 0; j < numCols; j++) {
+ dataset[i][j] = random.nextFloat() * 100;
+ }
+ }
+
+ int numQueries = random.nextInt(5) + 1; // 1 - 5 queries
+ float[][] queries = new float[numQueries][numCols];
+ for (int i = 0; i < numQueries; i++) {
+ for (int j = 0; j < numCols; j++) {
+ queries[i][j] = random.nextFloat() * 100;
+ }
+ }
+
+ int topK = random.nextInt(numRows) + 1;
+ System.out.println("Dataset size: " + numRows + "x" + numCols);
+ System.out.println("Query size: " + numQueries + "x" + numCols);
+ System.out.println("TopK: " + topK);
+
+ CuVSResources resources = new CuVSResources();
+ CagraIndexParams indexParams = new CagraIndexParams.Builder(resources).build();
+ CagraIndex index = new CagraIndex.Builder(resources)
+ .withDataset(dataset)
+ .withIndexParams(indexParams)
+ .build();
+
+ CagraQuery query = new CagraQuery.Builder()
+ .withQueryVectors(queries)
+ .withTopK(topK)
+ .withSearchParams(new CagraSearchParams.Builder(resources).build())
+ .build();
+
+ CagraSearchResults results = index.search(query);
+
+ results.getResults().forEach(result -> {
+ System.out.println("Result size: " + result.size());
+ assertEquals("TopK mismatch for query.", topK, result.size());
+ });
+ }
+
+ @Ignore
+ @Test
+ public void testEmptyResults() throws Throwable {
+ float[][] dataset = {
+ {10.0f, 10.0f},
+ {20.0f, 20.0f}
+ };
+
+ float[][] queries = {
+ {1000.0f, 1000.0f}
+ };
+
+ CuVSResources resources = new CuVSResources();
+ CagraIndexParams indexParams = new CagraIndexParams.Builder(resources).build();
+ CagraIndex index = new CagraIndex.Builder(resources)
+ .withDataset(dataset)
+ .withIndexParams(indexParams)
+ .build();
+
+ CagraQuery query = new CagraQuery.Builder()
+ .withQueryVectors(queries)
+ .withTopK(2)
+ .withSearchParams(new CagraSearchParams.Builder(resources).build())
+ .build();
+
+ CagraSearchResults results = index.search(query);
+ System.out.println(results.getResults());
+
+ // Verify no neighbors were found
+ assertTrue(results.getResults().isEmpty());
+ }
+ @Ignore
+ @Test
+ public void testSearchWithDeletedIndexFile() throws Throwable {
+ // Dataset and Query
+ float[][] dataset = {
+ {0.74021935f, 0.9209938f},
+ {0.03902049f, 0.9689629f},
+ {0.92514056f, 0.4463501f},
+ {0.6673192f, 0.10993068f}
+ };
+
+ float[][] queries = {
+ {0.48216683f, 0.0428398f},
+ {0.5084142f, 0.6545497f}
+ };
+
+ CuVSResources resources = new CuVSResources();
+ CagraIndexParams indexParams = new CagraIndexParams.Builder(resources)
+ .withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT)
+ .build();
+
+ // Create and serialize index
+ CagraIndex index = new CagraIndex.Builder(resources)
+ .withDataset(dataset)
+ .withIndexParams(indexParams)
+ .build();
+
+ String indexFileName = UUID.randomUUID().toString() + ".cag";
+ index.serialize(new FileOutputStream(indexFileName));
+
+ // Delete the serialized file
+ File indexFile = new File(indexFileName);
+ if (indexFile.exists()) {
+ indexFile.delete();
+ }
+
+ // Attempt to create an InputStream from the deleted file
+ Throwable exception = assertThrows(Exception.class, () -> {
+ try (InputStream inputStream = new FileInputStream(indexFile)) {
+ CagraIndex deletedIndex = new CagraIndex.Builder(resources)
+ .from(inputStream)
+ .build();
+
+ CagraQuery query = new CagraQuery.Builder()
+ .withTopK(3)
+ .withSearchParams(new CagraSearchParams.Builder(resources).build())
+ .withQueryVectors(queries)
+ .build();
+
+ deletedIndex.search(query);
+ }
+ });
+
+ // Assert the exception type
+ assertTrue("Expected FileNotFoundException", exception instanceof java.io.FileNotFoundException);
+ }
+ @Ignore
+ @Test
+ public void testNullQueryVectors() throws Throwable {
+ float[][] dataset = {
+ {0.74021935f, 0.9209938f},
+ {0.03902049f, 0.9689629f}
+ };
+
+ CuVSResources resources = new CuVSResources();
+ CagraIndexParams indexParams = new CagraIndexParams.Builder(resources).build();
+ CagraIndex index = new CagraIndex.Builder(resources)
+ .withDataset(dataset)
+ .withIndexParams(indexParams)
+ .build();
+
+ CagraQuery invalidQuery = new CagraQuery.Builder()
+ .withQueryVectors(null)
+ .withTopK(3)
+ .withSearchParams(new CagraSearchParams.Builder(resources).build())
+ .build();
+
+ Throwable exception = assertThrows(IllegalArgumentException.class, () -> {
+ index.search(invalidQuery);
+ });
+
+ assertEquals("Query vectors cannot be null", exception.getMessage());
+ }
+
+
+}
diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedTest.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedTest.java
new file mode 100644
index 0000000000..cf58a89065
--- /dev/null
+++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedTest.java
@@ -0,0 +1,150 @@
+package com.nvidia.cuvs;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+import java.lang.invoke.MethodHandles;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.carrotsearch.randomizedtesting.RandomizedRunner;
+
+@RunWith(RandomizedRunner.class)
+public class CagraRandomizedTest extends CuVSTestCase {
+
+ private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
+
+ @Before
+ public void setup() {
+ initializeRandom();
+ log.info("Random context initialized for test.");
+ }
+
+ @Test
+ public void testResultsTopKWithRandomValues() throws Throwable {
+ // Use old-style random generation logic
+ int datasetSize = random.nextInt(400) + 1;
+ int dimensions = random.nextInt(500) + 1;
+ int numQueries = random.nextInt(500) + 1;
+ int topK = random.nextInt(datasetSize) + 1;
+
+ // Generate a random dataset
+ float[][] dataset = new float[datasetSize][dimensions];
+ for (int i = 0; i < datasetSize; i++) {
+ for (int j = 0; j < dimensions; j++) {
+ dataset[i][j] = random.nextFloat() * 100;
+ }
+ }
+
+ // Generate random query vectors
+ float[][] queries = new float[numQueries][dimensions];
+ for (int i = 0; i < numQueries; i++) {
+ for (int j = 0; j < dimensions; j++) {
+ queries[i][j] = random.nextFloat() * 100;
+ }
+ }
+
+ log.info("Dataset size: {}x{}", datasetSize, dimensions);
+ log.info("Query size: {}x{}", numQueries, dimensions);
+ log.info("TopK: {}", topK);
+
+ // Debugging: Log dataset and queries
+ if (log.isDebugEnabled()) {
+ log.debug("Dataset:");
+ for (float[] row : dataset) {
+ log.debug(java.util.Arrays.toString(row));
+ }
+
+ log.debug("Queries:");
+ for (float[] query : queries) {
+ log.debug(java.util.Arrays.toString(query));
+ }
+ }
+
+ // Sanity checks
+ assert dataset.length > 0 : "Dataset is empty.";
+ assert queries.length > 0 : "Queries are empty.";
+ assert dimensions > 0 : "Invalid dimensions.";
+ assert topK > 0 && topK <= datasetSize : "Invalid topK value.";
+
+ // Generate expected results using brute force
+ List> expected = generateExpectedResults(topK, dataset, queries);
+
+ // Create CuVS index and query
+ CuVSResources resources = new CuVSResources();
+ CagraIndexParams indexParams = new CagraIndexParams.Builder(resources).build();
+ CagraIndex index = new CagraIndex.Builder(resources).withDataset(dataset).withIndexParams(indexParams).build();
+
+ log.info("Index built successfully.");
+
+ CagraQuery query = new CagraQuery.Builder()
+ .withQueryVectors(queries)
+ .withTopK(topK)
+ .withSearchParams(new CagraSearchParams.Builder(resources).build())
+ .build();
+
+ log.info("Query built successfully. Executing search...");
+
+ // Execute search and retrieve results
+ CagraSearchResults results = index.search(query);
+
+ // actual vs. expected results
+ for (int i = 0; i < results.getResults().size(); i++) {
+ Map result = results.getResults().get(i);
+ log.info("Actual result for query {}: {}", i, result.keySet());
+ log.info("Expected result for query {}: {}", i, expected.get(i));
+
+ assertEquals("TopK mismatch for query.", Math.min(topK, datasetSize), result.size());
+
+ // Sort result by values (distances) and extract keys
+ List sortedResultKeys = result.entrySet().stream()
+ .sorted(Map.Entry.comparingByValue()) // Sort by value (distance)
+ .map(Map.Entry::getKey) // Extract sorted keys
+ .toList();
+
+ log.info("Sorted Actual result for query {}: {}", i, sortedResultKeys);
+
+ // Compare using primitive int arrays
+ assertArrayEquals(
+ "Query " + i + " mismatched",
+ expected.get(i).stream().mapToInt(Integer::intValue).toArray(),
+ sortedResultKeys.stream().mapToInt(Integer::intValue).toArray()
+ );
+ }
+
+ }
+
+ private List> generateExpectedResults(int topK, float[][] dataset, float[][] queries) {
+ List> neighborsResult = new ArrayList<>();
+ int dimensions = dataset[0].length;
+
+ for (float[] query : queries) {
+ Map distances = new TreeMap<>();
+ for (int j = 0; j < dataset.length; j++) {
+ double distance = 0;
+ for (int k = 0; k < dimensions; k++) {
+ distance += (query[k] - dataset[j][k]) * (query[k] - dataset[j][k]);
+ }
+ distances.put(j, Math.sqrt(distance));
+ }
+
+ // Sort by distance and select the topK nearest neighbors
+ List neighbors = distances.entrySet().stream()
+ .sorted(Map.Entry.comparingByValue())
+ .map(Map.Entry::getKey)
+ .toList();
+ neighborsResult.add(neighbors.subList(0, Math.min(topK, dataset.length)));
+ }
+
+ log.info("Expected results generated successfully.");
+ return neighborsResult;
+ }
+}
diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CuVSTestCase.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CuVSTestCase.java
new file mode 100644
index 0000000000..a07cb4ebfe
--- /dev/null
+++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CuVSTestCase.java
@@ -0,0 +1,19 @@
+package com.nvidia.cuvs;
+
+import java.lang.invoke.MethodHandles;
+import java.util.Random;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.carrotsearch.randomizedtesting.RandomizedContext;
+
+public abstract class CuVSTestCase {
+ protected Random random;
+ private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
+
+ protected void initializeRandom() {
+ random = RandomizedContext.current().getRandom();
+ log.info("Test seed: " + RandomizedContext.current().getRunnerSeedAsString());
+ }
+}