committed by
Mark Paluch
parent
eab7aae16c
commit
21568c84eb
@@ -486,7 +486,7 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware,
|
||||
return doStream(query, entityType, collectionName, returnType, QueryResultConverter.entity());
|
||||
}
|
||||
|
||||
@SuppressWarnings("ConstantConditions")
|
||||
@SuppressWarnings({"ConstantConditions", "NullAway"})
|
||||
<T, R> Stream<R> doStream(Query query, Class<?> entityType, String collectionName, Class<T> returnType,
|
||||
QueryResultConverter<? super T, ? extends R> resultConverter) {
|
||||
|
||||
@@ -1086,34 +1086,29 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware,
|
||||
return new GeoResults<>(result, avgDistance);
|
||||
}
|
||||
|
||||
@Nullable
|
||||
@Override
|
||||
public <T> T findAndModify(Query query, UpdateDefinition update, Class<T> entityClass) {
|
||||
public <T> @Nullable T findAndModify(Query query, UpdateDefinition update, Class<T> entityClass) {
|
||||
return findAndModify(query, update, new FindAndModifyOptions(), entityClass, getCollectionName(entityClass));
|
||||
}
|
||||
|
||||
@Nullable
|
||||
@Override
|
||||
public <T> T findAndModify(Query query, UpdateDefinition update, Class<T> entityClass,
|
||||
public <T> @Nullable T findAndModify(Query query, UpdateDefinition update, Class<T> entityClass,
|
||||
String collectionName) {
|
||||
return findAndModify(query, update, new FindAndModifyOptions(), entityClass, collectionName);
|
||||
}
|
||||
|
||||
@Nullable
|
||||
@Override
|
||||
public <T> T findAndModify(Query query, UpdateDefinition update, FindAndModifyOptions options,
|
||||
public <T> @Nullable T findAndModify(Query query, UpdateDefinition update, FindAndModifyOptions options,
|
||||
Class<T> entityClass) {
|
||||
return findAndModify(query, update, options, entityClass, getCollectionName(entityClass));
|
||||
}
|
||||
|
||||
@Nullable
|
||||
@Override
|
||||
public <T> T findAndModify(Query query, UpdateDefinition update, FindAndModifyOptions options,
|
||||
public <T> @Nullable T findAndModify(Query query, UpdateDefinition update, FindAndModifyOptions options,
|
||||
Class<T> entityClass, String collectionName) {
|
||||
return findAndModify(query, update, options, entityClass, collectionName, QueryResultConverter.entity());
|
||||
}
|
||||
|
||||
<S, T> T findAndModify(Query query, UpdateDefinition update, FindAndModifyOptions options,
|
||||
<S, T> @Nullable T findAndModify(Query query, UpdateDefinition update, FindAndModifyOptions options,
|
||||
Class<S> entityClass, String collectionName, QueryResultConverter<? super S, ? extends T> resultConverter) {
|
||||
|
||||
Assert.notNull(query, "Query must not be null");
|
||||
@@ -1185,15 +1180,13 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware,
|
||||
// Find methods that take a Query to express the query and that return a single object that is also removed from the
|
||||
// collection in the database.
|
||||
|
||||
@Nullable
|
||||
@Override
|
||||
public <T> T findAndRemove(Query query, Class<T> entityClass) {
|
||||
public <T> @Nullable T findAndRemove(Query query, Class<T> entityClass) {
|
||||
return findAndRemove(query, entityClass, getCollectionName(entityClass));
|
||||
}
|
||||
|
||||
@Nullable
|
||||
@Override
|
||||
public <T> T findAndRemove(Query query, Class<T> entityClass, String collectionName) {
|
||||
public <T> @Nullable T findAndRemove(Query query, Class<T> entityClass, String collectionName) {
|
||||
|
||||
Assert.notNull(query, "Query must not be null");
|
||||
Assert.notNull(entityClass, "EntityClass must not be null");
|
||||
@@ -2161,11 +2154,11 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware,
|
||||
* @param entityClass
|
||||
* @return
|
||||
*/
|
||||
@SuppressWarnings("NullAway")
|
||||
protected <T> List<T> doFindAndDelete(String collectionName, Query query, Class<T> entityClass) {
|
||||
return doFindAndDelete(collectionName, query, entityClass, QueryResultConverter.entity());
|
||||
}
|
||||
|
||||
@SuppressWarnings("NullAway")
|
||||
<S, T> List<T> doFindAndDelete(String collectionName, Query query, Class<S> entityClass,
|
||||
QueryResultConverter<? super S, ? extends T> resultConverter) {
|
||||
|
||||
@@ -2229,7 +2222,7 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware,
|
||||
return doAggregate(aggregation, collectionName, outputType, QueryResultConverter.entity(), context);
|
||||
}
|
||||
|
||||
@SuppressWarnings("ConstantConditions")
|
||||
@SuppressWarnings({"ConstantConditions", "NullAway"})
|
||||
<T, O> AggregationResults<O> doAggregate(Aggregation aggregation, String collectionName, Class<T> outputType,
|
||||
QueryResultConverter<? super T, ? extends O> resultConverter, AggregationOperationContext context) {
|
||||
|
||||
|
||||
@@ -2293,6 +2293,7 @@ public class ReactiveMongoTemplate implements ReactiveMongoOperations, Applicati
|
||||
.flatMapSequential(deleteResult -> Flux.fromIterable(list)));
|
||||
}
|
||||
|
||||
@SuppressWarnings({"rawtypes", "unchecked", "NullAway"})
|
||||
<S, T> Flux<T> doFindAndDelete(String collectionName, Query query, Class<S> entityClass,
|
||||
QueryResultConverter<? super S, ? extends T> resultConverter) {
|
||||
|
||||
|
||||
@@ -22,8 +22,10 @@ import java.util.List;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
import org.bson.Document;
|
||||
import org.jspecify.annotations.Nullable;
|
||||
import org.springframework.lang.Contract;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* The {@link AggregationPipeline} holds the collection of {@link AggregationOperation aggregation stages}.
|
||||
@@ -82,6 +84,14 @@ public class AggregationPipeline {
|
||||
return Collections.unmodifiableList(pipeline);
|
||||
}
|
||||
|
||||
public @Nullable AggregationOperation firstOperation() {
|
||||
return CollectionUtils.firstElement(pipeline);
|
||||
}
|
||||
|
||||
public @Nullable AggregationOperation lastOperation() {
|
||||
return CollectionUtils.lastElement(pipeline);
|
||||
}
|
||||
|
||||
List<Document> toDocuments(AggregationOperationContext context) {
|
||||
|
||||
verify();
|
||||
@@ -97,8 +107,8 @@ public class AggregationPipeline {
|
||||
return false;
|
||||
}
|
||||
|
||||
AggregationOperation operation = pipeline.get(pipeline.size() - 1);
|
||||
return isOut(operation) || isMerge(operation);
|
||||
AggregationOperation operation = lastOperation();
|
||||
return operation != null && (isOut(operation) || isMerge(operation));
|
||||
}
|
||||
|
||||
void verify() {
|
||||
|
||||
@@ -356,6 +356,7 @@ public class ArrayOperators {
|
||||
* @return new instance of {@link SortArray}.
|
||||
* @since 4.5
|
||||
*/
|
||||
@SuppressWarnings("NullAway")
|
||||
public SortArray sort(Direction direction) {
|
||||
|
||||
if (usesFieldRef()) {
|
||||
|
||||
@@ -77,7 +77,7 @@ public class ConvertingParameterAccessor implements MongoParameterAccessor {
|
||||
}
|
||||
|
||||
@Override
|
||||
public Vector getVector() {
|
||||
public @Nullable Vector getVector() {
|
||||
return delegate.getVector();
|
||||
}
|
||||
|
||||
@@ -104,12 +104,12 @@ public class ConvertingParameterAccessor implements MongoParameterAccessor {
|
||||
}
|
||||
|
||||
@Override
|
||||
public @org.jspecify.annotations.Nullable Score getScore() {
|
||||
public @Nullable Score getScore() {
|
||||
return delegate.getScore();
|
||||
}
|
||||
|
||||
@Override
|
||||
public @org.jspecify.annotations.Nullable Range<Score> getScoreRange() {
|
||||
public @Nullable Range<Score> getScoreRange() {
|
||||
return delegate.getScoreRange();
|
||||
}
|
||||
|
||||
|
||||
@@ -61,14 +61,13 @@ public class MongoParametersParameterAccessor extends ParametersParameterAccesso
|
||||
public Range<Score> getScoreRange() {
|
||||
|
||||
MongoParameters mongoParameters = method.getParameters();
|
||||
int rangeIndex = mongoParameters.getScoreRangeIndex();
|
||||
|
||||
if (rangeIndex != -1) {
|
||||
return getValue(rangeIndex);
|
||||
if (mongoParameters.hasScoreRangeParameter()) {
|
||||
return getValue(mongoParameters.getScoreRangeIndex());
|
||||
}
|
||||
|
||||
int scoreIndex = mongoParameters.getScoreIndex();
|
||||
Bound<Score> maxDistance = scoreIndex == -1 ? Bound.unbounded() : Bound.inclusive((Score) getScore());
|
||||
Score score = getScore();
|
||||
Bound<Score> maxDistance = score != null ? Bound.inclusive(score) : Bound.unbounded();
|
||||
|
||||
return Range.of(Bound.unbounded(), maxDistance);
|
||||
}
|
||||
|
||||
@@ -15,7 +15,8 @@
|
||||
*/
|
||||
package org.springframework.data.mongodb.repository.query;
|
||||
|
||||
import static org.springframework.data.mongodb.core.query.Criteria.*;
|
||||
import static org.springframework.data.mongodb.core.query.Criteria.Placeholder;
|
||||
import static org.springframework.data.mongodb.core.query.Criteria.where;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
@@ -27,7 +28,6 @@ import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import org.bson.BsonRegularExpression;
|
||||
import org.jspecify.annotations.Nullable;
|
||||
|
||||
import org.springframework.data.domain.Range;
|
||||
import org.springframework.data.domain.Range.Bound;
|
||||
import org.springframework.data.domain.Sort;
|
||||
@@ -118,8 +118,9 @@ public class MongoQueryCreator extends AbstractQueryCreator<Query, Criteria> {
|
||||
return new Criteria();
|
||||
}
|
||||
|
||||
if (isSearchQuery && (part.getType().equals(Type.NEAR) || part.getType().equals(Type.WITHIN))) {
|
||||
return null;
|
||||
if (isPartOfSearchQuery(part)) {
|
||||
skip(part, iterator);
|
||||
return new Criteria();
|
||||
}
|
||||
|
||||
PersistentPropertyPath<MongoPersistentProperty> path = context.getPersistentPropertyPath(part.getProperty());
|
||||
@@ -135,7 +136,8 @@ public class MongoQueryCreator extends AbstractQueryCreator<Query, Criteria> {
|
||||
return create(part, iterator);
|
||||
}
|
||||
|
||||
if (isSearchQuery && (part.getType().equals(Type.NEAR) || part.getType().equals(Type.WITHIN))) {
|
||||
if (isPartOfSearchQuery(part)) {
|
||||
skip(part, iterator);
|
||||
return base;
|
||||
}
|
||||
|
||||
@@ -176,15 +178,6 @@ public class MongoQueryCreator extends AbstractQueryCreator<Query, Criteria> {
|
||||
@SuppressWarnings("NullAway")
|
||||
private Criteria from(Part part, MongoPersistentProperty property, Criteria criteria, Iterator<Object> parameters) {
|
||||
|
||||
if (isSearchQuery && (part.getType().equals(Type.NEAR) || part.getType().equals(Type.WITHIN))) {
|
||||
|
||||
int numberOfArguments = part.getType().getNumberOfArguments();
|
||||
for (int i = 0; i < numberOfArguments; i++) {
|
||||
parameters.next();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
Type type = part.getType();
|
||||
|
||||
switch (type) {
|
||||
@@ -206,13 +199,13 @@ public class MongoQueryCreator extends AbstractQueryCreator<Query, Criteria> {
|
||||
return criteria.is(null);
|
||||
case NOT_IN:
|
||||
Object ninValue = parameters.next();
|
||||
if(ninValue instanceof Placeholder) {
|
||||
if (ninValue instanceof Placeholder) {
|
||||
return criteria.raw("$nin", ninValue);
|
||||
}
|
||||
return criteria.nin(valueAsList(ninValue, part));
|
||||
case IN:
|
||||
Object inValue = parameters.next();
|
||||
if(inValue instanceof Placeholder) {
|
||||
if (inValue instanceof Placeholder) {
|
||||
return criteria.raw("$in", inValue);
|
||||
}
|
||||
return criteria.in(valueAsList(inValue, part));
|
||||
@@ -231,7 +224,7 @@ public class MongoQueryCreator extends AbstractQueryCreator<Query, Criteria> {
|
||||
return param instanceof Pattern pattern ? criteria.regex(pattern) : criteria.regex(param.toString());
|
||||
case EXISTS:
|
||||
Object next = parameters.next();
|
||||
if(next instanceof Placeholder placeholder) {
|
||||
if (next instanceof Placeholder placeholder) {
|
||||
return criteria.raw("$exists", placeholder);
|
||||
} else {
|
||||
return criteria.exists((Boolean) next);
|
||||
@@ -355,7 +348,7 @@ public class MongoQueryCreator extends AbstractQueryCreator<Query, Criteria> {
|
||||
|
||||
if (property.isCollectionLike()) {
|
||||
Object next = parameters.next();
|
||||
if(next instanceof Placeholder) {
|
||||
if (next instanceof Placeholder) {
|
||||
return criteria.raw("$in", next);
|
||||
}
|
||||
return criteria.in(valueAsList(next, part));
|
||||
@@ -433,8 +426,7 @@ public class MongoQueryCreator extends AbstractQueryCreator<Query, Criteria> {
|
||||
streamable = streamable.map(it -> {
|
||||
if (it instanceof String sv) {
|
||||
|
||||
return new BsonRegularExpression(MongoRegexCreator.INSTANCE.toRegularExpression(sv, matchMode),
|
||||
regexOptions);
|
||||
return new BsonRegularExpression(MongoRegexCreator.INSTANCE.toRegularExpression(sv, matchMode), regexOptions);
|
||||
}
|
||||
return it;
|
||||
});
|
||||
@@ -468,10 +460,23 @@ public class MongoQueryCreator extends AbstractQueryCreator<Query, Criteria> {
|
||||
return false;
|
||||
}
|
||||
|
||||
private boolean isPartOfSearchQuery(Part part) {
|
||||
return isSearchQuery && (part.getType().equals(Type.NEAR) || part.getType().equals(Type.WITHIN));
|
||||
}
|
||||
|
||||
private static void skip(Part part, Iterator<?> parameters) {
|
||||
|
||||
int total = part.getNumberOfArguments();
|
||||
int i = 0;
|
||||
while (parameters.hasNext() && i < total) {
|
||||
parameters.next();
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute a {@link Type#BETWEEN} typed {@link Part} using {@link Criteria#gt(Object) $gt},
|
||||
* {@link Criteria#gte(Object) $gte}, {@link Criteria#lt(Object) $lt} and {@link Criteria#lte(Object) $lte}.
|
||||
* <br />
|
||||
* {@link Criteria#gte(Object) $gte}, {@link Criteria#lt(Object) $lt} and {@link Criteria#lte(Object) $lte}. <br />
|
||||
* In case the first {@literal value} is actually a {@link Range} the lower and upper bounds of the {@link Range} are
|
||||
* used according to their {@link Bound#isInclusive() inclusion} definition. Otherwise the {@literal value} is used
|
||||
* for {@literal $gt} and {@link Iterator#next() parameters.next()} as {@literal $lt}.
|
||||
|
||||
@@ -15,18 +15,16 @@
|
||||
*/
|
||||
package org.springframework.data.mongodb.repository.query;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
import org.bson.Document;
|
||||
import org.jspecify.annotations.Nullable;
|
||||
|
||||
import org.springframework.data.domain.Page;
|
||||
import org.springframework.data.domain.Pageable;
|
||||
import org.springframework.data.domain.Range;
|
||||
import org.springframework.data.domain.ScoringFunction;
|
||||
import org.springframework.data.domain.SearchResult;
|
||||
import org.springframework.data.domain.SearchResults;
|
||||
import org.springframework.data.domain.Similarity;
|
||||
@@ -37,6 +35,7 @@ import org.springframework.data.geo.GeoPage;
|
||||
import org.springframework.data.geo.GeoResult;
|
||||
import org.springframework.data.geo.GeoResults;
|
||||
import org.springframework.data.geo.Point;
|
||||
import org.springframework.data.mongodb.core.ExecutableAggregationOperation.TerminatingAggregation;
|
||||
import org.springframework.data.mongodb.core.ExecutableFindOperation;
|
||||
import org.springframework.data.mongodb.core.ExecutableFindOperation.FindWithQuery;
|
||||
import org.springframework.data.mongodb.core.ExecutableFindOperation.TerminatingFind;
|
||||
@@ -45,12 +44,13 @@ import org.springframework.data.mongodb.core.ExecutableRemoveOperation.Executabl
|
||||
import org.springframework.data.mongodb.core.ExecutableRemoveOperation.TerminatingRemove;
|
||||
import org.springframework.data.mongodb.core.ExecutableUpdateOperation.ExecutableUpdate;
|
||||
import org.springframework.data.mongodb.core.MongoOperations;
|
||||
import org.springframework.data.mongodb.core.aggregation.AggregationOperation;
|
||||
import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
|
||||
import org.springframework.data.mongodb.core.aggregation.AggregationResults;
|
||||
import org.springframework.data.mongodb.core.aggregation.TypedAggregation;
|
||||
import org.springframework.data.mongodb.core.query.NearQuery;
|
||||
import org.springframework.data.mongodb.core.query.Query;
|
||||
import org.springframework.data.mongodb.core.query.UpdateDefinition;
|
||||
import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer;
|
||||
import org.springframework.data.mongodb.repository.util.SliceUtils;
|
||||
import org.springframework.data.repository.query.QueryMethod;
|
||||
import org.springframework.data.support.PageableExecutionUtils;
|
||||
@@ -186,7 +186,7 @@ public interface MongoQueryExecution {
|
||||
return isListOfGeoResult(method.getReturnType()) ? results.getContent() : results;
|
||||
}
|
||||
|
||||
@SuppressWarnings({"unchecked","NullAway"})
|
||||
@SuppressWarnings({ "unchecked", "NullAway" })
|
||||
GeoResults<Object> doExecuteQuery(Query query) {
|
||||
|
||||
Point nearLocation = accessor.getGeoNearLocation();
|
||||
@@ -225,52 +225,53 @@ public interface MongoQueryExecution {
|
||||
* {@link MongoQueryExecution} to execute vector search.
|
||||
*
|
||||
* @author Mark Paluch
|
||||
* @author Chistoph Strobl
|
||||
* @since 5.0
|
||||
*/
|
||||
class VectorSearchExecution implements MongoQueryExecution {
|
||||
|
||||
private final MongoOperations operations;
|
||||
private final MongoQueryMethod method;
|
||||
private final TypeInformation<?> returnType;
|
||||
private final String collectionName;
|
||||
private final VectorSearchDelegate.QueryMetadata queryMetadata;
|
||||
private final List<AggregationOperation> pipeline;
|
||||
private final Class<?> targetType;
|
||||
private final ScoringFunction scoringFunction;
|
||||
private final AggregationPipeline pipeline;
|
||||
|
||||
public VectorSearchExecution(MongoOperations operations, MongoQueryMethod method, String collectionName,
|
||||
VectorSearchDelegate.QueryMetadata queryMetadata, MongoParameterAccessor accessor) {
|
||||
VectorSearchExecution(MongoOperations operations, MongoQueryMethod method, String collectionName,
|
||||
QueryContainer queryContainer) {
|
||||
this(operations, queryContainer.outputType(), collectionName, method.getReturnType(), queryContainer.pipeline(),
|
||||
queryContainer.scoringFunction());
|
||||
}
|
||||
|
||||
public VectorSearchExecution(MongoOperations operations, Class<?> targetType, String collectionName,
|
||||
TypeInformation<?> returnType, AggregationPipeline pipeline, ScoringFunction scoringFunction) {
|
||||
|
||||
this.operations = operations;
|
||||
this.returnType = returnType;
|
||||
this.collectionName = collectionName;
|
||||
this.queryMetadata = queryMetadata;
|
||||
this.method = method;
|
||||
this.pipeline = queryMetadata.getAggregationPipeline(method, accessor);
|
||||
this.targetType = targetType;
|
||||
this.scoringFunction = scoringFunction;
|
||||
this.pipeline = pipeline;
|
||||
}
|
||||
|
||||
@Override
|
||||
@SuppressWarnings({ "unchecked", "rawtypes" })
|
||||
public Object execute(Query query) {
|
||||
|
||||
AggregationResults<?> aggregated = operations.aggregate(
|
||||
TypedAggregation.newAggregation(queryMetadata.outputType(), pipeline), collectionName,
|
||||
queryMetadata.outputType());
|
||||
TerminatingAggregation<?> executableAggregation = operations.aggregateAndReturn(targetType)
|
||||
.inCollection(collectionName).by(TypedAggregation.newAggregation(targetType, pipeline.getOperations()));
|
||||
|
||||
List<?> mappedResults = aggregated.getMappedResults();
|
||||
|
||||
if (isSearchResult(method.getReturnType())) {
|
||||
|
||||
List<org.bson.Document> rawResults = aggregated.getRawResults().getList("results", org.bson.Document.class);
|
||||
List<SearchResult<Object>> result = new ArrayList<>(mappedResults.size());
|
||||
|
||||
for (int i = 0; i < mappedResults.size(); i++) {
|
||||
Document document = rawResults.get(i);
|
||||
SearchResult<Object> searchResult = new SearchResult<>(mappedResults.get(i),
|
||||
Similarity.raw(document.getDouble("__score__"), queryMetadata.scoringFunction()));
|
||||
|
||||
result.add(searchResult);
|
||||
}
|
||||
|
||||
return isListOfSearchResult(method.getReturnType()) ? result : new SearchResults<>(result);
|
||||
if (!isSearchResult(returnType)) {
|
||||
return executableAggregation.all().getMappedResults();
|
||||
}
|
||||
|
||||
return mappedResults;
|
||||
AggregationResults<? extends SearchResult<?>> result = executableAggregation
|
||||
.map((raw, container) -> new SearchResult<>(container.get(),
|
||||
Similarity.raw(raw.getDouble("__score__"), scoringFunction)))
|
||||
.all();
|
||||
|
||||
return isListOfSearchResult(returnType) ? result.getMappedResults()
|
||||
: new SearchResults(result.getMappedResults());
|
||||
}
|
||||
|
||||
private static boolean isListOfSearchResult(TypeInformation<?> returnType) {
|
||||
|
||||
@@ -18,12 +18,9 @@ package org.springframework.data.mongodb.repository.query;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import org.bson.Document;
|
||||
import org.jspecify.annotations.Nullable;
|
||||
import org.reactivestreams.Publisher;
|
||||
|
||||
import org.springframework.core.convert.converter.Converter;
|
||||
import org.springframework.data.convert.DtoInstantiatingConverter;
|
||||
import org.springframework.data.domain.Pageable;
|
||||
@@ -36,11 +33,12 @@ import org.springframework.data.geo.Point;
|
||||
import org.springframework.data.mapping.model.EntityInstantiators;
|
||||
import org.springframework.data.mongodb.core.ReactiveMongoOperations;
|
||||
import org.springframework.data.mongodb.core.ReactiveUpdateOperation.ReactiveUpdate;
|
||||
import org.springframework.data.mongodb.core.aggregation.AggregationOperation;
|
||||
import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
|
||||
import org.springframework.data.mongodb.core.aggregation.TypedAggregation;
|
||||
import org.springframework.data.mongodb.core.query.NearQuery;
|
||||
import org.springframework.data.mongodb.core.query.Query;
|
||||
import org.springframework.data.mongodb.core.query.UpdateDefinition;
|
||||
import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer;
|
||||
import org.springframework.data.repository.query.ResultProcessor;
|
||||
import org.springframework.data.repository.query.ReturnedType;
|
||||
import org.springframework.data.util.ReactiveWrappers;
|
||||
@@ -134,24 +132,24 @@ interface ReactiveMongoQueryExecution {
|
||||
class VectorSearchExecution implements ReactiveMongoQueryExecution {
|
||||
|
||||
private final ReactiveMongoOperations operations;
|
||||
private final VectorSearchDelegate.QueryMetadata queryMetadata;
|
||||
private final List<AggregationOperation> pipeline;
|
||||
private final QueryContainer queryMetadata;
|
||||
private final AggregationPipeline pipeline;
|
||||
private final boolean returnSearchResult;
|
||||
|
||||
public VectorSearchExecution(ReactiveMongoOperations operations, MongoQueryMethod method,
|
||||
VectorSearchDelegate.QueryMetadata queryMetadata, MongoParameterAccessor accessor) {
|
||||
VectorSearchExecution(ReactiveMongoOperations operations, MongoQueryMethod method, QueryContainer queryMetadata) {
|
||||
|
||||
this.operations = operations;
|
||||
this.queryMetadata = queryMetadata;
|
||||
this.pipeline = queryMetadata.getAggregationPipeline(method, accessor);
|
||||
this.pipeline = queryMetadata.pipeline();
|
||||
this.returnSearchResult = isSearchResult(method.getReturnType());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Publisher<? extends Object> execute(Query query, Class<?> type, String collection) {
|
||||
|
||||
Flux<Document> aggregate = operations
|
||||
.aggregate(TypedAggregation.newAggregation(queryMetadata.outputType(), pipeline), collection, Document.class);
|
||||
Flux<Document> aggregate = operations.aggregate(
|
||||
TypedAggregation.newAggregation(queryMetadata.outputType(), pipeline.getOperations()), collection,
|
||||
Document.class);
|
||||
|
||||
return aggregate.map(document -> {
|
||||
|
||||
|
||||
@@ -19,13 +19,13 @@ import reactor.core.publisher.Mono;
|
||||
|
||||
import org.bson.Document;
|
||||
import org.reactivestreams.Publisher;
|
||||
|
||||
import org.springframework.data.mongodb.InvalidMongoDbApiUsageException;
|
||||
import org.springframework.data.mongodb.core.MongoOperations;
|
||||
import org.springframework.data.mongodb.core.ReactiveMongoOperations;
|
||||
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity;
|
||||
import org.springframework.data.mongodb.core.query.Query;
|
||||
import org.springframework.data.mongodb.repository.VectorSearch;
|
||||
import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer;
|
||||
import org.springframework.data.mongodb.util.json.ParameterBindingContext;
|
||||
import org.springframework.data.repository.query.ResultProcessor;
|
||||
import org.springframework.data.repository.query.ValueExpressionDelegate;
|
||||
@@ -84,11 +84,11 @@ public class ReactiveVectorSearchAggregation extends AbstractReactiveMongoQuery
|
||||
|
||||
ParameterBindingContext bindingContext = new ParameterBindingContext(accessor::getBindableValue,
|
||||
expressionEvaluator);
|
||||
VectorSearchDelegate.QueryMetadata query = delegate.createQuery(expressionEvaluator, processor, accessor,
|
||||
typeToRead, codec, bindingContext);
|
||||
QueryContainer query = delegate.createQuery(expressionEvaluator, processor, accessor, typeToRead, codec,
|
||||
bindingContext);
|
||||
|
||||
ReactiveMongoQueryExecution.VectorSearchExecution execution = new ReactiveMongoQueryExecution.VectorSearchExecution(
|
||||
mongoOperations, method, query, accessor);
|
||||
mongoOperations, method, query);
|
||||
|
||||
return execution.execute(query.query(), Document.class, collectionEntity.getCollection());
|
||||
});
|
||||
|
||||
@@ -15,16 +15,17 @@
|
||||
*/
|
||||
package org.springframework.data.mongodb.repository.query;
|
||||
|
||||
import org.jspecify.annotations.Nullable;
|
||||
import org.springframework.data.mapping.model.ValueExpressionEvaluator;
|
||||
import org.springframework.data.mongodb.InvalidMongoDbApiUsageException;
|
||||
import org.springframework.data.mongodb.core.MongoOperations;
|
||||
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity;
|
||||
import org.springframework.data.mongodb.core.query.Query;
|
||||
import org.springframework.data.mongodb.repository.VectorSearch;
|
||||
import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer;
|
||||
import org.springframework.data.mongodb.util.json.ParameterBindingContext;
|
||||
import org.springframework.data.repository.query.ResultProcessor;
|
||||
import org.springframework.data.repository.query.ValueExpressionDelegate;
|
||||
import org.springframework.lang.Nullable;
|
||||
|
||||
/**
|
||||
* {@link AbstractMongoQuery} implementation to run a {@link VectorSearchAggregation}. The pre-filter is either derived
|
||||
@@ -62,20 +63,19 @@ public class VectorSearchAggregation extends AbstractMongoQuery {
|
||||
this.delegate = new VectorSearchDelegate(method, mongoOperations.getConverter(), delegate);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@Override
|
||||
protected Object doExecute(MongoQueryMethod method, ResultProcessor processor, ConvertingParameterAccessor accessor,
|
||||
@Nullable Class<?> typeToRead) {
|
||||
|
||||
VectorSearchDelegate.QueryMetadata query = createVectorSearchQuery(processor, accessor, typeToRead);
|
||||
QueryContainer query = createVectorSearchQuery(processor, accessor, typeToRead);
|
||||
|
||||
MongoQueryExecution.VectorSearchExecution execution = new MongoQueryExecution.VectorSearchExecution(mongoOperations,
|
||||
method, collectionEntity.getCollection(), query, accessor);
|
||||
method, collectionEntity.getCollection(), query);
|
||||
|
||||
return execution.execute(query.query());
|
||||
}
|
||||
|
||||
VectorSearchDelegate.QueryMetadata createVectorSearchQuery(ResultProcessor processor, MongoParameterAccessor accessor,
|
||||
QueryContainer createVectorSearchQuery(ResultProcessor processor, MongoParameterAccessor accessor,
|
||||
@Nullable Class<?> typeToRead) {
|
||||
|
||||
ValueExpressionEvaluator evaluator = getExpressionEvaluatorFor(accessor);
|
||||
|
||||
@@ -20,7 +20,6 @@ import java.util.List;
|
||||
|
||||
import org.bson.Document;
|
||||
import org.jspecify.annotations.Nullable;
|
||||
|
||||
import org.springframework.data.domain.Limit;
|
||||
import org.springframework.data.domain.Range;
|
||||
import org.springframework.data.domain.Score;
|
||||
@@ -34,6 +33,7 @@ import org.springframework.data.mapping.model.ValueExpressionEvaluator;
|
||||
import org.springframework.data.mongodb.InvalidMongoDbApiUsageException;
|
||||
import org.springframework.data.mongodb.core.aggregation.Aggregation;
|
||||
import org.springframework.data.mongodb.core.aggregation.AggregationOperation;
|
||||
import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
|
||||
import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation;
|
||||
import org.springframework.data.mongodb.core.convert.MongoConverter;
|
||||
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity;
|
||||
@@ -47,6 +47,7 @@ import org.springframework.data.repository.query.ResultProcessor;
|
||||
import org.springframework.data.repository.query.ValueExpressionDelegate;
|
||||
import org.springframework.data.repository.query.parser.Part;
|
||||
import org.springframework.data.repository.query.parser.PartTree;
|
||||
import org.springframework.util.NumberUtils;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
/**
|
||||
@@ -58,32 +59,35 @@ class VectorSearchDelegate {
|
||||
|
||||
private final VectorSearchQueryFactory queryFactory;
|
||||
private final VectorSearchOperation.SearchType searchType;
|
||||
private final String indexName;
|
||||
private final @Nullable Integer numCandidates;
|
||||
private final @Nullable String numCandidatesExpression;
|
||||
private final Limit limit;
|
||||
private final @Nullable String limitExpression;
|
||||
private final MongoConverter converter;
|
||||
|
||||
public VectorSearchDelegate(MongoQueryMethod method, MongoConverter converter, ValueExpressionDelegate delegate) {
|
||||
VectorSearchDelegate(MongoQueryMethod method, MongoConverter converter, ValueExpressionDelegate delegate) {
|
||||
|
||||
VectorSearch vectorSearch = method.findAnnotatedVectorSearch().orElseThrow();
|
||||
|
||||
this.searchType = vectorSearch.searchType();
|
||||
this.indexName = method.getAnnotatedHint();
|
||||
|
||||
if (StringUtils.hasText(vectorSearch.numCandidates())) {
|
||||
|
||||
ValueExpression expression = delegate.getValueExpressionParser().parse(vectorSearch.numCandidates());
|
||||
|
||||
if (expression.isLiteral()) {
|
||||
numCandidates = Integer.parseInt(vectorSearch.numCandidates());
|
||||
numCandidatesExpression = null;
|
||||
this.numCandidates = Integer.parseInt(vectorSearch.numCandidates());
|
||||
this.numCandidatesExpression = null;
|
||||
} else {
|
||||
numCandidates = null;
|
||||
numCandidatesExpression = vectorSearch.numCandidates();
|
||||
this.numCandidates = null;
|
||||
this.numCandidatesExpression = vectorSearch.numCandidates();
|
||||
}
|
||||
|
||||
} else {
|
||||
numCandidates = null;
|
||||
numCandidatesExpression = null;
|
||||
this.numCandidates = null;
|
||||
this.numCandidatesExpression = null;
|
||||
}
|
||||
|
||||
if (StringUtils.hasText(vectorSearch.limit())) {
|
||||
@@ -91,26 +95,26 @@ class VectorSearchDelegate {
|
||||
ValueExpression expression = delegate.getValueExpressionParser().parse(vectorSearch.limit());
|
||||
|
||||
if (expression.isLiteral()) {
|
||||
limit = Limit.of(Integer.parseInt(vectorSearch.limit()));
|
||||
limitExpression = null;
|
||||
this.limit = Limit.of(Integer.parseInt(vectorSearch.limit()));
|
||||
this.limitExpression = null;
|
||||
} else {
|
||||
limit = Limit.unlimited();
|
||||
limitExpression = vectorSearch.limit();
|
||||
this.limit = Limit.unlimited();
|
||||
this.limitExpression = vectorSearch.limit();
|
||||
}
|
||||
|
||||
} else {
|
||||
limit = Limit.unlimited();
|
||||
limitExpression = null;
|
||||
this.limit = Limit.unlimited();
|
||||
this.limitExpression = null;
|
||||
}
|
||||
|
||||
this.converter = converter;
|
||||
|
||||
if (StringUtils.hasText(vectorSearch.filter())) {
|
||||
queryFactory = StringUtils.hasText(vectorSearch.path())
|
||||
this.queryFactory = StringUtils.hasText(vectorSearch.path())
|
||||
? new AnnotatedQueryFactory(vectorSearch.filter(), vectorSearch.path())
|
||||
: new AnnotatedQueryFactory(vectorSearch.filter(), method.getEntityInformation().getCollectionEntity());
|
||||
} else {
|
||||
queryFactory = new PartTreeQueryFactory(
|
||||
this.queryFactory = new PartTreeQueryFactory(
|
||||
new PartTree(method.getName(), method.getResultProcessor().getReturnedType().getDomainType()),
|
||||
converter.getMappingContext());
|
||||
}
|
||||
@@ -119,43 +123,136 @@ class VectorSearchDelegate {
|
||||
/**
|
||||
* Create Query Metadata for {@code $vectorSearch}.
|
||||
*/
|
||||
public QueryMetadata createQuery(ValueExpressionEvaluator evaluator, ResultProcessor processor,
|
||||
QueryContainer createQuery(ValueExpressionEvaluator evaluator, ResultProcessor processor,
|
||||
MongoParameterAccessor accessor, @Nullable Class<?> typeToRead, ParameterBindingDocumentCodec codec,
|
||||
ParameterBindingContext context) {
|
||||
|
||||
Integer numCandidates = null;
|
||||
Limit limit;
|
||||
String scoreField = "__score__";
|
||||
Class<?> outputType = typeToRead != null ? typeToRead : processor.getReturnedType().getReturnedType();
|
||||
VectorSearchInput query = queryFactory.createQuery(accessor, codec, context);
|
||||
VectorSearchInput vectorSearchInput = createSearchInput(evaluator, accessor, codec, context);
|
||||
AggregationPipeline pipeline = createVectorSearchPipeline(vectorSearchInput, scoreField, outputType, accessor,
|
||||
evaluator);
|
||||
|
||||
if (this.limitExpression != null) {
|
||||
Object value = evaluator.evaluate(this.limitExpression);
|
||||
limit = value instanceof Limit l ? l : Limit.of(((Number) value).intValue());
|
||||
} else if (this.limit.isLimited()) {
|
||||
limit = this.limit;
|
||||
} else {
|
||||
limit = accessor.getLimit();
|
||||
}
|
||||
return new QueryContainer(vectorSearchInput.path, scoreField, vectorSearchInput.query, pipeline, searchType,
|
||||
outputType, getSimilarityFunction(accessor), indexName);
|
||||
}
|
||||
|
||||
if (limit.isLimited()) {
|
||||
query.query().limit(limit);
|
||||
}
|
||||
@SuppressWarnings("NullAway")
|
||||
AggregationPipeline createVectorSearchPipeline(VectorSearchInput input, String scoreField, Class<?> outputType,
|
||||
MongoParameterAccessor accessor, ValueExpressionEvaluator evaluator) {
|
||||
|
||||
Vector vector = accessor.getVector();
|
||||
Score score = accessor.getScore();
|
||||
Range<Score> distance = accessor.getScoreRange();
|
||||
Limit limit = Limit.of(input.query().getLimit());
|
||||
|
||||
List<AggregationOperation> stages = new ArrayList<>();
|
||||
VectorSearchOperation $vectorSearch = Aggregation.vectorSearch(indexName).path(input.path()).vector(vector)
|
||||
.limit(limit);
|
||||
|
||||
Integer candidates = null;
|
||||
if (this.numCandidatesExpression != null) {
|
||||
numCandidates = ((Number) evaluator.evaluate(this.numCandidatesExpression)).intValue();
|
||||
candidates = ((Number) evaluator.evaluate(this.numCandidatesExpression)).intValue();
|
||||
} else if (this.numCandidates != null) {
|
||||
numCandidates = this.numCandidates;
|
||||
} else if (query.query().isLimited() && (searchType == VectorSearchOperation.SearchType.ANN
|
||||
candidates = this.numCandidates;
|
||||
} else if (input.query().isLimited() && (searchType == VectorSearchOperation.SearchType.ANN
|
||||
|| searchType == VectorSearchOperation.SearchType.DEFAULT)) {
|
||||
|
||||
/*
|
||||
MongoDB: We recommend that you specify a number at least 20 times higher than the number of documents to return (limit) to increase accuracy.
|
||||
*/
|
||||
numCandidates = query.query().getLimit() * 20;
|
||||
candidates = input.query().getLimit() * 20;
|
||||
}
|
||||
|
||||
return new QueryMetadata(query.path, "__score__", query.query, searchType, outputType, numCandidates,
|
||||
getSimilarityFunction(accessor));
|
||||
if (candidates != null) {
|
||||
$vectorSearch = $vectorSearch.numCandidates(candidates);
|
||||
}
|
||||
//
|
||||
$vectorSearch = $vectorSearch.filter(input.query.getQueryObject());
|
||||
$vectorSearch = $vectorSearch.searchType(this.searchType);
|
||||
$vectorSearch = $vectorSearch.withSearchScore(scoreField);
|
||||
|
||||
if (score != null) {
|
||||
$vectorSearch = $vectorSearch.withFilterBySore(c -> {
|
||||
c.gt(score.getValue());
|
||||
});
|
||||
} else if (distance.getLowerBound().isBounded() || distance.getUpperBound().isBounded()) {
|
||||
$vectorSearch = $vectorSearch.withFilterBySore(c -> {
|
||||
Range.Bound<Score> lower = distance.getLowerBound();
|
||||
if (lower.isBounded()) {
|
||||
double value = lower.getValue().get().getValue();
|
||||
if (lower.isInclusive()) {
|
||||
c.gte(value);
|
||||
} else {
|
||||
c.gt(value);
|
||||
}
|
||||
}
|
||||
|
||||
Range.Bound<Score> upper = distance.getUpperBound();
|
||||
if (upper.isBounded()) {
|
||||
|
||||
double value = upper.getValue().get().getValue();
|
||||
if (upper.isInclusive()) {
|
||||
c.lte(value);
|
||||
} else {
|
||||
c.lt(value);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
stages.add($vectorSearch);
|
||||
|
||||
if (input.query().isSorted()) {
|
||||
|
||||
stages.add(ctx -> {
|
||||
|
||||
Document mappedSort = ctx.getMappedObject(input.query().getSortObject(), outputType);
|
||||
mappedSort.append(scoreField, -1);
|
||||
return ctx.getMappedObject(new Document("$sort", mappedSort));
|
||||
});
|
||||
} else {
|
||||
stages.add(Aggregation.sort(Sort.Direction.DESC, scoreField));
|
||||
}
|
||||
|
||||
return new AggregationPipeline(stages);
|
||||
}
|
||||
|
||||
private VectorSearchInput createSearchInput(ValueExpressionEvaluator evaluator, MongoParameterAccessor accessor,
|
||||
ParameterBindingDocumentCodec codec, ParameterBindingContext context) {
|
||||
|
||||
VectorSearchInput input = queryFactory.createQuery(accessor, codec, context);
|
||||
Limit limit = getLimit(evaluator, accessor);
|
||||
if(!input.query.isLimited() || (input.query.isLimited() && !limit.isUnlimited())) {
|
||||
input.query().limit(limit);
|
||||
}
|
||||
return input;
|
||||
}
|
||||
|
||||
private Limit getLimit(ValueExpressionEvaluator evaluator, MongoParameterAccessor accessor) {
|
||||
|
||||
if (this.limitExpression != null) {
|
||||
|
||||
Object value = evaluator.evaluate(this.limitExpression);
|
||||
if (value != null) {
|
||||
if (value instanceof Limit l) {
|
||||
return l;
|
||||
}
|
||||
if (value instanceof Number n) {
|
||||
return Limit.of(n.intValue());
|
||||
}
|
||||
if (value instanceof String s) {
|
||||
return Limit.of(NumberUtils.parseNumber(s, Integer.class));
|
||||
}
|
||||
throw new IllegalArgumentException("Invalid type for Limit. Found [%s], expected Limit or Number");
|
||||
}
|
||||
}
|
||||
|
||||
if (this.limit.isLimited()) {
|
||||
return this.limit;
|
||||
}
|
||||
|
||||
return accessor.getLimit();
|
||||
}
|
||||
|
||||
public String getQueryString() {
|
||||
@@ -192,82 +289,10 @@ class VectorSearchDelegate {
|
||||
* @param query
|
||||
* @param searchType
|
||||
* @param outputType
|
||||
* @param numCandidates
|
||||
* @param scoringFunction
|
||||
*/
|
||||
public record QueryMetadata(String path, String scoreField, Query query, VectorSearchOperation.SearchType searchType,
|
||||
Class<?> outputType, @Nullable Integer numCandidates, ScoringFunction scoringFunction) {
|
||||
|
||||
/**
|
||||
* Create the Aggregation Pipeline.
|
||||
*
|
||||
* @param queryMethod
|
||||
* @param accessor
|
||||
* @return
|
||||
*/
|
||||
public List<AggregationOperation> getAggregationPipeline(MongoQueryMethod queryMethod,
|
||||
MongoParameterAccessor accessor) {
|
||||
|
||||
Vector vector = accessor.getVector();
|
||||
Score score = accessor.getScore();
|
||||
Range<Score> distance = accessor.getScoreRange();
|
||||
Limit limit = Limit.unlimited();
|
||||
|
||||
if (query.isLimited()) {
|
||||
limit = Limit.of(query.getLimit());
|
||||
}
|
||||
|
||||
List<AggregationOperation> stages = new ArrayList<>();
|
||||
VectorSearchOperation $vectorSearch = Aggregation.vectorSearch(queryMethod.getAnnotatedHint()).path(path())
|
||||
.vector(vector).limit(limit);
|
||||
|
||||
if (numCandidates() != null) {
|
||||
$vectorSearch = $vectorSearch.numCandidates(numCandidates());
|
||||
}
|
||||
|
||||
$vectorSearch = $vectorSearch.filter(query.getQueryObject());
|
||||
$vectorSearch = $vectorSearch.searchType(searchType());
|
||||
$vectorSearch = $vectorSearch.withSearchScore(scoreField());
|
||||
|
||||
if (score != null) {
|
||||
$vectorSearch = $vectorSearch.withFilterBySore(c -> {
|
||||
c.gt(score.getValue());
|
||||
});
|
||||
} else if (distance.getLowerBound().isBounded() || distance.getUpperBound().isBounded()) {
|
||||
$vectorSearch = $vectorSearch.withFilterBySore(c -> {
|
||||
Range.Bound<Score> lower = distance.getLowerBound();
|
||||
if (lower.isBounded()) {
|
||||
double value = lower.getValue().get().getValue();
|
||||
if (lower.isInclusive()) {
|
||||
c.gte(value);
|
||||
} else {
|
||||
c.gt(value);
|
||||
}
|
||||
}
|
||||
|
||||
Range.Bound<Score> upper = distance.getUpperBound();
|
||||
if (upper.isBounded()) {
|
||||
|
||||
double value = upper.getValue().get().getValue();
|
||||
if (upper.isInclusive()) {
|
||||
c.lte(value);
|
||||
} else {
|
||||
c.lt(value);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
stages.add($vectorSearch);
|
||||
|
||||
if (query.isSorted()) {
|
||||
// TODO stages.add(Aggregation.sort(query.with()));
|
||||
} else {
|
||||
stages.add(Aggregation.sort(Sort.Direction.DESC, "__score__"));
|
||||
}
|
||||
|
||||
return stages;
|
||||
}
|
||||
record QueryContainer(String path, String scoreField, Query query, AggregationPipeline pipeline,
|
||||
VectorSearchOperation.SearchType searchType, Class<?> outputType, ScoringFunction scoringFunction, String index) {
|
||||
|
||||
}
|
||||
|
||||
@@ -368,11 +393,12 @@ class VectorSearchDelegate {
|
||||
this.tree = tree;
|
||||
}
|
||||
|
||||
@SuppressWarnings("NullAway")
|
||||
public VectorSearchInput createQuery(MongoParameterAccessor parameterAccessor, ParameterBindingDocumentCodec codec,
|
||||
ParameterBindingContext context) {
|
||||
|
||||
MongoQueryCreator creator = new MongoQueryCreator(tree, parameterAccessor, converter.getMappingContext(),
|
||||
false, true);
|
||||
MongoQueryCreator creator = new MongoQueryCreator(tree, parameterAccessor, converter.getMappingContext(), false,
|
||||
true);
|
||||
|
||||
Query query = creator.createQuery(parameterAccessor.getSort());
|
||||
|
||||
|
||||
@@ -81,16 +81,15 @@ public class VectorSearchTests {
|
||||
|
||||
@Override
|
||||
public MongoClient mongoClient() {
|
||||
atlasLocal.start();
|
||||
return MongoClients.create(atlasLocal.getConnectionString());
|
||||
}
|
||||
}
|
||||
|
||||
@BeforeAll
|
||||
static void beforeAll() throws InterruptedException {
|
||||
|
||||
atlasLocal.start();
|
||||
|
||||
System.out.println(atlasLocal.getConnectionString());
|
||||
client = MongoClients.create(atlasLocal.getConnectionString());
|
||||
template = new MongoTestTemplate(client, "vector-search-tests");
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@ import org.springframework.data.mongodb.core.convert.MappingMongoConverter;
|
||||
import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver;
|
||||
import org.springframework.data.mongodb.core.mapping.MongoMappingContext;
|
||||
import org.springframework.data.mongodb.repository.VectorSearch;
|
||||
import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer;
|
||||
import org.springframework.data.projection.ProjectionFactory;
|
||||
import org.springframework.data.projection.SpelAwareProxyProjectionFactory;
|
||||
import org.springframework.data.repository.CrudRepository;
|
||||
@@ -68,7 +69,7 @@ class VectorSearchAggregationUnitTests {
|
||||
VectorSearchAggregation aggregation = aggregation(SampleRepository.class, "searchByCountryAndEmbeddingNear",
|
||||
String.class, Vector.class, Score.class, Limit.class);
|
||||
|
||||
VectorSearchDelegate.QueryMetadata query = aggregation.createVectorSearchQuery(
|
||||
QueryContainer query = aggregation.createVectorSearchQuery(
|
||||
aggregation.getQueryMethod().getResultProcessor(),
|
||||
new MongoParametersParameterAccessor(aggregation.getQueryMethod(),
|
||||
new Object[] { "de", Vector.of(1f), Score.of(1), Limit.unlimited() }),
|
||||
|
||||
@@ -15,23 +15,30 @@
|
||||
*/
|
||||
package org.springframework.data.mongodb.repository.query;
|
||||
|
||||
import static org.assertj.core.api.Assertions.*;
|
||||
import static org.mockito.Mockito.*;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.springframework.data.mongodb.test.util.Assertions.assertThat;
|
||||
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.List;
|
||||
|
||||
import org.bson.Document;
|
||||
import org.jspecify.annotations.Nullable;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.springframework.data.domain.Limit;
|
||||
import org.springframework.data.domain.Score;
|
||||
import org.springframework.data.domain.SearchResults;
|
||||
import org.springframework.data.domain.Vector;
|
||||
import org.springframework.data.mapping.model.ValueExpressionEvaluator;
|
||||
import org.springframework.data.mongodb.core.aggregation.Aggregation;
|
||||
import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
|
||||
import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation;
|
||||
import org.springframework.data.mongodb.core.convert.MappingMongoConverter;
|
||||
import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver;
|
||||
import org.springframework.data.mongodb.core.mapping.Field;
|
||||
import org.springframework.data.mongodb.core.mapping.MongoMappingContext;
|
||||
import org.springframework.data.mongodb.repository.VectorSearch;
|
||||
import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer;
|
||||
import org.springframework.data.mongodb.util.aggregation.TestAggregationContext;
|
||||
import org.springframework.data.mongodb.util.json.ParameterBindingContext;
|
||||
import org.springframework.data.mongodb.util.json.ParameterBindingDocumentCodec;
|
||||
import org.springframework.data.projection.SpelAwareProxyProjectionFactory;
|
||||
@@ -44,6 +51,7 @@ import org.springframework.data.repository.query.ValueExpressionDelegate;
|
||||
* Unit tests for {@link VectorSearchDelegate}.
|
||||
*
|
||||
* @author Mark Paluch
|
||||
* @author Christoph Strobl
|
||||
*/
|
||||
class VectorSearchDelegateUnitTests {
|
||||
|
||||
@@ -57,10 +65,10 @@ class VectorSearchDelegateUnitTests {
|
||||
MongoQueryMethod queryMethod = getMongoQueryMethod(method);
|
||||
MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1));
|
||||
|
||||
VectorSearchDelegate.QueryMetadata query = createQueryMetadata(queryMethod, accessor);
|
||||
QueryContainer container = createQueryContainer(queryMethod, accessor);
|
||||
|
||||
assertThat(query.query().getLimit()).isEqualTo(10);
|
||||
assertThat(query.numCandidates()).isEqualTo(10 * 20);
|
||||
assertThat(container.query().getLimit()).isEqualTo(10);
|
||||
assertThat(numCandidates(container.pipeline())).isEqualTo(10 * 20);
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -71,10 +79,10 @@ class VectorSearchDelegateUnitTests {
|
||||
MongoQueryMethod queryMethod = getMongoQueryMethod(method);
|
||||
MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1));
|
||||
|
||||
VectorSearchDelegate.QueryMetadata query = createQueryMetadata(queryMethod, accessor);
|
||||
QueryContainer container = createQueryContainer(queryMethod, accessor);
|
||||
|
||||
assertThat(query.query().getLimit()).isEqualTo(10);
|
||||
assertThat(query.numCandidates()).isNull();
|
||||
assertThat(container.query().getLimit()).isEqualTo(10);
|
||||
assertThat(numCandidates(container.pipeline())).isNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -86,19 +94,87 @@ class VectorSearchDelegateUnitTests {
|
||||
MongoQueryMethod queryMethod = getMongoQueryMethod(method);
|
||||
MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1), Limit.of(11));
|
||||
|
||||
VectorSearchDelegate.QueryMetadata query = createQueryMetadata(queryMethod, accessor);
|
||||
QueryContainer container = createQueryContainer(queryMethod, accessor);
|
||||
|
||||
assertThat(query.query().getLimit()).isEqualTo(11);
|
||||
assertThat(query.numCandidates()).isEqualTo(11 * 20);
|
||||
assertThat(container.query().getLimit()).isEqualTo(11);
|
||||
assertThat(numCandidates(container.pipeline())).isEqualTo(11 * 20);
|
||||
}
|
||||
|
||||
private VectorSearchDelegate.QueryMetadata createQueryMetadata(MongoQueryMethod queryMethod,
|
||||
MongoParametersParameterAccessor accessor) {
|
||||
@Test
|
||||
void considersDerivedQueryPart() throws ReflectiveOperationException {
|
||||
|
||||
Method method = VectorSearchRepository.class.getMethod("searchTop10ByFirstNameAndEmbeddingNear", String.class,
|
||||
Vector.class, Score.class);
|
||||
|
||||
MongoQueryMethod queryMethod = getMongoQueryMethod(method);
|
||||
MongoParametersParameterAccessor accessor = getAccessor(queryMethod, "spring", Vector.of(1, 2), Score.of(1));
|
||||
|
||||
QueryContainer container = createQueryContainer(queryMethod, accessor);
|
||||
|
||||
assertThat(vectorSearchStageOf(container.pipeline())).containsEntry("$vectorSearch.filter",
|
||||
new Document("first_name", "spring"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void considersDerivedQueryPartInDifferentOrder() throws ReflectiveOperationException {
|
||||
|
||||
Method method = VectorSearchRepository.class.getMethod("searchTop10ByEmbeddingNearAndFirstName", Vector.class,
|
||||
Score.class, String.class);
|
||||
|
||||
MongoQueryMethod queryMethod = getMongoQueryMethod(method);
|
||||
MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1), "spring");
|
||||
|
||||
QueryContainer container = createQueryContainer(queryMethod, accessor);
|
||||
|
||||
assertThat(vectorSearchStageOf(container.pipeline())).containsEntry("$vectorSearch.filter",
|
||||
new Document("first_name", "spring"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void defaultSortsByScore() throws NoSuchMethodException {
|
||||
|
||||
Method method = VectorSearchRepository.class.getMethod("searchTop10ByEmbeddingNear", Vector.class, Score.class,
|
||||
Limit.class);
|
||||
|
||||
MongoQueryMethod queryMethod = getMongoQueryMethod(method);
|
||||
MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1), Limit.of(10));
|
||||
|
||||
QueryContainer container = createQueryContainer(queryMethod, accessor);
|
||||
|
||||
List<Document> stages = container.pipeline().lastOperation()
|
||||
.toPipelineStages(TestAggregationContext.contextFor(WithVector.class));
|
||||
|
||||
assertThat(stages).containsExactly(new Document("$sort", new Document("__score__", -1)));
|
||||
}
|
||||
|
||||
@Test
|
||||
void usesDerivedSort() throws NoSuchMethodException {
|
||||
|
||||
Method method = VectorSearchRepository.class.getMethod("searchByEmbeddingNearOrderByFirstName", Vector.class,
|
||||
Score.class, Limit.class);
|
||||
|
||||
MongoQueryMethod queryMethod = getMongoQueryMethod(method);
|
||||
MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1), Limit.of(11));
|
||||
|
||||
QueryContainer container = createQueryContainer(queryMethod, accessor);
|
||||
AggregationPipeline aggregationPipeline = container.pipeline();
|
||||
|
||||
List<Document> stages = aggregationPipeline.lastOperation()
|
||||
.toPipelineStages(TestAggregationContext.contextFor(WithVector.class));
|
||||
|
||||
assertThat(stages).containsExactly(new Document("$sort", new Document("first_name", 1).append("__score__", -1)));
|
||||
}
|
||||
|
||||
Document vectorSearchStageOf(AggregationPipeline pipeline) {
|
||||
return pipeline.firstOperation().toPipelineStages(TestAggregationContext.contextFor(WithVector.class)).get(0);
|
||||
}
|
||||
|
||||
private QueryContainer createQueryContainer(MongoQueryMethod queryMethod, MongoParametersParameterAccessor accessor) {
|
||||
|
||||
VectorSearchDelegate delegate = new VectorSearchDelegate(queryMethod, converter, ValueExpressionDelegate.create());
|
||||
|
||||
return delegate.createQuery(mock(ValueExpressionEvaluator.class), queryMethod.getResultProcessor(), accessor,
|
||||
Object.class, new ParameterBindingDocumentCodec(), mock(ParameterBindingContext.class));
|
||||
return delegate.createQuery(mock(ValueExpressionEvaluator.class), queryMethod.getResultProcessor(), accessor, null,
|
||||
new ParameterBindingDocumentCodec(), mock(ParameterBindingContext.class));
|
||||
}
|
||||
|
||||
private MongoQueryMethod getMongoQueryMethod(Method method) {
|
||||
@@ -110,21 +186,69 @@ class VectorSearchDelegateUnitTests {
|
||||
return new MongoParametersParameterAccessor(queryMethod, values);
|
||||
}
|
||||
|
||||
@Nullable
|
||||
private static Integer numCandidates(AggregationPipeline pipeline) {
|
||||
|
||||
Document $vectorSearch = pipeline.firstOperation().toPipelineStages(Aggregation.DEFAULT_CONTEXT).get(0);
|
||||
if ($vectorSearch.containsKey("$vectorSearch")) {
|
||||
Object value = $vectorSearch.get("$vectorSearch", Document.class).get("numCandidates");
|
||||
return value instanceof Number i ? i.intValue() : null;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
interface VectorSearchRepository extends Repository<WithVector, String> {
|
||||
|
||||
@VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN)
|
||||
SearchResults<WithVector> searchTop10ByEmbeddingNear(Vector vector, Score similarity);
|
||||
|
||||
@VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN)
|
||||
SearchResults<WithVector> searchTop10ByFirstNameAndEmbeddingNear(String firstName, Vector vector, Score similarity);
|
||||
|
||||
@VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN)
|
||||
SearchResults<WithVector> searchTop10ByEmbeddingNearAndFirstName(Vector vector, Score similarity, String firstname);
|
||||
|
||||
@VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ENN)
|
||||
SearchResults<WithVector> searchTop10EnnByEmbeddingNear(Vector vector, Score similarity);
|
||||
|
||||
@VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN)
|
||||
SearchResults<WithVector> searchTop10ByEmbeddingNear(Vector vector, Score similarity, Limit limit);
|
||||
|
||||
@VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN)
|
||||
SearchResults<WithVector> searchByEmbeddingNearOrderByFirstName(Vector vector, Score similarity, Limit limit);
|
||||
|
||||
}
|
||||
|
||||
static class WithVector {
|
||||
|
||||
Vector embedding;
|
||||
|
||||
String lastName;
|
||||
|
||||
@Field("first_name") String firstName;
|
||||
|
||||
public Vector getEmbedding() {
|
||||
return embedding;
|
||||
}
|
||||
|
||||
public void setEmbedding(Vector embedding) {
|
||||
this.embedding = embedding;
|
||||
}
|
||||
|
||||
public String getLastName() {
|
||||
return lastName;
|
||||
}
|
||||
|
||||
public void setLastName(String lastName) {
|
||||
this.lastName = lastName;
|
||||
}
|
||||
|
||||
public String getFirstName() {
|
||||
return firstName;
|
||||
}
|
||||
|
||||
public void setFirstName(String firstName) {
|
||||
this.firstName = firstName;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ Java::
|
||||
[source,java,indent=0,subs="verbatim,quotes",role="primary"]
|
||||
----
|
||||
VectorIndex index = new VectorIndex("vector_index")
|
||||
.addVector("plotEmbedding"), vector -> vector.dimensions(1536).similarity(COSINE)) <1>
|
||||
.addVector("plotEmbedding", vector -> vector.dimensions(1536).similarity(COSINE)) <1>
|
||||
.addFilter("year"); <2>
|
||||
|
||||
mongoTemplate.searchIndexOps(Movie.class) <3>
|
||||
|
||||
@@ -6,13 +6,13 @@ Annotated search methods use the `@VectorSearch` annotation to define parameters
|
||||
----
|
||||
interface CommentRepository extends Repository<Comment, String> {
|
||||
|
||||
@VectorSearch(indexName = "cos-index", filter = "{country: ?0}")
|
||||
SearchResults<WithVector> searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding,
|
||||
@VectorSearch(indexName = "cos-index", filter = "{country: ?0}", limit="100", numCandidates="2000")
|
||||
SearchResults<Comment> searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding,
|
||||
Score distance);
|
||||
|
||||
@VectorSearch(indexName = "my-index", filter = "{country: ?0}", numCandidates = "#{#limit * 20}",
|
||||
@VectorSearch(indexName = "my-index", filter = "{country: ?0}", limit="?3", numCandidates = "#{#limit * 20}",
|
||||
searchType = VectorSearchOperation.SearchType.ANN)
|
||||
List<WithVector> findAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance, int limit);
|
||||
List<Comment> findAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance, int limit);
|
||||
}
|
||||
----
|
||||
====
|
||||
|
||||
@@ -6,14 +6,14 @@ MongoDB Search methods must use the `@VectorSearch` annotation to define the ind
|
||||
----
|
||||
interface CommentRepository extends Repository<Comment, String> {
|
||||
|
||||
@VectorSearch(indexName = "my-index")
|
||||
SearchResults<Comment> searchByEmbeddingNear(Vector vector, Score score);
|
||||
@VectorSearch(indexName = "my-index", numCandidates="200")
|
||||
SearchResults<Comment> searchTop10ByEmbeddingNear(Vector vector, Score score);
|
||||
|
||||
@VectorSearch(indexName = "my-index")
|
||||
SearchResults<Comment> searchByEmbeddingWithin(Vector vector, Range<Similarity> range);
|
||||
@VectorSearch(indexName = "my-index", numCandidates="200")
|
||||
SearchResults<Comment> searchTop10ByEmbeddingWithin(Vector vector, Range<Similarity> range);
|
||||
|
||||
@VectorSearch(indexName = "my-index")
|
||||
SearchResults<Comment> searchByCountryAndEmbeddingWithin(String country, Vector vector, Range<Similarity> range);
|
||||
@VectorSearch(indexName = "my-index", numCandidates="200")
|
||||
SearchResults<Comment> searchTop10ByCountryAndEmbeddingWithin(String country, Vector vector, Range<Similarity> range);
|
||||
}
|
||||
----
|
||||
====
|
||||
|
||||
@@ -4,12 +4,12 @@
|
||||
----
|
||||
interface CommentRepository extends Repository<Comment, String> {
|
||||
|
||||
@VectorSearch(indexName = "my-index")
|
||||
@VectorSearch(indexName = "my-index", numCandidates="#{#limit.max() * 20}")
|
||||
SearchResults<Comment> searchByCountryAndEmbeddingNear(String country, Vector vector, Score score,
|
||||
Limit limit);
|
||||
|
||||
@VectorSearch(indexName = "my-index")
|
||||
SearchResults<WithVector> searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding,
|
||||
@VectorSearch(indexName = "my-index", limit="10", numCandidates="200")
|
||||
SearchResults<Comment> searchByCountryAndEmbeddingWithin(String country, Vector embedding,
|
||||
Score score);
|
||||
|
||||
}
|
||||
@@ -17,3 +17,9 @@ interface CommentRepository extends Repository<Comment, String> {
|
||||
SearchResults<Comment> results = repository.searchByCountryAndEmbeddingNear("en", Vector.of(…), Score.of(0.9), Limit.of(10));
|
||||
----
|
||||
====
|
||||
|
||||
[TIP]
|
||||
====
|
||||
The MongoDB https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/[vector search aggregation] stage defines a set of required arguments and restrictions.
|
||||
Please make sure to follow the guidelines and make sure to provide required arguments like `limit`.
|
||||
====
|
||||
|
||||
@@ -9,13 +9,13 @@ The scoring function defaults to `ScoringFunction.unspecified()` as there is no
|
||||
interface CommentRepository extends Repository<Comment, String> {
|
||||
|
||||
@VectorSearch(…)
|
||||
SearchResults<Comment> searchByEmbeddingNear(Vector vector, Score similarity);
|
||||
SearchResults<Comment> searchTop10ByEmbeddingNear(Vector vector, Score similarity);
|
||||
|
||||
@VectorSearch(…)
|
||||
SearchResults<Comment> searchByEmbeddingNear(Vector vector, Similarity similarity);
|
||||
SearchResults<Comment> searchTop10ByEmbeddingNear(Vector vector, Similarity similarity);
|
||||
|
||||
@VectorSearch(…)
|
||||
SearchResults<Comment> searchByEmbeddingNear(Vector vector, Range<Similarity> range);
|
||||
SearchResults<Comment> searchTop10ByEmbeddingNear(Vector vector, Range<Similarity> range);
|
||||
}
|
||||
|
||||
repository.searchByEmbeddingNear(Vector.of(…), Score.of(0.9)); <1>
|
||||
|
||||
Reference in New Issue
Block a user