diff --git a/java/cuvs-java/pom.xml b/java/cuvs-java/pom.xml index aacbad2ca2..e4b5a43bd4 100644 --- a/java/cuvs-java/pom.xml +++ b/java/cuvs-java/pom.xml @@ -58,10 +58,18 @@ runtime + + junit + junit + 4.13.1 + test + + - org.junit.jupiter - junit-jupiter-api - 5.10.0 + org.apache.lucene + lucene-test-framework + 9.12.0 + test diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java index 8ad64d7d69..2c9b51825d 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java @@ -117,6 +117,9 @@ private void initializeMethodHandles() throws IOException { * index */ private IndexReference build() throws Throwable { + if (dataset == null || dataset.length == 0 || dataset[0].length == 0) { + throw new IllegalArgumentException("Dataset cannot be null or empty"); + } long rows = dataset.length; long cols = dataset[0].length; MemoryLayout layout = resources.linker.canonicalLayouts().get("int"); @@ -138,6 +141,9 @@ private IndexReference build() throws Throwable { * @return an instance of {@link CagraSearchResults} containing the results */ public CagraSearchResults search(CagraQuery query) throws Throwable { + if (query.getQueryVectors() == null) { + throw new IllegalArgumentException("Query vectors cannot be null"); + } long numQueries = query.getQueryVectors().length; long numBlocks = query.getTopK() * numQueries; int vectorDimension = numQueries > 0 ? query.getQueryVectors()[0].length : 0; @@ -168,6 +174,9 @@ public CagraSearchResults search(CagraQuery query) throws Throwable { * bytes into */ public void serialize(OutputStream outputStream) throws Throwable { + if (outputStream == null) { + throw new IllegalArgumentException("Output stream cannot be null"); + } serialize(outputStream, File.createTempFile(UUID.randomUUID().toString(), ".cag")); } diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchTest.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchTest.java index c5788d3427..bfedd646c3 100644 --- a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchTest.java +++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchTest.java @@ -16,7 +16,8 @@ package com.nvidia.cuvs; -import static org.junit.jupiter.api.Assertions.assertEquals; + +import static org.junit.Assert.*; import java.io.File; import java.io.FileInputStream; @@ -28,7 +29,7 @@ import java.util.Map; import java.util.UUID; -import org.junit.jupiter.api.Test; +import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -100,14 +101,14 @@ public void testIndexingAndSearchingFlow() throws Throwable { // Check results log.info(results.getResults().toString()); - assertEquals(expectedResults, results.getResults(), "Results different than expected"); + assertEquals("Results different than expected", expectedResults, results.getResults()); // Search from deserialized index results = loadedIndex.search(cuvsQuery); // Check results log.info(results.getResults().toString()); - assertEquals(expectedResults, results.getResults(), "Results different than expected"); + assertEquals("Results different than expected", expectedResults, results.getResults()); // Cleanup if (indexFile.exists()) { diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraIndexTest.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraIndexTest.java new file mode 100644 index 0000000000..bc0d2f5dbd --- /dev/null +++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraIndexTest.java @@ -0,0 +1,321 @@ +package com.nvidia.cuvs; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.InputStream; +import java.lang.invoke.MethodHandles; +import java.util.Map; +import java.util.Random; +import java.util.UUID; + +import org.apache.lucene.tests.util.LuceneTestCase; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.carrotsearch.randomizedtesting.RandomizedContext; + +public class CagraIndexTest extends LuceneTestCase { + Random random; + private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + + @Before + public void setup() { + this.random = random(); + log.info("Test seed: " +RandomizedContext.current().getRunnerSeedAsString()); + } + + @Ignore + @Test + public void testInvalidDataset() { + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + // Use consistent dataset parameters as the working test + float[][] invalidDataset = null; // Simulate an invalid dataset + CuVSResources resources = new CuVSResources(); + CagraIndexParams indexParams = new CagraIndexParams.Builder(resources) + .withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT) + .build(); + new CagraIndex.Builder(resources) + .withDataset(invalidDataset) + .withIndexParams(indexParams) + .build(); + }); + + assertEquals("Dataset cannot be null or empty", exception.getMessage()); + } + @Ignore + @Test + public void testSerializationWithoutOutputStream() throws Throwable { + // Use the same dataset as the working test + float[][] dataset = { + {0.74021935f, 0.9209938f}, + {0.03902049f, 0.9689629f}, + {0.92514056f, 0.4463501f}, + {0.6673192f, 0.10993068f} + }; + + CuVSResources resources = new CuVSResources(); + CagraIndexParams indexParams = new CagraIndexParams.Builder(resources) + .withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT) + .build(); + + CagraIndex index = new CagraIndex.Builder(resources) + .withDataset(dataset) + .withIndexParams(indexParams) + .build(); + + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + index.serialize(null); // Pass null output stream + }); + + assertEquals("Output stream cannot be null", exception.getMessage()); + } + @Ignore + @Test + public void testSingleElementDataset() throws Throwable { + // Match dataset and parameters to the working test + float[][] dataset = { + {0.74021935f, 0.9209938f}, + {0.03902049f, 0.9689629f}, + {0.92514056f, 0.4463501f}, + {0.6673192f, 0.10993068f} + }; + + CuVSResources resources = new CuVSResources(); + CagraIndexParams indexParams = new CagraIndexParams.Builder(resources) + .withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT) + .build(); + + CagraIndex index = new CagraIndex.Builder(resources) + .withDataset(dataset) + .withIndexParams(indexParams) + .build(); + + float[][] query = { + {0.48216683f, 0.0428398f}, + {0.5084142f, 0.6545497f}, + {0.51260436f, 0.2643005f}, + {0.05198065f, 0.5789965f} + }; + + CagraQuery cuvsQuery = new CagraQuery.Builder() + .withTopK(3) + .withSearchParams(new CagraSearchParams.Builder(resources).build()) + .withQueryVectors(query) + .build(); + + CagraSearchResults results = index.search(cuvsQuery); + + // Verify the results size matches the queries + assertEquals("Expected one result for each query", query.length, results.getResults().size()); + } + @Ignore + @Test + public void testSearchResultMapping() throws Throwable { + // Match dataset and parameters to the working test + float[][] dataset = { + {0.74021935f, 0.9209938f}, + {0.03902049f, 0.9689629f}, + {0.92514056f, 0.4463501f}, + {0.6673192f, 0.10993068f} + }; + + CuVSResources resources = new CuVSResources(); + CagraIndexParams indexParams = new CagraIndexParams.Builder(resources) + .withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT) + .build(); + + CagraIndex index = new CagraIndex.Builder(resources) + .withDataset(dataset) + .withIndexParams(indexParams) + .build(); + + // Use consistent query and mapping + Map mapping = Map.of(0, 100, 1, 200, 2, 300, 3, 400); + float[][] query = { + {0.48216683f, 0.0428398f}, + {0.5084142f, 0.6545497f}, + {0.51260436f, 0.2643005f}, + {0.05198065f, 0.5789965f} + }; + + CagraQuery cuvsQuery = new CagraQuery.Builder() + .withTopK(3) + .withSearchParams(new CagraSearchParams.Builder(resources).build()) + .withQueryVectors(query) + .withMapping(mapping) + .build(); + + CagraSearchResults results = index.search(cuvsQuery); + + // Verify mapped results contain expected keys + results.getResults().forEach(result -> { + assertTrue(result.containsKey(100) || result.containsKey(200) || result.containsKey(300) || result.containsKey(400)); + }); + } + + @Test + public void testResultsTopKWithRandomValues() throws Throwable { + int numRows = random.nextInt(10) + 1; // 1 - 10 rows + int numCols = random.nextInt(5) + 1; // 1 - 5 columns + float[][] dataset = new float[numRows][numCols]; + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numCols; j++) { + dataset[i][j] = random.nextFloat() * 100; + } + } + + int numQueries = random.nextInt(5) + 1; // 1 - 5 queries + float[][] queries = new float[numQueries][numCols]; + for (int i = 0; i < numQueries; i++) { + for (int j = 0; j < numCols; j++) { + queries[i][j] = random.nextFloat() * 100; + } + } + + int topK = random.nextInt(numRows) + 1; + System.out.println("Dataset size: " + numRows + "x" + numCols); + System.out.println("Query size: " + numQueries + "x" + numCols); + System.out.println("TopK: " + topK); + + CuVSResources resources = new CuVSResources(); + CagraIndexParams indexParams = new CagraIndexParams.Builder(resources).build(); + CagraIndex index = new CagraIndex.Builder(resources) + .withDataset(dataset) + .withIndexParams(indexParams) + .build(); + + CagraQuery query = new CagraQuery.Builder() + .withQueryVectors(queries) + .withTopK(topK) + .withSearchParams(new CagraSearchParams.Builder(resources).build()) + .build(); + + CagraSearchResults results = index.search(query); + + results.getResults().forEach(result -> { + System.out.println("Result size: " + result.size()); + assertEquals("TopK mismatch for query.", topK, result.size()); + }); + } + + @Ignore + @Test + public void testEmptyResults() throws Throwable { + float[][] dataset = { + {10.0f, 10.0f}, + {20.0f, 20.0f} + }; + + float[][] queries = { + {1000.0f, 1000.0f} + }; + + CuVSResources resources = new CuVSResources(); + CagraIndexParams indexParams = new CagraIndexParams.Builder(resources).build(); + CagraIndex index = new CagraIndex.Builder(resources) + .withDataset(dataset) + .withIndexParams(indexParams) + .build(); + + CagraQuery query = new CagraQuery.Builder() + .withQueryVectors(queries) + .withTopK(2) + .withSearchParams(new CagraSearchParams.Builder(resources).build()) + .build(); + + CagraSearchResults results = index.search(query); + System.out.println(results.getResults()); + + // Verify no neighbors were found + assertTrue(results.getResults().isEmpty()); + } + @Ignore + @Test + public void testSearchWithDeletedIndexFile() throws Throwable { + // Dataset and Query + float[][] dataset = { + {0.74021935f, 0.9209938f}, + {0.03902049f, 0.9689629f}, + {0.92514056f, 0.4463501f}, + {0.6673192f, 0.10993068f} + }; + + float[][] queries = { + {0.48216683f, 0.0428398f}, + {0.5084142f, 0.6545497f} + }; + + CuVSResources resources = new CuVSResources(); + CagraIndexParams indexParams = new CagraIndexParams.Builder(resources) + .withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT) + .build(); + + // Create and serialize index + CagraIndex index = new CagraIndex.Builder(resources) + .withDataset(dataset) + .withIndexParams(indexParams) + .build(); + + String indexFileName = UUID.randomUUID().toString() + ".cag"; + index.serialize(new FileOutputStream(indexFileName)); + + // Delete the serialized file + File indexFile = new File(indexFileName); + if (indexFile.exists()) { + indexFile.delete(); + } + + // Attempt to create an InputStream from the deleted file + Throwable exception = assertThrows(Exception.class, () -> { + try (InputStream inputStream = new FileInputStream(indexFile)) { + CagraIndex deletedIndex = new CagraIndex.Builder(resources) + .from(inputStream) + .build(); + + CagraQuery query = new CagraQuery.Builder() + .withTopK(3) + .withSearchParams(new CagraSearchParams.Builder(resources).build()) + .withQueryVectors(queries) + .build(); + + deletedIndex.search(query); + } + }); + + // Assert the exception type + assertTrue("Expected FileNotFoundException", exception instanceof java.io.FileNotFoundException); + } + @Ignore + @Test + public void testNullQueryVectors() throws Throwable { + float[][] dataset = { + {0.74021935f, 0.9209938f}, + {0.03902049f, 0.9689629f} + }; + + CuVSResources resources = new CuVSResources(); + CagraIndexParams indexParams = new CagraIndexParams.Builder(resources).build(); + CagraIndex index = new CagraIndex.Builder(resources) + .withDataset(dataset) + .withIndexParams(indexParams) + .build(); + + CagraQuery invalidQuery = new CagraQuery.Builder() + .withQueryVectors(null) + .withTopK(3) + .withSearchParams(new CagraSearchParams.Builder(resources).build()) + .build(); + + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + index.search(invalidQuery); + }); + + assertEquals("Query vectors cannot be null", exception.getMessage()); + } + + +} diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedTest.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedTest.java new file mode 100644 index 0000000000..cf58a89065 --- /dev/null +++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedTest.java @@ -0,0 +1,150 @@ +package com.nvidia.cuvs; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +import java.lang.invoke.MethodHandles; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.carrotsearch.randomizedtesting.RandomizedRunner; + +@RunWith(RandomizedRunner.class) +public class CagraRandomizedTest extends CuVSTestCase { + + private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + + @Before + public void setup() { + initializeRandom(); + log.info("Random context initialized for test."); + } + + @Test + public void testResultsTopKWithRandomValues() throws Throwable { + // Use old-style random generation logic + int datasetSize = random.nextInt(400) + 1; + int dimensions = random.nextInt(500) + 1; + int numQueries = random.nextInt(500) + 1; + int topK = random.nextInt(datasetSize) + 1; + + // Generate a random dataset + float[][] dataset = new float[datasetSize][dimensions]; + for (int i = 0; i < datasetSize; i++) { + for (int j = 0; j < dimensions; j++) { + dataset[i][j] = random.nextFloat() * 100; + } + } + + // Generate random query vectors + float[][] queries = new float[numQueries][dimensions]; + for (int i = 0; i < numQueries; i++) { + for (int j = 0; j < dimensions; j++) { + queries[i][j] = random.nextFloat() * 100; + } + } + + log.info("Dataset size: {}x{}", datasetSize, dimensions); + log.info("Query size: {}x{}", numQueries, dimensions); + log.info("TopK: {}", topK); + + // Debugging: Log dataset and queries + if (log.isDebugEnabled()) { + log.debug("Dataset:"); + for (float[] row : dataset) { + log.debug(java.util.Arrays.toString(row)); + } + + log.debug("Queries:"); + for (float[] query : queries) { + log.debug(java.util.Arrays.toString(query)); + } + } + + // Sanity checks + assert dataset.length > 0 : "Dataset is empty."; + assert queries.length > 0 : "Queries are empty."; + assert dimensions > 0 : "Invalid dimensions."; + assert topK > 0 && topK <= datasetSize : "Invalid topK value."; + + // Generate expected results using brute force + List> expected = generateExpectedResults(topK, dataset, queries); + + // Create CuVS index and query + CuVSResources resources = new CuVSResources(); + CagraIndexParams indexParams = new CagraIndexParams.Builder(resources).build(); + CagraIndex index = new CagraIndex.Builder(resources).withDataset(dataset).withIndexParams(indexParams).build(); + + log.info("Index built successfully."); + + CagraQuery query = new CagraQuery.Builder() + .withQueryVectors(queries) + .withTopK(topK) + .withSearchParams(new CagraSearchParams.Builder(resources).build()) + .build(); + + log.info("Query built successfully. Executing search..."); + + // Execute search and retrieve results + CagraSearchResults results = index.search(query); + + // actual vs. expected results + for (int i = 0; i < results.getResults().size(); i++) { + Map result = results.getResults().get(i); + log.info("Actual result for query {}: {}", i, result.keySet()); + log.info("Expected result for query {}: {}", i, expected.get(i)); + + assertEquals("TopK mismatch for query.", Math.min(topK, datasetSize), result.size()); + + // Sort result by values (distances) and extract keys + List sortedResultKeys = result.entrySet().stream() + .sorted(Map.Entry.comparingByValue()) // Sort by value (distance) + .map(Map.Entry::getKey) // Extract sorted keys + .toList(); + + log.info("Sorted Actual result for query {}: {}", i, sortedResultKeys); + + // Compare using primitive int arrays + assertArrayEquals( + "Query " + i + " mismatched", + expected.get(i).stream().mapToInt(Integer::intValue).toArray(), + sortedResultKeys.stream().mapToInt(Integer::intValue).toArray() + ); + } + + } + + private List> generateExpectedResults(int topK, float[][] dataset, float[][] queries) { + List> neighborsResult = new ArrayList<>(); + int dimensions = dataset[0].length; + + for (float[] query : queries) { + Map distances = new TreeMap<>(); + for (int j = 0; j < dataset.length; j++) { + double distance = 0; + for (int k = 0; k < dimensions; k++) { + distance += (query[k] - dataset[j][k]) * (query[k] - dataset[j][k]); + } + distances.put(j, Math.sqrt(distance)); + } + + // Sort by distance and select the topK nearest neighbors + List neighbors = distances.entrySet().stream() + .sorted(Map.Entry.comparingByValue()) + .map(Map.Entry::getKey) + .toList(); + neighborsResult.add(neighbors.subList(0, Math.min(topK, dataset.length))); + } + + log.info("Expected results generated successfully."); + return neighborsResult; + } +} diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CuVSTestCase.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CuVSTestCase.java new file mode 100644 index 0000000000..a07cb4ebfe --- /dev/null +++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CuVSTestCase.java @@ -0,0 +1,19 @@ +package com.nvidia.cuvs; + +import java.lang.invoke.MethodHandles; +import java.util.Random; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.carrotsearch.randomizedtesting.RandomizedContext; + +public abstract class CuVSTestCase { + protected Random random; + private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + + protected void initializeRandom() { + random = RandomizedContext.current().getRandom(); + log.info("Test seed: " + RandomizedContext.current().getRunnerSeedAsString()); + } +}