Skip to content

Commit e0e5d81

Browse files
kaivalnp (Kaival Parikh) and Kaival Parikh authored
Add timeout support to AbstractVectorSimilarityQuery (apache#13285)
Co-authored-by: Kaival Parikh <[email protected]>
1 parent 43c8011 commit e0e5d81

File tree

5 files changed: +235 -54 lines changed

5 files changed: +235 -54 lines changed

lucene/CHANGES.txt

+3
Original file line number | Diff line number | Diff line change
@@ -288,6 +288,9 @@ Improvements
288288

289289
* GITHUB#13625: Remove BitSet#nextSetBit code duplication. (Greg Miller)
290290

291+
* GITHUB#13285: Early terminate graph searches of AbstractVectorSimilarityQuery to follow timeout set from
292+
IndexSearcher#setTimeout(QueryTimeout). (Kaival Parikh)
293+
291294
Optimizations
292295
---------------------
293296

lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java

+68-48
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,8 @@
2323
import java.util.Objects;
2424
import org.apache.lucene.index.LeafReader;
2525
import org.apache.lucene.index.LeafReaderContext;
26+
import org.apache.lucene.index.QueryTimeout;
27+
import org.apache.lucene.search.knn.KnnCollectorManager;
2628
import org.apache.lucene.util.BitSet;
2729
import org.apache.lucene.util.BitSetIterator;
2830
import org.apache.lucene.util.Bits;
@@ -58,10 +60,19 @@ abstract class AbstractVectorSimilarityQuery extends Query {
5860
this.filter = filter;
5961
}
6062

63+
protected KnnCollectorManager getKnnCollectorManager() {
64+
return (visitedLimit, context) ->
65+
new VectorSimilarityCollector(traversalSimilarity, resultSimilarity, visitedLimit);
66+
}
67+
6168
abstract VectorScorer createVectorScorer(LeafReaderContext context) throws IOException;
6269

6370
protected abstract TopDocs approximateSearch(
64-
LeafReaderContext context, Bits acceptDocs, int visitLimit) throws IOException;
71+
LeafReaderContext context,
72+
Bits acceptDocs,
73+
int visitLimit,
74+
KnnCollectorManager knnCollectorManager)
75+
throws IOException;
6576

6677
@Override
6778
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
@@ -72,6 +83,10 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo
7283
? null
7384
: searcher.createWeight(searcher.rewrite(filter), ScoreMode.COMPLETE_NO_SCORES, 1);
7485

86+
final QueryTimeout queryTimeout = searcher.getTimeout();
87+
final TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager =
88+
new TimeLimitingKnnCollectorManager(getKnnCollectorManager(), queryTimeout);
89+
7590
@Override
7691
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
7792
if (filterWeight != null) {
@@ -103,16 +118,14 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio
103118
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
104119
LeafReader leafReader = context.reader();
105120
Bits liveDocs = leafReader.getLiveDocs();
106-
final Scorer vectorSimilarityScorer;
121+
107122
// If there is no filter
108123
if (filterWeight == null) {
109124
// Return exhaustive results
110-
TopDocs results = approximateSearch(context, liveDocs, Integer.MAX_VALUE);
111-
if (results.scoreDocs.length == 0) {
112-
return null;
113-
}
114-
vectorSimilarityScorer =
115-
VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs);
125+
TopDocs results =
126+
approximateSearch(
127+
context, liveDocs, Integer.MAX_VALUE, timeLimitingKnnCollectorManager);
128+
return VectorSimilarityScorerSupplier.fromScoreDocs(boost, results.scoreDocs);
116129
} else {
117130
Scorer scorer = filterWeight.scorer(context);
118131
if (scorer == null) {
@@ -143,27 +156,23 @@ protected boolean match(int doc) {
143156
}
144157

145158
// Perform an approximate search
146-
TopDocs results = approximateSearch(context, acceptDocs, cardinality);
159+
TopDocs results =
160+
approximateSearch(context, acceptDocs, cardinality, timeLimitingKnnCollectorManager);
147161

148-
// If the limit was exhausted
149-
if (results.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO) {
150-
// Return a lazy-loading iterator
151-
vectorSimilarityScorer =
152-
VectorSimilarityScorer.fromAcceptDocs(
153-
this,
154-
boost,
155-
createVectorScorer(context),
156-
new BitSetIterator(acceptDocs, cardinality),
157-
resultSimilarity);
158-
} else if (results.scoreDocs.length == 0) {
159-
return null;
160-
} else {
162+
if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO
163+
// Return partial results only when timeout is met
164+
|| (queryTimeout != null && queryTimeout.shouldExit())) {
161165
// Return an iterator over the collected results
162-
vectorSimilarityScorer =
163-
VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs);
166+
return VectorSimilarityScorerSupplier.fromScoreDocs(boost, results.scoreDocs);
167+
} else {
168+
// Return a lazy-loading iterator
169+
return VectorSimilarityScorerSupplier.fromAcceptDocs(
170+
boost,
171+
createVectorScorer(context),
172+
new BitSetIterator(acceptDocs, cardinality),
173+
resultSimilarity);
164174
}
165175
}
166-
return new DefaultScorerSupplier(vectorSimilarityScorer);
167176
}
168177

169178
@Override
@@ -197,16 +206,20 @@ public int hashCode() {
197206
return Objects.hash(field, traversalSimilarity, resultSimilarity, filter);
198207
}
199208

