From 11d876cc9dcba3a76cbdb2af2bddadbd076ec7e0 Mon Sep 17 00:00:00 2001 From: Ted Willke Date: Fri, 19 Jun 2026 05:12:34 +0000 Subject: [PATCH] Permanently disable pruning while retaining non-zero threshold search behavior. --- .../jbellis/jvector/graph/GraphSearcher.java | 24 +- .../jbellis/jvector/graph/ScoreTracker.java | 29 +-- .../graph/TestLowCardinalityFiltering.java | 4 +- .../graph/TestPruningCompatibility.java | 212 ++++++++++++++++++ 4 files changed, 240 insertions(+), 29 deletions(-) create mode 100644 jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestPruningCompatibility.java diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java index 73cc5fbd5..7ca4c4c06 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java @@ -86,7 +86,7 @@ protected GraphSearcher(ImmutableGraphIndex.View view) { this.rerankedResults = new NodeQueue(new BoundedLongHeap(100), NodeQueue.Order.MIN_HEAP); this.visited = new IntHashSet(); - this.pruneSearch = true; + this.pruneSearch = false; this.scoreTrackerFactory = new ScoreTracker.ScoreTrackerFactory(); } @@ -117,12 +117,16 @@ public ImmutableGraphIndex.View getView() { } /** - * When using pruning, we are using a heuristic to terminate the search earlier. - * In certain cases, it can lead to speedups. This is set to false by default. - * @param usage a boolean that determines whether we do early termination or not. + * @deprecated TopK and filtered graph-search pruning is disabled because the + * existing heuristic can reduce recall and has not shown reliable production + * value. This method is retained for API compatibility and has no effect. + * + * Threshold searches, where {@code threshold > 0}, continue to use their + * legacy threshold early-termination behavior. */ + @Deprecated public void usePruning(boolean usage) { - pruneSearch = usage; + pruneSearch = false; } /** @@ -339,14 +343,17 @@ void initializeInternal(SearchScoreProvider scoreProvider, NodeAtLevel entry, Bi private boolean stopSearch(NodeQueue localCandidates, ScoreTracker scoreTracker, int rerankK, float threshold) { float topCandidateScore = localCandidates.topScore(); + // we're done when we have K results and the best candidate is worse than the worst result so far if (approximateResults.size() >= rerankK && topCandidateScore < approximateResults.topScore()) { return true; } - // when querying by threshold, also stop when we are probabilistically unlikely to find more qualifying results + + // preserve legacy threshold early termination if (threshold > 0 && scoreTracker.shouldStop()) { return true; } + return false; } @@ -394,8 +401,9 @@ void searchOneLayer(SearchScoreProvider scoreProvider, assert approximateResults.size() == 0; // should be cleared by setEntryPointsFromPreviousLayer approximateResults.setMaxSize(rerankK); - // track scores to predict when we are done with threshold queries - var scoreTracker = scoreTrackerFactory.getScoreTracker(pruneSearch, rerankK, threshold); + // TopK and filtered pruning are disabled. Threshold searches retain their + // legacy threshold early-termination path inside ScoreTrackerFactory. + var scoreTracker = scoreTrackerFactory.getScoreTracker(false, rerankK, threshold); // the main search loop while (candidates.size() > 0) { diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ScoreTracker.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ScoreTracker.java index 815476db4..c6a26843a 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ScoreTracker.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ScoreTracker.java @@ -36,32 +36,23 @@ class ScoreTrackerFactory { } public ScoreTracker getScoreTracker(boolean pruneSearch, int rerankK, float threshold) { - // track scores to predict when we are done with threshold queries - final ScoreTracker scoreTracker; - + // Preserve legacy threshold behavior. Threshold searches used TwoPhaseTracker + // independent of the pruning flag, and may still do so for compatibility. if (threshold > 0) { if (twoPhaseTracker == null) { twoPhaseTracker = new ScoreTracker.TwoPhaseTracker(threshold); } else { twoPhaseTracker.reset(threshold); } - scoreTracker = twoPhaseTracker; - } else { - if (pruneSearch) { - if (relaxedMonotonicityTracker == null) { - relaxedMonotonicityTracker = new ScoreTracker.RelaxedMonotonicityTracker(rerankK); - } else { - relaxedMonotonicityTracker.reset(rerankK); - } - scoreTracker = relaxedMonotonicityTracker; - } else { - if (noOpTracker == null) { - noOpTracker = new ScoreTracker.NoOpTracker(); - } - scoreTracker = noOpTracker; - } + return twoPhaseTracker; + } + + // TopK and filtered pruning are disabled. Do not return + // RelaxedMonotonicityTracker, regardless of caller preference. + if (noOpTracker == null) { + noOpTracker = new ScoreTracker.NoOpTracker(); } - return scoreTracker; + return noOpTracker; } } diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestLowCardinalityFiltering.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestLowCardinalityFiltering.java index 3d19e972b..e860ebd6f 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestLowCardinalityFiltering.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestLowCardinalityFiltering.java @@ -51,8 +51,8 @@ It splits the vectors in two classes (with probability 0.5) and tests that the f */ @Test public void testLowCardinalityFiltering() throws IOException { - testLowCardinalityFiltering(32, 0.044f, 0.91f, false); - testLowCardinalityFiltering(32, 0.048f, 0.93f, true); + testLowCardinalityFiltering(32, 0.055f, 0.95f, false); + testLowCardinalityFiltering(32, 0.055f, 0.95f, true); } public void testLowCardinalityFiltering(int maxDegree, float visitedRatioThreshold, float recallThreshold, boolean addHierarchy) throws IOException { var R = getRandom(); diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestPruningCompatibility.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestPruningCompatibility.java new file mode 100644 index 000000000..6cb244aee --- /dev/null +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestPruningCompatibility.java @@ -0,0 +1,212 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.graph; + +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import io.github.jbellis.jvector.LuceneTestCase; +import io.github.jbellis.jvector.TestUtil; +import io.github.jbellis.jvector.graph.similarity.DefaultSearchScoreProvider; +import io.github.jbellis.jvector.util.Bits; +import io.github.jbellis.jvector.util.FixedBitSet; +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import io.github.jbellis.jvector.vector.types.VectorFloat; +import org.junit.Test; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +public class TestPruningCompatibility extends LuceneTestCase { + private static final int N_VECTORS = 10_000; + private static final int N_QUERIES = 10; + private static final int DIMENSIONS = 16; + private static final int TOP_K = 10; + private static final int RERANK_K = 100; + private static final int THRESHOLD_TOP_K_CAP = 1_000; + private static final int MAX_DEGREE = 32; + private static final VectorSimilarityFunction SIMILARITY = VectorSimilarityFunction.COSINE; + + @Test + public void testScoreTrackerFactoryPolicy() { + var factory = new ScoreTracker.ScoreTrackerFactory(); + + assertTrue(factory.getScoreTracker(false, RERANK_K, 0.0f) instanceof ScoreTracker.NoOpTracker); + assertTrue(factory.getScoreTracker(true, RERANK_K, 0.0f) instanceof ScoreTracker.NoOpTracker); + + // Preserve legacy threshold behavior: threshold searches still use TwoPhaseTracker. + assertTrue(factory.getScoreTracker(false, RERANK_K, 0.5f) instanceof ScoreTracker.TwoPhaseTracker); + assertTrue(factory.getScoreTracker(true, RERANK_K, 0.5f) instanceof ScoreTracker.TwoPhaseTracker); + } + + @Test + @SuppressWarnings("deprecation") + public void testUsePruningIgnoredForTopKAndFilteredTopK() throws IOException { + for (boolean addHierarchy : List.of(false, true)) { + Fixture fixture = buildFixture(addHierarchy); + + for (VectorFloat query : fixture.queries) { + assertSameWithPruningOffAndOn(fixture, query, Bits.ALL, 0.0f, TOP_K, RERANK_K); + assertSameWithPruningOffAndOn(fixture, query, fixture.evenOrds, 0.0f, TOP_K, RERANK_K); + } + } + } + + @Test + @SuppressWarnings("deprecation") + public void testUsePruningIgnoredForThresholdSearch() throws IOException { + for (boolean addHierarchy : List.of(false, true)) { + Fixture fixture = buildFixture(addHierarchy); + + for (VectorFloat query : fixture.queries) { + float threshold = exactThreshold(fixture.ravv, query, 100); + + assertSameWithPruningOffAndOn( + fixture, + query, + Bits.ALL, + threshold, + THRESHOLD_TOP_K_CAP, + THRESHOLD_TOP_K_CAP); + } + } + } + + private void assertSameWithPruningOffAndOn(Fixture fixture, + VectorFloat query, + Bits acceptOrds, + float threshold, + int topK, + int rerankK) { + SearchResult pruningOff = search(fixture, query, acceptOrds, threshold, topK, rerankK, false); + SearchResult pruningOn = search(fixture, query, acceptOrds, threshold, topK, rerankK, true); + + assertEquals(pruningOff.getVisitedCount(), pruningOn.getVisitedCount()); + assertEquals(pruningOff.getNodes().length, pruningOn.getNodes().length); + assertArrayEquals(sortedNodes(pruningOff), sortedNodes(pruningOn)); + + if (threshold > 0.0f) { + assertAllAtOrAboveThreshold(fixture, query, threshold, pruningOff); + assertAllAtOrAboveThreshold(fixture, query, threshold, pruningOn); + } + } + + @SuppressWarnings("deprecation") + private SearchResult search(Fixture fixture, + VectorFloat query, + Bits acceptOrds, + float threshold, + int topK, + int rerankK, + boolean usePruning) { + var searcher = new GraphSearcher(fixture.graph); + searcher.usePruning(usePruning); + + var sf = fixture.ravv.rerankerFor(query, SIMILARITY); + return searcher.search( + new DefaultSearchScoreProvider(sf), + topK, + rerankK, + threshold, + 0.0f, + acceptOrds); + } + + private Fixture buildFixture(boolean addHierarchy) throws IOException { + var random = getRandom(); + + VectorFloat[] vectors = TestVectorGraph.createRandomFloatVectors(N_VECTORS, DIMENSIONS, random); + var ravv = new ListRandomAccessVectorValues(List.of(vectors), DIMENSIONS); + + var builder = new GraphIndexBuilder( + ravv, + SIMILARITY, + MAX_DEGREE, + 2 * MAX_DEGREE, + 1.2f, + 1.2f, + addHierarchy); + var graph = builder.build(ravv); + + FixedBitSet evenOrds = new FixedBitSet(N_VECTORS); + for (int i = 0; i < N_VECTORS; i += 2) { + evenOrds.set(i); + } + + VectorFloat[] queries = new VectorFloat[N_QUERIES]; + for (int i = 0; i < N_QUERIES; i++) { + queries[i] = TestUtil.randomVector(random, DIMENSIONS); + } + + return new Fixture(ravv, graph, evenOrds, queries); + } + + private float exactThreshold(RandomAccessVectorValues ravv, + VectorFloat query, + int targetMatches) { + float[] scores = new float[ravv.size()]; + for (int i = 0; i < ravv.size(); i++) { + scores[i] = SIMILARITY.compare(query, ravv.getVector(i)); + } + + Arrays.sort(scores); + return scores[scores.length - targetMatches]; + } + + private void assertAllAtOrAboveThreshold(Fixture fixture, + VectorFloat query, + float threshold, + SearchResult result) { + for (var nodeScore : result.getNodes()) { + float score = SIMILARITY.compare(query, fixture.ravv.getVector(nodeScore.node)); + assertTrue( + "returned node below threshold: node=" + nodeScore.node + + ", score=" + score + + ", threshold=" + threshold, + score + 1e-6f >= threshold); + } + } + + private static int[] sortedNodes(SearchResult result) { + int[] nodes = Arrays.stream(result.getNodes()) + .mapToInt(nodeScore -> nodeScore.node) + .toArray(); + Arrays.sort(nodes); + return nodes; + } + + private static class Fixture { + final RandomAccessVectorValues ravv; + final ImmutableGraphIndex graph; + final FixedBitSet evenOrds; + final VectorFloat[] queries; + + Fixture(RandomAccessVectorValues ravv, + ImmutableGraphIndex graph, + FixedBitSet evenOrds, + VectorFloat[] queries) { + this.ravv = ravv; + this.graph = graph; + this.evenOrds = evenOrds; + this.queries = queries; + } + } +}