Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,9 @@ API Changes
instance instead of a Bits instance to identify document IDs to filter.
(Shubham Chaudhary, Adrien Grand)

* GITHUB#15614: Added optional query bit size hint plumbing to KnnSearchStrategy.Hnsw,
including passthrough in Patience/seeded strategy wrapping. (Arup Chauhan)

New Features
---------------------
* GITHUB#15015: MultiIndexMergeScheduler: a production multi-tenant merge scheduler (Shawn Yarbrough)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,16 @@ public void nextCandidate() {
public KnnSearchStrategy getSearchStrategy() {
KnnSearchStrategy delegateStrategy = delegate.getSearchStrategy();
if (delegateStrategy instanceof KnnSearchStrategy.Hnsw hnswStrategy) {
return new KnnSearchStrategy.Patience(this, hnswStrategy.filteredSearchThreshold());
return new KnnSearchStrategy.Patience(
this, hnswStrategy.filteredSearchThreshold(), hnswStrategy.queryBitSizeHint());
} else if (delegateStrategy instanceof KnnSearchStrategy.Seeded seededStrategy) {
if (seededStrategy.originalStrategy() instanceof KnnSearchStrategy.Hnsw hnswStrategy) {
// rewrap the underlying HNSW strategy with patience
// this way we still use the seeded entry points, filter threshold,
// and can utilize patience thresholds
KnnSearchStrategy.Patience patienceStrategy =
new KnnSearchStrategy.Patience(this, hnswStrategy.filteredSearchThreshold());
new KnnSearchStrategy.Patience(
this, hnswStrategy.filteredSearchThreshold(), hnswStrategy.queryBitSizeHint());
return new KnnSearchStrategy.Seeded(
seededStrategy.entryPoints(), seededStrategy.numberOfEntryPoints(), patienceStrategy);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ public static class Hnsw extends KnnSearchStrategy {
public static final Hnsw DEFAULT = new Hnsw(DEFAULT_FILTERED_SEARCH_THRESHOLD);

private final int filteredSearchThreshold;
private final Integer queryBitSizeHint;

/**
* Create a new Hnsw strategy
Expand All @@ -61,16 +62,41 @@ public static class Hnsw extends KnnSearchStrategy {
* 100 where 0 means never use filtered search and 100 means always use filtered search.
*/
public Hnsw(int filteredSearchThreshold) {
this(filteredSearchThreshold, null);
}

/**
* Create a new Hnsw strategy
*
* @param filteredSearchThreshold threshold for filtered search, a percentage value from 0 to
* 100 where 0 means never use filtered search and 100 means always use filtered search.
* @param queryBitSizeHint optional hint for query bit size to be used by codecs/scorers that
* support this optimization.
*/
public Hnsw(int filteredSearchThreshold, Integer queryBitSizeHint) {
if (filteredSearchThreshold < 0 || filteredSearchThreshold > 100) {
throw new IllegalArgumentException("filteredSearchThreshold must be >= 0 and <= 100");
}
if (queryBitSizeHint != null && queryBitSizeHint <= 0) {
throw new IllegalArgumentException("queryBitSizeHint must be > 0");
}
this.filteredSearchThreshold = filteredSearchThreshold;
this.queryBitSizeHint = queryBitSizeHint;
}

public int filteredSearchThreshold() {
return filteredSearchThreshold;
}

/**
* Optional hint for query bit size used by supported codecs/scorers.
*
* @return query bit size hint, or null when unspecified
*/
public Integer queryBitSizeHint() {
return queryBitSizeHint;
}

/**
* Whether to use filtered search based on the ratio of vectors that pass the filter
*
Expand All @@ -87,12 +113,13 @@ public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Hnsw hnsw = (Hnsw) o;
return filteredSearchThreshold == hnsw.filteredSearchThreshold;
return filteredSearchThreshold == hnsw.filteredSearchThreshold
&& Objects.equals(queryBitSizeHint, hnsw.queryBitSizeHint);
}

@Override
public int hashCode() {
return Objects.hashCode(filteredSearchThreshold);
return Objects.hash(filteredSearchThreshold, queryBitSizeHint);
}

@Override
Expand Down Expand Up @@ -180,7 +207,14 @@ public static class Patience extends Hnsw {
private final HnswQueueSaturationCollector collector;

public Patience(HnswQueueSaturationCollector collector, int filteredSearchThreshold) {
super(filteredSearchThreshold);
this(collector, filteredSearchThreshold, null);
}

public Patience(
HnswQueueSaturationCollector collector,
int filteredSearchThreshold,
Integer queryBitSizeHint) {
super(filteredSearchThreshold, queryBitSizeHint);
this.collector = collector;
}

Expand All @@ -191,7 +225,7 @@ public boolean equals(Object obj) {

@Override
public int hashCode() {
return Objects.hash(super.filteredSearchThreshold, collector);
return Objects.hash(super.hashCode(), collector);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;

import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.apache.lucene.tests.util.LuceneTestCase;

public class TestKnnSearchStrategy extends LuceneTestCase {

public void testHnswDefaultConstructorHasNoQueryBitHint() {
KnnSearchStrategy.Hnsw strategy = new KnnSearchStrategy.Hnsw(60);
assertNull(strategy.queryBitSizeHint());
}

public void testHnswConstructorWithQueryBitHint() {
KnnSearchStrategy.Hnsw strategy = new KnnSearchStrategy.Hnsw(60, 4);
assertEquals(Integer.valueOf(4), strategy.queryBitSizeHint());
}

public void testHnswQueryBitHintValidation() {
IllegalArgumentException exception =
expectThrows(IllegalArgumentException.class, () -> new KnnSearchStrategy.Hnsw(60, 0));
assertEquals("queryBitSizeHint must be > 0", exception.getMessage());
}

public void testHnswEqualsAndHashCodeIncludeQueryBitHint() {
KnnSearchStrategy.Hnsw withHintA = new KnnSearchStrategy.Hnsw(60, 4);
KnnSearchStrategy.Hnsw withHintB = new KnnSearchStrategy.Hnsw(60, 4);
KnnSearchStrategy.Hnsw withDifferentHint = new KnnSearchStrategy.Hnsw(60, 2);
KnnSearchStrategy.Hnsw withoutHint = new KnnSearchStrategy.Hnsw(60);

assertEquals(withHintA, withHintB);
assertEquals(withHintA.hashCode(), withHintB.hashCode());
assertNotEquals(withHintA, withDifferentHint);
assertNotEquals(withHintA, withoutHint);
}

public void testSeededPreservesOriginalHnswQueryBitHint() {
KnnSearchStrategy.Hnsw original = new KnnSearchStrategy.Hnsw(60, 2);
KnnSearchStrategy.Seeded seeded =
new KnnSearchStrategy.Seeded(DocIdSetIterator.empty(), 0, original);

assertSame(original, seeded.originalStrategy());
assertEquals(
Integer.valueOf(2),
((KnnSearchStrategy.Hnsw) seeded.originalStrategy()).queryBitSizeHint());
}

public void testPatienceWrappingPreservesQueryBitHint() {
KnnSearchStrategy.Hnsw strategy = new KnnSearchStrategy.Hnsw(60, 1);
TopKnnCollector collector = new TopKnnCollector(2, 10, strategy);
HnswQueueSaturationCollector wrapped = new HnswQueueSaturationCollector(collector, 0.99d, 5);

KnnSearchStrategy wrappedStrategy = wrapped.getSearchStrategy();
assertTrue(wrappedStrategy instanceof KnnSearchStrategy.Patience);
assertEquals(
Integer.valueOf(1), ((KnnSearchStrategy.Patience) wrappedStrategy).queryBitSizeHint());
}

public void testPatienceEqualsAndHashCodeIncludeQueryBitHint() {
TopKnnCollector collector = new TopKnnCollector(2, 10, KnnSearchStrategy.Hnsw.DEFAULT);
HnswQueueSaturationCollector wrapped = new HnswQueueSaturationCollector(collector, 0.99d, 5);
KnnSearchStrategy.Patience withHintA = new KnnSearchStrategy.Patience(wrapped, 60, 4);
KnnSearchStrategy.Patience withHintB = new KnnSearchStrategy.Patience(wrapped, 60, 4);
KnnSearchStrategy.Patience withDifferentHint = new KnnSearchStrategy.Patience(wrapped, 60, 2);

assertEquals(withHintA, withHintB);
assertEquals(withHintA.hashCode(), withHintB.hashCode());
assertNotEquals(withHintA, withDifferentHint);
}
}
Loading