Skip to content

Commit 4d2982f

Browse files
authored
Merge branch 'main' into tvclean0
2 parents 8924385 + fd9f0ff commit 4d2982f

File tree

8 files changed

+227
-32
lines changed

8 files changed

+227
-32
lines changed

docs/changelog/138685.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 138685
2+
summary: Introduce an adaptive HNSW Patience collector
3+
area: Vector Search
4+
type: enhancement
5+
issues: []
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.vectors;
11+
12+
import org.apache.lucene.search.HnswQueueSaturationCollector;
13+
import org.apache.lucene.search.KnnCollector;
14+
15+
/**
16+
* A {@link KnnCollector.Decorator} extending {@link HnswQueueSaturationCollector}
17+
* that adaptively early-exits HNSW search using an online-estimated discovery rate,
18+
* rolling mean/variance, and adaptive patience threshold.
19+
* It tracks smoothed discovery rate (how many new neighbors are collected per candidate),
20+
* maintains a rolling mean and variance of the rate (using Welford's algorithm).
21+
* Those are used to define an adaptive saturation threshold = mean + looseness * stddev
22+
* and adaptive patience = patience-scaling / (1 + stddev).
23+
* Adaptive patience scales inversely with volatility (stddev) and looseness.
24+
* Patience-scaling defines patience order of magnitude.
25+
* Saturation happens when the discovery rate is lower than the adaptive saturation threshold.
26+
* The collector early exits once saturation persists for longer than adaptive patience.
27+
*/
28+
public class AdaptiveHnswQueueSaturationCollector extends HnswQueueSaturationCollector {
29+
30+
private static final float DEFAULT_DISCOVERY_RATE_SMOOTHING = 0.9f;
31+
private static final float DEFAULT_THRESHOLD_LOOSENESS = 0.01f;
32+
private static final float DEFAULT_PATIENCE_SCALING = 10.0f;
33+
34+
private final float discoveryRateSmoothing;
35+
private final float thresholdLooseness;
36+
private final float patienceScaling;
37+
38+
private boolean patienceFinished = false;
39+
40+
private int previousQueueSize = 0;
41+
private int currentQueueSize = 0;
42+
43+
private float smoothedDiscoveryRate = 0.0f;
44+
private float mean = 0.0f;
45+
private float m2 = 0.0f;
46+
private int samples = 0;
47+
private int steps = 0;
48+
49+
private int saturatedCount = 0;
50+
51+
private AdaptiveHnswQueueSaturationCollector(
52+
KnnCollector delegate,
53+
float discoveryRateSmoothing,
54+
float thresholdLooseness,
55+
float patienceScaling
56+
) {
57+
super(delegate, 0, 0);
58+
this.discoveryRateSmoothing = discoveryRateSmoothing;
59+
this.thresholdLooseness = thresholdLooseness;
60+
this.patienceScaling = patienceScaling;
61+
}
62+
63+
public AdaptiveHnswQueueSaturationCollector(KnnCollector delegate) {
64+
this(delegate, DEFAULT_DISCOVERY_RATE_SMOOTHING, DEFAULT_THRESHOLD_LOOSENESS, DEFAULT_PATIENCE_SCALING);
65+
}
66+
67+
@Override
68+
public boolean earlyTerminated() {
69+
return patienceFinished || super.earlyTerminated();
70+
}
71+
72+
@Override
73+
public boolean collect(int docId, float similarity) {
74+
boolean collected = super.collect(docId, similarity);
75+
if (collected) {
76+
currentQueueSize++;
77+
}
78+
steps++;
79+
return collected;
80+
}
81+
82+
@Override
83+
public void nextCandidate() {
84+
// rate of newly discovered neighbors for the current candidate
85+
float discoveryRate = (float) ((currentQueueSize - previousQueueSize) / (1e-9 + steps * k()));
86+
float rate = Math.max(0, discoveryRate);
87+
88+
// exponentially smoothed discovery rate
89+
smoothedDiscoveryRate = discoveryRateSmoothing * rate + (1 - discoveryRateSmoothing) * smoothedDiscoveryRate;
90+
91+
// update rolling mean and variance using Welford's algorithm
92+
samples++;
93+
float deltaMean = smoothedDiscoveryRate - mean;
94+
mean += deltaMean / samples;
95+
m2 += deltaMean * (smoothedDiscoveryRate - mean);
96+
double variance = samples > 1 ? m2 / (samples - 1) : 0.0;
97+
double stddev = Math.sqrt(variance);
98+
99+
// update adaptive threshold and patience
100+
double adaptiveThreshold = mean + thresholdLooseness * stddev;
101+
double adaptivePatience = patienceScaling / (1.0 + stddev);
102+
103+
if (smoothedDiscoveryRate < adaptiveThreshold) {
104+
saturatedCount++;
105+
} else {
106+
saturatedCount = 0;
107+
}
108+
109+
if (saturatedCount > adaptivePatience) {
110+
patienceFinished = true;
111+
}
112+
113+
previousQueueSize = currentQueueSize;
114+
steps = 0;
115+
}
116+
117+
}

