Skip to content

Commit 3c4c401

Browse files
committed
Knn function minimal support
1 parent 6f4d011 commit 3c4c401

File tree

6 files changed

+307
-2
lines changed

6 files changed

+307
-2
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.plugin;
9+
10+
import org.elasticsearch.action.index.IndexRequestBuilder;
11+
import org.elasticsearch.cluster.metadata.IndexMetadata;
12+
import org.elasticsearch.common.settings.Settings;
13+
import org.elasticsearch.xcontent.XContentBuilder;
14+
import org.elasticsearch.xcontent.XContentFactory;
15+
import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase;
16+
import org.junit.Before;
17+
18+
import java.io.IOException;
19+
import java.util.ArrayList;
20+
import java.util.HashMap;
21+
import java.util.List;
22+
import java.util.Map;
23+
24+
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
25+
26+
public class KnnFunctionIT extends AbstractEsqlIntegTestCase {
27+
28+
private final Map<Integer, List<Float>> indexedVectors = new HashMap<>();
29+
30+
public void testKnn() {
31+
var query = """
32+
FROM test
33+
| WHERE knn(vector, [1.0, 2.0, 3.0])
34+
| KEEP id, floats
35+
""";
36+
37+
try (var resp = run(query)) {
38+
assertColumnNames(resp.columns(), List.of("id", "floats"));
39+
assertColumnTypes(resp.columns(), List.of("integer", "double"));
40+
}
41+
}
42+
43+
@Before
44+
public void setup() throws IOException {
45+
var indexName = "test";
46+
var client = client().admin().indices();
47+
XContentBuilder mapping = XContentFactory.jsonBuilder()
48+
.startObject()
49+
.startObject("properties")
50+
.startObject("id")
51+
.field("type", "integer")
52+
.endObject()
53+
.startObject("vector")
54+
.field("type", "dense_vector")
55+
.field("similarity", "l2_norm")
56+
.endObject()
57+
.startObject("floats")
58+
.field("type", "float")
59+
.endObject()
60+
.endObject()
61+
.endObject();
62+
63+
Settings.Builder settingsBuilder = Settings.builder()
64+
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
65+
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1);
66+
67+
var CreateRequest = client.prepareCreate(indexName).setMapping(mapping).setSettings(settingsBuilder.build());
68+
assertAcked(CreateRequest);
69+
70+
int numDocs = randomIntBetween(10, 100);
71+
int numDims = 3;
72+
IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];
73+
float value = 0.0f;
74+
for (int i = 0; i < numDocs; i++) {
75+
List<Float> vector = new ArrayList<>(numDims);
76+
for (int j = 0; j < numDims; j++) {
77+
vector.add(value++);
78+
}
79+
docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "vector", vector);
80+
indexedVectors.put(i, vector);
81+
}
82+
83+
indexRandom(true, docs);
84+
}
85+
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

+6-1
Original file line numberDiff line numberDiff line change
@@ -1044,7 +1044,12 @@ public enum Cap {
10441044
/**
10451045
* Support for dense_vector field type
10461046
*/
1047-
DENSE_VECTOR_SUPPORT(DENSE_VECTOR_FEATURE_FLAG);
1047+
DENSE_VECTOR_SUPPORT(DENSE_VECTOR_FEATURE_FLAG),
1048+
1049+
/**
1050+
* Support for the knn function
1051+
*/
1052+
KNN_FUNCTION(Build.current().isSnapshot());
10481053

10491054
private final boolean enabled;
10501055

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java

+10
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.esql.expression;
99

1010
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
11+
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
1112
import org.elasticsearch.xpack.esql.core.expression.ExpressionCoreWritables;
1213
import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
1314
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateWritables;
@@ -73,6 +74,7 @@
7374
import org.elasticsearch.xpack.esql.expression.function.scalar.string.Trim;
7475
import org.elasticsearch.xpack.esql.expression.function.scalar.string.WildcardLike;
7576
import org.elasticsearch.xpack.esql.expression.function.scalar.util.Delay;
77+
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
7678
import org.elasticsearch.xpack.esql.expression.predicate.logical.Not;
7779
import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNotNull;
7880
import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNull;
@@ -105,6 +107,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
105107
entries.addAll(binaryComparisons());
106108
entries.addAll(fullText());
107109
entries.addAll(unaryScalars());
110+
entries.addAll(vector());
108111
return entries;
109112
}
110113

