Skip to content

Commit 25e934c

Browse files
committed
Add $rerank aggregation stage support
Adds builder support for the $rerank pipeline stage (MongoDB 8.3, Atlas only). API: (1) RerankQuery: a query object created from a text shorthand or from a full Bson document for future modalities; (2) Aggregates.rerank(): 2 overloads (single path, multi path); (3) Scala wrappers in Aggregates.scala and a type alias in package.scala. JAVA-6052
1 parent c0f9627 commit 25e934c

6 files changed

Lines changed: 345 additions & 6 deletions

File tree

driver-core/src/main/com/mongodb/client/model/Aggregates.java

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
import static com.mongodb.internal.Iterables.concat;
6060
import static com.mongodb.internal.client.model.Util.sizeAtLeast;
6161
import static java.util.Arrays.asList;
62+
import static java.util.Collections.singletonList;
6263

6364
/**
6465
* Builders for aggregation pipeline stages.
@@ -1040,6 +1041,57 @@ public static Bson vectorSearch(
10401041
return new VectorSearchBson(path, queryVector, index, limit, options);
10411042
}
10421043

1044+
/**
1045+
* Creates a {@code $rerank} pipeline stage supported by MongoDB Atlas.
1046+
* You may use the {@code $meta: "score"} expression to extract the relevance score
1047+
* assigned to each reranked document.
1048+
*
1049+
* @param query The query to rerank against, created via {@link RerankQuery#rerankQuery(String)}.
1050+
* @param path The document field to send to the reranker.
1051+
* @param numDocsToRerank The maximum number of documents to rerank (1-1000).
1052+
* @param model The reranking model name. Accepted values:
1053+
* {@code "rerank-2.5"}, {@code "rerank-2.5-lite"}, {@code "rerank-2"}, {@code "rerank-2-lite"}.
1054+
* @return The {@code $rerank} pipeline stage.
1055+
* @mongodb.server.release 8.3
1056+
* @since 5.8
1057+
*/
1058+
@Beta(Reason.SERVER)
1059+
public static Bson rerank(
1060+
final RerankQuery query,
1061+
final String path,
1062+
final int numDocsToRerank,
1063+
final String model) {
1064+
notNull("path", path);
1065+
return rerank(query, singletonList(path), numDocsToRerank, model);
1066+
}
1067+
1068+
/**
1069+
* Creates a {@code $rerank} pipeline stage supported by MongoDB Atlas.
1070+
* You may use the {@code $meta: "score"} expression to extract the relevance score
1071+
* assigned to each reranked document.
1072+
*
1073+
* @param query The query to rerank against, created via {@link RerankQuery#rerankQuery(String)}.
1074+
* @param paths The document field(s) to send to the reranker.
1075+
* @param numDocsToRerank The maximum number of documents to rerank (1-1000).
1076+
* @param model The reranking model name. Accepted values:
1077+
* {@code "rerank-2.5"}, {@code "rerank-2.5-lite"}, {@code "rerank-2"}, {@code "rerank-2-lite"}.
1078+
* @return The {@code $rerank} pipeline stage.
1079+
* @mongodb.server.release 8.3
1080+
* @since 5.8
1081+
*/
1082+
@Beta(Reason.SERVER)
1083+
public static Bson rerank(
1084+
final RerankQuery query,
1085+
final List<String> paths,
1086+
final int numDocsToRerank,
1087+
final String model) {
1088+
notNull("query", query);
1089+
notNull("paths", paths);
1090+
isTrueArgument("paths must not be empty", !paths.isEmpty());
1091+
notNull("model", model);
1092+
return new RerankBson(query, paths, numDocsToRerank, model);
1093+
}
1094+
10431095
/**
10441096
* Creates an $unset pipeline stage that removes/excludes fields from documents
10451097
*
@@ -2290,4 +2342,38 @@ public String toString() {
22902342
+ '}';
22912343
}
22922344
}
2345+
2346+
private static class RerankBson implements Bson {
2347+
private final RerankQuery query;
2348+
private final List<String> paths;
2349+
private final int numDocsToRerank;
2350+
private final String model;
2351+
2352+
RerankBson(final RerankQuery query, final List<String> paths, final int numDocsToRerank,
2353+
final String model) {
2354+
this.query = query;
2355+
this.paths = paths;
2356+
this.numDocsToRerank = numDocsToRerank;
2357+
this.model = model;
2358+
}
2359+
2360+
@Override
2361+
public <TDocument> BsonDocument toBsonDocument(final Class<TDocument> documentClass, final CodecRegistry codecRegistry) {
2362+
Document specificationDoc = new Document("query", query)
2363+
.append("path", paths.size() == 1 ? paths.get(0) : paths)
2364+
.append("numDocsToRerank", numDocsToRerank)
2365+
.append("model", model);
2366+
return new Document("$rerank", specificationDoc).toBsonDocument(documentClass, codecRegistry);
2367+
}
2368+
2369+
@Override
2370+
public String toString() {
2371+
return "Stage{name=$rerank"
2372+
+ ", query=" + query
2373+
+ ", paths=" + paths
2374+
+ ", numDocsToRerank=" + numDocsToRerank
2375+
+ ", model=" + model
2376+
+ '}';
2377+
}
2378+
}
22932379
}
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/*
2+
* Copyright 2008-present MongoDB, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.mongodb.client.model;
18+
19+
import org.bson.BsonDocument;
20+
import org.bson.BsonString;
21+
import org.bson.annotations.Beta;
22+
import org.bson.annotations.Reason;
23+
import org.bson.codecs.configuration.CodecRegistry;
24+
import org.bson.conversions.Bson;
25+
26+
import static com.mongodb.assertions.Assertions.notNull;
27+
28+
/**
29+
* Represents a query for the {@code $rerank} aggregation pipeline stage.
30+
* <p>
31+
* Use {@link #rerankQuery(String)} for a simple text query, or
32+
* {@link #rerankQuery(Bson)} to specify the full query document directly
33+
* (e.g., for future modalities like imageURL or videoURL).
34+
*
35+
* @mongodb.server.release 8.3
36+
* @since 5.8
37+
*/
38+
@Beta(Reason.SERVER)
39+
public final class RerankQuery implements Bson {
40+
private final Bson query;
41+
42+
private RerankQuery(final Bson query) {
43+
this.query = query;
44+
}
45+
46+
/**
47+
* Creates a rerank query with the specified text.
48+
* <p>
49+
* This is a convenience for {@code rerankQuery(new Document("text", text))}.
50+
*
51+
* @param text the query text to rerank against.
52+
* @return a new {@link RerankQuery}
53+
*/
54+
public static RerankQuery rerankQuery(final String text) {
55+
notNull("text", text);
56+
return new RerankQuery(new BsonDocument("text", new BsonString(text)));
57+
}
58+
59+
/**
60+
* Creates a rerank query from a full query document.
61+
* <p>
62+
* Use this overload for future query modalities (e.g., imageURL, videoURL)
63+
* or to pass additional fields alongside text.
64+
*
65+
* @param query the query document.
66+
* @return a new {@link RerankQuery}
67+
*/
68+
public static RerankQuery rerankQuery(final Bson query) {
69+
notNull("query", query);
70+
return new RerankQuery(query);
71+
}
72+
73+
@Override
74+
public <TDocument> BsonDocument toBsonDocument(final Class<TDocument> documentClass, final CodecRegistry codecRegistry) {
75+
return query.toBsonDocument(documentClass, codecRegistry);
76+
}
77+
78+
@Override
79+
public String toString() {
80+
return "RerankQuery{" + query + '}';
81+
}
82+
}

