diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 9bc40b404293..ecb568878a7b 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -184,6 +184,9 @@ API Changes * GITHUB#15636: Introduce TermsEnum.preferSeekExact(). Update and delete processing obey this setting, allowing bloom or other approximate membership filters to apply in these paths. (Trevor McCulloch) +* GITHUB#15614: Add an optional queryBitSizeHint to KnnSearchStrategy.Hnsw; it is preserved when + HnswQueueSaturationCollector rewraps an Hnsw or Seeded strategy with Patience. (Arup Chauhan) + New Features --------------------- (No changes) diff --git a/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java b/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java index 9148316651fe..5766392bd4c2 100644 --- a/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java @@ -96,14 +96,16 @@ public void nextCandidate() { public KnnSearchStrategy getSearchStrategy() { KnnSearchStrategy delegateStrategy = delegate.getSearchStrategy(); if (delegateStrategy instanceof KnnSearchStrategy.Hnsw hnswStrategy) { - return new KnnSearchStrategy.Patience(this, hnswStrategy.filteredSearchThreshold()); + return new KnnSearchStrategy.Patience( + this, hnswStrategy.filteredSearchThreshold(), hnswStrategy.queryBitSizeHint()); } else if (delegateStrategy instanceof KnnSearchStrategy.Seeded seededStrategy) { if (seededStrategy.originalStrategy() instanceof KnnSearchStrategy.Hnsw hnswStrategy) { // rewrap the underlying HNSW strategy with patience // this way we still use the seeded entry points, filter threshold, // and can utilize patience thresholds KnnSearchStrategy.Patience patienceStrategy = - new KnnSearchStrategy.Patience(this, hnswStrategy.filteredSearchThreshold()); + new KnnSearchStrategy.Patience( + this, hnswStrategy.filteredSearchThreshold(), hnswStrategy.queryBitSizeHint()); 
return new KnnSearchStrategy.Seeded( seededStrategy.entryPoints(), seededStrategy.numberOfEntryPoints(), patienceStrategy); } diff --git a/lucene/core/src/java/org/apache/lucene/search/knn/KnnSearchStrategy.java b/lucene/core/src/java/org/apache/lucene/search/knn/KnnSearchStrategy.java index b0923da43a4b..ef6c14ecd835 100644 --- a/lucene/core/src/java/org/apache/lucene/search/knn/KnnSearchStrategy.java +++ b/lucene/core/src/java/org/apache/lucene/search/knn/KnnSearchStrategy.java @@ -53,6 +53,7 @@ public static class Hnsw extends KnnSearchStrategy { public static final Hnsw DEFAULT = new Hnsw(DEFAULT_FILTERED_SEARCH_THRESHOLD); private final int filteredSearchThreshold; + private final Integer queryBitSizeHint; /** * Create a new Hnsw strategy @@ -61,16 +62,41 @@ public static class Hnsw extends KnnSearchStrategy { * 100 where 0 means never use filtered search and 100 means always use filtered search. */ public Hnsw(int filteredSearchThreshold) { + this(filteredSearchThreshold, null); + } + + /** + * Create a new Hnsw strategy + * + * @param filteredSearchThreshold threshold for filtered search, a percentage value from 0 to + * 100 where 0 means never use filtered search and 100 means always use filtered search. + * @param queryBitSizeHint optional hint for query bit size to be used by codecs/scorers that + * support this optimization; {@code null} means unspecified, otherwise it must be positive. + */ + public Hnsw(int filteredSearchThreshold, Integer queryBitSizeHint) { if (filteredSearchThreshold < 0 || filteredSearchThreshold > 100) { throw new IllegalArgumentException("filteredSearchThreshold must be >= 0 and <= 100"); } + if (queryBitSizeHint != null && queryBitSizeHint <= 0) { + throw new IllegalArgumentException("queryBitSizeHint must be > 0"); + } this.filteredSearchThreshold = filteredSearchThreshold; + this.queryBitSizeHint = queryBitSizeHint; } public int filteredSearchThreshold() { return filteredSearchThreshold; } + /** + * Optional hint for query bit size used by supported codecs/scorers. 
+ * + * @return query bit size hint, or null when unspecified + */ + public Integer queryBitSizeHint() { + return queryBitSizeHint; + } + /** * Whether to use filtered search based on the ratio of vectors that pass the filter * @@ -87,12 +113,13 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; Hnsw hnsw = (Hnsw) o; - return filteredSearchThreshold == hnsw.filteredSearchThreshold; + return filteredSearchThreshold == hnsw.filteredSearchThreshold + && Objects.equals(queryBitSizeHint, hnsw.queryBitSizeHint); } @Override public int hashCode() { - return Objects.hashCode(filteredSearchThreshold); + return Objects.hash(filteredSearchThreshold, queryBitSizeHint); } @Override @@ -180,7 +207,14 @@ public static class Patience extends Hnsw { private final HnswQueueSaturationCollector collector; public Patience(HnswQueueSaturationCollector collector, int filteredSearchThreshold) { - super(filteredSearchThreshold); + this(collector, filteredSearchThreshold, null); + } + + public Patience( + HnswQueueSaturationCollector collector, + int filteredSearchThreshold, + Integer queryBitSizeHint) { + super(filteredSearchThreshold, queryBitSizeHint); this.collector = collector; } @@ -191,7 +225,7 @@ public boolean equals(Object obj) { @Override public int hashCode() { - return Objects.hash(super.filteredSearchThreshold, collector); + return Objects.hash(super.hashCode(), collector); } @Override diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnSearchStrategy.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnSearchStrategy.java new file mode 100644 index 000000000000..90a1b8513717 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnSearchStrategy.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import org.apache.lucene.search.knn.KnnSearchStrategy; +import org.apache.lucene.tests.util.LuceneTestCase; + +public class TestKnnSearchStrategy extends LuceneTestCase { + + public void testHnswDefaultConstructorHasNoQueryBitHint() { + KnnSearchStrategy.Hnsw strategy = new KnnSearchStrategy.Hnsw(60); + assertNull(strategy.queryBitSizeHint()); + } + + public void testHnswConstructorWithQueryBitHint() { + KnnSearchStrategy.Hnsw strategy = new KnnSearchStrategy.Hnsw(60, 4); + assertEquals(Integer.valueOf(4), strategy.queryBitSizeHint()); + } + + public void testHnswQueryBitHintValidation() { + IllegalArgumentException exception = + expectThrows(IllegalArgumentException.class, () -> new KnnSearchStrategy.Hnsw(60, 0)); + assertEquals("queryBitSizeHint must be > 0", exception.getMessage()); + } + + public void testHnswEqualsAndHashCodeIncludeQueryBitHint() { + KnnSearchStrategy.Hnsw withHintA = new KnnSearchStrategy.Hnsw(60, 4); + KnnSearchStrategy.Hnsw withHintB = new KnnSearchStrategy.Hnsw(60, 4); + KnnSearchStrategy.Hnsw withDifferentHint = new KnnSearchStrategy.Hnsw(60, 2); + KnnSearchStrategy.Hnsw withoutHint = new KnnSearchStrategy.Hnsw(60); + + assertEquals(withHintA, withHintB); + assertEquals(withHintA.hashCode(), withHintB.hashCode()); 
+ assertNotEquals(withHintA, withDifferentHint); + assertNotEquals(withHintA, withoutHint); + } + + public void testSeededPreservesOriginalHnswQueryBitHint() { + KnnSearchStrategy.Hnsw original = new KnnSearchStrategy.Hnsw(60, 2); + KnnSearchStrategy.Seeded seeded = + new KnnSearchStrategy.Seeded(DocIdSetIterator.empty(), 0, original); + + assertSame(original, seeded.originalStrategy()); + assertEquals( + Integer.valueOf(2), + ((KnnSearchStrategy.Hnsw) seeded.originalStrategy()).queryBitSizeHint()); + } + + public void testPatienceWrappingPreservesQueryBitHint() { + KnnSearchStrategy.Hnsw strategy = new KnnSearchStrategy.Hnsw(60, 1); + TopKnnCollector collector = new TopKnnCollector(2, 10, strategy); + HnswQueueSaturationCollector wrapped = new HnswQueueSaturationCollector(collector, 0.99d, 5); + + KnnSearchStrategy wrappedStrategy = wrapped.getSearchStrategy(); + assertTrue(wrappedStrategy instanceof KnnSearchStrategy.Patience); + assertEquals( + Integer.valueOf(1), ((KnnSearchStrategy.Patience) wrappedStrategy).queryBitSizeHint()); + } + + public void testPatienceEqualsAndHashCodeIncludeQueryBitHint() { + TopKnnCollector collector = new TopKnnCollector(2, 10, KnnSearchStrategy.Hnsw.DEFAULT); + HnswQueueSaturationCollector wrapped = new HnswQueueSaturationCollector(collector, 0.99d, 5); + KnnSearchStrategy.Patience withHintA = new KnnSearchStrategy.Patience(wrapped, 60, 4); + KnnSearchStrategy.Patience withHintB = new KnnSearchStrategy.Patience(wrapped, 60, 4); + KnnSearchStrategy.Patience withDifferentHint = new KnnSearchStrategy.Patience(wrapped, 60, 2); + + assertEquals(withHintA, withHintB); + assertEquals(withHintA.hashCode(), withHintB.hashCode()); + assertNotEquals(withHintA, withDifferentHint); + } +}