From 9338a5311b14c66a66e870c005cab1834216150c Mon Sep 17 00:00:00 2001 From: Michael Marshall Date: Fri, 5 Dec 2025 15:15:06 -0600 Subject: [PATCH] Add diversityScoreFunctionFor to avoid creation of wrapper object --- .../bench/PQDistanceCalculationBenchmark.java | 15 +++++++ .../diversity/VamanaDiversityProvider.java | 2 +- .../graph/similarity/BuildScoreProvider.java | 44 +++++++++++++++---- 3 files changed, 51 insertions(+), 10 deletions(-) diff --git a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/PQDistanceCalculationBenchmark.java b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/PQDistanceCalculationBenchmark.java index 59342e41a..fe8f4857b 100644 --- a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/PQDistanceCalculationBenchmark.java +++ b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/PQDistanceCalculationBenchmark.java @@ -130,6 +130,21 @@ public void diversityCalculation(Blackhole blackhole) { blackhole.consume(totalSimilarity); } + @Benchmark + public void diversityCalculationScoreProvider(Blackhole blackhole) { + float totalSimilarity = 0; + + for (int q = 0; q < queryCount; q++) { + for (int i = 0; i < vectorCount; i++) { + final ScoreFunction sf = buildScoreProvider.diversityScoreFunctionFor(i); + float similarity = sf.similarityTo(q); + totalSimilarity += similarity; + } + } + + blackhole.consume(totalSimilarity); + } + private VectorFloat createRandomVector(int dimension) { VectorFloat vector = VECTOR_TYPE_SUPPORT.createFloatVector(dimension); for (int i = 0; i < dimension; i++) { diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/diversity/VamanaDiversityProvider.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/diversity/VamanaDiversityProvider.java index 0bdc6415f..4721ab2ca 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/diversity/VamanaDiversityProvider.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/diversity/VamanaDiversityProvider.java @@ -60,7 +60,7 @@ public double retainDiverse(NodeArray neighbors, int maxDegree, int diverseBefor int cNode = neighbors.getNode(i); float cScore = neighbors.getScore(i); - var sf = scoreProvider.diversityProviderFor(cNode).scoreFunction(); + var sf = scoreProvider.diversityScoreFunctionFor(cNode); if (isDiverse(cNode, cScore, neighbors, sf, selected, currentAlpha)) { selected.set(i); nSelected++; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java index f0b184e67..1049069de 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java @@ -82,6 +82,14 @@ public interface BuildScoreProvider { */ SearchScoreProvider diversityProviderFor(int node1); + /** + * Create the diversity provider's score function. See {@link #diversityProviderFor(int)} for documentation + * on the use of the ScoreFunction. + */ + default ScoreFunction diversityScoreFunctionFor(int node1) { + return diversityProviderFor(node1).scoreFunction(); + } + /** * Returns a BSP that performs exact score comparisons using the given RandomAccessVectorValues and VectorSimilarityFunction. * @@ -141,6 +149,14 @@ public SearchScoreProvider diversityProviderFor(int node1) { var vc = vectorsCopy.get(); return DefaultSearchScoreProvider.exact(v, similarityFunction, vc); } + + @Override + public ScoreFunction diversityScoreFunctionFor(int node1) { + var v = vectors.get().getVector(node1); + var vc = vectorsCopy.get(); + // don't use ESF.reranker, we need thread safety here + return (ScoreFunction.ExactScoreFunction) node2 -> similarityFunction.compare(v, vc.getVector(node2)); + } }; } @@ -162,11 +178,16 @@ public boolean isExact() { } @Override - public SearchScoreProvider diversityProviderFor(int node1) { + public ScoreFunction diversityScoreFunctionFor(int node1) { // like searchProviderFor, this skips reranking; unlike sPF, it uses pqv.scoreFunctionFor // instead of precomputedScoreFunctionFor; since we only perform a few dozen comparisons // during diversity computation, this is cheaper than precomputing a lookup table - var asf = pqv.diversityFunctionFor(node1, vsf); // not precomputed! + return pqv.diversityFunctionFor(node1, vsf); // not precomputed! + } + + @Override + public SearchScoreProvider diversityProviderFor(int node1) { + var asf = diversityScoreFunctionFor(node1); return new DefaultSearchScoreProvider(asf); } @@ -210,8 +231,18 @@ public SearchScoreProvider searchProviderFor(VectorFloat vector) { @Override public SearchScoreProvider searchProviderFor(int node1) { + return new DefaultSearchScoreProvider(diversityScoreFunctionFor(node1)); + } + + @Override + public SearchScoreProvider diversityProviderFor(int node1) { + return searchProviderFor(node1); + } + + @Override + public ScoreFunction diversityScoreFunctionFor(int node1) { var encoded1 = bqv.get(node1); - return new DefaultSearchScoreProvider(new ScoreFunction() { + return new ScoreFunction() { @Override public boolean isExact() { return false; @@ -221,12 +252,7 @@ public boolean isExact() { public float similarityTo(int node2) { return bqv.similarityBetween(encoded1, bqv.get(node2)); } - }); - } - - @Override - public SearchScoreProvider diversityProviderFor(int node1) { - return searchProviderFor(node1); + }; } }; }