driver-core/src/test/functional/com/mongodb/client/model/AggregatesTest.java

Lines changed: 85 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
import org.junit.jupiter.params.provider.MethodSource;
3434

3535
import java.math.RoundingMode;
36-
import java.util.Arrays;
3736
import java.util.Collections;
3837
import java.util.List;
3938
import java.util.stream.Stream;
@@ -43,8 +42,10 @@
4342
import static com.mongodb.client.model.Accumulators.percentile;
4443
import static com.mongodb.client.model.Aggregates.geoNear;
4544
import static com.mongodb.client.model.Aggregates.group;
45+
import static com.mongodb.client.model.Aggregates.rerank;
4646
import static com.mongodb.client.model.Aggregates.unset;
4747
import static com.mongodb.client.model.Aggregates.vectorSearch;
48+
import static com.mongodb.client.model.RerankQuery.rerankQuery;
4849
import static com.mongodb.client.model.GeoNearOptions.geoNearOptions;
4950
import static com.mongodb.client.model.Sorts.ascending;
5051
import static com.mongodb.client.model.Windows.Bound.UNBOUNDED;
@@ -260,17 +261,17 @@ public void testDocuments() {
260261
"{$documents: [{a: 1, b: {$add: [1, 1]}}, {a: 3, b: 4}]}",
261262
stage);
262263