server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,6 @@ public KnnSearchStrategy getStrategy() {
6969
@Override
7070
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
7171
KnnCollectorManager knnCollectorManager = super.getKnnCollectorManager(k, searcher);
72-
return earlyTermination ? PatienceCollectorManager.wrap(knnCollectorManager, k) : knnCollectorManager;
72+
return earlyTermination ? PatienceCollectorManager.wrap(knnCollectorManager) : knnCollectorManager;
7373
}
7474
}

server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,6 @@ public KnnSearchStrategy getStrategy() {
6464
@Override
6565
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
6666
KnnCollectorManager knnCollectorManager = super.getKnnCollectorManager(k, searcher);
67-
return earlyTermination ? PatienceCollectorManager.wrap(knnCollectorManager, k) : knnCollectorManager;
67+
return earlyTermination ? PatienceCollectorManager.wrap(knnCollectorManager) : knnCollectorManager;
6868
}
6969
}

server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,6 @@ public KnnSearchStrategy getStrategy() {
6464
@Override
6565
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
6666
KnnCollectorManager knnCollectorManager = super.getKnnCollectorManager(k, searcher);
67-
return earlyTermination ? PatienceCollectorManager.wrap(knnCollectorManager, k) : knnCollectorManager;
67+
return earlyTermination ? PatienceCollectorManager.wrap(knnCollectorManager) : knnCollectorManager;
6868
}
6969
}

server/src/main/java/org/elasticsearch/search/vectors/PatienceCollectorManager.java

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
package org.elasticsearch.search.vectors;
1111

