diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/EmptyIntrospectedQuery.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/EmptyIntrospectedQuery.java index 188b0b8c2..a7336b98d 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/EmptyIntrospectedQuery.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/EmptyIntrospectedQuery.java @@ -34,8 +34,6 @@ enum EmptyIntrospectedQuery implements EntityQuery { EmptyIntrospectedQuery() {} - - @Override public boolean hasParameterBindings() { return false; @@ -61,6 +59,7 @@ enum EmptyIntrospectedQuery implements EntityQuery { } @Override + @SuppressWarnings("NullAway") public T doWithEnhancer(Function function) { return null; } diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JSqlParserQueryEnhancer.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JSqlParserQueryEnhancer.java index 4b17555c5..b340d49ce 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JSqlParserQueryEnhancer.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JSqlParserQueryEnhancer.java @@ -15,8 +15,9 @@ */ package org.springframework.data.jpa.repository.query; -import static org.springframework.data.jpa.repository.query.JSqlParserUtils.*; -import static org.springframework.data.jpa.repository.query.QueryUtils.*; +import static org.springframework.data.jpa.repository.query.JSqlParserUtils.getJSqlCount; +import static org.springframework.data.jpa.repository.query.JSqlParserUtils.getJSqlLower; +import static org.springframework.data.jpa.repository.query.QueryUtils.checkSortExpression; import net.sf.jsqlparser.expression.Alias; import net.sf.jsqlparser.expression.Expression; @@ -52,7 +53,6 @@ import java.util.function.Predicate; import java.util.function.Supplier; import org.jspecify.annotations.Nullable; - import org.springframework.data.domain.Sort; import org.springframework.data.util.Predicates; import org.springframework.util.Assert; @@ -356,6 +356,8 @@ public class JSqlParserQueryEnhancer implements QueryEnhancer { private String applySorting(@Nullable Select selectStatement, Sort sort, @Nullable String alias) { + Assert.notNull(selectStatement, "SelectStatement must not be null"); + if (selectStatement instanceof SetOperationList setOperationList) { return applySortingToSetOperationList(setOperationList, sort); } @@ -381,6 +383,7 @@ public class JSqlParserQueryEnhancer implements QueryEnhancer { } @Override + @SuppressWarnings("NullAway") public String createCountQueryFor(@Nullable String countProjection) { if (this.parsedType != ParsedType.SELECT) { diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryExecution.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryExecution.java index c15716168..b0e7c49a4 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryExecution.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryExecution.java @@ -295,6 +295,7 @@ public abstract class JpaQueryExecution { return provider.getResultCount(resultQuery, () -> doCount(repositoryQuery, accessor)); } + @SuppressWarnings("NullAway") long doCount(AbstractJpaQuery repositoryQuery, JpaParametersParameterAccessor accessor) { List totals = repositoryQuery.createCountQuery(accessor).getResultList(); diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterBinding.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterBinding.java index ac5462175..443b6ca3c 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterBinding.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterBinding.java @@ -323,10 +323,6 @@ public class ParameterBinding { return Collections.singleton(value); } - - public String lower() { - return null; - } } /** diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterMetadataProvider.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterMetadataProvider.java index b1c08fc58..968751704 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterMetadataProvider.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterMetadataProvider.java @@ -120,7 +120,7 @@ public class ParameterMetadataProvider { this.templates = templates; } - public JpaParameters getParameters() { + JpaParameters getParameters() { return this.jpaParameters; } @@ -216,6 +216,10 @@ public class ParameterMetadataProvider { return binding; } + /** + * @return the scoring function if available {@link ScoringFunction#unspecified()} by default. + * @since 4.0 + */ ScoringFunction getScoringFunction() { if (accessor != null) { @@ -225,6 +229,12 @@ public class ParameterMetadataProvider { return ScoringFunction.unspecified(); } + /** + * + * @return the vector binding identifier. + * @throws IllegalStateException if parameters do not cotain + * @since 4.0 + */ ParameterBinding getVectorBinding() { if (!getParameters().hasVectorParameter()) { diff --git a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/AbstractVectorIntegrationTests.java b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/AbstractVectorIntegrationTests.java index f4c334d39..71538f9df 100644 --- a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/AbstractVectorIntegrationTests.java +++ b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/AbstractVectorIntegrationTests.java @@ -15,7 +15,7 @@ */ package org.springframework.data.jpa.repository; -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; import jakarta.persistence.Column; import jakarta.persistence.Entity; @@ -36,7 +36,6 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; - import org.springframework.beans.factory.annotation.Autowired; import org.springframework.data.domain.Range; import org.springframework.data.domain.Score; @@ -53,6 +52,7 @@ import org.springframework.transaction.annotation.Transactional; * Testcase to verify Vector Search work with Hibernate. * * @author Mark Paluch + * @author Christoph Strobl */ @Transactional @Rollback(value = false) @@ -65,10 +65,11 @@ abstract class AbstractVectorIntegrationTests { @BeforeEach void setUp() { - WithVector w1 = new WithVector("de", "one", new float[] { 0.1001f, 0.22345f, 0.33456f, 0.44567f, 0.55678f }); - WithVector w2 = new WithVector("de", "two", new float[] { 0.2001f, 0.32345f, 0.43456f, 0.54567f, 0.65678f }); - WithVector w3 = new WithVector("en", "three", new float[] { 0.9001f, 0.82345f, 0.73456f, 0.64567f, 0.55678f }); - WithVector w4 = new WithVector("de", "four", new float[] { 0.9001f, 0.92345f, 0.93456f, 0.94567f, 0.95678f }); + WithVector w1 = new WithVector("de", "one", "d1", new float[] { 0.1001f, 0.22345f, 0.33456f, 0.44567f, 0.55678f }); + WithVector w2 = new WithVector("de", "two", "d2", new float[] { 0.2001f, 0.32345f, 0.43456f, 0.54567f, 0.65678f }); + WithVector w3 = new WithVector("en", "three", "d3", + new float[] { 0.9001f, 0.82345f, 0.73456f, 0.64567f, 0.55678f }); + WithVector w4 = new WithVector("de", "four", "d4", new float[] { 0.9001f, 0.92345f, 0.93456f, 0.94567f, 0.95678f }); repository.deleteAllInBatch(); repository.saveAllAndFlush(Arrays.asList(w1, w2, w3, w4)); @@ -93,7 +94,7 @@ abstract class AbstractVectorIntegrationTests { VectorScoringFunctions.EUCLIDEAN); } - @Test + @Test // GH-3868 void shouldNormalizeEuclideanSimilarity() { SearchResults results = repository.searchTop5ByCountryAndEmbeddingWithin("de", VECTOR, @@ -108,7 +109,16 @@ abstract class AbstractVectorIntegrationTests { assertThat(two.getScore().getValue()).isGreaterThan(0.99); } - @Test + @Test // GH-3868 + void orderTargetsProperty() { + + SearchResults results = repository.searchTop5ByCountryAndEmbeddingWithinOrderByDistance("de", VECTOR, + Similarity.of(0, VectorScoringFunctions.EUCLIDEAN)); + + assertThat(results.getContent()).extracting(it -> it.getContent().getDistance()).containsExactly("d1", "d2", "d4"); + } + + @Test// GH-3868 void shouldNormalizeCosineSimilarity() { SearchResults results = repository.searchTop5ByCountryAndEmbeddingWithin("de", VECTOR, @@ -123,7 +133,7 @@ abstract class AbstractVectorIntegrationTests { assertThat(two.getScore().getValue()).isGreaterThan(0.99); } - @Test + @Test // GH-3868 void shouldRunStringQuery() { List results = repository.findAnnotatedByCountryAndEmbeddingWithin("de", VECTOR, @@ -133,7 +143,7 @@ abstract class AbstractVectorIntegrationTests { assertThat(results).extracting(WithVector::getDescription).containsSequence("two", "one", "four"); } - @Test + @Test // GH-3868 void shouldRunStringQueryWithDistance() { SearchResults results = repository.searchAnnotatedByCountryAndEmbeddingWithin("de", VECTOR, @@ -149,7 +159,7 @@ abstract class AbstractVectorIntegrationTests { assertThat(result.getScore().getFunction()).isEqualTo(VectorScoringFunctions.COSINE); } - @Test + @Test // GH-3868 void shouldRunStringQueryWithFloatDistance() { SearchResults results = repository.searchAnnotatedByCountryAndEmbeddingWithin("de", VECTOR, 2); @@ -164,7 +174,7 @@ abstract class AbstractVectorIntegrationTests { assertThat(result.getScore().getFunction()).isEqualTo(ScoringFunction.unspecified()); } - @Test + @Test // GH-3868 void shouldApplyVectorSearchWithRange() { SearchResults results = repository.searchAllByCountryAndEmbeddingWithin("de", VECTOR, @@ -176,7 +186,7 @@ abstract class AbstractVectorIntegrationTests { .containsSequence("two", "one", "four"); } - @Test + @Test // GH-3868 void shouldApplyVectorSearchAndReturnList() { List results = repository.findAllByCountryAndEmbeddingWithin("de", VECTOR, @@ -186,7 +196,7 @@ abstract class AbstractVectorIntegrationTests { assertThat(results).extracting(WithVector::getDescription).containsSequence("one", "two", "four"); } - @Test + @Test // GH-3868 void shouldProjectVectorSearchAsInterface() { SearchResults results = repository.searchInterfaceProjectionByCountryAndEmbeddingWithin("de", @@ -196,7 +206,7 @@ abstract class AbstractVectorIntegrationTests { .containsSequence("two", "one", "four"); } - @Test + @Test // GH-3868 void shouldProjectVectorSearchAsDto() { SearchResults results = repository.searchDtoByCountryAndEmbeddingWithin("de", VECTOR, @@ -206,7 +216,7 @@ abstract class AbstractVectorIntegrationTests { .containsSequence("two", "one", "four"); } - @Test + @Test // GH-3868 void shouldProjectVectorSearchDynamically() { SearchResults dtos = repository.searchDynamicByCountryAndEmbeddingWithin("de", VECTOR, @@ -233,16 +243,19 @@ abstract class AbstractVectorIntegrationTests { private String country; private String description; + private String distance; + @Column(name = "the_embedding") @JdbcTypeCode(SqlTypes.VECTOR) @Array(length = 5) private float[] embedding; public WithVector() {} - public WithVector(String country, String description, float[] embedding) { + public WithVector(String country, String description, String distance, float[] embedding) { this.country = country; this.description = description; this.embedding = embedding; + this.distance = distance; } public Integer getId() { @@ -273,9 +286,22 @@ abstract class AbstractVectorIntegrationTests { this.embedding = embedding; } + public void setDescription(String description) { + this.description = description; + } + + public String getDistance() { + return distance; + } + + public void setDistance(String distance) { + this.distance = distance; + } + @Override public String toString() { - return "WithVector{" + "country='" + country + '\'' + ", description='" + description + '\'' + '}'; + return "WithVector{" + "id=" + id + ", country='" + country + '\'' + ", description='" + description + '\'' + + ", distance='" + distance + '\'' + ", embedding=" + Arrays.toString(embedding) + '}'; } } @@ -328,6 +354,9 @@ abstract class AbstractVectorIntegrationTests { SearchResults searchTop5ByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance); + SearchResults searchTop5ByCountryAndEmbeddingWithinOrderByDistance(String country, Vector embedding, + Score distance); + SearchResults searchInterfaceProjectionByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance); diff --git a/spring-data-jpa/src/test/resources/scripts/oracle-vector.sql b/spring-data-jpa/src/test/resources/scripts/oracle-vector.sql index 2d0bf06de..f11fb13fc 100644 --- a/spring-data-jpa/src/test/resources/scripts/oracle-vector.sql +++ b/spring-data-jpa/src/test/resources/scripts/oracle-vector.sql @@ -5,6 +5,7 @@ CREATE TABLE IF NOT EXISTS with_vector id NUMBER GENERATED BY DEFAULT ON NULL AS IDENTITY, country varchar2(10), description varchar2(10), + distance varchar2(10), the_embedding vector(5, FLOAT32) annotations(Distance 'COSINE', IndexType 'IVF') );; diff --git a/spring-data-jpa/src/test/resources/scripts/pgvector.sql b/spring-data-jpa/src/test/resources/scripts/pgvector.sql index 4057dd952..b91725750 100644 --- a/spring-data-jpa/src/test/resources/scripts/pgvector.sql +++ b/spring-data-jpa/src/test/resources/scripts/pgvector.sql @@ -2,6 +2,6 @@ CREATE EXTENSION IF NOT EXISTS vector; DROP TABLE IF EXISTS with_vector; -CREATE TABLE IF NOT EXISTS with_vector (id bigserial PRIMARY KEY,country varchar(10), description varchar(10),the_embedding vector(5)); +CREATE TABLE IF NOT EXISTS with_vector (id bigserial PRIMARY KEY,country varchar(10), description varchar(10), distance varchar(10), the_embedding vector(5)); CREATE INDEX ON with_vector USING hnsw (the_embedding vector_l2_ops); diff --git a/src/main/antora/modules/ROOT/partials/vector-search-method-annotated-include.adoc b/src/main/antora/modules/ROOT/partials/vector-search-method-annotated-include.adoc index 8b27401c7..851457e68 100644 --- a/src/main/antora/modules/ROOT/partials/vector-search-method-annotated-include.adoc +++ b/src/main/antora/modules/ROOT/partials/vector-search-method-annotated-include.adoc @@ -11,7 +11,7 @@ interface CommentRepository extends Repository { WHERE c.country = ?1 AND cosine_distance(c.embedding, :embedding) <= :distance ORDER BY distance asc""") - SearchResults searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, + SearchResults searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance); @Query(""" @@ -19,7 +19,7 @@ interface CommentRepository extends Repository { WHERE c.country = ?1 AND cosine_distance(c.embedding, :embedding) <= :distance ORDER BY cosine_distance(c.embedding, :embedding) asc""") - List findAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance); + List findAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance); } ---- ==== diff --git a/src/main/antora/modules/ROOT/partials/vector-search-repository-include.adoc b/src/main/antora/modules/ROOT/partials/vector-search-repository-include.adoc index 716bf5a56..8955bafe8 100644 --- a/src/main/antora/modules/ROOT/partials/vector-search-repository-include.adoc +++ b/src/main/antora/modules/ROOT/partials/vector-search-repository-include.adoc @@ -12,7 +12,7 @@ interface CommentRepository extends Repository { WHERE c.country = ?1 AND cosine_distance(c.embedding, :embedding) <= :distance ORDER BY distance asc""") - SearchResults searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, + SearchResults searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance); }