263-
List<Bson> pipeline = Arrays.asList(stage);
264+
List<Bson> pipeline = asList(stage);
264265
getCollectionHelper().aggregateDb(pipeline);
265266

266267
assertEquals(
267268
parseToList("[{a: 1, b: 2}, {a: 3, b: 4}]"),
268269
getCollectionHelper().aggregateDb(pipeline));
269270

270271
// accepts lists of Documents and BsonDocuments
271-
List<BsonDocument> documents = Arrays.asList(BsonDocument.parse("{a: 1, b: 2}"));
272+
List<BsonDocument> documents = asList(BsonDocument.parse("{a: 1, b: 2}"));
272273
assertPipeline("{$documents: [{a: 1, b: 2}]}", Aggregates.documents(documents));
273-
List<BsonDocument> bsonDocuments = Arrays.asList(BsonDocument.parse("{a: 1, b: 2}"));
274+
List<BsonDocument> bsonDocuments = asList(BsonDocument.parse("{a: 1, b: 2}"));
274275
assertPipeline("{$documents: [{a: 1, b: 2}]}", Aggregates.documents(bsonDocuments));
275276
}
276277

@@ -281,13 +282,13 @@ public void testDocumentsLookup() {
281282
getCollectionHelper().insertDocuments("[{_id: 1, a: 8}, {_id: 2, a: 9}]");
282283
Bson documentsStage = Aggregates.documents(asList(Document.parse("{a: 5}")));
283284

284-
Bson lookupStage = Aggregates.lookup(null, Arrays.asList(documentsStage), "added");
285+
Bson lookupStage = Aggregates.lookup(null, asList(documentsStage), "added");
285286
assertPipeline(
286287
"{'$lookup': {'pipeline': [{'$documents': [{'a': 5}]}], 'as': 'added'}}",
287288
lookupStage);
288289
assertEquals(
289290
parseToList("[{_id:1, a:8, added: [{a: 5}]}, {_id:2, a:9, added: [{a: 5}]}]"),
290-
getCollectionHelper().aggregate(Arrays.asList(lookupStage)));
291+
getCollectionHelper().aggregate(asList(lookupStage)));
291292
}
292293

293294
@Test
@@ -374,4 +375,82 @@ public void testExactVectorSearchWithQueryObject() {
374375
exactVectorSearchOptions()
375376
));
376377
}
378+
379+
@Test
380+
public void testRerankWithSinglePath() {
381+
assertPipeline(
382+
"{"
383+
+ " '$rerank': {"
384+
+ " 'query': {'text': 'machine learning tutorials'},"
385+
+ " 'path': 'content',"
386+
+ " 'numDocsToRerank': 25,"
387+
+ " 'model': 'rerank-2.5'"
388+
+ " }"
389+
+ "}",
390+
rerank(
391+
rerankQuery("machine learning tutorials"),
392+
"content",
393+
25,
394+
"rerank-2.5"
395+
));
396+
}
397+
398+
@Test
399+
public void testRerankWithMultiplePaths() {
400+
assertPipeline(
401+
"{"
402+
+ " '$rerank': {"
403+
+ " 'query': {'text': 'machine learning tutorials'},"
404+
+ " 'path': ['content', 'title'],"
405+
+ " 'numDocsToRerank': 50,"
406+
+ " 'model': 'rerank-2.5-lite'"
407+
+ " }"
408+
+ "}",
409+
rerank(
410+
rerankQuery("machine learning tutorials"),
411+
asList("content", "title"),
412+
50,
413+
"rerank-2.5-lite"
414+
));
415+
}
416+
417+
@Test
418+
public void testRerankWithBsonQuery() {
419+
assertPipeline(
420+
"{"
421+
+ " '$rerank': {"
422+
+ " 'query': {'text': 'machine learning tutorials', 'imageURL': 'https://example.com/img.png'},"
423+
+ " 'path': 'content',"
424+
+ " 'numDocsToRerank': 25,"
425+
+ " 'model': 'rerank-2.5'"
426+
+ " }"
427+
+ "}",
428+
rerank(
429+
rerankQuery(new Document("text", "machine learning tutorials")
430+
.append("imageURL", "https://example.com/img.png")),
431+
"content",
432+
25,
433+
"rerank-2.5"
434+
));
435+
}
436+
437+
@Test
438+
public void testRerankWithMultiplePathsAndBsonQuery() {
439+
assertPipeline(
440+
"{"
441+
+ " '$rerank': {"
442+
+ " 'query': {'text': 'machine learning tutorials', 'imageURL': 'https://example.com/img.png'},"
443+
+ " 'path': ['content', 'title'],"
444+
+ " 'numDocsToRerank': 100,"
445+
+ " 'model': 'rerank-2'"
446+
+ " }"
447+
+ "}",
448+
rerank(
449+
rerankQuery(new Document("text", "machine learning tutorials")
450+
.append("imageURL", "https://example.com/img.png")),
451+
asList("content", "title"),
452+
100,
453+
"rerank-2"
454+
));
455+
}
377456
}

