Skip to content

Commit 78de179

Browse files
foxstariusKennyLindahlAndreas Franzon
authored
feat: Add kNN Query (#198)
kNN search was added to Elasticsearch in v8.0 Co-authored-by: kennylindahl <[email protected]> Co-authored-by: Andreas Franzon <[email protected]>
1 parent 8f73d34 commit 78de179

File tree

7 files changed

+485
-11
lines changed

7 files changed

+485
-11
lines changed

src/core/index.js

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ exports.Aggregation = require('./aggregation');
88

99
exports.Query = require('./query');
1010

11+
exports.KNN = require('./knn');
12+
1113
exports.Suggester = require('./suggester');
1214

1315
exports.Script = require('./script');

src/core/knn.js

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
'use strict';
2+
3+
const { recursiveToJSON, checkType } = require('./util');
4+
const Query = require('./query');
5+
6+
/**
7+
* Class representing a k-Nearest Neighbors (k-NN) query.
8+
* This class is a standalone builder (it does not extend `Query`) for the specifics of k-NN search, including setting up the field,
9+
* query vector, number of neighbors (k), and number of candidates.
10+
*
11+
* @example
12+
* const qry = esb.kNN('my_field', 100, 1000).queryVector([1,2,3]);
13+
* const qry = esb.kNN('my_field', 100, 1000).queryVectorBuilder('model_123', 'Sample model text');
14+
*
15+
* NOTE: kNN search was added to Elasticsearch in v8.0
16+
*
17+
* [Elasticsearch reference](https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html)
18+
*/
19+
class KNN {
20+
// eslint-disable-next-line require-jsdoc
21+
constructor(field, k, numCandidates) {
22+
if (k > numCandidates)
23+
throw new Error('KNN numCandidates cannot be less than k');
24+
this._body = {};
25+
this._body.field = field;
26+
this._body.k = k;
27+
this._body.filter = [];
28+
this._body.num_candidates = numCandidates;
29+
}
30+
31+
/**
32+
* Sets the query vector for the k-NN search.
33+
* @param {Array<number>} vector - The query vector.
34+
* @returns {KNN} Returns the instance of KNN for method chaining.
35+
*/
36+
queryVector(vector) {
37+
if (this._body.query_vector_builder)
38+
throw new Error(
39+
'cannot provide both query_vector_builder and query_vector'
40+
);
41+
this._body.query_vector = vector;
42+
return this;
43+
}
44+
45+
/**
46+
* Sets the query vector builder for the k-NN search.
47+
* This method configures a query vector builder using a specified model ID and model text.
48+
* It's important to note that either a direct query vector or a query vector builder can be
49+
* provided, but not both.
50+
*
51+
* @param {string} modelId - The ID of the model to be used for generating the query vector.
52+
* @param {string} modelText - The text input based on which the query vector is generated.
53+
* @returns {KNN} Returns the instance of KNN for method chaining.
54+
* @throws {Error} Throws an error if both query_vector_builder and query_vector are provided.
55+
*
56+
* @example
57+
* let knn = new esb.KNN('my_field', 100, 1000).queryVectorBuilder('model_123', 'Sample model text');
58+
*/
59+
queryVectorBuilder(modelId, modelText) {
60+
if (this._body.query_vector)
61+
throw new Error(
62+
'cannot provide both query_vector_builder and query_vector'
63+
);
64+
this._body.query_vector_builder = {
65+
text_embeddings: {
66+
model_id: modelId,
67+
model_text: modelText
68+
}
69+
};
70+
return this;
71+
}
72+
73+
/**
74+
* Adds one or more filter queries to the k-NN search.
75+
*
76+
* This method is designed to apply filters to the k-NN search. It accepts either a single
77+
* query or an array of queries. Each query acts as a filter, refining the search results
78+
* according to the specified conditions. These queries must be instances of the `Query` class.
79+
* If any provided query is not an instance of `Query`, a TypeError is thrown.
80+
*
81+
* @param {Query|Query[]} queries - A single `Query` instance or an array of `Query` instances for filtering.
82+
* @returns {KNN} Returns `this` to allow method chaining.
83+
* @throws {TypeError} If any of the provided queries is not an instance of `Query`.
84+
*
85+
* @example
86+
* let knn = new esb.KNN('my_field', 100, 1000).filter(new esb.TermQuery('field', 'value')); // Applying a single filter query
87+
*
88+
* @example
89+
* let knn = new esb.KNN('my_field', 100, 1000).filter([
90+
* new esb.TermQuery('field1', 'value1'),
91+
* new esb.TermQuery('field2', 'value2')
92+
* ]); // Applying multiple filter queries
93+
*/
94+
filter(queries) {
95+
const queryArray = Array.isArray(queries) ? queries : [queries];
96+
queryArray.forEach(query => {
97+
checkType(query, Query);
98+
this._body.filter.push(query);
99+
});
100+
return this;
101+
}
102+
103+
/**
104+
* Sets the boost factor applied to the k-NN search.
105+
* @param {number} boost - The boost factor.
106+
* @returns {KNN} Returns the instance of KNN for method chaining.
107+
*/
108+
boost(boost) {
109+
this._body.boost = boost;
110+
return this;
111+
}
112+
113+
/**
114+
* Sets the minimum similarity required for a document to be considered a match.
115+
* @param {number} similarity - The minimum similarity threshold.
116+
* @returns {KNN} Returns the instance of KNN for method chaining.
117+
*/
118+
similarity(similarity) {
119+
this._body.similarity = similarity;
120+
return this;
121+
}
122+
123+
/**
124+
* Returns the DSL representation of this k-NN clause; throws if neither `query_vector` nor `query_vector_builder` has been set.
125+
*
126+
* @override
127+
* @returns {Object} returns an Object which maps to the elasticsearch query DSL
128+
*/
129+
toJSON() {
130+
if (!this._body.query_vector && !this._body.query_vector_builder)
131+
throw new Error(
132+
'either query_vector_builder or query_vector must be provided'
133+
);
134+
return recursiveToJSON(this._body);
135+
}
136+
}
137+
138+
module.exports = KNN;

src/core/request-body-search.js

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ const Query = require('./query'),
1010
Rescore = require('./rescore'),
1111
Sort = require('./sort'),
1212
Highlight = require('./highlight'),
13-
InnerHits = require('./inner-hits');
13+
InnerHits = require('./inner-hits'),
14+
KNN = require('./knn');
1415

1516
const { checkType, setDefault, recursiveToJSON } = require('./util');
1617
const RuntimeField = require('./runtime-field');
@@ -70,6 +71,7 @@ class RequestBodySearch {
7071
constructor() {
7172
// Maybe accept some optional parameter?
7273
this._body = {};
74+
this._knn = [];
7375
this._aggs = [];
7476
this._suggests = [];
7577
this._suggestText = null;
@@ -88,6 +90,21 @@ class RequestBodySearch {
8890
return this;
8991
}
9092

93+
/**
94+
* Sets knn on the search request body.
95+
*
96+
* @param {KNN|KNN[]} knn A single `KNN` instance or an array of them; each one is appended to the request's kNN list.
97+
* @returns {RequestBodySearch} returns `this` so that calls can be chained.
98+
*/
99+
kNN(knn) {
100+
const knns = Array.isArray(knn) ? knn : [knn];
101+
knns.forEach(_knn => {
102+
checkType(_knn, KNN);
103+
this._knn.push(_knn);
104+
});
105+
return this;
106+
}
107+
91108
/**
92109
* Sets aggregation on the request body.
93110
* Alias for method `aggregation`
@@ -867,6 +884,12 @@ class RequestBodySearch {
867884
toJSON() {
868885
const dsl = recursiveToJSON(this._body);
869886

887+
if (!isEmpty(this._knn))
888+
dsl.knn =
889+
this._knn.length == 1
890+
? recMerge(this._knn)
891+
: this._knn.map(knn => recursiveToJSON(knn));
892+
870893
if (!isEmpty(this._aggs)) dsl.aggs = recMerge(this._aggs);
871894

872895
if (!isEmpty(this._suggests) || !isNil(this._suggestText)) {

src/index.d.ts

Lines changed: 94 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@ declare namespace esb {
1818
*/
1919
query(query: Query): this;
2020

21+
/**
22+
* Sets knn on the request body.
23+
*
24+
* @param {KNN|KNN[]} knn
25+
*/
26+
kNN(knn: KNN | KNN[]): this;
27+
2128
/**
2229
* Sets aggregation on the request body.
2330
* Alias for method `aggregation`
@@ -3141,7 +3148,7 @@ declare namespace esb {
31413148

31423149
/**
31433150
* Sets the script used to compute the score of documents returned by the query.
3144-
*
3151+
*
31453152
* @param {Script} script A valid `Script` object
31463153
*/
31473154
script(script: Script): this;
@@ -3761,6 +3768,84 @@ declare namespace esb {
37613768
spanQry?: SpanQueryBase
37623769
): SpanFieldMaskingQuery;
37633770

3771+
/**
3772+
* Knn performs k-nearest neighbor (KNN) searches.
3773+
* This class allows configuring the KNN search with various parameters such as field, query vector,
3774+
* number of nearest neighbors (k), number of candidates, boost factor, and similarity metric.
3775+
*
3776+
* NOTE: Only available in Elasticsearch v8.0+
3777+
*/
3778+
export class KNN {
3779+
/**
3780+
* Creates an instance of Knn, initializing the internal state for the k-NN search.
3781+
*
3782+
* @param {string} field - The field against which to perform the k-NN search.
3783+
* @param {number} k - The number of nearest neighbors to retrieve.
3784+
* @param {number} numCandidates - The number of candidate neighbors to consider during the search.
3785+
* @throws {Error} If the number of candidates (numCandidates) is less than the number of neighbors (k).
3786+
*/
3787+
constructor(field: string, k: number, numCandidates: number);
3788+
3789+
/**
3790+
* Sets the query vector for the KNN search, an array of numbers representing the reference point.
3791+
*
3792+
* @param {number[]} vector
3793+
*/
3794+
queryVector(vector: number[]): this;
3795+
3796+
/**
3797+
* Sets the query vector builder for the k-NN search.
3798+
* This method configures a query vector builder using a specified model ID and model text.
3799+
* Note that either a direct query vector or a query vector builder can be provided, but not both.
3800+
*
3801+
* @param {string} modelId - The ID of the model used for generating the query vector.
3802+
* @param {string} modelText - The text input based on which the query vector is generated.
3803+
* @returns {KNN} Returns the instance of Knn for method chaining.
3804+
* @throws {Error} If both query_vector_builder and query_vector are provided.
3805+
*/
3806+
queryVectorBuilder(modelId: string, modelText: string): this;
3807+
3808+
/**
3809+
* Adds one or more filter queries to the k-NN search.
3810+
* This method is designed to apply filters to the k-NN search. It accepts either a single
3811+
* query or an array of queries. Each query acts as a filter, refining the search results
3812+
* according to the specified conditions. These queries must be instances of the `Query` class.
3813+
*
3814+
* @param {Query|Query[]} queries - A single `Query` instance or an array of `Query` instances for filtering.
3815+
* @returns {KNN} Returns `this` to allow method chaining.
3816+
* @throws {TypeError} If any of the provided queries is not an instance of `Query`.
3817+
*/
3818+
filter(queries: Query | Query[]): this;
3819+
3820+
/**
3821+
* Applies a boost factor to the query to influence the relevance score of returned documents.
3822+
*
3823+
* @param {number} boost
3824+
*/
3825+
boost(boost: number): this;
3826+
3827+
/**
3828+
* Sets the minimum similarity required for a document to be considered a match.
3829+
*
3830+
* @param {number} similarity
3831+
*/
3832+
similarity(similarity: number): this;
3833+
3834+
/**
3835+
* Override default `toJSON` to return DSL representation for the `query`
3836+
*
3837+
* @override
3838+
*/
3839+
toJSON(): object;
3840+
}
3841+
3842+
/**
3843+
* Factory function to instantiate a new `KNN` object; accepts the same arguments as the `KNN` constructor.
3844+
*
3845+
* @returns {KNN}
3846+
*/
3847+
export function kNN(field: string, k: number, numCandidates: number): KNN;
3848+
37643849
/**
37653850
* Base class implementation for all aggregation types.
37663851
*
@@ -3913,9 +3998,9 @@ declare namespace esb {
39133998
/**
39143999
* A single-value metrics aggregation that computes the weighted average of numeric values that are extracted from the aggregated documents.
39154000
* These values can be extracted either from specific numeric fields in the documents.
3916-
*
4001+
*
39174002
* [Elasticsearch reference](https://www.elastic.co/guide/en/elasticsearch/reference/current/search-aggregations-metrics-weight-avg-aggregation.html)
3918-
*
4003+
*
39194004
* Added in Elasticsearch v6.4.0
39204005
* [Release notes](https://www.elastic.co/guide/en/elasticsearch/reference/6.4/release-notes-6.4.0.html)
39214006
*
@@ -3929,7 +4014,7 @@ declare namespace esb {
39294014

39304015
/**
39314016
* Sets the value
3932-
*
4017+
*
39334018
* @param {string | Script} value Field name or script to be used as the value
39344019
* @param {number=} missing Sets the missing parameter which defines how documents
39354020
* that are missing a value should be treated.
@@ -3939,7 +4024,7 @@ declare namespace esb {
39394024

39404025
/**
39414026
* Sets the weight
3942-
*
4027+
*
39434028
* @param {string | Script} weight Field name or script to be used as the weight
39444029
* @param {number=} missing Sets the missing parameter which defines how documents
39454030
* that are missing a value should be treated.
@@ -3969,9 +4054,9 @@ declare namespace esb {
39694054
/**
39704055
* A single-value metrics aggregation that computes the weighted average of numeric values that are extracted from the aggregated documents.
39714056
* These values can be extracted either from specific numeric fields in the documents.
3972-
*
4057+
*
39734058
* [Elasticsearch reference](https://www.elastic.co/guide/en/elasticsearch/reference/current/search-aggregations-metrics-weight-avg-aggregation.html)
3974-
*
4059+
*
39754060
* Added in Elasticsearch v6.4.0
39764061
* [Release notes](https://www.elastic.co/guide/en/elasticsearch/reference/6.4/release-notes-6.4.0.html)
39774062
*
@@ -8922,15 +9007,15 @@ declare namespace esb {
89229007

89239008
/**
89249009
* Sets the type of the runtime field.
8925-
*
9010+
*
89269011
* @param {string} type One of `boolean`, `composite`, `date`, `double`, `geo_point`, `ip`, `keyword`, `long`, `lookup`.
89279012
* @returns {void}
89289013
*/
89299014
type(type: 'boolean' | 'composite' | 'date' | 'double' | 'geo_point' | 'ip' | 'keyword' | 'long' | 'lookup'): void;
89309015

89319016
/**
89329017
* Sets the source of the script.
8933-
*
9018+
*
89349019
* @param {string} script
89359020
* @returns {void}
89369021
*/

src/index.js

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ const {
1515
RuntimeField,
1616
SearchTemplate,
1717
Query,
18+
KNN,
1819
util: { constructorWrapper }
1920
} = require('./core');
2021

@@ -343,6 +344,13 @@ exports.spanWithinQuery = constructorWrapper(SpanWithinQuery);
343344

344345
exports.SpanFieldMaskingQuery = SpanFieldMaskingQuery;
345346
exports.spanFieldMaskingQuery = constructorWrapper(SpanFieldMaskingQuery);
347+
348+
/* ============ ============ ============ */
349+
/* ======== KNN ======== */
350+
/* ============ ============ ============ */
351+
exports.KNN = KNN;
352+
exports.kNN = constructorWrapper(KNN);
353+
346354
/* ============ ============ ============ */
347355
/* ======== Metrics Aggregations ======== */
348356
/* ============ ============ ============ */

0 commit comments

Comments
 (0)