@@ -226,4 +229,11 @@ private static List<NamedWriteableRegistry.Entry> binaryComparisons() {
226229
private static List<NamedWriteableRegistry.Entry> fullText() {
227230
return FullTextWritables.getNamedWriteables();
228231
}
232+
233+
private static List<NamedWriteableRegistry.Entry> vector() {
234+
if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
235+
return List.of(Knn.ENTRY);
236+
}
237+
return List.of();
238+
}
229239
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@
159159
import org.elasticsearch.xpack.esql.expression.function.scalar.string.ToUpper;
160160
import org.elasticsearch.xpack.esql.expression.function.scalar.string.Trim;
161161
import org.elasticsearch.xpack.esql.expression.function.scalar.util.Delay;
162+
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
162163
import org.elasticsearch.xpack.esql.parser.ParsingException;
163164
import org.elasticsearch.xpack.esql.session.Configuration;
164165

@@ -446,7 +447,8 @@ private static FunctionDefinition[][] snapshotFunctions() {
446447
def(MaxOverTime.class, uni(MaxOverTime::new), "max_over_time"),
447448
def(AvgOverTime.class, uni(AvgOverTime::new), "avg_over_time"),
448449
def(LastOverTime.class, LastOverTime::withUnresolvedTimestamp, "last_over_time"),
449-
def(Term.class, bi(Term::new), "term") } };
450+
def(Term.class, bi(Term::new), "term"),
451+
def(Knn.class, bi(Knn::new), "knn") } };
450452
}
451453