driver-scala/src/main/scala/org/mongodb/scala/model/Aggregates.scala

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import com.mongodb.client.model.search.FieldSearchPath
2222

2323
import scala.collection.JavaConverters._
2424
import com.mongodb.client.model.{ Aggregates => JAggregates }
25+
import com.mongodb.client.model.RerankQuery
2526
import org.mongodb.scala.MongoNamespace
2627
import org.mongodb.scala.bson.conversions.Bson
2728
import org.mongodb.scala.model.densify.{ DensifyOptions, DensifyRange }
@@ -746,6 +747,50 @@ object Aggregates {
746747
): Bson =
747748
JAggregates.vectorSearch(path, queryVector.asJava, index, limit, options)
748749

750+
/**
 * Builds a `\$rerank` pipeline stage, supported by MongoDB Atlas, that sends a single document field
 * to the reranker. The relevance score assigned to each reranked document can be extracted with the
 * `\$meta: "score"` expression.
 *
 * @param query The query to rerank against.
 * @param path The document field to send to the reranker.
 * @param numDocsToRerank The maximum number of documents to rerank (1-1000).
 * @param model The reranking model name.
 * @return The `\$rerank` pipeline stage.
 * @note Requires MongoDB on Atlas 8.3 or greater
 * @since 5.8
 */
@Beta(Array(Reason.SERVER))
def rerank(query: RerankQuery, path: String, numDocsToRerank: Int, model: String): Bson = {
  // Thin wrapper: delegates straight to the Java builder.
  JAggregates.rerank(query, path, numDocsToRerank, model)
}
771+
772+
/**
 * Builds a `\$rerank` pipeline stage, supported by MongoDB Atlas, that sends one or more document
 * fields to the reranker. The relevance score assigned to each reranked document can be extracted with the
 * `\$meta: "score"` expression.
 *
 * @param query The query to rerank against.
 * @param paths The document field(s) to send to the reranker.
 * @param numDocsToRerank The maximum number of documents to rerank (1-1000).
 * @param model The reranking model name.
 * @return The `\$rerank` pipeline stage.
 * @note Requires MongoDB on Atlas 8.3 or greater
 * @since 5.8
 */
@Beta(Array(Reason.SERVER))
def rerank(query: RerankQuery, paths: Seq[String], numDocsToRerank: Int, model: String): Bson = {
  // The Java builder takes a java.util.List, so convert the Scala sequence first.
  val javaPaths = paths.toList.asJava
  JAggregates.rerank(query, javaPaths, numDocsToRerank, model)
}
793+
749794
/**
750795
* Creates an `\$unset` pipeline stage that removes/excludes fields from documents
751796
*

driver-scala/src/main/scala/org/mongodb/scala/model/package.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -987,6 +987,8 @@ package object model {
987987

988988
type GeoNearOptions = com.mongodb.client.model.GeoNearOptions
989989

990+
/**
 * Alias for the Java driver's `com.mongodb.client.model.RerankQuery`, the query type accepted by
 * `Aggregates.rerank`.
 */
type RerankQuery = com.mongodb.client.model.RerankQuery
991+
990992
/**
991993
* @see `QuantileMethod.approximate()`
992994
*/

0 commit comments

Comments
 (0)