200-
private static class VectorSimilarityScorer extends Scorer {
209+
private static class VectorSimilarityScorerSupplier extends ScorerSupplier {
201210
final DocIdSetIterator iterator;
202211
final float[] cachedScore;
203212

204-
VectorSimilarityScorer(DocIdSetIterator iterator, float[] cachedScore) {
213+
VectorSimilarityScorerSupplier(DocIdSetIterator iterator, float[] cachedScore) {
205214
this.iterator = iterator;
206215
this.cachedScore = cachedScore;
207216
}
208217

209-
static VectorSimilarityScorer fromScoreDocs(Weight weight, float boost, ScoreDoc[] scoreDocs) {
218+
static VectorSimilarityScorerSupplier fromScoreDocs(float boost, ScoreDoc[] scoreDocs) {
219+
if (scoreDocs.length == 0) {
220+
return null;
221+
}
222+
210223
// Sort in ascending order of docid
211224
Arrays.sort(scoreDocs, Comparator.comparingInt(scoreDoc -> scoreDoc.doc));
212225

@@ -252,18 +265,15 @@ public long cost() {
252265
}
253266
};
254267

255-
return new VectorSimilarityScorer(iterator, cachedScore);
268+
return new VectorSimilarityScorerSupplier(iterator, cachedScore);
256269
}
257270

258-
static VectorSimilarityScorer fromAcceptDocs(
259-
Weight weight,
260-
float boost,
261-
VectorScorer scorer,
262-
DocIdSetIterator acceptDocs,
263-
float threshold) {
271+
static VectorSimilarityScorerSupplier fromAcceptDocs(
272+
float boost, VectorScorer scorer, DocIdSetIterator acceptDocs, float threshold) {
264273
if (scorer == null) {
265274
return null;
266275
}
276+
267277
float[] cachedScore = new float[1];
268278
DocIdSetIterator vectorIterator = scorer.iterator();
269279
DocIdSetIterator conjunction =
@@ -281,27 +291,37 @@ protected boolean match(int doc) throws IOException {
281291
}
282292
};
283293

284-
return new VectorSimilarityScorer(iterator, cachedScore);
294+
return new VectorSimilarityScorerSupplier(iterator, cachedScore);
285295
}
286296

287297
@Override
288-
public int docID() {
289-
return iterator.docID();
290-
}
298+
public Scorer get(long leadCost) {
299+
return new Scorer() {
300+
@Override
301+
public int docID() {
302+
return iterator.docID();
303+
}
291304

292-
@Override
293-
public DocIdSetIterator iterator() {
294-
return iterator;
295-
}
305+
@Override
306+
public DocIdSetIterator iterator() {
307+
return iterator;
308+
}
296309

297-
@Override
298-
public float getMaxScore(int upTo) {
299-
return Float.POSITIVE_INFINITY;
310+
@Override
311+
public float getMaxScore(int upTo) {
312+
return Float.POSITIVE_INFINITY;
313+
}
314+
315+
@Override
316+
public float score() {
317+
return cachedScore[0];
318+
}
319+
};
300320
}
301321

302322
@Override
303-
public float score() {
304-
return cachedScore[0];
323+
public long cost() {
324+
return iterator.cost();
305325
}
306326
}
307327
}

lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityQuery.java

+7-3
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
import org.apache.lucene.document.KnnByteVectorField;
2424
import org.apache.lucene.index.ByteVectorValues;
2525
import org.apache.lucene.index.LeafReaderContext;
26+
import org.apache.lucene.search.knn.KnnCollectorManager;
2627
import org.apache.lucene.util.Bits;
2728

2829
/**
@@ -106,10 +107,13 @@ VectorScorer createVectorScorer(LeafReaderContext context) throws IOException {
106107

107108
@Override
108109
@SuppressWarnings("resource")
109-
protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitLimit)
110+
protected TopDocs approximateSearch(
111+
LeafReaderContext context,
112+
Bits acceptDocs,
113+
int visitLimit,
114+
KnnCollectorManager knnCollectorManager)
110115
throws IOException {
111-
KnnCollector collector =
112-
new VectorSimilarityCollector(traversalSimilarity, resultSimilarity, visitLimit);
116+
KnnCollector collector = knnCollectorManager.newCollector(visitLimit, context);
113117
context.reader().searchNearestVectors(field, target, collector, acceptDocs);
114118
return collector.topDocs();
115119
}

lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityQuery.java

+7-3
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
import org.apache.lucene.document.KnnFloatVectorField;
2424
import org.apache.lucene.index.FloatVectorValues;
2525
import org.apache.lucene.index.LeafReaderContext;
26+
import org.apache.lucene.search.knn.KnnCollectorManager;
2627
import org.apache.lucene.util.Bits;
2728
import org.apache.lucene.util.VectorUtil;
2829

@@ -108,10 +109,13 @@ VectorScorer createVectorScorer(LeafReaderContext context) throws IOException {
108109

109110
@Override
110111
@SuppressWarnings("resource")
111-
protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitLimit)
112+
protected TopDocs approximateSearch(
113+
LeafReaderContext context,
114+
Bits acceptDocs,
115+
int visitLimit,
116+
KnnCollectorManager knnCollectorManager)
112117
throws IOException {
113-
KnnCollector collector =
114-
new VectorSimilarityCollector(traversalSimilarity, resultSimilarity, visitLimit);
118+
KnnCollector collector = knnCollectorManager.newCollector(visitLimit, context);
115119
context.reader().searchNearestVectors(field, target, collector, acceptDocs);
116120
return collector.topDocs();
117121
}

0 commit comments

Comments
 (0)