452454
public EsqlFunctionRegistry snapshotRegistry() {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.expression.function.vector;
9+
10+
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
11+
import org.elasticsearch.common.io.stream.StreamInput;
12+
import org.elasticsearch.common.io.stream.StreamOutput;
13+
import org.elasticsearch.xpack.esql.capabilities.TranslationAware;
14+
import org.elasticsearch.xpack.esql.core.expression.Expression;
15+
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
16+
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
17+
import org.elasticsearch.xpack.esql.core.expression.function.Function;
18+
import org.elasticsearch.xpack.esql.core.querydsl.query.Query;
19+
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
20+
import org.elasticsearch.xpack.esql.core.tree.Source;
21+
import org.elasticsearch.xpack.esql.core.type.DataType;
22+
import org.elasticsearch.xpack.esql.core.util.Check;
23+
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo;
24+
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle;
25+
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
26+
import org.elasticsearch.xpack.esql.expression.function.fulltext.Match;
27+
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
28+
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
29+
import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
30+
import org.elasticsearch.xpack.esql.querydsl.query.KnnQuery;
31+
32+
import java.io.IOException;
33+
import java.util.List;
34+
import java.util.Objects;
35+
36+
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
37+
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull;
38+
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
39+
import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
40+
import static org.elasticsearch.xpack.esql.expression.function.fulltext.Match.getNameFromFieldAttribute;
41+
42+
public class Knn extends Function implements TranslationAware {
43+
44+
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Knn", Knn::readFrom);
45+
46+
private final Expression field;
47+
private final Expression query;
48+
49+
@FunctionInfo(
50+
returnType = "boolean",
51+
preview = true,
52+
description = """
53+
Finds the k nearest vectors to a query vector, as measured by a similarity metric.
54+
knn function finds nearest vectors through approximate search on indexed dense_vectors
55+
""",
56+
appliesTo = {
57+
@FunctionAppliesTo(
58+
lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT
59+
) }
60+
)
61+
public Knn(Source source, Expression field, Expression query) {
62+
super(source, List.of(field, query));
63+
this.field = field;
64+
this.query = query;
65+
}
66+
67+
public Expression field() {
68+
return field;
69+
}
70+
71+
public Expression query() {
72+
return query;
73+
}
74+
75+
@Override
76+
public DataType dataType() {
77+
return DataType.BOOLEAN;
78+
}
79+
80+
@Override
81+
protected final TypeResolution resolveType() {
82+
if (childrenResolved() == false) {
83+
return new TypeResolution("Unresolved children");
84+
}
85+
86+
return isNotNull(field(), sourceText(), FIRST).and(isType(field(), dt -> dt == DENSE_VECTOR, sourceText(), FIRST, "dense_vector"))
87+
.and(TypeResolutions.isNumeric(query(), sourceText(), TypeResolutions.ParamOrdinal.SECOND));
88+
}
89+
90+
@Override
91+
public boolean translatable(LucenePushdownPredicates pushdownPredicates) {
92+
return true;
93+
}
94+
95+
@Override
96+
public Query asQuery(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
97+
var fieldAttribute = Match.fieldAsFieldAttribute(field());
98+
99+
Check.notNull(fieldAttribute, "Match must have a field attribute as the first argument");
100+
String fieldName = getNameFromFieldAttribute(fieldAttribute);
101+
@SuppressWarnings("unchecked")
102+
List<Double> queryFolded = (List<Double>) query().fold(FoldContext.small() /* TODO remove me */);
103+
float[] queryAsFloats = new float[queryFolded.size()];
104+
for (int i = 0; i < queryFolded.size(); i++) {
105+
queryAsFloats[i] = queryFolded.get(i).floatValue();
106+
}
107+
return new KnnQuery(source(), fieldName, queryAsFloats);
108+
}
109+
110+
@Override
111+
public Expression replaceChildren(List<Expression> newChildren) {
112+
return new Knn(source(), newChildren.get(0), newChildren.get(1));
113+
}
114+
115+
@Override
116+
protected NodeInfo<? extends Expression> info() {
117+
return NodeInfo.create(this, Knn::new, field(), query());
118+
}
119+
120+
@Override
121+
public String getWriteableName() {
122+
return ENTRY.name;
123+
}
124+
125+
private static Knn readFrom(StreamInput in) throws IOException {
126+
Source source = Source.readFrom((PlanStreamInput) in);
127+
Expression field = in.readNamedWriteable(Expression.class);
128+
Expression query = in.readNamedWriteable(Expression.class);
129+
130+
return new Knn(source, field, query);
131+
}
132+
133+
@Override
134+
public void writeTo(StreamOutput out) throws IOException {
135+
source().writeTo(out);
136+
out.writeNamedWriteable(field());
137+
out.writeNamedWriteable(query());
138+
}
139+
140+
@Override
141+
public boolean equals(Object o) {
142+
if (o == null || getClass() != o.getClass()) return false;
143+
if (super.equals(o) == false) return false;
144+
Knn knn = (Knn) o;
145+
return Objects.equals(field, knn.field) && Objects.equals(query, knn.query);
146+
}
147+
148+
@Override
149+
public int hashCode() {
150+
return Objects.hash(super.hashCode(), field, query);
151+
}
152+
153+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.querydsl.query;
9+
10+
import org.elasticsearch.index.query.QueryBuilder;
11+
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
12+
import org.elasticsearch.xpack.esql.core.querydsl.query.Query;
13+
import org.elasticsearch.xpack.esql.core.tree.Source;
14+
15+
import java.util.Arrays;
16+
import java.util.Objects;
17+
18+
public class KnnQuery extends Query {
19+
20+
private final String field;
21+
private final float[] query;
22+
23+
public KnnQuery(Source source, String field, float[] query) {
24+
super(source);
25+
this.field = field;
26+
this.query = query;
27+
}
28+
29+
@Override
30+
protected QueryBuilder asBuilder() {
31+
return new KnnVectorQueryBuilder(field, query, null, null, null, null);
32+
}
33+
34+
@Override
35+
protected String innerToString() {
36+
return "knn(" + field + ", " + Arrays.toString(query) + ")";
37+
}
38+
39+
@Override
40+
public boolean equals(Object o) {
41+
if (!(o instanceof KnnQuery knnQuery)) return false;
42+
if (super.equals(o) == false) return false;
43+
return Objects.equals(field, knnQuery.field) && Objects.deepEquals(query, knnQuery.query);
44+
}
45+
46+
@Override
47+
public int hashCode() {
48+
return Objects.hash(super.hashCode(), field, Arrays.hashCode(query));
49+
}
50+
}

0 commit comments

Comments
 (0)