1212
import org.apache.lucene.index.LeafReaderContext;
13-
import org.apache.lucene.search.HnswQueueSaturationCollector;
1413
import org.apache.lucene.search.KnnCollector;
1514
import org.apache.lucene.search.knn.KnnCollectorManager;
1615
import org.apache.lucene.search.knn.KnnSearchStrategy;
@@ -27,40 +26,27 @@
2726
* tested for termination.
2827
*/
2928
class PatienceCollectorManager implements KnnCollectorManager {
30-
private static final double DEFAULT_SATURATION_THRESHOLD = 0.995;
3129

3230
private final KnnCollectorManager knnCollectorManager;
33-
private final int patience;
34-
private final double saturationThreshold;
3531

36-
PatienceCollectorManager(KnnCollectorManager knnCollectorManager, int patience, double saturationThreshold) {
32+
PatienceCollectorManager(KnnCollectorManager knnCollectorManager) {
3733
this.knnCollectorManager = knnCollectorManager;
38-
this.patience = patience;
39-
this.saturationThreshold = saturationThreshold;
4034
}
4135

42-
static KnnCollectorManager wrap(KnnCollectorManager knnCollectorManager, int k) {
43-
return new PatienceCollectorManager(knnCollectorManager, Math.max(7, (int) (k * 0.3)), DEFAULT_SATURATION_THRESHOLD);
36+
static KnnCollectorManager wrap(KnnCollectorManager knnCollectorManager) {
37+
return new PatienceCollectorManager(knnCollectorManager);
4438
}
4539

4640
@Override
4741
public KnnCollector newCollector(int visitLimit, KnnSearchStrategy searchStrategy, LeafReaderContext ctx) throws IOException {
48-
return new HnswQueueSaturationCollector(
49-
knnCollectorManager.newCollector(visitLimit, searchStrategy, ctx),
50-
saturationThreshold,
51-
patience
52-
);
42+
return new AdaptiveHnswQueueSaturationCollector(knnCollectorManager.newCollector(visitLimit, searchStrategy, ctx));
5343
}
5444

5545
@Override
5646
public KnnCollector newOptimisticCollector(int visitLimit, KnnSearchStrategy searchStrategy, LeafReaderContext ctx, int k)
5747
throws IOException {
5848
if (knnCollectorManager.isOptimistic()) {
59-
return new HnswQueueSaturationCollector(
60-
knnCollectorManager.newOptimisticCollector(visitLimit, searchStrategy, ctx, k),
61-
saturationThreshold,
62-
patience
63-
);
49+
return new AdaptiveHnswQueueSaturationCollector(knnCollectorManager.newOptimisticCollector(visitLimit, searchStrategy, ctx, k));
6450
} else {
6551
return null;
6652
}
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.vectors;
11+
12+
import org.apache.lucene.search.KnnCollector;
13+
import org.apache.lucene.search.TopDocs;
14+
import org.apache.lucene.search.TopKnnCollector;
15+
import org.apache.lucene.search.TotalHits;
16+
import org.apache.lucene.tests.util.LuceneTestCase;
17+
18+
import java.util.Random;
19+
20+
public class AdaptiveHnswQueueSaturationCollectorTests extends LuceneTestCase {
21+
22+
public void testDelegate() {
23+
Random random = random();
24+
int numDocs = 100;
25+
int k = random.nextInt(1, 10);
26+
KnnCollector delegate = new TopKnnCollector(k, numDocs);
27+
AdaptiveHnswQueueSaturationCollector queueSaturationCollector = new AdaptiveHnswQueueSaturationCollector(delegate);
28+
for (int i = 0; i < random.nextInt(numDocs); i++) {
29+
queueSaturationCollector.collect(random.nextInt(numDocs), random.nextFloat(1.0f));
30+
}
31+
assertEquals(delegate.k(), queueSaturationCollector.k());
32+
assertEquals(delegate.visitedCount(), queueSaturationCollector.visitedCount());
33+
assertEquals(delegate.visitLimit(), queueSaturationCollector.visitLimit());
34+
assertEquals(delegate.minCompetitiveSimilarity(), queueSaturationCollector.minCompetitiveSimilarity(), 1e-3);
35+
}
36+
37+
public void testEarlyExpectedExit() {
38+
int numDocs = 1000;
39+
int k = 10;
40+
KnnCollector delegate = new TopKnnCollector(k, numDocs);
41+
AdaptiveHnswQueueSaturationCollector queueSaturationCollector = new AdaptiveHnswQueueSaturationCollector(delegate);
42+
for (int i = 0; i < numDocs; i++) {
43+
queueSaturationCollector.collect(i, 1.0f - i * 1e-3f);
44+
if (i % 10 == 0) {
45+
queueSaturationCollector.nextCandidate();
46+
}
47+
if (queueSaturationCollector.earlyTerminated()) {
48+
assertEquals(110, i);
49+
break;
50+
}
51+
}
52+
}
53+
54+
public void testDelegateVsSaturateEarlyExit() {
55+
Random random = random();
56+
int numDocs = 10000;
57+
int k = random.nextInt(1, 100);
58+
KnnCollector delegate = new TopKnnCollector(k, numDocs);
59+
AdaptiveHnswQueueSaturationCollector queueSaturationCollector = new AdaptiveHnswQueueSaturationCollector(delegate);
60+
for (int i = 0; i < random.nextInt(numDocs); i++) {
61+
queueSaturationCollector.collect(random.nextInt(numDocs), random.nextFloat(1.0f));
62+
if (i % 10 == 0) {
63+
queueSaturationCollector.nextCandidate();
64+
}
65+
boolean earlyTerminatedSaturation = queueSaturationCollector.earlyTerminated();
66+
boolean earlyTerminatedDelegate = delegate.earlyTerminated();
67+
assertTrue(earlyTerminatedSaturation || earlyTerminatedDelegate == false);
68+
}
69+
}
70+
71+
public void testEarlyExitRelation() {
72+
Random random = random();
73+
int numDocs = 10000;
74+
int k = random.nextInt(1, 100);
75+
KnnCollector delegate = new TopKnnCollector(k, random.nextInt(numDocs));
76+
AdaptiveHnswQueueSaturationCollector queueSaturationCollector = new AdaptiveHnswQueueSaturationCollector(delegate);
77+
for (int i = 0; i < random.nextInt(numDocs); i++) {
78+
queueSaturationCollector.collect(random.nextInt(numDocs), random.nextFloat(1.0f));
79+
if (i % 10 == 0) {
80+
queueSaturationCollector.nextCandidate();
81+
}
82+
if (delegate.earlyTerminated()) {
83+
TopDocs topDocs = queueSaturationCollector.topDocs();
84+
assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, topDocs.totalHits.relation());
85+
}
86+
if (queueSaturationCollector.earlyTerminated()) {
87+
TopDocs topDocs = queueSaturationCollector.topDocs();
88+
assertEquals(TotalHits.Relation.EQUAL_TO, topDocs.totalHits.relation());
89+
break;
90+
}
91+
}
92+
}
93+
94+
}

server/src/test/java/org/elasticsearch/search/vectors/PatienceCollectorManagerTests.java

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
package org.elasticsearch.search.vectors;
1111

12-
import org.apache.lucene.search.HnswQueueSaturationCollector;
1312
import org.apache.lucene.search.knn.KnnSearchStrategy;
1413
import org.apache.lucene.search.knn.TopKnnCollectorManager;
1514
import org.elasticsearch.test.ESTestCase;
@@ -20,20 +19,14 @@ public class PatienceCollectorManagerTests extends ESTestCase {
2019

2120
public void testEarlyTermination() throws IOException {
2221
int k = randomIntBetween(1, 10);
23-
int patience = randomIntBetween(1, 2);
24-
double saturationThreshold = randomDoubleBetween(0.01, 0.02, true);
25-
PatienceCollectorManager patienceCollectorManager = new PatienceCollectorManager(
26-
new TopKnnCollectorManager(k, null),
27-
patience,
28-
saturationThreshold
29-
);
30-
HnswQueueSaturationCollector knnCollector = (HnswQueueSaturationCollector) patienceCollectorManager.newCollector(
22+
PatienceCollectorManager patienceCollectorManager = new PatienceCollectorManager(new TopKnnCollectorManager(k, null));
23+
AdaptiveHnswQueueSaturationCollector knnCollector = (AdaptiveHnswQueueSaturationCollector) patienceCollectorManager.newCollector(
3124
randomIntBetween(100, 1000),
3225
new KnnSearchStrategy.Hnsw(10),
3326
null
3427
);
3528

36-
for (int i = 0; i < 100; i++) {
29+
for (int i = 0; i < 1000; i++) {
3730
knnCollector.collect(i, 1 - i);
3831
if (i % 10 == 0) {
3932
knnCollector.nextCandidate();

0 commit comments

Comments
 (0)