23
23
import java .util .Objects ;
24
24
import org .apache .lucene .index .LeafReader ;
25
25
import org .apache .lucene .index .LeafReaderContext ;
26
+ import org .apache .lucene .index .QueryTimeout ;
27
+ import org .apache .lucene .search .knn .KnnCollectorManager ;
26
28
import org .apache .lucene .util .BitSet ;
27
29
import org .apache .lucene .util .BitSetIterator ;
28
30
import org .apache .lucene .util .Bits ;
@@ -58,10 +60,19 @@ abstract class AbstractVectorSimilarityQuery extends Query {
58
60
this .filter = filter ;
59
61
}
60
62
63
+ protected KnnCollectorManager getKnnCollectorManager () {
64
+ return (visitedLimit , context ) ->
65
+ new VectorSimilarityCollector (traversalSimilarity , resultSimilarity , visitedLimit );
66
+ }
67
+
61
68
abstract VectorScorer createVectorScorer (LeafReaderContext context ) throws IOException ;
62
69
63
70
protected abstract TopDocs approximateSearch (
64
- LeafReaderContext context , Bits acceptDocs , int visitLimit ) throws IOException ;
71
+ LeafReaderContext context ,
72
+ Bits acceptDocs ,
73
+ int visitLimit ,
74
+ KnnCollectorManager knnCollectorManager )
75
+ throws IOException ;
65
76
66
77
@ Override
67
78
public Weight createWeight (IndexSearcher searcher , ScoreMode scoreMode , float boost )
@@ -72,6 +83,10 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo
72
83
? null
73
84
: searcher .createWeight (searcher .rewrite (filter ), ScoreMode .COMPLETE_NO_SCORES , 1 );
74
85
86
+ final QueryTimeout queryTimeout = searcher .getTimeout ();
87
+ final TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager =
88
+ new TimeLimitingKnnCollectorManager (getKnnCollectorManager (), queryTimeout );
89
+
75
90
@ Override
76
91
public Explanation explain (LeafReaderContext context , int doc ) throws IOException {
77
92
if (filterWeight != null ) {
@@ -103,16 +118,14 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio
103
118
public ScorerSupplier scorerSupplier (LeafReaderContext context ) throws IOException {
104
119
LeafReader leafReader = context .reader ();
105
120
Bits liveDocs = leafReader .getLiveDocs ();
106
- final Scorer vectorSimilarityScorer ;
121
+
107
122
// If there is no filter
108
123
if (filterWeight == null ) {
109
124
// Return exhaustive results
110
- TopDocs results = approximateSearch (context , liveDocs , Integer .MAX_VALUE );
111
- if (results .scoreDocs .length == 0 ) {
112
- return null ;
113
- }
114
- vectorSimilarityScorer =
115
- VectorSimilarityScorer .fromScoreDocs (this , boost , results .scoreDocs );
125
+ TopDocs results =
126
+ approximateSearch (
127
+ context , liveDocs , Integer .MAX_VALUE , timeLimitingKnnCollectorManager );
128
+ return VectorSimilarityScorerSupplier .fromScoreDocs (boost , results .scoreDocs );
116
129
} else {
117
130
Scorer scorer = filterWeight .scorer (context );
118
131
if (scorer == null ) {
@@ -143,27 +156,23 @@ protected boolean match(int doc) {
143
156
}
144
157
145
158
// Perform an approximate search
146
- TopDocs results = approximateSearch (context , acceptDocs , cardinality );
159
+ TopDocs results =
160
+ approximateSearch (context , acceptDocs , cardinality , timeLimitingKnnCollectorManager );
147
161
148
- // If the limit was exhausted
149
- if (results .totalHits .relation == TotalHits .Relation .GREATER_THAN_OR_EQUAL_TO ) {
150
- // Return a lazy-loading iterator
151
- vectorSimilarityScorer =
152
- VectorSimilarityScorer .fromAcceptDocs (
153
- this ,
154
- boost ,
155
- createVectorScorer (context ),
156
- new BitSetIterator (acceptDocs , cardinality ),
157
- resultSimilarity );
158
- } else if (results .scoreDocs .length == 0 ) {
159
- return null ;
160
- } else {
162
+ if (results .totalHits .relation == TotalHits .Relation .EQUAL_TO
163
+ // Return partial results only when timeout is met
164
+ || (queryTimeout != null && queryTimeout .shouldExit ())) {
161
165
// Return an iterator over the collected results
162
- vectorSimilarityScorer =
163
- VectorSimilarityScorer .fromScoreDocs (this , boost , results .scoreDocs );
166
+ return VectorSimilarityScorerSupplier .fromScoreDocs (boost , results .scoreDocs );
167
+ } else {
168
+ // Return a lazy-loading iterator
169
+ return VectorSimilarityScorerSupplier .fromAcceptDocs (
170
+ boost ,
171
+ createVectorScorer (context ),
172
+ new BitSetIterator (acceptDocs , cardinality ),
173
+ resultSimilarity );
164
174
}
165
175
}
166
- return new DefaultScorerSupplier (vectorSimilarityScorer );
167
176
}
168
177
169
178
@ Override
@@ -197,16 +206,20 @@ public int hashCode() {
197
206
return Objects .hash (field , traversalSimilarity , resultSimilarity , filter );
198
207
}
199
208
200
- private static class VectorSimilarityScorer extends Scorer {
209
+ private static class VectorSimilarityScorerSupplier extends ScorerSupplier {
201
210
final DocIdSetIterator iterator ;
202
211
final float [] cachedScore ;
203
212
204
- VectorSimilarityScorer (DocIdSetIterator iterator , float [] cachedScore ) {
213
+ VectorSimilarityScorerSupplier (DocIdSetIterator iterator , float [] cachedScore ) {
205
214
this .iterator = iterator ;
206
215
this .cachedScore = cachedScore ;
207
216
}
208
217
209
- static VectorSimilarityScorer fromScoreDocs (Weight weight , float boost , ScoreDoc [] scoreDocs ) {
218
+ static VectorSimilarityScorerSupplier fromScoreDocs (float boost , ScoreDoc [] scoreDocs ) {
219
+ if (scoreDocs .length == 0 ) {
220
+ return null ;
221
+ }
222
+
210
223
// Sort in ascending order of docid
211
224
Arrays .sort (scoreDocs , Comparator .comparingInt (scoreDoc -> scoreDoc .doc ));
212
225
@@ -252,18 +265,15 @@ public long cost() {
252
265
}
253
266
};
254
267
255
- return new VectorSimilarityScorer (iterator , cachedScore );
268
+ return new VectorSimilarityScorerSupplier (iterator , cachedScore );
256
269
}
257
270
258
- static VectorSimilarityScorer fromAcceptDocs (
259
- Weight weight ,
260
- float boost ,
261
- VectorScorer scorer ,
262
- DocIdSetIterator acceptDocs ,
263
- float threshold ) {
271
+ static VectorSimilarityScorerSupplier fromAcceptDocs (
272
+ float boost , VectorScorer scorer , DocIdSetIterator acceptDocs , float threshold ) {
264
273
if (scorer == null ) {
265
274
return null ;
266
275
}
276
+
267
277
float [] cachedScore = new float [1 ];
268
278
DocIdSetIterator vectorIterator = scorer .iterator ();
269
279
DocIdSetIterator conjunction =
@@ -281,27 +291,37 @@ protected boolean match(int doc) throws IOException {
281
291
}
282
292
};
283
293
284
- return new VectorSimilarityScorer (iterator , cachedScore );
294
+ return new VectorSimilarityScorerSupplier (iterator , cachedScore );
285
295
}
286
296
287
297
@ Override
288
- public int docID () {
289
- return iterator .docID ();
290
- }
298
+ public Scorer get (long leadCost ) {
299
+ return new Scorer () {
300
+ @ Override
301
+ public int docID () {
302
+ return iterator .docID ();
303
+ }
291
304
292
- @ Override
293
- public DocIdSetIterator iterator () {
294
- return iterator ;
295
- }
305
+ @ Override
306
+ public DocIdSetIterator iterator () {
307
+ return iterator ;
308
+ }
296
309
297
- @ Override
298
- public float getMaxScore (int upTo ) {
299
- return Float .POSITIVE_INFINITY ;
310
+ @ Override
311
+ public float getMaxScore (int upTo ) {
312
+ return Float .POSITIVE_INFINITY ;
313
+ }
314
+
315
+ @ Override
316
+ public float score () {
317
+ return cachedScore [0 ];
318
+ }
319
+ };
300
320
}
301
321
302
322
@ Override
303
- public float score () {
304
- return cachedScore [ 0 ] ;
323
+ public long cost () {
324
+ return iterator . cost () ;
305
325
}
306
326
}
307
327
}
0 commit comments