diff --git a/pom.xml b/pom.xml index ffbe2b8e19..9eb477077f 100644 --- a/pom.xml +++ b/pom.xml @@ -5,7 +5,7 @@ org.springframework.data spring-data-mongodb-parent - 5.0.0-SNAPSHOT + 5.0.0-SEARCH-SNAPSHOT pom Spring Data MongoDB @@ -26,7 +26,7 @@ multi spring-data-mongodb - 4.0.0-SNAPSHOT + 4.0.0-SEARCH-RESULT-SNAPSHOT 5.4.0 1.19 diff --git a/spring-data-mongodb-distribution/pom.xml b/spring-data-mongodb-distribution/pom.xml index fc88571622..6e0e6b99f4 100644 --- a/spring-data-mongodb-distribution/pom.xml +++ b/spring-data-mongodb-distribution/pom.xml @@ -15,7 +15,7 @@ org.springframework.data spring-data-mongodb-parent - 5.0.0-SNAPSHOT + 5.0.0-SEARCH-SNAPSHOT ../pom.xml diff --git a/spring-data-mongodb/pom.xml b/spring-data-mongodb/pom.xml index ad3c1338ec..3e7ccd09e8 100644 --- a/spring-data-mongodb/pom.xml +++ b/spring-data-mongodb/pom.xml @@ -13,7 +13,7 @@ org.springframework.data spring-data-mongodb-parent - 5.0.0-SNAPSHOT + 5.0.0-SEARCH-SNAPSHOT ../pom.xml diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java index 5c7df76cc5..5ed7f9b8a3 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java @@ -1098,7 +1098,7 @@ public GeoResults geoNear(NearQuery near, Class domainType, String col result.add(geoResult); } - Distance avgDistance = new Distance( + Distance avgDistance = Distance.of( result.size() == 0 ? 0 : aggregate.divide(new BigDecimal(result.size()), RoundingMode.HALF_UP).doubleValue(), near.getMetric()); @@ -2654,7 +2654,9 @@ protected List doFind(String collectionName, if (LOGGER.isDebugEnabled()) { - Document mappedSort = preparer instanceof SortingQueryCursorPreparer sqcp ? getMappedSortObject(sqcp.getSortObject(), entity) : null; + Document mappedSort = preparer instanceof SortingQueryCursorPreparer sqcp + ? getMappedSortObject(sqcp.getSortObject(), entity) + : null; LOGGER.debug(String.format("find using query: %s fields: %s sort: %s for class: %s in collection: %s", serializeToJsonSafely(mappedQuery), mappedFields, serializeToJsonSafely(mappedSort), entityClass, collectionName)); @@ -3553,7 +3555,7 @@ public GeoResult doWith(Document object) { T doWith = delegate.doWith(object); - return new GeoResult<>(doWith, new Distance(distance, metric)); + return new GeoResult<>(doWith, Distance.of(distance, metric)); } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java index 325a96dc85..e263735187 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java @@ -3227,7 +3227,7 @@ public Mono> doWith(Document object) { double distance = getDistance(object); - return delegate.doWith(object).map(doWith -> new GeoResult<>(doWith, new Distance(distance, metric))); + return delegate.doWith(object).map(doWith -> new GeoResult<>(doWith, Distance.of(distance, metric))); } double getDistance(Document object) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationPipeline.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationPipeline.java index 40966bcf3d..f06803997b 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationPipeline.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationPipeline.java @@ -22,8 +22,10 @@ import java.util.function.Predicate; import org.bson.Document; +import org.jspecify.annotations.Nullable; import org.springframework.lang.Contract; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; /** * The {@link AggregationPipeline} holds the collection of {@link AggregationOperation aggregation stages}. @@ -82,6 +84,14 @@ public List getOperations() { return Collections.unmodifiableList(pipeline); } + public @Nullable AggregationOperation firstOperation() { + return CollectionUtils.firstElement(pipeline); + } + + public @Nullable AggregationOperation lastOperation() { + return CollectionUtils.lastElement(pipeline); + } + List toDocuments(AggregationOperationContext context) { verify(); @@ -97,8 +107,8 @@ public boolean isOutOrMerge() { return false; } - AggregationOperation operation = pipeline.get(pipeline.size() - 1); - return isOut(operation) || isMerge(operation); + AggregationOperation operation = lastOperation(); + return operation != null && (isOut(operation) || isMerge(operation)); } void verify() { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationResults.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationResults.java index f5a861cddd..7b27739229 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationResults.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationResults.java @@ -105,4 +105,5 @@ public Document getRawResults() { Object object = rawResults.get("serverUsed"); return object instanceof String stringValue ? stringValue : null; } + } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java index 02b805d5ed..85952d8f39 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java @@ -356,6 +356,7 @@ public SortArray sort(Sort sort) { * @return new instance of {@link SortArray}. * @since 4.5 */ + @SuppressWarnings("NullAway") public SortArray sort(Direction direction) { if (usesFieldRef()) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java index 2c74900bc5..95f1c5b4d2 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java @@ -237,7 +237,10 @@ public Document toDocument(AggregationOperationContext context) { } $vectorSearch.append("index", indexName); - $vectorSearch.append("limit", limit.max()); + + if(limit.isLimited()) { // TODO: exception or pass it on? + $vectorSearch.append("limit", limit.max()); + } if (numCandidates != null) { $vectorSearch.append("numCandidates", numCandidates); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/GeoConverters.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/GeoConverters.java index ae73ab68bd..b595ab688f 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/GeoConverters.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/GeoConverters.java @@ -270,7 +270,7 @@ enum DocumentToCircleConverter implements Converter { Assert.notNull(center, "Center must not be null"); Assert.notNull(radius, "Radius must not be null"); - Distance distance = new Distance(toPrimitiveDoubleValue(radius)); + Distance distance = Distance.of(toPrimitiveDoubleValue(radius)); if (source.containsKey("metric")) { @@ -335,7 +335,7 @@ enum DocumentToSphereConverter implements Converter { Assert.notNull(center, "Center must not be null"); Assert.notNull(radius, "Radius must not be null"); - Distance distance = new Distance(toPrimitiveDoubleValue(radius)); + Distance distance = Distance.of(toPrimitiveDoubleValue(radius)); if (source.containsKey("metric")) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/geo/Sphere.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/geo/Sphere.java index 47be645869..d3ca840d6b 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/geo/Sphere.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/geo/Sphere.java @@ -63,7 +63,7 @@ public Sphere(Point center, Distance radius) { * @param radius */ public Sphere(Point center, double radius) { - this(center, new Distance(radius)); + this(center, Distance.of(radius)); } /** diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/MongoSimpleTypes.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/MongoSimpleTypes.java index 3b3a520bc3..6b4d9b9e9b 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/MongoSimpleTypes.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/MongoSimpleTypes.java @@ -29,6 +29,7 @@ import org.bson.types.Decimal128; import org.bson.types.ObjectId; import org.bson.types.Symbol; + import org.springframework.data.mapping.model.SimpleTypeHolder; import com.mongodb.DBRef; diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/NearQuery.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/NearQuery.java index 6dad07b8cb..88d7dc5c1d 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/NearQuery.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/NearQuery.java @@ -19,6 +19,7 @@ import org.bson.Document; import org.jspecify.annotations.Nullable; + import org.springframework.data.domain.Pageable; import org.springframework.data.geo.CustomMetric; import org.springframework.data.geo.Distance; @@ -329,7 +330,7 @@ public NearQuery with(Pageable pageable) { */ @Contract("_ -> this") public NearQuery maxDistance(double maxDistance) { - return maxDistance(new Distance(maxDistance, getMetric())); + return maxDistance(Distance.of(maxDistance, getMetric())); } /** @@ -345,7 +346,7 @@ public NearQuery maxDistance(double maxDistance, Metric metric) { Assert.notNull(metric, "Metric must not be null"); - return maxDistance(new Distance(maxDistance, metric)); + return maxDistance(Distance.of(maxDistance, metric)); } /** @@ -388,7 +389,7 @@ public NearQuery maxDistance(Distance distance) { */ @Contract("_ -> this") public NearQuery minDistance(double minDistance) { - return minDistance(new Distance(minDistance, getMetric())); + return minDistance(Distance.of(minDistance, getMetric())); } /** @@ -405,7 +406,7 @@ public NearQuery minDistance(double minDistance, Metric metric) { Assert.notNull(metric, "Metric must not be null"); - return minDistance(new Distance(minDistance, metric)); + return minDistance(Distance.of(minDistance, metric)); } /** @@ -611,7 +612,7 @@ public NearQuery withReadPreference(ReadPreference readPreference) { * Get the {@link ReadConcern} to use. Will return the underlying {@link #query(Query) queries} * {@link Query#getReadConcern() ReadConcern} if present or the one defined on the {@link NearQuery#readConcern} * itself. - * + * * @return can be {@literal null} if none set. * @since 4.1 * @see ReadConcernAware diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/VectorSearch.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/VectorSearch.java new file mode 100644 index 0000000000..336889f719 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/VectorSearch.java @@ -0,0 +1,123 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.core.annotation.AliasFor; +import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation; + +/** + * Annotation to declare Vector Search queries directly on repository methods. Vector Search queries are used to search + * for similar documents based on vector embeddings typically returning + * {@link org.springframework.data.domain.SearchResults} and limited by either a + * {@link org.springframework.data.domain.Score} (within) or a {@link org.springframework.data.domain.Range} of scores + * (between). + *

+ * Vector search must define an index name using the {@link #indexName()} attribute. The index must be created in the + * MongoDB Atlas cluster before executing the query. Any misspelling of the index name will result in returning no + * results. + *

+ * When using pre-filters, you can either define {@link #filter()} or use query derivation to define the pre-filter. + * {@link org.springframework.data.domain.Vector} and distance parameters are considered once these are present. Vector + * search supports sorting and will consider {@link org.springframework.data.domain.Sort} parameters. + * + * @author Mark Paluch + * @since 5.0 + * @see org.springframework.data.domain.Score + * @see org.springframework.data.domain.Vector + * @see org.springframework.data.domain.SearchResults + */ +@Retention(RetentionPolicy.RUNTIME) +@Target({ ElementType.METHOD, ElementType.ANNOTATION_TYPE }) +@Documented +@Query +@Hint +public @interface VectorSearch { + + /** + * Configuration whether to use + * {@link org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType#ANN} or + * {@link org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType#ENN} for the search. + * + * @return the search type to use. + */ + VectorSearchOperation.SearchType searchType() default VectorSearchOperation.SearchType.DEFAULT; + + /** + * Name of the Atlas Vector Search index to use. Atlas Vector Search doesn't return results if you misspell the index + * name or if the specified index doesn't already exist on the cluster. + * + * @return name of the Atlas Vector Search index to use. + */ + @AliasFor(annotation = Hint.class, value = "indexName") + String indexName(); + + /** + * Indexed vector type field to search. This is defaulted from the domain model using the first Vector property found. + * + * @return an empty String by default. + */ + String path() default ""; + + /** + * Takes a MongoDB JSON (MQL) string defining the pre-filter against indexed fields. Supports Value Expressions. Alias + * for {@link VectorSearch#filter}. + * + * @return an empty String by default. + */ + @AliasFor(annotation = Query.class) + String value() default ""; + + /** + * Takes a MongoDB JSON (MQL) string defining the pre-filter against indexed fields. Supports Value Expressions. Alias + * for {@link VectorSearch#value}. + * + * @return an empty String by default. + */ + @AliasFor(annotation = Query.class, value = "value") + String filter() default ""; + + /** + * Number of documents to return in the results. This value can't exceed the value of {@link #numCandidates} if you + * specify {@link #numCandidates}. Limit accepts Value Expressions. A Vector Search method cannot define both, + * {@code limit()} and a {@link org.springframework.data.domain.Limit} parameter. Supports Value Expressions. + * + * @return number of documents to return in the results. + */ + String limit() default ""; + + /** + * Number of nearest neighbors to use during the search. Value must be less than or equal to ({@code <=}) + * {@code 10000}. You can't specify a number less than the {@link #limit() number of documents to return}. We + * recommend that you specify a number at least {@code 20} times higher than the {@link #limit() number of documents + * to return} to increase accuracy. + *

+ * This over-request pattern is the recommended way to trade off latency and recall in your ANN searches, and we + * recommend tuning this parameter based on your specific dataset size and query requirements. Required if the query + * uses + * {@link org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType#ANN}/{@link org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType#DEFAULT}. + * Supports Value Expressions. + * + * @return number of nearest neighbors to use during the search. + */ + String numCandidates() default ""; + +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java index 831d21bb44..17c19ad951 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java @@ -25,8 +25,10 @@ import org.jspecify.annotations.Nullable; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; import org.springframework.data.domain.ScrollPosition; import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Vector; import org.springframework.data.geo.Distance; import org.springframework.data.geo.Point; import org.springframework.data.mongodb.core.convert.MongoCustomConversions; @@ -129,6 +131,21 @@ public Range getDistanceRange() { return null; } + @Override + public @Nullable Vector getVector() { + return null; + } + + @Override + public @Nullable Score getScore() { + return null; + } + + @Override + public @Nullable Range getScoreRange() { + return null; + } + @Override public @Nullable Point getGeoNearLocation() { return null; diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java index def03c7973..40d72a69f7 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java @@ -29,6 +29,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.jspecify.annotations.Nullable; + import org.springframework.core.annotation.AnnotatedElementUtils; import org.springframework.data.mongodb.core.MongoOperations; import org.springframework.data.mongodb.core.aggregation.AggregationUpdate; @@ -178,10 +179,11 @@ private QueryInteraction createStringQuery(RepositoryInformation repositoryInfor private static boolean backoff(MongoQueryMethod method) { - boolean skip = method.isGeoNearQuery() || method.isScrollQuery() || method.isStreamQuery(); + boolean skip = method.isGeoNearQuery() || method.isScrollQuery() || method.isStreamQuery() + || method.isSearchQuery(); if (skip && logger.isDebugEnabled()) { - logger.debug("Skipping AOT generation for [%s]. Method is either geo-near, streaming or scrolling query" + logger.debug("Skipping AOT generation for [%s]. Method is either geo-near, streaming, search or scrolling query" .formatted(method.getName())); } return skip; @@ -193,7 +195,6 @@ private static MethodContributor aggregationMethodContributor( return MethodContributor.forQueryMethod(queryMethod).withMetadata(aggregation).contribute(context -> { CodeBlock.Builder builder = CodeBlock.builder(); - builder.add(context.codeBlocks().logDebug("invoking [%s]".formatted(context.getMethod().getName()))); builder.add(aggregationBlockBuilder(context, queryMethod).stages(aggregation) .usingAggregationVariableName("aggregation").build()); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AbstractMongoQuery.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AbstractMongoQuery.java index f56c2c7a22..596b895ebd 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AbstractMongoQuery.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AbstractMongoQuery.java @@ -164,7 +164,7 @@ private Query applyAnnotatedReadPreferenceIfPresent(Query query) { } @SuppressWarnings("NullAway") - private MongoQueryExecution getExecution(ConvertingParameterAccessor accessor, FindWithQuery operation) { + MongoQueryExecution getExecution(ConvertingParameterAccessor accessor, FindWithQuery operation) { if (isDeleteQuery()) { return new DeleteExecution<>(executableRemove, method); @@ -345,7 +345,7 @@ private Document bindParameters(String source, ConvertingParameterAccessor acces * @return never {@literal null}. * @since 3.4 */ - protected ParameterBindingContext prepareBindingContext(String source, ConvertingParameterAccessor accessor) { + protected ParameterBindingContext prepareBindingContext(String source, MongoParameterAccessor accessor) { ValueExpressionEvaluator evaluator = getExpressionEvaluatorFor(accessor); return new ParameterBindingContext(accessor::getBindableValue, evaluator); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ConvertingParameterAccessor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ConvertingParameterAccessor.java index d075b67efe..f203b67e67 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ConvertingParameterAccessor.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ConvertingParameterAccessor.java @@ -22,11 +22,14 @@ import java.util.List; import org.jspecify.annotations.Nullable; + import org.springframework.data.domain.Limit; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; import org.springframework.data.domain.ScrollPosition; import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Vector; import org.springframework.data.geo.Distance; import org.springframework.data.geo.Point; import org.springframework.data.mongodb.core.convert.MongoWriter; @@ -73,6 +76,11 @@ public PotentiallyConvertingIterator iterator() { return new ConvertingIterator(delegate.iterator()); } + @Override + public @Nullable Vector getVector() { + return delegate.getVector(); + } + @Override public @Nullable ScrollPosition getScrollPosition() { return delegate.getScrollPosition(); @@ -95,6 +103,16 @@ public Sort getSort() { return getConvertedValue(delegate.getBindableValue(index), null); } + @Override + public @Nullable Score getScore() { + return delegate.getScore(); + } + + @Override + public @Nullable Range getScoreRange() { + return delegate.getScoreRange(); + } + @Override public @Nullable Range getDistanceRange() { return delegate.getDistanceRange(); @@ -208,7 +226,7 @@ private static Collection asCollection(@Nullable Object source) { if (source instanceof Iterable iterable) { - if(source instanceof Collection collection) { + if (source instanceof Collection collection) { return new ArrayList<>(collection); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParameterAccessor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParameterAccessor.java index 00d748f8a9..1b52233eac 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParameterAccessor.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParameterAccessor.java @@ -16,6 +16,7 @@ package org.springframework.data.mongodb.repository.query; import org.jspecify.annotations.Nullable; + import org.springframework.data.domain.Range; import org.springframework.data.geo.Distance; import org.springframework.data.geo.Point; diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParameters.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParameters.java index cb91ccd8e6..94acef17ce 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParameters.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParameters.java @@ -21,8 +21,10 @@ import java.util.List; import org.jspecify.annotations.Nullable; + import org.springframework.core.MethodParameter; import org.springframework.data.domain.Range; +import org.springframework.data.domain.Vector; import org.springframework.data.geo.Distance; import org.springframework.data.geo.GeoPage; import org.springframework.data.geo.GeoResult; @@ -195,10 +197,6 @@ static int findNearIndexInParameters(Method method) { return index; } - public int getDistanceRangeIndex() { - return -1; - } - /** * Returns the index of the {@link Distance} parameter to be used for max distance in geo queries. * @@ -317,7 +315,8 @@ static class MongoParameter extends Parameter { @Override public boolean isSpecialParameter() { - return super.isSpecialParameter() || Distance.class.isAssignableFrom(getType()) || isNearParameter() + return super.isSpecialParameter() || Distance.class.isAssignableFrom(getType()) + || Vector.class.isAssignableFrom(getType()) || isNearParameter() || TextCriteria.class.isAssignableFrom(getType()) || Collation.class.isAssignableFrom(getType()); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParametersParameterAccessor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParametersParameterAccessor.java index 66529dfce9..0f56223492 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParametersParameterAccessor.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParametersParameterAccessor.java @@ -16,8 +16,10 @@ package org.springframework.data.mongodb.repository.query; import org.jspecify.annotations.Nullable; + import org.springframework.data.domain.Range; import org.springframework.data.domain.Range.Bound; +import org.springframework.data.domain.Score; import org.springframework.data.geo.Distance; import org.springframework.data.geo.Point; import org.springframework.data.mongodb.core.query.Collation; @@ -55,7 +57,24 @@ public MongoParametersParameterAccessor(MongoQueryMethod method, Object[] values } @SuppressWarnings("NullAway") - public @Nullable Range getDistanceRange() { + @Override + public Range getScoreRange() { + + MongoParameters mongoParameters = method.getParameters(); + + if (mongoParameters.hasScoreRangeParameter()) { + return getValue(mongoParameters.getScoreRangeIndex()); + } + + Score score = getScore(); + Bound maxDistance = score != null ? Bound.inclusive(score) : Bound.unbounded(); + + return Range.of(Bound.unbounded(), maxDistance); + } + + @SuppressWarnings("NullAway") + @Override + public Range getDistanceRange() { MongoParameters mongoParameters = method.getParameters(); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryCreator.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryCreator.java index b8a8c34f48..ba7394ec17 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryCreator.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryCreator.java @@ -15,7 +15,8 @@ */ package org.springframework.data.mongodb.repository.query; -import static org.springframework.data.mongodb.core.query.Criteria.*; +import static org.springframework.data.mongodb.core.query.Criteria.Placeholder; +import static org.springframework.data.mongodb.core.query.Criteria.where; import java.util.Arrays; import java.util.Collection; @@ -72,6 +73,7 @@ public class MongoQueryCreator extends AbstractQueryCreator { private final MongoParameterAccessor accessor; private final MappingContext context; private final boolean isGeoNearQuery; + private final boolean isSearchQuery; /** * Creates a new {@link MongoQueryCreator} from the given {@link PartTree}, {@link ConvertingParameterAccessor} and @@ -81,9 +83,9 @@ public class MongoQueryCreator extends AbstractQueryCreator { * @param accessor * @param context */ - public MongoQueryCreator(PartTree tree, ConvertingParameterAccessor accessor, + public MongoQueryCreator(PartTree tree, MongoParameterAccessor accessor, MappingContext context) { - this(tree, accessor, context, false); + this(tree, accessor, context, false, false); } /** @@ -94,9 +96,10 @@ public MongoQueryCreator(PartTree tree, ConvertingParameterAccessor accessor, * @param accessor * @param context * @param isGeoNearQuery + * @param isSearchQuery */ - public MongoQueryCreator(PartTree tree, ConvertingParameterAccessor accessor, - MappingContext context, boolean isGeoNearQuery) { + public MongoQueryCreator(PartTree tree, MongoParameterAccessor accessor, + MappingContext context, boolean isGeoNearQuery, boolean isSearchQuery) { super(tree, accessor); @@ -104,6 +107,7 @@ public MongoQueryCreator(PartTree tree, ConvertingParameterAccessor accessor, this.accessor = accessor; this.isGeoNearQuery = isGeoNearQuery; + this.isSearchQuery = isSearchQuery; this.context = context; } @@ -114,6 +118,11 @@ protected Criteria create(Part part, Iterator iterator) { return new Criteria(); } + if (isPartOfSearchQuery(part)) { + skip(part, iterator); + return new Criteria(); + } + PersistentPropertyPath path = context.getPersistentPropertyPath(part.getProperty()); MongoPersistentProperty property = path.getLeafProperty(); @@ -127,6 +136,11 @@ protected Criteria and(Part part, Criteria base, Iterator iterator) { return create(part, iterator); } + if (isPartOfSearchQuery(part)) { + skip(part, iterator); + return base; + } + PersistentPropertyPath path = context.getPersistentPropertyPath(part.getProperty()); MongoPersistentProperty property = path.getLeafProperty(); @@ -185,13 +199,13 @@ private Criteria from(Part part, MongoPersistentProperty property, Criteria crit return criteria.is(null); case NOT_IN: Object ninValue = parameters.next(); - if(ninValue instanceof Placeholder) { + if (ninValue instanceof Placeholder) { return criteria.raw("$nin", ninValue); } return criteria.nin(valueAsList(ninValue, part)); case IN: Object inValue = parameters.next(); - if(inValue instanceof Placeholder) { + if (inValue instanceof Placeholder) { return criteria.raw("$in", inValue); } return criteria.in(valueAsList(inValue, part)); @@ -210,7 +224,7 @@ private Criteria from(Part part, MongoPersistentProperty property, Criteria crit return param instanceof Pattern pattern ? criteria.regex(pattern) : criteria.regex(param.toString()); case EXISTS: Object next = parameters.next(); - if(next instanceof Placeholder placeholder) { + if (next instanceof Placeholder placeholder) { return criteria.raw("$exists", placeholder); } else { return criteria.exists((Boolean) next); @@ -334,7 +348,7 @@ private Criteria createContainingCriteria(Part part, MongoPersistentProperty pro if (property.isCollectionLike()) { Object next = parameters.next(); - if(next instanceof Placeholder) { + if (next instanceof Placeholder) { return criteria.raw("$in", next); } return criteria.in(valueAsList(next, part)); @@ -412,8 +426,7 @@ private java.util.List valueAsList(Object value, Part part) { streamable = streamable.map(it -> { if (it instanceof String sv) { - return new BsonRegularExpression(MongoRegexCreator.INSTANCE.toRegularExpression(sv, matchMode), - regexOptions); + return new BsonRegularExpression(MongoRegexCreator.INSTANCE.toRegularExpression(sv, matchMode), regexOptions); } return it; }); @@ -447,10 +460,23 @@ private boolean isSpherical(MongoPersistentProperty property) { return false; } + private boolean isPartOfSearchQuery(Part part) { + return isSearchQuery && (part.getType().equals(Type.NEAR) || part.getType().equals(Type.WITHIN)); + } + + private static void skip(Part part, Iterator parameters) { + + int total = part.getNumberOfArguments(); + int i = 0; + while (parameters.hasNext() && i < total) { + parameters.next(); + i++; + } + } + /** * Compute a {@link Type#BETWEEN} typed {@link Part} using {@link Criteria#gt(Object) $gt}, - * {@link Criteria#gte(Object) $gte}, {@link Criteria#lt(Object) $lt} and {@link Criteria#lte(Object) $lte}. - *
+ * {@link Criteria#gte(Object) $gte}, {@link Criteria#lt(Object) $lt} and {@link Criteria#lte(Object) $lte}.
* In case the first {@literal value} is actually a {@link Range} the lower and upper bounds of the {@link Range} are * used according to their {@link Bound#isInclusive() inclusion} definition. Otherwise the {@literal value} is used * for {@literal $gt} and {@link Iterator#next() parameters.next()} as {@literal $lt}. diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java index 01d4e0c63d..f606a59859 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java @@ -15,14 +15,21 @@ */ package org.springframework.data.mongodb.repository.query; +import java.util.ArrayList; +import java.util.Collection; import java.util.Iterator; import java.util.List; import java.util.function.Supplier; +import org.bson.Document; import org.jspecify.annotations.Nullable; import org.springframework.data.domain.Page; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Range; +import org.springframework.data.domain.ScoringFunction; +import org.springframework.data.domain.SearchResult; +import org.springframework.data.domain.SearchResults; +import org.springframework.data.domain.Similarity; import org.springframework.data.domain.Slice; import org.springframework.data.domain.SliceImpl; import org.springframework.data.geo.Distance; @@ -37,9 +44,14 @@ import org.springframework.data.mongodb.core.ExecutableRemoveOperation.ExecutableRemove; import org.springframework.data.mongodb.core.ExecutableRemoveOperation.TerminatingRemove; import org.springframework.data.mongodb.core.ExecutableUpdateOperation.ExecutableUpdate; +import org.springframework.data.mongodb.core.MongoOperations; +import org.springframework.data.mongodb.core.aggregation.AggregationPipeline; +import org.springframework.data.mongodb.core.aggregation.AggregationResults; +import org.springframework.data.mongodb.core.aggregation.TypedAggregation; import org.springframework.data.mongodb.core.query.NearQuery; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.core.query.UpdateDefinition; +import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer; import org.springframework.data.mongodb.repository.util.SliceUtils; import org.springframework.data.repository.query.QueryMethod; import org.springframework.data.support.PageableExecutionUtils; @@ -175,7 +187,7 @@ public Object execute(Query query) { return isListOfGeoResult(method.getReturnType()) ? results.getContent() : results; } - @SuppressWarnings({"unchecked","NullAway"}) + @SuppressWarnings({ "unchecked", "NullAway" }) GeoResults doExecuteQuery(Query query) { Point nearLocation = accessor.getGeoNearLocation(); @@ -210,6 +222,91 @@ private static boolean isListOfGeoResult(TypeInformation returnType) { } } + /** + * {@link MongoQueryExecution} to execute vector search. + * + * @author Mark Paluch + * @author Chistoph Strobl + * @since 5.0 + */ + class VectorSearchExecution implements MongoQueryExecution { + + private final MongoOperations operations; + private final TypeInformation returnType; + private final String collectionName; + private final Class targetType; + private final ScoringFunction scoringFunction; + private final AggregationPipeline pipeline; + + VectorSearchExecution(MongoOperations operations, MongoQueryMethod method, String collectionName, + QueryContainer queryContainer) { + this(operations, queryContainer.outputType(), collectionName, method.getReturnType(), queryContainer.pipeline(), + queryContainer.scoringFunction()); + } + + public VectorSearchExecution(MongoOperations operations, Class targetType, String collectionName, + TypeInformation returnType, AggregationPipeline pipeline, ScoringFunction scoringFunction) { + + this.operations = operations; + this.returnType = returnType; + this.collectionName = collectionName; + this.targetType = targetType; + this.scoringFunction = scoringFunction; + this.pipeline = pipeline; + } + + @Override + public Object execute(Query query) { + + AggregationResults aggregated = operations + .aggregate(TypedAggregation.newAggregation(targetType, pipeline.getOperations()), collectionName, targetType); + + List mappedResults = aggregated.getMappedResults(); + + if (!isSearchResult(returnType)) { + return mappedResults; + } + + List rawResults = aggregated.getRawResults().getList("results", org.bson.Document.class); + List> result = new ArrayList<>(mappedResults.size()); + + for (int i = 0; i < mappedResults.size(); i++) { + + Document document = rawResults.get(i); + SearchResult searchResult = new SearchResult<>(mappedResults.get(i), + Similarity.raw(document.getDouble("__score__"), scoringFunction)); + + result.add(searchResult); + } + + return isListOfSearchResult(returnType) ? result : new SearchResults<>(result); + } + + private static boolean isListOfSearchResult(TypeInformation returnType) { + + if (!Collection.class.isAssignableFrom(returnType.getType())) { + return false; + } + + TypeInformation componentType = returnType.getComponentType(); + return componentType != null && SearchResult.class.equals(componentType.getType()); + } + + private static boolean isSearchResult(TypeInformation returnType) { + + if (SearchResults.class.isAssignableFrom(returnType.getType())) { + return true; + } + + if (!Iterable.class.isAssignableFrom(returnType.getType())) { + return false; + } + + TypeInformation componentType = returnType.getComponentType(); + return componentType != null && SearchResult.class.equals(componentType.getType()); + } + } + /** * {@link MongoQueryExecution} to execute geo-near queries with paging. * diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryMethod.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryMethod.java index 52c5e32555..060d03e223 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryMethod.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryMethod.java @@ -35,6 +35,7 @@ import org.springframework.data.mongodb.repository.ReadPreference; import org.springframework.data.mongodb.repository.Tailable; import org.springframework.data.mongodb.repository.Update; +import org.springframework.data.mongodb.repository.VectorSearch; import org.springframework.data.mongodb.util.BsonUtils; import org.springframework.data.projection.ProjectionFactory; import org.springframework.data.repository.core.RepositoryMetadata; @@ -414,10 +415,28 @@ private Optional findAnnotatedAggregation() { .filter(it -> !ObjectUtils.isEmpty(it)); } + /** + * Returns whether the method has an annotated vector search. + * + * @return true if {@link VectorSearch} is present. + * @since 5.0 + */ + public boolean hasAnnotatedVectorSearch() { + return findAnnotatedVectorSearch().isPresent(); + } + + Optional findAnnotatedVectorSearch() { + return lookupVectorSearchAnnotation(); + } + Optional lookupAggregationAnnotation() { return doFindAnnotation(Aggregation.class); } + Optional lookupVectorSearchAnnotation() { + return doFindAnnotation(VectorSearch.class); + } + Optional lookupUpdateAnnotation() { return doFindAnnotation(Update.class); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/PartTreeMongoQuery.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/PartTreeMongoQuery.java index 6116cc5534..9682e4971f 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/PartTreeMongoQuery.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/PartTreeMongoQuery.java @@ -81,7 +81,7 @@ public PartTree getTree() { @SuppressWarnings("NullAway") protected Query createQuery(ConvertingParameterAccessor accessor) { - MongoQueryCreator creator = new MongoQueryCreator(tree, accessor, context, isGeoNearQuery); + MongoQueryCreator creator = new MongoQueryCreator(tree, accessor, context, isGeoNearQuery, false); Query query = creator.createQuery(); if (tree.isLimiting()) { @@ -126,7 +126,7 @@ protected Query createQuery(ConvertingParameterAccessor accessor) { @Override protected Query createCountQuery(ConvertingParameterAccessor accessor) { - return new MongoQueryCreator(tree, accessor, context, false).createQuery(); + return new MongoQueryCreator(tree, accessor, context, false, false).createQuery(); } @Override diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecution.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecution.java index 06f946d745..29e2127e18 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecution.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecution.java @@ -18,21 +18,27 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import org.bson.Document; import org.jspecify.annotations.Nullable; import org.reactivestreams.Publisher; import org.springframework.core.convert.converter.Converter; import org.springframework.data.convert.DtoInstantiatingConverter; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Range; +import org.springframework.data.domain.SearchResult; +import org.springframework.data.domain.Similarity; import org.springframework.data.geo.Distance; import org.springframework.data.geo.GeoResult; import org.springframework.data.geo.Point; import org.springframework.data.mapping.model.EntityInstantiators; import org.springframework.data.mongodb.core.ReactiveMongoOperations; import org.springframework.data.mongodb.core.ReactiveUpdateOperation.ReactiveUpdate; +import org.springframework.data.mongodb.core.aggregation.AggregationPipeline; +import org.springframework.data.mongodb.core.aggregation.TypedAggregation; import org.springframework.data.mongodb.core.query.NearQuery; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.core.query.UpdateDefinition; +import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer; import org.springframework.data.repository.query.ResultProcessor; import org.springframework.data.repository.query.ReturnedType; import org.springframework.data.util.ReactiveWrappers; @@ -117,6 +123,57 @@ private boolean isStreamOfGeoResult() { } } + /** + * {@link ReactiveMongoQueryExecution} to execute vector search. + * + * @author Mark Paluch + * @since 5.0 + */ + class VectorSearchExecution implements ReactiveMongoQueryExecution { + + private final ReactiveMongoOperations operations; + private final QueryContainer queryMetadata; + private final AggregationPipeline pipeline; + private final boolean returnSearchResult; + + VectorSearchExecution(ReactiveMongoOperations operations, MongoQueryMethod method, QueryContainer queryMetadata) { + + this.operations = operations; + this.queryMetadata = queryMetadata; + this.pipeline = queryMetadata.pipeline(); + this.returnSearchResult = isSearchResult(method.getReturnType()); + } + + @Override + public Publisher execute(Query query, Class type, String collection) { + + Flux aggregate = operations.aggregate( + TypedAggregation.newAggregation(queryMetadata.outputType(), pipeline.getOperations()), collection, + Document.class); + + return aggregate.map(document -> { + + Object mappedResult = operations.getConverter().read(queryMetadata.outputType(), document); + + return returnSearchResult + ? new SearchResult<>(mappedResult, + Similarity.raw(document.getDouble(queryMetadata.scoreField()), queryMetadata.scoringFunction())) + : mappedResult; + }); + } + + private static boolean isSearchResult(TypeInformation returnType) { + + if (!Publisher.class.isAssignableFrom(returnType.getType())) { + return false; + } + + TypeInformation componentType = returnType.getComponentType(); + return componentType != null && SearchResult.class.equals(componentType.getType()); + } + + } + /** * {@link ReactiveMongoQueryExecution} removing documents matching the query. * diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactivePartTreeMongoQuery.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactivePartTreeMongoQuery.java index 4aa773091b..9a17b2b5fc 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactivePartTreeMongoQuery.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactivePartTreeMongoQuery.java @@ -90,7 +90,7 @@ protected Mono createCountQuery(ConvertingParameterAccessor accessor) { @SuppressWarnings("NullAway") private Query createQueryInternal(ConvertingParameterAccessor accessor, boolean isCountQuery) { - MongoQueryCreator creator = new MongoQueryCreator(tree, accessor, context, !isCountQuery && isGeoNearQuery); + MongoQueryCreator creator = new MongoQueryCreator(tree, accessor, context, !isCountQuery && isGeoNearQuery, false); Query query = creator.createQuery(); if (isCountQuery) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveVectorSearchAggregation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveVectorSearchAggregation.java new file mode 100644 index 0000000000..cf75c7db94 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveVectorSearchAggregation.java @@ -0,0 +1,123 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.query; + +import reactor.core.publisher.Mono; + +import org.bson.Document; +import org.reactivestreams.Publisher; +import org.springframework.data.mongodb.InvalidMongoDbApiUsageException; +import org.springframework.data.mongodb.core.MongoOperations; +import org.springframework.data.mongodb.core.ReactiveMongoOperations; +import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; +import org.springframework.data.mongodb.core.query.Query; +import org.springframework.data.mongodb.repository.VectorSearch; +import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer; +import org.springframework.data.mongodb.util.json.ParameterBindingContext; +import org.springframework.data.repository.query.ResultProcessor; +import org.springframework.data.repository.query.ValueExpressionDelegate; +import org.springframework.data.spel.ExpressionDependencies; + +/** + * {@link AbstractReactiveMongoQuery} implementation to run a {@link VectorSearchAggregation}. The pre-filter is either + * derived from the method name or provided through {@link VectorSearch#filter()}. + * + * @author Mark Paluch + * @since 5.0 + */ +public class ReactiveVectorSearchAggregation extends AbstractReactiveMongoQuery { + + private final ReactiveMongoOperations mongoOperations; + private final MongoPersistentEntity collectionEntity; + private final ValueExpressionDelegate valueExpressionDelegate; + private final VectorSearchDelegate delegate; + + /** + * Creates a new {@link ReactiveVectorSearchAggregation} from the given {@link MongoQueryMethod} and + * {@link MongoOperations}. + * + * @param method must not be {@literal null}. + * @param mongoOperations must not be {@literal null}. + * @param delegate must not be {@literal null}. + */ + public ReactiveVectorSearchAggregation(ReactiveMongoQueryMethod method, ReactiveMongoOperations mongoOperations, + ValueExpressionDelegate delegate) { + + super(method, mongoOperations, delegate); + + this.valueExpressionDelegate = delegate; + if (!method.isSearchQuery() && !method.isCollectionQuery()) { + throw new InvalidMongoDbApiUsageException(String.format( + "Repository Vector Search method '%s' must return either return SearchResults or List but was %s", + method.getName(), method.getReturnType().getType().getSimpleName())); + } + + this.mongoOperations = mongoOperations; + this.collectionEntity = method.getEntityInformation().getCollectionEntity(); + this.delegate = new VectorSearchDelegate(method, mongoOperations.getConverter(), delegate); + } + + @Override + protected Publisher doExecute(ReactiveMongoQueryMethod method, ResultProcessor processor, + ConvertingParameterAccessor accessor, @org.jspecify.annotations.Nullable Class typeToRead) { + + return getParameterBindingCodec().flatMapMany(codec -> { + + String json = delegate.getQueryString(); + ExpressionDependencies dependencies = codec.captureExpressionDependencies(json, accessor::getBindableValue, + valueExpressionDelegate); + + return getValueExpressionEvaluatorLater(dependencies, accessor).flatMapMany(expressionEvaluator -> { + + ParameterBindingContext bindingContext = new ParameterBindingContext(accessor::getBindableValue, + expressionEvaluator); + QueryContainer query = delegate.createQuery(expressionEvaluator, processor, accessor, typeToRead, codec, + bindingContext); + + ReactiveMongoQueryExecution.VectorSearchExecution execution = new ReactiveMongoQueryExecution.VectorSearchExecution( + mongoOperations, method, query); + + return execution.execute(query.query(), Document.class, collectionEntity.getCollection()); + }); + }); + } + + @Override + protected Mono createQuery(ConvertingParameterAccessor accessor) { + throw new UnsupportedOperationException(); + } + + @Override + protected boolean isCountQuery() { + return false; + } + + @Override + protected boolean isExistsQuery() { + return false; + } + + @Override + protected boolean isDeleteQuery() { + return false; + } + + @Override + protected boolean isLimiting() { + return false; + } + +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregation.java new file mode 100644 index 0000000000..eb8dc2e52e --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregation.java @@ -0,0 +1,112 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.query; + +import org.jspecify.annotations.Nullable; +import org.springframework.data.mapping.model.ValueExpressionEvaluator; +import org.springframework.data.mongodb.InvalidMongoDbApiUsageException; +import org.springframework.data.mongodb.core.MongoOperations; +import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; +import org.springframework.data.mongodb.core.query.Query; +import org.springframework.data.mongodb.repository.VectorSearch; +import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer; +import org.springframework.data.mongodb.util.json.ParameterBindingContext; +import org.springframework.data.repository.query.ResultProcessor; +import org.springframework.data.repository.query.ValueExpressionDelegate; + +/** + * {@link AbstractMongoQuery} implementation to run a {@link VectorSearchAggregation}. The pre-filter is either derived + * from the method name or provided through {@link VectorSearch#filter()}. + * + * @author Mark Paluch + * @since 5.0 + */ +public class VectorSearchAggregation extends AbstractMongoQuery { + + private final MongoOperations mongoOperations; + private final MongoPersistentEntity collectionEntity; + private final VectorSearchDelegate delegate; + + /** + * Creates a new {@link VectorSearchAggregation} from the given {@link MongoQueryMethod} and {@link MongoOperations}. + * + * @param method must not be {@literal null}. + * @param mongoOperations must not be {@literal null}. + * @param delegate must not be {@literal null}. + */ + public VectorSearchAggregation(MongoQueryMethod method, MongoOperations mongoOperations, + ValueExpressionDelegate delegate) { + + super(method, mongoOperations, delegate); + + if (!method.isSearchQuery() && !method.isCollectionQuery()) { + throw new InvalidMongoDbApiUsageException(String.format( + "Repository Vector Search method '%s' must return either return SearchResults or List but was %s", + method.getName(), method.getReturnType().getType().getSimpleName())); + } + + this.mongoOperations = mongoOperations; + this.collectionEntity = method.getEntityInformation().getCollectionEntity(); + this.delegate = new VectorSearchDelegate(method, mongoOperations.getConverter(), delegate); + } + + @Override + protected Object doExecute(MongoQueryMethod method, ResultProcessor processor, ConvertingParameterAccessor accessor, + @Nullable Class typeToRead) { + + QueryContainer query = createVectorSearchQuery(processor, accessor, typeToRead); + + MongoQueryExecution.VectorSearchExecution execution = new MongoQueryExecution.VectorSearchExecution(mongoOperations, + method, collectionEntity.getCollection(), query); + + return execution.execute(query.query()); + } + + QueryContainer createVectorSearchQuery(ResultProcessor processor, MongoParameterAccessor accessor, + @Nullable Class typeToRead) { + + ValueExpressionEvaluator evaluator = getExpressionEvaluatorFor(accessor); + ParameterBindingContext bindingContext = prepareBindingContext(delegate.getQueryString(), accessor); + + return delegate.createQuery(evaluator, processor, accessor, typeToRead, getParameterBindingCodec(), bindingContext); + } + + @Override + protected Query createQuery(ConvertingParameterAccessor accessor) { + throw new UnsupportedOperationException(); + } + + @Override + protected boolean isCountQuery() { + return false; + } + + @Override + protected boolean isExistsQuery() { + return false; + } + + @Override + protected boolean isDeleteQuery() { + return false; + } + + @Override + protected boolean isLimiting() { + return false; + } + +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegate.java new file mode 100644 index 0000000000..0dbff2e932 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegate.java @@ -0,0 +1,422 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.query; + +import java.util.ArrayList; +import java.util.List; + +import org.bson.Document; +import org.jspecify.annotations.Nullable; +import org.springframework.data.domain.Limit; +import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.ScoringFunction; +import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Vector; +import org.springframework.data.expression.ValueExpression; +import org.springframework.data.mapping.PersistentPropertyPath; +import org.springframework.data.mapping.context.MappingContext; +import org.springframework.data.mapping.model.ValueExpressionEvaluator; +import org.springframework.data.mongodb.InvalidMongoDbApiUsageException; +import org.springframework.data.mongodb.core.aggregation.Aggregation; +import org.springframework.data.mongodb.core.aggregation.AggregationOperation; +import org.springframework.data.mongodb.core.aggregation.AggregationPipeline; +import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation; +import org.springframework.data.mongodb.core.convert.MongoConverter; +import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; +import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty; +import org.springframework.data.mongodb.core.query.BasicQuery; +import org.springframework.data.mongodb.core.query.Query; +import org.springframework.data.mongodb.repository.VectorSearch; +import org.springframework.data.mongodb.util.json.ParameterBindingContext; +import org.springframework.data.mongodb.util.json.ParameterBindingDocumentCodec; +import org.springframework.data.repository.query.ResultProcessor; +import org.springframework.data.repository.query.ValueExpressionDelegate; +import org.springframework.data.repository.query.parser.Part; +import org.springframework.data.repository.query.parser.PartTree; +import org.springframework.util.NumberUtils; +import org.springframework.util.StringUtils; + +/** + * Delegate to assemble information about Vector Search queries necessary to run a MongoDB {@code $vectorSearch}. + * + * @author Mark Paluch + */ +class VectorSearchDelegate { + + private final VectorSearchQueryFactory queryFactory; + private final VectorSearchOperation.SearchType searchType; + private final String indexName; + private final @Nullable Integer numCandidates; + private final @Nullable String numCandidatesExpression; + private final Limit limit; + private final @Nullable String limitExpression; + private final MongoConverter converter; + + VectorSearchDelegate(MongoQueryMethod method, MongoConverter converter, ValueExpressionDelegate delegate) { + + VectorSearch vectorSearch = method.findAnnotatedVectorSearch().orElseThrow(); + + this.searchType = vectorSearch.searchType(); + this.indexName = method.getAnnotatedHint(); + + if (StringUtils.hasText(vectorSearch.numCandidates())) { + + ValueExpression expression = delegate.getValueExpressionParser().parse(vectorSearch.numCandidates()); + + if (expression.isLiteral()) { + this.numCandidates = Integer.parseInt(vectorSearch.numCandidates()); + this.numCandidatesExpression = null; + } else { + this.numCandidates = null; + this.numCandidatesExpression = vectorSearch.numCandidates(); + } + + } else { + this.numCandidates = null; + this.numCandidatesExpression = null; + } + + if (StringUtils.hasText(vectorSearch.limit())) { + + ValueExpression expression = delegate.getValueExpressionParser().parse(vectorSearch.limit()); + + if (expression.isLiteral()) { + this.limit = Limit.of(Integer.parseInt(vectorSearch.limit())); + this.limitExpression = null; + } else { + this.limit = Limit.unlimited(); + this.limitExpression = vectorSearch.limit(); + } + + } else { + this.limit = Limit.unlimited(); + this.limitExpression = null; + } + + this.converter = converter; + + if (StringUtils.hasText(vectorSearch.filter())) { + this.queryFactory = StringUtils.hasText(vectorSearch.path()) + ? new AnnotatedQueryFactory(vectorSearch.filter(), vectorSearch.path()) + : new AnnotatedQueryFactory(vectorSearch.filter(), method.getEntityInformation().getCollectionEntity()); + } else { + this.queryFactory = new PartTreeQueryFactory( + new PartTree(method.getName(), method.getResultProcessor().getReturnedType().getDomainType()), + converter.getMappingContext()); + } + } + + /** + * Create Query Metadata for {@code $vectorSearch}. + */ + QueryContainer createQuery(ValueExpressionEvaluator evaluator, ResultProcessor processor, + MongoParameterAccessor accessor, @Nullable Class typeToRead, ParameterBindingDocumentCodec codec, + ParameterBindingContext context) { + + String scoreField = "__score__"; + Class outputType = typeToRead != null ? typeToRead : processor.getReturnedType().getReturnedType(); + VectorSearchInput vectorSearchInput = createSearchInput(evaluator, accessor, codec, context); + AggregationPipeline pipeline = createVectorSearchPipeline(vectorSearchInput, scoreField, outputType, accessor, + evaluator); + + return new QueryContainer(vectorSearchInput.path, scoreField, vectorSearchInput.query, pipeline, searchType, + outputType, getSimilarityFunction(accessor), indexName); + } + + @SuppressWarnings("NullAway") + AggregationPipeline createVectorSearchPipeline(VectorSearchInput input, String scoreField, Class outputType, + MongoParameterAccessor accessor, ValueExpressionEvaluator evaluator) { + + Vector vector = accessor.getVector(); + Score score = accessor.getScore(); + Range distance = accessor.getScoreRange(); + Limit limit = Limit.of(input.query().getLimit()); + + List stages = new ArrayList<>(); + VectorSearchOperation $vectorSearch = Aggregation.vectorSearch(indexName).path(input.path()).vector(vector) + .limit(limit); + + Integer candidates = null; + if (this.numCandidatesExpression != null) { + candidates = ((Number) evaluator.evaluate(this.numCandidatesExpression)).intValue(); + } else if (this.numCandidates != null) { + candidates = this.numCandidates; + } else if (input.query().isLimited() && (searchType == VectorSearchOperation.SearchType.ANN + || searchType == VectorSearchOperation.SearchType.DEFAULT)) { + + /* + MongoDB: We recommend that you specify a number at least 20 times higher than the number of documents to return (limit) to increase accuracy. + */ + candidates = input.query().getLimit() * 20; + } + + if (candidates != null) { + $vectorSearch = $vectorSearch.numCandidates(candidates); + } + // + $vectorSearch = $vectorSearch.filter(input.query.getQueryObject()); + $vectorSearch = $vectorSearch.searchType(this.searchType); + $vectorSearch = $vectorSearch.withSearchScore(scoreField); + + if (score != null) { + $vectorSearch = $vectorSearch.withFilterBySore(c -> { + c.gt(score.getValue()); + }); + } else if (distance.getLowerBound().isBounded() || distance.getUpperBound().isBounded()) { + $vectorSearch = $vectorSearch.withFilterBySore(c -> { + Range.Bound lower = distance.getLowerBound(); + if (lower.isBounded()) { + double value = lower.getValue().get().getValue(); + if (lower.isInclusive()) { + c.gte(value); + } else { + c.gt(value); + } + } + + Range.Bound upper = distance.getUpperBound(); + if (upper.isBounded()) { + + double value = upper.getValue().get().getValue(); + if (upper.isInclusive()) { + c.lte(value); + } else { + c.lt(value); + } + } + }); + } + + stages.add($vectorSearch); + + if (input.query().isSorted()) { + + stages.add(ctx -> { + + Document mappedSort = ctx.getMappedObject(input.query().getSortObject(), outputType); + mappedSort.append(scoreField, -1); + return ctx.getMappedObject(new Document("$sort", mappedSort)); + }); + } else { + stages.add(Aggregation.sort(Sort.Direction.DESC, scoreField)); + } + + return new AggregationPipeline(stages); + } + + private VectorSearchInput createSearchInput(ValueExpressionEvaluator evaluator, MongoParameterAccessor accessor, + ParameterBindingDocumentCodec codec, ParameterBindingContext context) { + + VectorSearchInput input = queryFactory.createQuery(accessor, codec, context); + Limit limit = getLimit(evaluator, accessor); + if(!input.query.isLimited() || (input.query.isLimited() && !limit.isUnlimited())) { + input.query().limit(limit); + } + return input; + } + + private Limit getLimit(ValueExpressionEvaluator evaluator, MongoParameterAccessor accessor) { + + if (this.limitExpression != null) { + + Object value = evaluator.evaluate(this.limitExpression); + if (value != null) { + if (value instanceof Limit l) { + return l; + } + if (value instanceof Number n) { + return Limit.of(n.intValue()); + } + if (value instanceof String s) { + return Limit.of(NumberUtils.parseNumber(s, Integer.class)); + } + throw new IllegalArgumentException("Invalid type for Limit. Found [%s], expected Limit or Number"); + } + } + + if (this.limit.isLimited()) { + return this.limit; + } + + return accessor.getLimit(); + } + + public String getQueryString() { + return queryFactory.getQueryString(); + } + + ScoringFunction getSimilarityFunction(MongoParameterAccessor accessor) { + + Score score = accessor.getScore(); + + if (score != null) { + return score.getFunction(); + } + + Range scoreRange = accessor.getScoreRange(); + + if (scoreRange != null) { + if (scoreRange.getUpperBound().isBounded()) { + return scoreRange.getUpperBound().getValue().get().getFunction(); + } + + if (scoreRange.getLowerBound().isBounded()) { + return scoreRange.getLowerBound().getValue().get().getFunction(); + } + } + + return ScoringFunction.unspecified(); + } + + /** + * Metadata for a Vector Search Aggregation. + * + * @param path + * @param query + * @param searchType + * @param outputType + * @param scoringFunction + */ + record QueryContainer(String path, String scoreField, Query query, AggregationPipeline pipeline, + VectorSearchOperation.SearchType searchType, Class outputType, ScoringFunction scoringFunction, String index) { + + } + + /** + * Strategy interface to implement a query factory for the Vector Search pre-filter query. + */ + private interface VectorSearchQueryFactory { + + VectorSearchInput createQuery(MongoParameterAccessor parameterAccessor, ParameterBindingDocumentCodec codec, + ParameterBindingContext context); + + /** + * @return the underlying query string to determine {@link ParameterBindingContext}. + */ + String getQueryString(); + } + + private static class AnnotatedQueryFactory implements VectorSearchQueryFactory { + + private final String query; + private final String path; + + AnnotatedQueryFactory(String query, String path) { + + this.query = query; + this.path = path; + } + + AnnotatedQueryFactory(String query, MongoPersistentEntity entity) { + + this.query = query; + String path = null; + for (MongoPersistentProperty property : entity) { + if (Vector.class.isAssignableFrom(property.getType())) { + path = property.getFieldName(); + break; + } + } + + if (path == null) { + throw new InvalidMongoDbApiUsageException( + "Cannot find Vector Search property in entity [%s]".formatted(entity.getName())); + } + + this.path = path; + } + + public VectorSearchInput createQuery(MongoParameterAccessor parameterAccessor, ParameterBindingDocumentCodec codec, + ParameterBindingContext context) { + + Document queryObject = codec.decode(this.query, context); + Query query = new BasicQuery(queryObject); + + Sort sort = parameterAccessor.getSort(); + if (sort.isSorted()) { + query = query.with(sort); + } + + return new VectorSearchInput(path, query); + } + + @Override + public String getQueryString() { + return this.query; + } + } + + private class PartTreeQueryFactory implements VectorSearchQueryFactory { + + private final String path; + private final PartTree tree; + + @SuppressWarnings("NullableProblems") + PartTreeQueryFactory(PartTree tree, MappingContext context) { + + String path = null; + for (PartTree.OrPart part : tree) { + for (Part p : part) { + if (p.getType() == Part.Type.SIMPLE_PROPERTY || p.getType() == Part.Type.NEAR + || p.getType() == Part.Type.WITHIN || p.getType() == Part.Type.BETWEEN) { + PersistentPropertyPath ppp = context.getPersistentPropertyPath(p.getProperty()); + MongoPersistentProperty property = ppp.getLeafProperty(); + + if (Vector.class.isAssignableFrom(property.getType())) { + path = p.getProperty().toDotPath(); + break; + } + } + } + } + + if (path == null) { + throw new InvalidMongoDbApiUsageException( + "No Simple Property/Near/Within/Between part found for a Vector property"); + } + + this.path = path; + this.tree = tree; + } + + @SuppressWarnings("NullAway") + public VectorSearchInput createQuery(MongoParameterAccessor parameterAccessor, ParameterBindingDocumentCodec codec, + ParameterBindingContext context) { + + MongoQueryCreator creator = new MongoQueryCreator(tree, parameterAccessor, converter.getMappingContext(), false, + true); + + Query query = creator.createQuery(parameterAccessor.getSort()); + + if (tree.isLimiting()) { + query.limit(tree.getMaxResults()); + } + + return new VectorSearchInput(path, query); + } + + @Override + public String getQueryString() { + return ""; + } + } + + private record VectorSearchInput(String path, Query query) { + + } + +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/MongoRepositoryFactory.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/MongoRepositoryFactory.java index e1abcdc2ab..d6047aa058 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/MongoRepositoryFactory.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/MongoRepositoryFactory.java @@ -34,6 +34,7 @@ import org.springframework.data.mongodb.repository.query.PartTreeMongoQuery; import org.springframework.data.mongodb.repository.query.StringBasedAggregation; import org.springframework.data.mongodb.repository.query.StringBasedMongoQuery; +import org.springframework.data.mongodb.repository.query.VectorSearchAggregation; import org.springframework.data.projection.ProjectionFactory; import org.springframework.data.querydsl.QuerydslPredicateExecutor; import org.springframework.data.repository.core.NamedQueries; @@ -182,6 +183,8 @@ public RepositoryQuery resolveQuery(Method method, RepositoryMetadata metadata, if (namedQueries.hasQuery(namedQueryName)) { String namedQuery = namedQueries.getQuery(namedQueryName); return new StringBasedMongoQuery(namedQuery, queryMethod, operations, expressionSupport); + } else if (queryMethod.hasAnnotatedVectorSearch()) { + return new VectorSearchAggregation(queryMethod, operations, expressionSupport); } else if (queryMethod.hasAnnotatedAggregation()) { return new StringBasedAggregation(queryMethod, operations, expressionSupport); } else if (queryMethod.hasAnnotatedQuery()) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/ReactiveMongoRepositoryFactory.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/ReactiveMongoRepositoryFactory.java index ae8561bc17..11c5b09460 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/ReactiveMongoRepositoryFactory.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/ReactiveMongoRepositoryFactory.java @@ -22,6 +22,7 @@ import java.util.Optional; import org.jspecify.annotations.Nullable; + import org.springframework.beans.factory.BeanFactory; import org.springframework.data.mapping.context.MappingContext; import org.springframework.data.mongodb.core.ReactiveMongoOperations; @@ -33,6 +34,7 @@ import org.springframework.data.mongodb.repository.query.ReactivePartTreeMongoQuery; import org.springframework.data.mongodb.repository.query.ReactiveStringBasedAggregation; import org.springframework.data.mongodb.repository.query.ReactiveStringBasedMongoQuery; +import org.springframework.data.mongodb.repository.query.ReactiveVectorSearchAggregation; import org.springframework.data.projection.ProjectionFactory; import org.springframework.data.querydsl.ReactiveQuerydslPredicateExecutor; import org.springframework.data.repository.core.NamedQueries; @@ -174,6 +176,8 @@ public RepositoryQuery resolveQuery(Method method, RepositoryMetadata metadata, if (namedQueries.hasQuery(namedQueryName)) { String namedQuery = namedQueries.getQuery(namedQueryName); return new ReactiveStringBasedMongoQuery(namedQuery, queryMethod, operations, delegate); + } else if (queryMethod.hasAnnotatedVectorSearch()) { + return new ReactiveVectorSearchAggregation(queryMethod, operations, delegate); } else if (queryMethod.hasAnnotatedAggregation()) { return new ReactiveStringBasedAggregation(queryMethod, operations, delegate); } else if (queryMethod.hasAnnotatedQuery()) { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/GeoNearOperationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/GeoNearOperationUnitTests.java index 1b9aba1ba0..5f66e61bdc 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/GeoNearOperationUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/GeoNearOperationUnitTests.java @@ -70,7 +70,7 @@ public void rendersNearQueryWithKeyCorrectly() { @Test // DATAMONGO-2264 public void rendersMaxDistanceCorrectly() { - NearQuery query = NearQuery.near(10.0, 20.0).maxDistance(new Distance(30.0)); + NearQuery query = NearQuery.near(10.0, 20.0).maxDistance(Distance.of(30.0)); assertThat(new GeoNearOperation(query, "distance").toPipelineStages(Aggregation.DEFAULT_CONTEXT)) .containsExactly($geoNear().near(10.0, 20.0).maxDistance(30.0).doc()); @@ -79,7 +79,7 @@ public void rendersMaxDistanceCorrectly() { @Test // DATAMONGO-2264 public void rendersMinDistanceCorrectly() { - NearQuery query = NearQuery.near(10.0, 20.0).minDistance(new Distance(30.0)); + NearQuery query = NearQuery.near(10.0, 20.0).minDistance(Distance.of(30.0)); assertThat(new GeoNearOperation(query, "distance").toPipelineStages(Aggregation.DEFAULT_CONTEXT)) .containsExactly($geoNear().near(10.0, 20.0).minDistance(30.0).doc()); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperationUnitTests.java index 4ce045fe6f..936460f466 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperationUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperationUnitTests.java @@ -15,14 +15,14 @@ */ package org.springframework.data.mongodb.core.aggregation; -import static org.assertj.core.api.Assertions.*; +import static org.springframework.data.mongodb.test.util.Assertions.assertThat; import java.util.List; import org.bson.Document; import org.junit.jupiter.api.Test; - import org.springframework.data.annotation.Id; +import org.springframework.data.domain.Limit; import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType; import org.springframework.data.mongodb.core.mapping.Field; import org.springframework.data.mongodb.core.query.Criteria; @@ -103,6 +103,16 @@ void mapsCriteriaToDomainType() { .containsExactly(new Document("$vectorSearch", new Document($VECTOR_SEARCH).append("filter", filter))); } + @Test + void withInvalidLimit() { + + VectorSearchOperation $search = VectorSearchOperation.search("vector_index").path("plot_embedding") + .vector(-0.0016261312, -0.028070757, -0.011342932).limit(Limit.unlimited()); + + List stages = $search.toPipelineStages(TestAggregationContext.contextFor(Movie.class)); + assertThat(stages.get(0)).doesNotContainKey("$vectorSearch.limit"); + } + static class Movie { @Id String id; diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/GeoConvertersUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/GeoConvertersUnitTests.java index 7fb664b00c..84a494f9d8 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/GeoConvertersUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/GeoConvertersUnitTests.java @@ -69,7 +69,7 @@ public void convertsCircleToDocumentAndBackCorrectlyNeutralDistance() { @Test // DATAMONGO-858 public void convertsCircleToDocumentAndBackCorrectlyMilesDistance() { - Distance radius = new Distance(3, Metrics.MILES); + Distance radius = Distance.of(3, Metrics.MILES); Circle circle = new Circle(new Point(1, 2), radius); Document document = CircleToDocumentConverter.INSTANCE.convert(circle); @@ -106,7 +106,7 @@ public void convertsSphereToDocumentAndBackCorrectlyWithNeutralDistance() { @Test // DATAMONGO-858 public void convertsSphereToDocumentAndBackCorrectlyWithKilometerDistance() { - Distance radius = new Distance(3, Metrics.KILOMETERS); + Distance radius = Distance.of(3, Metrics.KILOMETERS); Sphere sphere = new Sphere(new Point(1, 2), radius); Document document = SphereToDocumentConverter.INSTANCE.convert(sphere); @@ -160,7 +160,7 @@ public void convertsCircleCorrectlyWhenUsingNonDoubleForCoordinates() { circle.put("radius", 3L); assertThat(DocumentToCircleConverter.INSTANCE.convert(circle)) - .isEqualTo(new Circle(new Point(1, 2), new Distance(3))); + .isEqualTo(new Circle(new Point(1, 2), Distance.of(3))); } @Test // DATAMONGO-1607 @@ -171,7 +171,7 @@ public void convertsSphereCorrectlyWhenUsingNonDoubleForCoordinates() { sphere.put("radius", 3L); assertThat(DocumentToSphereConverter.INSTANCE.convert(sphere)) - .isEqualTo(new Sphere(new Point(1, 2), new Distance(3))); + .isEqualTo(new Sphere(new Point(1, 2), Distance.of(3))); } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MappingMongoConverterUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MappingMongoConverterUnitTests.java index 5bd7e06b97..6f1c7439c0 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MappingMongoConverterUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MappingMongoConverterUnitTests.java @@ -1626,7 +1626,7 @@ void shouldWriteEntityWithGeoSphereCorrectly() { void shouldWriteEntityWithGeoSphereWithMetricDistanceCorrectly() { ClassWithGeoSphere object = new ClassWithGeoSphere(); - Sphere sphere = new Sphere(new Point(1, 2), new Distance(3, Metrics.KILOMETERS)); + Sphere sphere = new Sphere(new Point(1, 2), Distance.of(3, Metrics.KILOMETERS)); Distance radius = sphere.getRadius(); object.sphere = sphere; @@ -4082,8 +4082,7 @@ static class WithExplicitTargetTypes { @Field(targetType = FieldType.DECIMAL128) // BigDecimal bigDecimal; - @Field(targetType = FieldType.DECIMAL128) - BigInteger bigInteger; + @Field(targetType = FieldType.DECIMAL128) BigInteger bigInteger; @Field(targetType = FieldType.INT64) // Date dateAsLong; diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/geo/GeoSpatial2DSphereTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/geo/GeoSpatial2DSphereTests.java index 3a9140d34c..1774c36493 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/geo/GeoSpatial2DSphereTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/geo/GeoSpatial2DSphereTests.java @@ -23,9 +23,9 @@ import java.util.List; import org.junit.Test; + import org.springframework.data.domain.Sort.Direction; import org.springframework.data.geo.GeoResults; -import org.springframework.data.geo.Metric; import org.springframework.data.geo.Metrics; import org.springframework.data.geo.Point; import org.springframework.data.mongodb.core.Venue; @@ -67,7 +67,7 @@ public void geoNearWithMinDistance() { GeoResults result = template.geoNear(geoNear, Venue.class); assertThat(result.getContent().size()).isNotEqualTo(0); - assertThat(result.getAverageDistance().getMetric()).isEqualTo((Metric) Metrics.KILOMETERS); + assertThat(result.getAverageDistance().getMetric()).isEqualTo(Metrics.KILOMETERS); } @Test // DATAMONGO-1110 diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/query/MetricConversionUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/query/MetricConversionUnitTests.java index bbdad047f2..fdfa840d58 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/query/MetricConversionUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/query/MetricConversionUnitTests.java @@ -17,6 +17,7 @@ package org.springframework.data.mongodb.core.query; import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.data.Offset.*; import static org.assertj.core.data.Offset.offset; import org.junit.jupiter.api.Test; @@ -34,7 +35,7 @@ public class MetricConversionUnitTests { @Test // DATAMONGO-1348 public void shouldConvertMilesToMeters() { - Distance distance = new Distance(1, Metrics.MILES); + Distance distance = Distance.of(1, Metrics.MILES); double distanceInMeters = MetricConversion.getDistanceInMeters(distance); assertThat(distanceInMeters).isCloseTo(1609.3438343d, offset(0.000000001)); @@ -43,7 +44,7 @@ public void shouldConvertMilesToMeters() { @Test // DATAMONGO-1348 public void shouldConvertKilometersToMeters() { - Distance distance = new Distance(1, Metrics.KILOMETERS); + Distance distance = Distance.of(1, Metrics.KILOMETERS); double distanceInMeters = MetricConversion.getDistanceInMeters(distance); assertThat(distanceInMeters).isCloseTo(1000, offset(0.000000001)); @@ -72,11 +73,13 @@ public void shouldCalculateMetersToMilesMultiplier() { @Test // GH-4004 void shouldConvertKilometersToRadians/* on an earth like sphere with r=6378.137km */() { - assertThat(MetricConversion.toRadians(new Distance(1, Metrics.KILOMETERS))).isCloseTo(0.000156785594d, offset(0.000000001)); + assertThat(MetricConversion.toRadians(Distance.of(1, Metrics.KILOMETERS))).isCloseTo(0.000156785594d, + offset(0.000000001)); } @Test // GH-4004 void shouldConvertMilesToRadians/* on an earth like sphere with r=6378.137km */() { - assertThat(MetricConversion.toRadians(new Distance(1, Metrics.MILES))).isCloseTo(0.000252321328d, offset(0.000000001)); + assertThat(MetricConversion.toRadians(Distance.of(1, Metrics.MILES))).isCloseTo(0.000252321328d, + offset(0.000000001)); } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/query/NearQueryUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/query/NearQueryUnitTests.java index f4e3d26eb1..2b600988db 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/query/NearQueryUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/query/NearQueryUnitTests.java @@ -21,10 +21,10 @@ import java.math.RoundingMode; import org.junit.jupiter.api.Test; + import org.springframework.data.domain.PageRequest; import org.springframework.data.domain.Pageable; import org.springframework.data.geo.Distance; -import org.springframework.data.geo.Metric; import org.springframework.data.geo.Metrics; import org.springframework.data.geo.Point; import org.springframework.data.mongodb.core.DocumentTestUtils; @@ -44,7 +44,7 @@ */ public class NearQueryUnitTests { - private static final Distance ONE_FIFTY_KILOMETERS = new Distance(150, Metrics.KILOMETERS); + private static final Distance ONE_FIFTY_KILOMETERS = Distance.of(150, Metrics.KILOMETERS); @Test public void rejectsNullPoint() { @@ -57,7 +57,7 @@ public void settingUpNearWithMetricRecalculatesDistance() { NearQuery query = NearQuery.near(2.5, 2.5, Metrics.KILOMETERS).maxDistance(150); assertThat(query.getMaxDistance()).isEqualTo(ONE_FIFTY_KILOMETERS); - assertThat(query.getMetric()).isEqualTo((Metric) Metrics.KILOMETERS); + assertThat(query.getMetric()).isEqualTo(Metrics.KILOMETERS); assertThat(query.isSpherical()).isTrue(); } @@ -68,27 +68,27 @@ public void settingMetricRecalculatesMaxDistance() { query.inMiles(); - assertThat(query.getMetric()).isEqualTo((Metric) Metrics.MILES); + assertThat(query.getMetric()).isEqualTo(Metrics.MILES); } @Test public void configuresResultMetricCorrectly() { NearQuery query = NearQuery.near(2.5, 2.1); - assertThat(query.getMetric()).isEqualTo((Metric) Metrics.NEUTRAL); + assertThat(query.getMetric()).isEqualTo(Metrics.NEUTRAL); query = query.maxDistance(ONE_FIFTY_KILOMETERS); - assertThat(query.getMetric()).isEqualTo((Metric) Metrics.KILOMETERS); + assertThat(query.getMetric()).isEqualTo(Metrics.KILOMETERS); assertThat(query.getMaxDistance()).isEqualTo(ONE_FIFTY_KILOMETERS); assertThat(query.isSpherical()).isTrue(); query = query.in(Metrics.MILES); - assertThat(query.getMetric()).isEqualTo((Metric) Metrics.MILES); + assertThat(query.getMetric()).isEqualTo(Metrics.MILES); assertThat(query.getMaxDistance()).isEqualTo(ONE_FIFTY_KILOMETERS); assertThat(query.isSpherical()).isTrue(); - query = query.maxDistance(new Distance(200, Metrics.KILOMETERS)); - assertThat(query.getMetric()).isEqualTo((Metric) Metrics.MILES); + query = query.maxDistance(Distance.of(200, Metrics.KILOMETERS)); + assertThat(query.getMetric()).isEqualTo(Metrics.MILES); } @Test // DATAMONGO-445, DATAMONGO-2264 @@ -200,7 +200,7 @@ public void shouldUseMetersForGeoJsonData() { public void shouldUseMetersForGeoJsonDataWhenDistanceInKilometers() { NearQuery query = NearQuery.near(new GeoJsonPoint(27.987901, 86.9165379)); - query.maxDistance(new Distance(1, Metrics.KILOMETERS)); + query.maxDistance(Distance.of(1, Metrics.KILOMETERS)); assertThat(query.toDocument()).containsEntry("maxDistance", 1000D).containsEntry("distanceMultiplier", 0.001D); } @@ -209,7 +209,7 @@ public void shouldUseMetersForGeoJsonDataWhenDistanceInKilometers() { public void shouldUseMetersForGeoJsonDataWhenDistanceInMiles() { NearQuery query = NearQuery.near(new GeoJsonPoint(27.987901, 86.9165379)); - query.maxDistance(new Distance(1, Metrics.MILES)); + query.maxDistance(Distance.of(1, Metrics.MILES)); assertThat(query.toDocument()).containsEntry("maxDistance", 1609.3438343D).containsEntry("distanceMultiplier", 0.00062137D); @@ -219,7 +219,7 @@ public void shouldUseMetersForGeoJsonDataWhenDistanceInMiles() { public void shouldUseKilometersForDistanceWhenMaxDistanceInMiles() { NearQuery query = NearQuery.near(new GeoJsonPoint(27.987901, 86.9165379)); - query.maxDistance(new Distance(1, Metrics.MILES)).in(Metrics.KILOMETERS); + query.maxDistance(Distance.of(1, Metrics.MILES)).in(Metrics.KILOMETERS); assertThat(query.toDocument()).containsEntry("maxDistance", 1609.3438343D).containsEntry("distanceMultiplier", 0.001D); @@ -229,7 +229,7 @@ public void shouldUseKilometersForDistanceWhenMaxDistanceInMiles() { public void shouldUseMilesForDistanceWhenMaxDistanceInKilometers() { NearQuery query = NearQuery.near(new GeoJsonPoint(27.987901, 86.9165379)); - query.maxDistance(new Distance(1, Metrics.KILOMETERS)).in(Metrics.MILES); + query.maxDistance(Distance.of(1, Metrics.KILOMETERS)).in(Metrics.MILES); assertThat(query.toDocument()).containsEntry("maxDistance", 1000D).containsEntry("distanceMultiplier", 0.00062137D); } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java index 3f2e60f4c4..c2cb6cacf8 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java @@ -38,6 +38,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.extension.ExtendWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.dao.DuplicateKeyException; import org.springframework.dao.IncorrectResultSizeDataAccessException; @@ -49,7 +50,6 @@ import org.springframework.data.geo.Distance; import org.springframework.data.geo.GeoPage; import org.springframework.data.geo.GeoResults; -import org.springframework.data.geo.Metric; import org.springframework.data.geo.Metrics; import org.springframework.data.geo.Point; import org.springframework.data.geo.Polygon; @@ -458,7 +458,7 @@ void executesGeoNearQueryForResultsCorrectly() { repository.save(dave); GeoResults results = repository.findByLocationNear(new Point(-73.99, 40.73), - new Distance(2000, Metrics.KILOMETERS)); + Distance.of(2000, Metrics.KILOMETERS)); assertThat(results.getContent()).isNotEmpty(); } @@ -470,11 +470,11 @@ void executesGeoPageQueryForResultsCorrectly() { repository.save(dave); GeoPage results = repository.findByLocationNear(new Point(-73.99, 40.73), - new Distance(2000, Metrics.KILOMETERS), PageRequest.of(0, 20)); + Distance.of(2000, Metrics.KILOMETERS), PageRequest.of(0, 20)); assertThat(results.getContent()).isNotEmpty(); // DATAMONGO-607 - assertThat(results.getAverageDistance().getMetric()).isEqualTo((Metric) Metrics.KILOMETERS); + assertThat(results.getAverageDistance().getMetric()).isEqualTo(Metrics.KILOMETERS); } @Test // DATAMONGO-323 @@ -634,13 +634,13 @@ void executesGeoPageQueryForWithPageRequestForPageInBetween() { repository.saveAll(Arrays.asList(dave, oliver, carter, boyd, leroi)); GeoPage results = repository.findByLocationNear(new Point(-73.99, 40.73), - new Distance(2000, Metrics.KILOMETERS), PageRequest.of(1, 2)); + Distance.of(2000, Metrics.KILOMETERS), PageRequest.of(1, 2)); assertThat(results.getContent()).isNotEmpty(); assertThat(results.getNumberOfElements()).isEqualTo(2); assertThat(results.isFirst()).isFalse(); assertThat(results.isLast()).isFalse(); - assertThat(results.getAverageDistance().getMetric()).isEqualTo((Metric) Metrics.KILOMETERS); + assertThat(results.getAverageDistance().getMetric()).isEqualTo(Metrics.KILOMETERS); assertThat(results.getAverageDistance().getNormalizedValue()).isEqualTo(0.0); } @@ -656,12 +656,12 @@ void executesGeoPageQueryForWithPageRequestForPageAtTheEnd() { repository.saveAll(Arrays.asList(dave, oliver, carter)); GeoPage results = repository.findByLocationNear(new Point(-73.99, 40.73), - new Distance(2000, Metrics.KILOMETERS), PageRequest.of(1, 2)); + Distance.of(2000, Metrics.KILOMETERS), PageRequest.of(1, 2)); assertThat(results.getContent()).isNotEmpty(); assertThat(results.getNumberOfElements()).isEqualTo(1); assertThat(results.isFirst()).isFalse(); assertThat(results.isLast()).isTrue(); - assertThat(results.getAverageDistance().getMetric()).isEqualTo((Metric) Metrics.KILOMETERS); + assertThat(results.getAverageDistance().getMetric()).isEqualTo(Metrics.KILOMETERS); } @Test // DATAMONGO-445 @@ -672,13 +672,13 @@ void executesGeoPageQueryForWithPageRequestForJustOneElement() { repository.save(dave); GeoPage results = repository.findByLocationNear(new Point(-73.99, 40.73), - new Distance(2000, Metrics.KILOMETERS), PageRequest.of(0, 2)); + Distance.of(2000, Metrics.KILOMETERS), PageRequest.of(0, 2)); assertThat(results.getContent()).isNotEmpty(); assertThat(results.getNumberOfElements()).isEqualTo(1); assertThat(results.isFirst()).isTrue(); assertThat(results.isLast()).isTrue(); - assertThat(results.getAverageDistance().getMetric()).isEqualTo((Metric) Metrics.KILOMETERS); + assertThat(results.getAverageDistance().getMetric()).isEqualTo(Metrics.KILOMETERS); } @Test // DATAMONGO-445 @@ -688,13 +688,13 @@ void executesGeoPageQueryForWithPageRequestForJustOneElementEmptyPage() { repository.save(dave); GeoPage results = repository.findByLocationNear(new Point(-73.99, 40.73), - new Distance(2000, Metrics.KILOMETERS), PageRequest.of(1, 2)); + Distance.of(2000, Metrics.KILOMETERS), PageRequest.of(1, 2)); assertThat(results.getContent()).isEmpty(); assertThat(results.getNumberOfElements()).isEqualTo(0); assertThat(results.isFirst()).isFalse(); assertThat(results.isLast()).isTrue(); - assertThat(results.getAverageDistance().getMetric()).isEqualTo((Metric) Metrics.KILOMETERS); + assertThat(results.getAverageDistance().getMetric()).isEqualTo(Metrics.KILOMETERS); } @Test // DATAMONGO-1608 @@ -1117,7 +1117,7 @@ void executesGeoNearQueryForResultsCorrectlyWhenGivenMinAndMaxDistance() { dave.setLocation(point); repository.save(dave); - Range range = Distance.between(new Distance(0.01, KILOMETERS), new Distance(2000, KILOMETERS)); + Range range = Distance.between(Distance.of(0.01, KILOMETERS), Distance.of(2000, KILOMETERS)); GeoResults results = repository.findPersonByLocationNear(new Point(-73.99, 40.73), range); assertThat(results.getContent()).isNotEmpty(); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java index 93a293ecff..1f4f682ebc 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java @@ -24,6 +24,7 @@ import java.util.stream.Stream; import org.jspecify.annotations.Nullable; + import org.springframework.data.domain.Limit; import org.springframework.data.domain.Page; import org.springframework.data.domain.Pageable; diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/ReactiveMongoRepositoryTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/ReactiveMongoRepositoryTests.java index 1a481b49ed..2a76c0ba6c 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/ReactiveMongoRepositoryTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/ReactiveMongoRepositoryTests.java @@ -20,6 +20,7 @@ import static org.springframework.data.domain.Sort.Direction.*; import static org.springframework.data.mongodb.core.query.Criteria.*; import static org.springframework.data.mongodb.core.query.Query.*; +import static org.springframework.data.mongodb.test.util.Assertions.*; import static org.springframework.data.mongodb.test.util.Assertions.assertThat; import reactor.core.Disposable; @@ -40,6 +41,7 @@ import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.extension.ExtendWith; import org.reactivestreams.Publisher; + import org.springframework.beans.factory.BeanFactory; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; @@ -353,7 +355,7 @@ void findsPeopleGeoresultByLocationWithinBox() { repository.save(dave).as(StepVerifier::create).expectNextCount(1).verifyComplete(); repository.findByLocationNear(new Point(-73.99, 40.73), // - new Distance(2000, Metrics.KILOMETERS)).as(StepVerifier::create).consumeNextWith(actual -> { + Distance.of(2000, Metrics.KILOMETERS)).as(StepVerifier::create).consumeNextWith(actual -> { assertThat(actual.getDistance().getValue()).isCloseTo(1, offset(1d)); assertThat(actual.getContent()).isEqualTo(dave); @@ -372,7 +374,7 @@ void findsPeoplePageableGeoresultByLocationWithinBox() throws InterruptedExcepti Thread.sleep(500); repository.findByLocationNear(new Point(-73.99, 40.73), // - new Distance(2000, Metrics.KILOMETERS), // + Distance.of(2000, Metrics.KILOMETERS), // PageRequest.of(0, 10)).as(StepVerifier::create) // .consumeNextWith(actual -> { @@ -393,7 +395,7 @@ void findsPeopleByLocationWithinBox() throws InterruptedException { Thread.sleep(500); repository.findPersonByLocationNear(new Point(-73.99, 40.73), // - new Distance(2000, Metrics.KILOMETERS)).as(StepVerifier::create) // + Distance.of(2000, Metrics.KILOMETERS)).as(StepVerifier::create) // .expectNext(dave) // .verifyComplete(); } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/ReactiveVectorSearchTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/ReactiveVectorSearchTests.java new file mode 100644 index 0000000000..14a4749c8a --- /dev/null +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/ReactiveVectorSearchTests.java @@ -0,0 +1,224 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository; + +import static org.assertj.core.api.Assertions.*; + +import reactor.core.publisher.Flux; +import reactor.test.StepVerifier; + +import java.util.List; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.ComponentScan; +import org.springframework.context.annotation.FilterType; +import org.springframework.data.domain.Limit; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.SearchResult; +import org.springframework.data.domain.Similarity; +import org.springframework.data.domain.Vector; +import org.springframework.data.mongodb.core.ReactiveMongoTemplate; +import org.springframework.data.mongodb.core.SimpleReactiveMongoDatabaseFactory; +import org.springframework.data.mongodb.core.TestMongoConfiguration; +import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation; +import org.springframework.data.mongodb.core.convert.MappingMongoConverter; +import org.springframework.data.mongodb.core.index.VectorIndex; +import org.springframework.data.mongodb.core.index.VectorIndex.SimilarityFunction; +import org.springframework.data.mongodb.repository.config.EnableReactiveMongoRepositories; +import org.springframework.data.mongodb.test.util.AtlasContainer; +import org.springframework.data.mongodb.test.util.MongoTestTemplate; +import org.springframework.data.repository.CrudRepository; +import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; + +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.mongodb.MongoDBAtlasLocalContainer; + +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoClients; + +/** + * Integration tests using reactive Vector Search and Vector Indexes through local MongoDB Atlas. + * + * @author Mark Paluch + */ +@Testcontainers(disabledWithoutDocker = true) +@SpringJUnitConfig(classes = { ReactiveVectorSearchTests.Config.class }) +public class ReactiveVectorSearchTests { + + Vector VECTOR = Vector.of(0.2001f, 0.32345f, 0.43456f, 0.54567f, 0.65678f); + + private static final MongoDBAtlasLocalContainer atlasLocal = AtlasContainer.bestMatch().withReuse(true); + private static final String COLLECTION_NAME = "collection-1"; + + static MongoClient client; + static MongoTestTemplate template; + + @Autowired ReactiveVectorSearchRepository repository; + + @EnableReactiveMongoRepositories( + includeFilters = { + @ComponentScan.Filter(value = ReactiveVectorSearchRepository.class, type = FilterType.ASSIGNABLE_TYPE) }, + considerNestedRepositories = true) + static class Config extends TestMongoConfiguration { + + @Override + public String getDatabaseName() { + return "vector-search-tests"; + } + + @Override + public MongoClient mongoClient() { + atlasLocal.start(); + return MongoClients.create(atlasLocal.getConnectionString()); + } + + @Bean + public com.mongodb.reactivestreams.client.MongoClient reactiveMongoClient() { + atlasLocal.start(); + return com.mongodb.reactivestreams.client.MongoClients.create(atlasLocal.getConnectionString()); + } + + @Bean + ReactiveMongoTemplate reactiveMongoTemplate(MappingMongoConverter mongoConverter) { + return new ReactiveMongoTemplate(new SimpleReactiveMongoDatabaseFactory(reactiveMongoClient(), getDatabaseName()), + mongoConverter); + } + } + + @BeforeAll + static void beforeAll() throws InterruptedException { + atlasLocal.start(); + + System.out.println(atlasLocal.getConnectionString()); + client = MongoClients.create(atlasLocal.getConnectionString()); + template = new MongoTestTemplate(client, "vector-search-tests"); + + template.remove(WithVectorFields.class).all(); + initDocuments(); + initIndexes(); + + Thread.sleep(500); // just wait a little or the index will be broken + } + + @Test + void shouldSearchEnnWithAnnotatedFilter() { + + Flux> results = repository.searchAnnotated("de", VECTOR, Score.of(0.4), + Limit.of(10)); + + results.as(StepVerifier::create).consumeNextWith(actual -> { + assertThat(actual.getScore().getValue()).isGreaterThan(0.4); + assertThat(actual.getScore()).isInstanceOf(Similarity.class); + + }).expectNextCount(2).verifyComplete(); + } + + @Test + void shouldSearchEnnWithDerivedFilter() { + + Flux results = repository.searchByCountryAndEmbeddingNear("de", VECTOR, Limit.of(10)); + + results.as(StepVerifier::create).consumeNextWith(actual -> assertThat(actual).isInstanceOf(WithVectorFields.class)) + .expectNextCount(2).verifyComplete(); + } + + static void initDocuments() { + + WithVectorFields w1 = new WithVectorFields("de", "one", Vector.of(0.1001f, 0.22345f, 0.33456f, 0.44567f, 0.55678f)); + WithVectorFields w2 = new WithVectorFields("de", "two", Vector.of(0.2001f, 0.32345f, 0.43456f, 0.54567f, 0.65678f)); + WithVectorFields w3 = new WithVectorFields("en", "three", + Vector.of(0.9001f, 0.82345f, 0.73456f, 0.64567f, 0.55678f)); + WithVectorFields w4 = new WithVectorFields("de", "four", + Vector.of(0.9001f, 0.92345f, 0.93456f, 0.94567f, 0.95678f)); + + template.insertAll(List.of(w1, w2, w3, w4)); + } + + static void initIndexes() { + + VectorIndex cosIndex = new VectorIndex("cos-index") + .addVector("embedding", it -> it.similarity(SimilarityFunction.COSINE).dimensions(5)).addFilter("country"); + + template.searchIndexOps(WithVectorFields.class).createIndex(cosIndex); + + VectorIndex euclideanIndex = new VectorIndex("euc-index") + .addVector("embedding", it -> it.similarity(SimilarityFunction.EUCLIDEAN).dimensions(5)).addFilter("country"); + + VectorIndex inner = new VectorIndex("ip-index") + .addVector("embedding", it -> it.similarity(SimilarityFunction.DOT_PRODUCT).dimensions(5)).addFilter("country"); + + template.searchIndexOps(WithVectorFields.class).createIndex(cosIndex); + template.searchIndexOps(WithVectorFields.class).createIndex(euclideanIndex); + template.searchIndexOps(WithVectorFields.class).createIndex(inner); + template.awaitIndexCreation(WithVectorFields.class, cosIndex.getName()); + template.awaitIndexCreation(WithVectorFields.class, euclideanIndex.getName()); + template.awaitIndexCreation(WithVectorFields.class, inner.getName()); + } + + interface ReactiveVectorSearchRepository extends CrudRepository { + + @VectorSearch(indexName = "cos-index", filter = "{country: ?0}", numCandidates = "#{10+10}", + searchType = VectorSearchOperation.SearchType.ANN) + Flux> searchAnnotated(String country, Vector vector, Score distance, Limit limit); + + @VectorSearch(indexName = "cos-index") + Flux searchByCountryAndEmbeddingNear(String country, Vector vector, Limit limit); + + } + + @org.springframework.data.mongodb.core.mapping.Document(COLLECTION_NAME) + static class WithVectorFields { + + String id; + String country; + String description; + + Vector embedding; + + public WithVectorFields(String country, String description, Vector embedding) { + this.country = country; + this.description = description; + this.embedding = embedding; + } + + public String getId() { + return id; + } + + public String getCountry() { + return country; + } + + public String getDescription() { + return description; + } + + public Vector getEmbedding() { + return embedding; + } + + @Override + public String toString() { + return "WithVectorFields{" + "id='" + id + '\'' + ", country='" + country + '\'' + ", description='" + description + + '\'' + '}'; + } + } + +} diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VectorSearchTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VectorSearchTests.java new file mode 100644 index 0000000000..a224481da1 --- /dev/null +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VectorSearchTests.java @@ -0,0 +1,285 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository; + +import static org.assertj.core.api.Assertions.*; + +import java.util.List; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.ComponentScan; +import org.springframework.context.annotation.FilterType; +import org.springframework.data.domain.Limit; +import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.SearchResult; +import org.springframework.data.domain.SearchResults; +import org.springframework.data.domain.Similarity; +import org.springframework.data.domain.Vector; +import org.springframework.data.mongodb.core.TestMongoConfiguration; +import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation; +import org.springframework.data.mongodb.core.index.VectorIndex; +import org.springframework.data.mongodb.core.index.VectorIndex.SimilarityFunction; +import org.springframework.data.mongodb.repository.config.EnableMongoRepositories; +import org.springframework.data.mongodb.test.util.AtlasContainer; +import org.springframework.data.mongodb.test.util.MongoTestTemplate; +import org.springframework.data.repository.CrudRepository; +import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; + +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.mongodb.MongoDBAtlasLocalContainer; + +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoClients; + +/** + * Integration tests using Vector Search and Vector Indexes through local MongoDB Atlas. + * + * @author Christoph Strobl + * @author Mark Paluch + */ +@Testcontainers(disabledWithoutDocker = true) +@SpringJUnitConfig(classes = { VectorSearchTests.Config.class }) +public class VectorSearchTests { + + Vector VECTOR = Vector.of(0.2001f, 0.32345f, 0.43456f, 0.54567f, 0.65678f); + + private static final MongoDBAtlasLocalContainer atlasLocal = AtlasContainer.bestMatch().withReuse(true); + private static final String COLLECTION_NAME = "collection-1"; + + static MongoClient client; + static MongoTestTemplate template; + + @Autowired VectorSearchRepository repository; + + @EnableMongoRepositories( + includeFilters = { + @ComponentScan.Filter(value = VectorSearchRepository.class, type = FilterType.ASSIGNABLE_TYPE) }, + considerNestedRepositories = true) + static class Config extends TestMongoConfiguration { + + @Override + public String getDatabaseName() { + return "vector-search-tests"; + } + + @Override + public MongoClient mongoClient() { + return MongoClients.create(atlasLocal.getConnectionString()); + } + } + + @BeforeAll + static void beforeAll() throws InterruptedException { + + atlasLocal.start(); + + client = MongoClients.create(atlasLocal.getConnectionString()); + template = new MongoTestTemplate(client, "vector-search-tests"); + + template.remove(WithVectorFields.class).all(); + initDocuments(); + initIndexes(); + + Thread.sleep(500); // just wait a little or the index will be broken + } + + @Test + void shouldSearchEnnWithAnnotatedFilter() { + + SearchResults results = repository.searchAnnotated("de", VECTOR, + Score.of(0.4), Limit.of(10)); + + assertThat(results).extracting(SearchResult::getScore).hasOnlyElementsOfType(Similarity.class); + assertThat(results).hasSize(3); + } + + @Test + void shouldSearchEnnWithDerivedFilter() { + + SearchResults results = repository.searchCosineByCountryAndEmbeddingNear("de", VECTOR, + Similarity.of(0.98), + Limit.of(10)); + + assertThat(results).extracting(SearchResult::getScore).hasOnlyElementsOfType(Similarity.class); + assertThat(results).hasSize(2).extracting(SearchResult::getContent).extracting(WithVectorFields::getCountry) + .containsOnly("de", "de"); + + assertThat(results).extracting(SearchResult::getContent).extracting(WithVectorFields::getDescription) + .containsExactlyInAnyOrder("two", "one"); + } + + @Test + void shouldSearchEnnWithDerivedFilterWithoutScore() { + + SearchResults de = repository.searchCosineByCountryAndEmbeddingNear("de", VECTOR, + Similarity.of(0.4), Limit.of(10)); + + assertThat(de).hasSizeGreaterThanOrEqualTo(2); + + assertThat(repository.searchCosineByCountryAndEmbeddingNear("de", VECTOR, Similarity.of(0.999), Limit.of(10))) + .hasSize(1); + } + + @Test + void shouldSearchAsListEnnWithDerivedFilterWithoutScore() { + + List de = repository.searchAsListByCountryAndEmbeddingNear("de", VECTOR, Limit.of(10)); + + assertThat(de).hasOnlyElementsOfType(WithVectorFields.class); + } + + @Test + void shouldSearchEuclideanWithDerivedFilter() { + + SearchResults results = repository.searchEuclideanByCountryAndEmbeddingNear("de", VECTOR, + Limit.of(2)); + + assertThat(results).hasSize(2).extracting(SearchResult::getContent).extracting(WithVectorFields::getCountry) + .containsOnly("de", "de"); + + assertThat(results).extracting(SearchResult::getContent).extracting(WithVectorFields::getDescription) + .containsExactlyInAnyOrder("two", "one"); + } + + @Test + void shouldSearchEnnWithDerivedFilterWithin() { + + SearchResults results = repository.searchByCountryAndEmbeddingWithin("de", VECTOR, + Similarity.between(0.93, 0.98)); + + assertThat(results).hasSize(1); + for (SearchResult result : results) { + assertThat(result.getScore().getValue()).isBetween(0.93, 0.98); + } + } + + @Test + void shouldSearchEnnWithDerivedAndLimitedFilterWithin() { + + SearchResults results = repository.searchTop1ByCountryAndEmbeddingWithin("de", VECTOR, + Similarity.between(0.8, 1)); + + assertThat(results).hasSize(1); + + for (SearchResult result : results) { + assertThat(result.getScore().getValue()).isBetween(0.8, 1.0); + } + } + + static void initDocuments() { + + WithVectorFields w1 = new WithVectorFields("de", "one", Vector.of(0.1001f, 0.22345f, 0.33456f, 0.44567f, 0.55678f)); + WithVectorFields w2 = new WithVectorFields("de", "two", Vector.of(0.2001f, 0.32345f, 0.43456f, 0.54567f, 0.65678f)); + WithVectorFields w3 = new WithVectorFields("en", "three", + Vector.of(0.9001f, 0.82345f, 0.73456f, 0.64567f, 0.55678f)); + WithVectorFields w4 = new WithVectorFields("de", "four", + Vector.of(0.9001f, 0.92345f, 0.93456f, 0.94567f, 0.95678f)); + + template.insertAll(List.of(w1, w2, w3, w4)); + } + + static void initIndexes() { + + VectorIndex cosIndex = new VectorIndex("cos-index") + .addVector("embedding", it -> it.similarity(SimilarityFunction.COSINE).dimensions(5)).addFilter("country"); + + template.searchIndexOps(WithVectorFields.class).createIndex(cosIndex); + + VectorIndex euclideanIndex = new VectorIndex("euc-index") + .addVector("embedding", it -> it.similarity(SimilarityFunction.EUCLIDEAN).dimensions(5)).addFilter("country"); + + VectorIndex inner = new VectorIndex("ip-index") + .addVector("embedding", it -> it.similarity(SimilarityFunction.DOT_PRODUCT).dimensions(5)).addFilter("country"); + + template.searchIndexOps(WithVectorFields.class).createIndex(cosIndex); + template.searchIndexOps(WithVectorFields.class).createIndex(euclideanIndex); + template.searchIndexOps(WithVectorFields.class).createIndex(inner); + template.awaitIndexCreation(WithVectorFields.class, cosIndex.getName()); + template.awaitIndexCreation(WithVectorFields.class, euclideanIndex.getName()); + template.awaitIndexCreation(WithVectorFields.class, inner.getName()); + } + + interface VectorSearchRepository extends CrudRepository { + + @VectorSearch(indexName = "cos-index", filter = "{country: ?0}", numCandidates = "#{10+10}", + searchType = VectorSearchOperation.SearchType.ANN) + SearchResults searchAnnotated(String country, Vector vector, + Score distance, Limit limit); + + @VectorSearch(indexName = "cos-index") + SearchResults searchCosineByCountryAndEmbeddingNear(String country, Vector vector, + Score similarity, Limit limit); + + @VectorSearch(indexName = "cos-index") + List searchAsListByCountryAndEmbeddingNear(String country, Vector vector, Limit limit); + + @VectorSearch(indexName = "euc-index") + SearchResults searchEuclideanByCountryAndEmbeddingNear(String country, Vector vector, + Limit limit); + + @VectorSearch(indexName = "cos-index", limit = "10") + SearchResults searchByCountryAndEmbeddingWithin(String country, Vector vector, + Range distance); + + @VectorSearch(indexName = "cos-index") + SearchResults searchTop1ByCountryAndEmbeddingWithin(String country, Vector vector, + Range distance); + + } + + @org.springframework.data.mongodb.core.mapping.Document(COLLECTION_NAME) + static class WithVectorFields { + + String id; + String country; + String description; + + Vector embedding; + + public WithVectorFields(String country, String description, Vector embedding) { + this.country = country; + this.description = description; + this.embedding = embedding; + } + + public String getId() { + return id; + } + + public String getCountry() { + return country; + } + + public String getDescription() { + return description; + } + + public Vector getEmbedding() { + return embedding; + } + + @Override + public String toString() { + return "WithVectorFields{" + "id='" + id + '\'' + ", country='" + country + '\'' + ", description='" + description + + '\'' + '}'; + } + } + +} diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoParametersParameterAccessorUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoParametersParameterAccessorUnitTests.java index 1c856394d8..f0ffebde20 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoParametersParameterAccessorUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoParametersParameterAccessorUnitTests.java @@ -22,8 +22,10 @@ import org.bson.Document; import org.junit.jupiter.api.Test; + import org.springframework.data.domain.Range; import org.springframework.data.domain.Range.Bound; +import org.springframework.data.domain.Score; import org.springframework.data.geo.Distance; import org.springframework.data.geo.Metrics; import org.springframework.data.geo.Point; @@ -45,15 +47,15 @@ * @author Oliver Gierke * @author Christoph Strobl */ -public class MongoParametersParameterAccessorUnitTests { +class MongoParametersParameterAccessorUnitTests { - Distance DISTANCE = new Distance(2.5, Metrics.KILOMETERS); - RepositoryMetadata metadata = new DefaultRepositoryMetadata(PersonRepository.class); - MongoMappingContext context = new MongoMappingContext(); - ProjectionFactory factory = new SpelAwareProxyProjectionFactory(); + private Distance DISTANCE = Distance.of(2.5, Metrics.KILOMETERS); + private RepositoryMetadata metadata = new DefaultRepositoryMetadata(PersonRepository.class); + private MongoMappingContext context = new MongoMappingContext(); + private ProjectionFactory factory = new SpelAwareProxyProjectionFactory(); @Test - public void returnsUnboundedForDistanceIfNoneAvailable() throws NoSuchMethodException, SecurityException { + void returnsUnboundedForDistanceIfNoneAvailable() throws NoSuchMethodException, SecurityException { Method method = PersonRepository.class.getMethod("findByLocationNear", Point.class); MongoQueryMethod queryMethod = new MongoQueryMethod(method, metadata, factory, context); @@ -64,7 +66,7 @@ public void returnsUnboundedForDistanceIfNoneAvailable() throws NoSuchMethodExce } @Test - public void returnsDistanceIfAvailable() throws NoSuchMethodException, SecurityException { + void returnsDistanceIfAvailable() throws NoSuchMethodException, SecurityException { Method method = PersonRepository.class.getMethod("findByLocationNear", Point.class, Distance.class); MongoQueryMethod queryMethod = new MongoQueryMethod(method, metadata, factory, context); @@ -75,7 +77,7 @@ public void returnsDistanceIfAvailable() throws NoSuchMethodException, SecurityE } @Test // DATAMONGO-973 - public void shouldReturnAsFullTextStringWhenNoneDefinedForMethod() throws NoSuchMethodException, SecurityException { + void shouldReturnAsFullTextStringWhenNoneDefinedForMethod() throws NoSuchMethodException, SecurityException { Method method = PersonRepository.class.getMethod("findByLocationNear", Point.class, Distance.class); MongoQueryMethod queryMethod = new MongoQueryMethod(method, metadata, factory, context); @@ -86,7 +88,7 @@ public void shouldReturnAsFullTextStringWhenNoneDefinedForMethod() throws NoSuch } @Test // DATAMONGO-973 - public void shouldProperlyConvertTextCriteria() throws NoSuchMethodException, SecurityException { + void shouldProperlyConvertTextCriteria() throws NoSuchMethodException, SecurityException { Method method = PersonRepository.class.getMethod("findByFirstname", String.class, TextCriteria.class); MongoQueryMethod queryMethod = new MongoQueryMethod(method, metadata, factory, context); @@ -98,13 +100,13 @@ public void shouldProperlyConvertTextCriteria() throws NoSuchMethodException, Se } @Test // DATAMONGO-1110 - public void shouldDetectMinAndMaxDistance() throws NoSuchMethodException, SecurityException { + void shouldDetectMinAndMaxDistance() throws NoSuchMethodException, SecurityException { Method method = PersonRepository.class.getMethod("findByLocationNear", Point.class, Range.class); MongoQueryMethod queryMethod = new MongoQueryMethod(method, metadata, factory, context); - Distance min = new Distance(10, Metrics.KILOMETERS); - Distance max = new Distance(20, Metrics.KILOMETERS); + Distance min = Distance.of(10, Metrics.KILOMETERS); + Distance max = Distance.of(20, Metrics.KILOMETERS); MongoParameterAccessor accessor = new MongoParametersParameterAccessor(queryMethod, new Object[] { new Point(10, 20), Distance.between(min, max) }); @@ -116,7 +118,7 @@ public void shouldDetectMinAndMaxDistance() throws NoSuchMethodException, Securi } @Test // DATAMONGO-1854 - public void shouldDetectCollation() throws NoSuchMethodException, SecurityException { + void shouldDetectCollation() throws NoSuchMethodException, SecurityException { Method method = PersonRepository.class.getMethod("findByFirstname", String.class, Collation.class); MongoQueryMethod queryMethod = new MongoQueryMethod(method, metadata, factory, context); @@ -129,7 +131,7 @@ public void shouldDetectCollation() throws NoSuchMethodException, SecurityExcept } @Test // GH-2107 - public void shouldReturnUpdateIfPresent() throws NoSuchMethodException, SecurityException { + void shouldReturnUpdateIfPresent() throws NoSuchMethodException, SecurityException { Method method = PersonRepository.class.getMethod("findAndModifyByFirstname", String.class, UpdateDefinition.class); MongoQueryMethod queryMethod = new MongoQueryMethod(method, metadata, factory, context); @@ -142,7 +144,7 @@ public void shouldReturnUpdateIfPresent() throws NoSuchMethodException, Security } @Test // GH-2107 - public void shouldReturnNullIfNoUpdatePresent() throws NoSuchMethodException, SecurityException { + void shouldReturnNullIfNoUpdatePresent() throws NoSuchMethodException, SecurityException { Method method = PersonRepository.class.getMethod("findByLocationNear", Point.class); MongoQueryMethod queryMethod = new MongoQueryMethod(method, metadata, factory, context); @@ -153,6 +155,23 @@ public void shouldReturnNullIfNoUpdatePresent() throws NoSuchMethodException, Se assertThat(accessor.getUpdate()).isNull(); } + @Test // GH- + void shouldReturnRangeFromScore() throws NoSuchMethodException, SecurityException { + + Method method = PersonRepository.class.getMethod("findByFirstname", String.class, Score.class); + MongoQueryMethod queryMethod = new MongoQueryMethod(method, metadata, factory, context); + + MongoParameterAccessor accessor = new MongoParametersParameterAccessor(queryMethod, + new Object[] { "foo", Score.of(1) }); + + Range scoreRange = accessor.getScoreRange(); + + assertThat(scoreRange).isNotNull(); + assertThat(scoreRange.getLowerBound().isBounded()).isFalse(); + assertThat(scoreRange.getUpperBound().isBounded()).isTrue(); + assertThat(scoreRange.getUpperBound().getValue()).contains(Score.of(1)); + } + interface PersonRepository extends Repository { List findByLocationNear(Point point); @@ -165,6 +184,8 @@ interface PersonRepository extends Repository { List findByFirstname(String firstname, Collation collation); + List findByFirstname(String firstname, Score score); + List findAndModifyByFirstname(String firstname, UpdateDefinition update); } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoParametersUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoParametersUnitTests.java index 93674e23fc..fc1ffb971e 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoParametersUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoParametersUnitTests.java @@ -27,6 +27,8 @@ import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.Vector; import org.springframework.data.geo.Distance; import org.springframework.data.geo.GeoResults; import org.springframework.data.geo.Point; @@ -43,6 +45,7 @@ * * @author Oliver Gierke * @author Christoph Strobl + * @author Mark Paluch */ @ExtendWith(MockitoExtension.class) class MongoParametersUnitTests { @@ -184,6 +187,21 @@ void shouldReturnInvalidIndexIfUpdateDoesNotExist() throws NoSuchMethodException assertThat(parameters.getUpdateIndex()).isEqualTo(-1); } + @Test // GH-2107 + void shouldOmitVector() throws NoSuchMethodException, SecurityException { + + Method method = PersonRepository.class.getMethod("shouldOmitVector", Vector.class, Score.class, + Range.class, String.class); + MongoParameters parameters = new MongoParameters(ParametersSource.of(method), false); + + assertThat(parameters.getVectorIndex()).isEqualTo(0); + assertThat(parameters.getScoreIndex()).isEqualTo(1); + assertThat(parameters.getScoreRangeIndex()).isEqualTo(2); + + MongoParameters bindableParameters = parameters.getBindableParameters(); + assertThat(bindableParameters).hasSize(3); + } + interface PersonRepository { List findByLocationNear(Point point, Distance distance); @@ -205,5 +223,8 @@ interface PersonRepository { List findByText(String text, Collation collation); List findAndModifyByFirstname(String firstname, UpdateDefinition update, Pageable page); + + List shouldOmitVector(Vector vector, Score distance, Range range, + String country); } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoQueryCreatorUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoQueryCreatorUnitTests.java index 609e0a0018..55e3df6b43 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoQueryCreatorUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoQueryCreatorUnitTests.java @@ -29,6 +29,7 @@ import org.bson.types.ObjectId; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; + import org.springframework.data.domain.Range; import org.springframework.data.domain.Range.Bound; import org.springframework.data.geo.Distance; @@ -120,7 +121,7 @@ void createsIsNullQueryCorrectly() { void bindsMetricDistanceParameterToNearSphereCorrectly() throws Exception { Point point = new Point(10, 20); - Distance distance = new Distance(2.5, Metrics.KILOMETERS); + Distance distance = Distance.of(2.5, Metrics.KILOMETERS); Query query = query( where("location").nearSphere(point).maxDistance(distance.getNormalizedValue()).and("firstname").is("Dave")); @@ -131,7 +132,7 @@ void bindsMetricDistanceParameterToNearSphereCorrectly() throws Exception { void bindsDistanceParameterToNearCorrectly() throws Exception { Point point = new Point(10, 20); - Distance distance = new Distance(2.5); + Distance distance = Distance.of(2.5); Query query = query( where("location").near(point).maxDistance(distance.getNormalizedValue()).and("firstname").is("Dave")); @@ -405,7 +406,7 @@ void shouldCreateRegexWhenUsingNotContainsOnStringProperty() { void createsNonSphericalNearForDistanceWithDefaultMetric() { Point point = new Point(1.0, 1.0); - Distance distance = new Distance(1.0); + Distance distance = Distance.of(1.0); PartTree tree = new PartTree("findByLocationNear", Venue.class); MongoQueryCreator creator = new MongoQueryCreator(tree, getAccessor(converter, point, distance), context); @@ -445,7 +446,7 @@ void shouldCreateNearSphereQueryForSphericalProperty() { void shouldCreateNearSphereQueryForSphericalPropertyHavingDistanceWithDefaultMetric() { Point point = new Point(1.0, 1.0); - Distance distance = new Distance(1.0); + Distance distance = Distance.of(1.0); PartTree tree = new PartTree("findByAddress2dSphere_GeoNear", User.class); MongoQueryCreator creator = new MongoQueryCreator(tree, getAccessor(converter, point, distance), context); @@ -458,7 +459,7 @@ void shouldCreateNearSphereQueryForSphericalPropertyHavingDistanceWithDefaultMet void shouldCreateNearQueryForMinMaxDistance() { Point point = new Point(10, 20); - Range range = Distance.between(new Distance(10), new Distance(20)); + Range range = Distance.between(Distance.of(10), Distance.of(20)); PartTree tree = new PartTree("findByAddress_GeoNear", User.class); MongoQueryCreator creator = new MongoQueryCreator(tree, getAccessor(converter, point, range), context); @@ -664,7 +665,7 @@ void nearShouldUseMetricDistanceForGeoJsonTypes() { GeoJsonPoint point = new GeoJsonPoint(27.987901, 86.9165379); PartTree tree = new PartTree("findByLocationNear", User.class); MongoQueryCreator creator = new MongoQueryCreator(tree, - getAccessor(converter, point, new Distance(1, Metrics.KILOMETERS)), context); + getAccessor(converter, point, Distance.of(1, Metrics.KILOMETERS)), context); assertThat(creator.createQuery()).isEqualTo(query(where("location").nearSphere(point).maxDistance(1000.0D))); } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoQueryExecutionUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoQueryExecutionUnitTests.java index dbd17aa805..2c0c996bc3 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoQueryExecutionUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoQueryExecutionUnitTests.java @@ -32,6 +32,7 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; + import org.springframework.data.domain.PageRequest; import org.springframework.data.domain.Pageable; import org.springframework.data.geo.Distance; @@ -86,7 +87,7 @@ class MongoQueryExecutionUnitTests { @Mock DbRefResolver dbRefResolver; private Point POINT = new Point(10, 20); - private Distance DISTANCE = new Distance(2.5, Metrics.KILOMETERS); + private Distance DISTANCE = Distance.of(2.5, Metrics.KILOMETERS); private RepositoryMetadata metadata = new DefaultRepositoryMetadata(PersonRepository.class); private MongoMappingContext context = new MongoMappingContext(); private ProjectionFactory factory = new SpelAwareProxyProjectionFactory(); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoQueryMethodUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoQueryMethodUnitTests.java index 8f9824e14d..386d0fa4b5 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoQueryMethodUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoQueryMethodUnitTests.java @@ -23,6 +23,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; + import org.springframework.data.domain.Pageable; import org.springframework.data.geo.Distance; import org.springframework.data.geo.GeoPage; diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecutionUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecutionUnitTests.java index d7a3430048..1fbd60414a 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecutionUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecutionUnitTests.java @@ -71,7 +71,7 @@ public void geoNearExecutionShouldApplyQuerySettings() throws Exception { Query query = new Query(); when(parameterAccessor.getGeoNearLocation()).thenReturn(new Point(1, 2)); when(parameterAccessor.getDistanceRange()) - .thenReturn(Range.from(Bound.inclusive(new Distance(10))).to(Bound.inclusive(new Distance(15)))); + .thenReturn(Range.from(Bound.inclusive(Distance.of(10))).to(Bound.inclusive(Distance.of(15)))); when(parameterAccessor.getPageable()).thenReturn(PageRequest.of(1, 10)); new GeoNearExecution(operations, parameterAccessor, TypeInformation.fromReturnTypeOf(geoNear)).execute(query, @@ -83,8 +83,8 @@ public void geoNearExecutionShouldApplyQuerySettings() throws Exception { NearQuery nearQuery = queryArgumentCaptor.getValue(); assertThat(nearQuery.toDocument().get("near")).isEqualTo(Arrays.asList(1d, 2d)); assertThat(nearQuery.getSkip()).isEqualTo(10L); - assertThat(nearQuery.getMinDistance()).isEqualTo(new Distance(10)); - assertThat(nearQuery.getMaxDistance()).isEqualTo(new Distance(15)); + assertThat(nearQuery.getMinDistance()).isEqualTo(Distance.of(10)); + assertThat(nearQuery.getMaxDistance()).isEqualTo(Distance.of(15)); } @Test // DATAMONGO-1444 diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryMethodUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryMethodUnitTests.java index 82cd0a157c..14cbbc0394 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryMethodUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryMethodUnitTests.java @@ -17,7 +17,6 @@ import static org.assertj.core.api.Assertions.*; -import org.springframework.data.mongodb.repository.query.MongoQueryMethodUnitTests.PersonRepository; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -27,6 +26,7 @@ import org.assertj.core.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; + import org.springframework.dao.InvalidDataAccessApiUsageException; import org.springframework.data.domain.Page; import org.springframework.data.domain.Pageable; diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/StubParameterAccessor.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/StubParameterAccessor.java index 3ed7ace0f9..91f23bb049 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/StubParameterAccessor.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/StubParameterAccessor.java @@ -19,11 +19,14 @@ import java.util.Iterator; import org.jspecify.annotations.Nullable; + import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Range; import org.springframework.data.domain.Range.Bound; +import org.springframework.data.domain.Score; import org.springframework.data.domain.ScrollPosition; import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Vector; import org.springframework.data.geo.Distance; import org.springframework.data.geo.Point; import org.springframework.data.mongodb.core.convert.MongoWriter; @@ -73,6 +76,21 @@ public StubParameterAccessor(Object... values) { } } + @Override + public Vector getVector() { + return null; + } + + @Override + public @org.jspecify.annotations.Nullable Score getScore() { + return null; + } + + @Override + public @org.jspecify.annotations.Nullable Range getScoreRange() { + return null; + } + @Override public ScrollPosition getScrollPosition() { return null; diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregationUnitTests.java new file mode 100644 index 0000000000..819bba5a48 --- /dev/null +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregationUnitTests.java @@ -0,0 +1,109 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.query; + +import static org.assertj.core.api.Assertions.*; +import static org.mockito.Mockito.*; + +import java.lang.reflect.Method; + +import org.bson.conversions.Bson; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import org.springframework.data.domain.Limit; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.SearchResults; +import org.springframework.data.domain.Vector; +import org.springframework.data.mongodb.core.MongoOperations; +import org.springframework.data.mongodb.core.convert.MappingMongoConverter; +import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver; +import org.springframework.data.mongodb.core.mapping.MongoMappingContext; +import org.springframework.data.mongodb.repository.VectorSearch; +import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer; +import org.springframework.data.projection.ProjectionFactory; +import org.springframework.data.projection.SpelAwareProxyProjectionFactory; +import org.springframework.data.repository.CrudRepository; +import org.springframework.data.repository.core.support.DefaultRepositoryMetadata; +import org.springframework.data.repository.query.ValueExpressionDelegate; + +/** + * Unit tests for {@link VectorSearchAggregation}. + * + * @author Mark Paluch + */ +class VectorSearchAggregationUnitTests { + + MongoOperations operationsMock; + MongoMappingContext context; + MappingMongoConverter converter; + + @BeforeEach + public void setUp() { + + context = new MongoMappingContext(); + converter = new MappingMongoConverter(NoOpDbRefResolver.INSTANCE, context); + operationsMock = Mockito.mock(MongoOperations.class); + + when(operationsMock.getConverter()).thenReturn(converter); + when(operationsMock.execute(any())).thenReturn(Bson.DEFAULT_CODEC_REGISTRY); + } + + @Test + void derivesPrefilter() throws Exception { + + VectorSearchAggregation aggregation = aggregation(SampleRepository.class, "searchByCountryAndEmbeddingNear", + String.class, Vector.class, Score.class, Limit.class); + + QueryContainer query = aggregation.createVectorSearchQuery( + aggregation.getQueryMethod().getResultProcessor(), + new MongoParametersParameterAccessor(aggregation.getQueryMethod(), + new Object[] { "de", Vector.of(1f), Score.of(1), Limit.unlimited() }), + Object.class); + + assertThat(query.query().getQueryObject()).containsEntry("country", "de"); + } + + private VectorSearchAggregation aggregation(Class repository, String name, Class... parameters) + throws Exception { + + Method method = repository.getMethod(name, parameters); + ProjectionFactory factory = new SpelAwareProxyProjectionFactory(); + MongoQueryMethod queryMethod = new MongoQueryMethod(method, new DefaultRepositoryMetadata(repository), factory, + context); + return new VectorSearchAggregation(queryMethod, operationsMock, ValueExpressionDelegate.create()); + } + + interface SampleRepository extends CrudRepository { + + @VectorSearch(indexName = "cos-index") + SearchResults searchByCountryAndEmbeddingNear(String country, Vector vector, Score similarity, + Limit limit); + + } + + static class WithVectorFields { + + String id; + String country; + String description; + + Vector embedding; + + } + +} diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegateUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegateUnitTests.java new file mode 100644 index 0000000000..078c01eece --- /dev/null +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegateUnitTests.java @@ -0,0 +1,254 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.query; + +import static org.mockito.Mockito.mock; +import static org.springframework.data.mongodb.test.util.Assertions.assertThat; + +import java.lang.reflect.Method; +import java.util.List; + +import org.bson.Document; +import org.jspecify.annotations.Nullable; +import org.junit.jupiter.api.Test; +import org.springframework.data.domain.Limit; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.SearchResults; +import org.springframework.data.domain.Vector; +import org.springframework.data.mapping.model.ValueExpressionEvaluator; +import org.springframework.data.mongodb.core.aggregation.Aggregation; +import org.springframework.data.mongodb.core.aggregation.AggregationPipeline; +import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation; +import org.springframework.data.mongodb.core.convert.MappingMongoConverter; +import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver; +import org.springframework.data.mongodb.core.mapping.Field; +import org.springframework.data.mongodb.core.mapping.MongoMappingContext; +import org.springframework.data.mongodb.repository.VectorSearch; +import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer; +import org.springframework.data.mongodb.util.aggregation.TestAggregationContext; +import org.springframework.data.mongodb.util.json.ParameterBindingContext; +import org.springframework.data.mongodb.util.json.ParameterBindingDocumentCodec; +import org.springframework.data.projection.SpelAwareProxyProjectionFactory; +import org.springframework.data.repository.Repository; +import org.springframework.data.repository.core.RepositoryMetadata; +import org.springframework.data.repository.core.support.AnnotationRepositoryMetadata; +import org.springframework.data.repository.query.ValueExpressionDelegate; + +/** + * Unit tests for {@link VectorSearchDelegate}. + * + * @author Mark Paluch + * @author Christoph Strobl + */ +class VectorSearchDelegateUnitTests { + + MappingMongoConverter converter = new MappingMongoConverter(NoOpDbRefResolver.INSTANCE, new MongoMappingContext()); + + @Test + void shouldConsiderDerivedLimit() throws ReflectiveOperationException { + + Method method = VectorSearchRepository.class.getMethod("searchTop10ByEmbeddingNear", Vector.class, Score.class); + + MongoQueryMethod queryMethod = getMongoQueryMethod(method); + MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1)); + + QueryContainer container = createQueryContainer(queryMethod, accessor); + + assertThat(container.query().getLimit()).isEqualTo(10); + assertThat(numCandidates(container.pipeline())).isEqualTo(10 * 20); + } + + @Test + void shouldNotSetNumCandidates() throws ReflectiveOperationException { + + Method method = VectorSearchRepository.class.getMethod("searchTop10EnnByEmbeddingNear", Vector.class, Score.class); + + MongoQueryMethod queryMethod = getMongoQueryMethod(method); + MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1)); + + QueryContainer container = createQueryContainer(queryMethod, accessor); + + assertThat(container.query().getLimit()).isEqualTo(10); + assertThat(numCandidates(container.pipeline())).isNull(); + } + + @Test + void shouldConsiderProvidedLimit() throws ReflectiveOperationException { + + Method method = VectorSearchRepository.class.getMethod("searchTop10ByEmbeddingNear", Vector.class, Score.class, + Limit.class); + + MongoQueryMethod queryMethod = getMongoQueryMethod(method); + MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1), Limit.of(11)); + + QueryContainer container = createQueryContainer(queryMethod, accessor); + + assertThat(container.query().getLimit()).isEqualTo(11); + assertThat(numCandidates(container.pipeline())).isEqualTo(11 * 20); + } + + @Test + void considersDerivedQueryPart() throws ReflectiveOperationException { + + Method method = VectorSearchRepository.class.getMethod("searchTop10ByFirstNameAndEmbeddingNear", String.class, + Vector.class, Score.class); + + MongoQueryMethod queryMethod = getMongoQueryMethod(method); + MongoParametersParameterAccessor accessor = getAccessor(queryMethod, "spring", Vector.of(1, 2), Score.of(1)); + + QueryContainer container = createQueryContainer(queryMethod, accessor); + + assertThat(vectorSearchStageOf(container.pipeline())).containsEntry("$vectorSearch.filter", + new Document("first_name", "spring")); + } + + @Test + void considersDerivedQueryPartInDifferentOrder() throws ReflectiveOperationException { + + Method method = VectorSearchRepository.class.getMethod("searchTop10ByEmbeddingNearAndFirstName", Vector.class, + Score.class, String.class); + + MongoQueryMethod queryMethod = getMongoQueryMethod(method); + MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1), "spring"); + + QueryContainer container = createQueryContainer(queryMethod, accessor); + + assertThat(vectorSearchStageOf(container.pipeline())).containsEntry("$vectorSearch.filter", + new Document("first_name", "spring")); + } + + @Test + void defaultSortsByScore() throws NoSuchMethodException { + + Method method = VectorSearchRepository.class.getMethod("searchTop10ByEmbeddingNear", Vector.class, Score.class, + Limit.class); + + MongoQueryMethod queryMethod = getMongoQueryMethod(method); + MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1), Limit.of(10)); + + QueryContainer container = createQueryContainer(queryMethod, accessor); + + List stages = container.pipeline().lastOperation() + .toPipelineStages(TestAggregationContext.contextFor(WithVector.class)); + + assertThat(stages).containsExactly(new Document("$sort", new Document("__score__", -1))); + } + + @Test + void usesDerivedSort() throws NoSuchMethodException { + + Method method = VectorSearchRepository.class.getMethod("searchByEmbeddingNearOrderByFirstName", Vector.class, + Score.class, Limit.class); + + MongoQueryMethod queryMethod = getMongoQueryMethod(method); + MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1), Limit.of(11)); + + QueryContainer container = createQueryContainer(queryMethod, accessor); + AggregationPipeline aggregationPipeline = container.pipeline(); + + List stages = aggregationPipeline.lastOperation() + .toPipelineStages(TestAggregationContext.contextFor(WithVector.class)); + + assertThat(stages).containsExactly(new Document("$sort", new Document("first_name", 1).append("__score__", -1))); + } + + Document vectorSearchStageOf(AggregationPipeline pipeline) { + return pipeline.firstOperation().toPipelineStages(TestAggregationContext.contextFor(WithVector.class)).get(0); + } + + private QueryContainer createQueryContainer(MongoQueryMethod queryMethod, MongoParametersParameterAccessor accessor) { + + VectorSearchDelegate delegate = new VectorSearchDelegate(queryMethod, converter, ValueExpressionDelegate.create()); + + return delegate.createQuery(mock(ValueExpressionEvaluator.class), queryMethod.getResultProcessor(), accessor, null, + new ParameterBindingDocumentCodec(), mock(ParameterBindingContext.class)); + } + + private MongoQueryMethod getMongoQueryMethod(Method method) { + RepositoryMetadata metadata = AnnotationRepositoryMetadata.getMetadata(method.getDeclaringClass()); + return new MongoQueryMethod(method, metadata, new SpelAwareProxyProjectionFactory(), converter.getMappingContext()); + } + + private static MongoParametersParameterAccessor getAccessor(MongoQueryMethod queryMethod, Object... values) { + return new MongoParametersParameterAccessor(queryMethod, values); + } + + @Nullable + private static Integer numCandidates(AggregationPipeline pipeline) { + + Document $vectorSearch = pipeline.firstOperation().toPipelineStages(Aggregation.DEFAULT_CONTEXT).get(0); + if ($vectorSearch.containsKey("$vectorSearch")) { + Object value = $vectorSearch.get("$vectorSearch", Document.class).get("numCandidates"); + return value instanceof Number i ? i.intValue() : null; + } + return null; + } + + interface VectorSearchRepository extends Repository { + + @VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN) + SearchResults searchTop10ByEmbeddingNear(Vector vector, Score similarity); + + @VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN) + SearchResults searchTop10ByFirstNameAndEmbeddingNear(String firstName, Vector vector, Score similarity); + + @VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN) + SearchResults searchTop10ByEmbeddingNearAndFirstName(Vector vector, Score similarity, String firstname); + + @VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ENN) + SearchResults searchTop10EnnByEmbeddingNear(Vector vector, Score similarity); + + @VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN) + SearchResults searchTop10ByEmbeddingNear(Vector vector, Score similarity, Limit limit); + + @VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN) + SearchResults searchByEmbeddingNearOrderByFirstName(Vector vector, Score similarity, Limit limit); + + } + + static class WithVector { + + Vector embedding; + + String lastName; + + @Field("first_name") String firstName; + + public Vector getEmbedding() { + return embedding; + } + + public void setEmbedding(Vector embedding) { + this.embedding = embedding; + } + + public String getLastName() { + return lastName; + } + + public void setLastName(String lastName) { + this.lastName = lastName; + } + + public String getFirstName() { + return firstName; + } + + public void setFirstName(String firstName) { + this.firstName = firstName; + } + } +} diff --git a/spring-data-mongodb/src/test/kotlin/org/springframework/data/mongodb/core/ReactiveFindOperationExtensionsTests.kt b/spring-data-mongodb/src/test/kotlin/org/springframework/data/mongodb/core/ReactiveFindOperationExtensionsTests.kt index cbb7ae46f3..99d57002e4 100644 --- a/spring-data-mongodb/src/test/kotlin/org/springframework/data/mongodb/core/ReactiveFindOperationExtensionsTests.kt +++ b/spring-data-mongodb/src/test/kotlin/org/springframework/data/mongodb/core/ReactiveFindOperationExtensionsTests.kt @@ -270,9 +270,9 @@ class ReactiveFindOperationExtensionsTests { fun terminatingFindNearAllAsFlow() { val spec = mockk>() - val foo = GeoResult("foo", Distance(0.0)) - val bar = GeoResult("bar", Distance(0.0)) - val baz = GeoResult("baz", Distance(0.0)) + val foo = GeoResult("foo", Distance.of(0.0)) + val bar = GeoResult("bar", Distance.of(0.0)) + val baz = GeoResult("baz", Distance.of(0.0)) every { spec.all() } returns Flux.just(foo, bar, baz) runBlocking { diff --git a/src/main/antora/modules/ROOT/nav.adoc b/src/main/antora/modules/ROOT/nav.adoc index a7401fb11f..6f2d1e2847 100644 --- a/src/main/antora/modules/ROOT/nav.adoc +++ b/src/main/antora/modules/ROOT/nav.adoc @@ -45,6 +45,7 @@ ** xref:repositories/create-instances.adoc[] ** xref:repositories/query-methods-details.adoc[] ** xref:mongodb/repositories/query-methods.adoc[] +** xref:mongodb/repositories/vector-search.adoc[] ** xref:mongodb/repositories/modifying-methods.adoc[] ** xref:repositories/projections.adoc[] ** xref:repositories/custom-implementations.adoc[] diff --git a/src/main/antora/modules/ROOT/pages/mongodb/mongo-search-indexes.adoc b/src/main/antora/modules/ROOT/pages/mongodb/mongo-search-indexes.adoc index 345b5dbb6c..7fc51de007 100644 --- a/src/main/antora/modules/ROOT/pages/mongodb/mongo-search-indexes.adoc +++ b/src/main/antora/modules/ROOT/pages/mongodb/mongo-search-indexes.adoc @@ -25,7 +25,7 @@ Java:: [source,java,indent=0,subs="verbatim,quotes",role="primary"] ---- VectorIndex index = new VectorIndex("vector_index") - .addVector("plotEmbedding"), vector -> vector.dimensions(1536).similarity(COSINE)) <1> + .addVector("plotEmbedding", vector -> vector.dimensions(1536).similarity(COSINE)) <1> .addFilter("year"); <2> mongoTemplate.searchIndexOps(Movie.class) <3> diff --git a/src/main/antora/modules/ROOT/pages/mongodb/repositories/vector-search.adoc b/src/main/antora/modules/ROOT/pages/mongodb/repositories/vector-search.adoc new file mode 100644 index 0000000000..2e590107ec --- /dev/null +++ b/src/main/antora/modules/ROOT/pages/mongodb/repositories/vector-search.adoc @@ -0,0 +1,8 @@ +:vector-search-intro-include: data-mongodb::partial$vector-search-intro-include.adoc +:vector-search-model-include: data-mongodb::partial$vector-search-model-include.adoc +:vector-search-repository-include: data-mongodb::partial$vector-search-repository-include.adoc +:vector-search-scoring-include: data-mongodb::partial$vector-search-scoring-include.adoc +:vector-search-method-derived-include: data-mongodb::partial$vector-search-method-derived-include.adoc +:vector-search-method-annotated-include: data-mongodb::partial$vector-search-method-annotated-include.adoc + +include::{commons}@data-commons::page$repositories/vector-search.adoc[] diff --git a/src/main/antora/modules/ROOT/partials/vector-search-intro-include.adoc b/src/main/antora/modules/ROOT/partials/vector-search-intro-include.adoc new file mode 100644 index 0000000000..355bccf4e3 --- /dev/null +++ b/src/main/antora/modules/ROOT/partials/vector-search-intro-include.adoc @@ -0,0 +1 @@ +To use Vector Search with MongoDB, you need a MongoDB Atlas instance that is either running in the cloud or by using https://www.mongodb.com/docs/atlas/cli/current/atlas-cli-deploy-docker/[Docker]. diff --git a/src/main/antora/modules/ROOT/partials/vector-search-method-annotated-include.adoc b/src/main/antora/modules/ROOT/partials/vector-search-method-annotated-include.adoc new file mode 100644 index 0000000000..252437f0b7 --- /dev/null +++ b/src/main/antora/modules/ROOT/partials/vector-search-method-annotated-include.adoc @@ -0,0 +1,23 @@ +Annotated search methods use the `@VectorSearch` annotation to define parameters for the https://www.mongodb.com/docs/upcoming/reference/operator/aggregation/vectorSearch/[`$vectorSearch`] aggregation stage. + +.Using `@VectorSearch` Search Methods +==== +[source,java] +---- +interface CommentRepository extends Repository { + + @VectorSearch(indexName = "cos-index", filter = "{country: ?0}", limit="100", numCandidates="2000") + SearchResults searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, + Score distance); + + @VectorSearch(indexName = "my-index", filter = "{country: ?0}", limit="?3", numCandidates = "#{#limit * 20}", + searchType = VectorSearchOperation.SearchType.ANN) + List findAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance, int limit); +} +---- +==== + +Annotated Search Methods can define `filter` for pre-filter usage. + +`filter`, `limit`, and `numCandidates` support xref:page$mongodb/value-expressions.adoc[Value Expressions] allowing references to search method arguments. + diff --git a/src/main/antora/modules/ROOT/partials/vector-search-method-derived-include.adoc b/src/main/antora/modules/ROOT/partials/vector-search-method-derived-include.adoc new file mode 100644 index 0000000000..f2b006b8e4 --- /dev/null +++ b/src/main/antora/modules/ROOT/partials/vector-search-method-derived-include.adoc @@ -0,0 +1,21 @@ +MongoDB Search methods must use the `@VectorSearch` annotation to define the index name for the https://www.mongodb.com/docs/upcoming/reference/operator/aggregation/vectorSearch/[`$vectorSearch`] aggregation stage. + +.Using `Near` and `Within` Keywords in Repository Search Methods +==== +[source,java] +---- +interface CommentRepository extends Repository { + + @VectorSearch(indexName = "my-index", numCandidates="200") + SearchResults searchTop10ByEmbeddingNear(Vector vector, Score score); + + @VectorSearch(indexName = "my-index", numCandidates="200") + SearchResults searchTop10ByEmbeddingWithin(Vector vector, Range range); + + @VectorSearch(indexName = "my-index", numCandidates="200") + SearchResults searchTop10ByCountryAndEmbeddingWithin(String country, Vector vector, Range range); +} +---- +==== + +Derived Search Methods can define domain model attributes to create the pre-filter for indexed fields. diff --git a/src/main/antora/modules/ROOT/partials/vector-search-model-include.adoc b/src/main/antora/modules/ROOT/partials/vector-search-model-include.adoc new file mode 100644 index 0000000000..e657f3aa63 --- /dev/null +++ b/src/main/antora/modules/ROOT/partials/vector-search-model-include.adoc @@ -0,0 +1,15 @@ +==== +[source,java] +---- +class Comment { + + @Id String id; + String country; + String comment; + + Vector embedding; + + // getters, setters, … +} +---- +==== diff --git a/src/main/antora/modules/ROOT/partials/vector-search-repository-include.adoc b/src/main/antora/modules/ROOT/partials/vector-search-repository-include.adoc new file mode 100644 index 0000000000..0e987fc1c5 --- /dev/null +++ b/src/main/antora/modules/ROOT/partials/vector-search-repository-include.adoc @@ -0,0 +1,25 @@ +.Using `SearchResult` in a Repository Search Method +==== +[source,java] +---- +interface CommentRepository extends Repository { + + @VectorSearch(indexName = "my-index", numCandidates="#{#limit.max() * 20}") + SearchResults searchByCountryAndEmbeddingNear(String country, Vector vector, Score score, + Limit limit); + + @VectorSearch(indexName = "my-index", limit="10", numCandidates="200") + SearchResults searchByCountryAndEmbeddingWithin(String country, Vector embedding, + Score score); + +} + +SearchResults results = repository.searchByCountryAndEmbeddingNear("en", Vector.of(…), Score.of(0.9), Limit.of(10)); +---- +==== + +[TIP] +==== +The MongoDB https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/[vector search aggregation] stage defines a set of required arguments and restrictions. +Please make sure to follow the guidelines and make sure to provide required arguments like `limit`. +==== diff --git a/src/main/antora/modules/ROOT/partials/vector-search-scoring-include.adoc b/src/main/antora/modules/ROOT/partials/vector-search-scoring-include.adoc new file mode 100644 index 0000000000..313d8bf394 --- /dev/null +++ b/src/main/antora/modules/ROOT/partials/vector-search-scoring-include.adoc @@ -0,0 +1,32 @@ +MongoDB reports the score directly as similarity value. +The scoring function must be specified in the index and therefore, Vector search methods do not consider the `Score.scoringFunction`. +The scoring function defaults to `ScoringFunction.unspecified()` as there is no information inside of search results how the score has been computed. + +.Using `Score` and `Similarity` in a Repository Search Methods +==== +[source,java] +---- +interface CommentRepository extends Repository { + + @VectorSearch(…) + SearchResults searchTop10ByEmbeddingNear(Vector vector, Score similarity); + + @VectorSearch(…) + SearchResults searchTop10ByEmbeddingNear(Vector vector, Similarity similarity); + + @VectorSearch(…) + SearchResults searchTop10ByEmbeddingNear(Vector vector, Range range); +} + +repository.searchByEmbeddingNear(Vector.of(…), Score.of(0.9)); <1> + +repository.searchByEmbeddingNear(Vector.of(…), Similarity.of(0.9)); <2> + +repository.searchByEmbeddingNear(Vector.of(…), Similarity.between(0.5, 1)); <3> +---- + +<1> Run a search and return results with a similarity of `0.9` or greater. +<2> Return results with a similarity of `0.9` or greater. +<3> Return results with a similarity of between `0.5` and `1.0` or greater. +==== +