Polishing.

Original Pull Request: #3868
This commit is contained in:
Christoph Strobl
2025-05-08 12:11:16 +02:00
committed by Mark Paluch
parent 17a59905a7
commit e802143baf
10 changed files with 71 additions and 32 deletions

View File

@@ -34,8 +34,6 @@ enum EmptyIntrospectedQuery implements EntityQuery {
// Enum constructor with an empty body: there is nothing to initialize.
EmptyIntrospectedQuery() {}
@Override
public boolean hasParameterBindings() {
return false;
@@ -61,6 +59,7 @@ enum EmptyIntrospectedQuery implements EntityQuery {
}
/**
 * Always returns {@code null}; the given {@code function} is never invoked.
 * <p>
 * The {@code NullAway} suppression silences the null-analysis warning for this {@code null} return.
 *
 * @param function callback that would receive a {@code QueryEnhancer}; ignored by this implementation.
 * @return always {@code null}.
 */
@Override
@SuppressWarnings("NullAway")
public <T> T doWithEnhancer(Function<QueryEnhancer, T> function) {
return null;
}

View File

@@ -15,8 +15,9 @@
*/
package org.springframework.data.jpa.repository.query;
import static org.springframework.data.jpa.repository.query.JSqlParserUtils.*;
import static org.springframework.data.jpa.repository.query.QueryUtils.*;
import static org.springframework.data.jpa.repository.query.JSqlParserUtils.getJSqlCount;
import static org.springframework.data.jpa.repository.query.JSqlParserUtils.getJSqlLower;
import static org.springframework.data.jpa.repository.query.QueryUtils.checkSortExpression;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;
@@ -52,7 +53,6 @@ import java.util.function.Predicate;
import java.util.function.Supplier;
import org.jspecify.annotations.Nullable;
import org.springframework.data.domain.Sort;
import org.springframework.data.util.Predicates;
import org.springframework.util.Assert;
@@ -356,6 +356,8 @@ public class JSqlParserQueryEnhancer implements QueryEnhancer {
private String applySorting(@Nullable Select selectStatement, Sort sort, @Nullable String alias) {
Assert.notNull(selectStatement, "SelectStatement must not be null");
if (selectStatement instanceof SetOperationList setOperationList) {
return applySortingToSetOperationList(setOperationList, sort);
}
@@ -381,6 +383,7 @@ public class JSqlParserQueryEnhancer implements QueryEnhancer {
}
@Override
@SuppressWarnings("NullAway")
public String createCountQueryFor(@Nullable String countProjection) {
if (this.parsedType != ParsedType.SELECT) {

View File

@@ -295,6 +295,7 @@ public abstract class JpaQueryExecution {
return provider.getResultCount(resultQuery, () -> doCount(repositoryQuery, accessor));
}
@SuppressWarnings("NullAway")
long doCount(AbstractJpaQuery repositoryQuery, JpaParametersParameterAccessor accessor) {
List<?> totals = repositoryQuery.createCountQuery(accessor).getResultList();

View File

@@ -323,10 +323,6 @@ public class ParameterBinding {
return Collections.singleton(value);
}
public String lower() {
return null;
}
}
/**

View File

@@ -120,7 +120,7 @@ public class ParameterMetadataProvider {
this.templates = templates;
}
public JpaParameters getParameters() {
JpaParameters getParameters() {
return this.jpaParameters;
}
@@ -216,6 +216,10 @@ public class ParameterMetadataProvider {
return binding;
}
/**
* @return the scoring function if available; {@link ScoringFunction#unspecified()} by default.
* @since 4.0
*/
ScoringFunction getScoringFunction() {
if (accessor != null) {
@@ -225,6 +229,12 @@ public class ParameterMetadataProvider {
return ScoringFunction.unspecified();
}
/**
*
* @return the vector binding identifier.
* @throws IllegalStateException if the parameters do not contain a vector parameter.
* @since 4.0
*/
ParameterBinding getVectorBinding() {
if (!getParameters().hasVectorParameter()) {

View File

@@ -15,7 +15,7 @@
*/
package org.springframework.data.jpa.repository;
import static org.assertj.core.api.Assertions.*;
import static org.assertj.core.api.Assertions.assertThat;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
@@ -36,7 +36,6 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.domain.Range;
import org.springframework.data.domain.Score;
@@ -53,6 +52,7 @@ import org.springframework.transaction.annotation.Transactional;
* Test case to verify that Vector Search works with Hibernate.
*
* @author Mark Paluch
* @author Christoph Strobl
*/
@Transactional
@Rollback(value = false)
@@ -65,10 +65,11 @@ abstract class AbstractVectorIntegrationTests {
@BeforeEach
void setUp() {
WithVector w1 = new WithVector("de", "one", new float[] { 0.1001f, 0.22345f, 0.33456f, 0.44567f, 0.55678f });
WithVector w2 = new WithVector("de", "two", new float[] { 0.2001f, 0.32345f, 0.43456f, 0.54567f, 0.65678f });
WithVector w3 = new WithVector("en", "three", new float[] { 0.9001f, 0.82345f, 0.73456f, 0.64567f, 0.55678f });
WithVector w4 = new WithVector("de", "four", new float[] { 0.9001f, 0.92345f, 0.93456f, 0.94567f, 0.95678f });
WithVector w1 = new WithVector("de", "one", "d1", new float[] { 0.1001f, 0.22345f, 0.33456f, 0.44567f, 0.55678f });
WithVector w2 = new WithVector("de", "two", "d2", new float[] { 0.2001f, 0.32345f, 0.43456f, 0.54567f, 0.65678f });
WithVector w3 = new WithVector("en", "three", "d3",
new float[] { 0.9001f, 0.82345f, 0.73456f, 0.64567f, 0.55678f });
WithVector w4 = new WithVector("de", "four", "d4", new float[] { 0.9001f, 0.92345f, 0.93456f, 0.94567f, 0.95678f });
repository.deleteAllInBatch();
repository.saveAllAndFlush(Arrays.asList(w1, w2, w3, w4));
@@ -93,7 +94,7 @@ abstract class AbstractVectorIntegrationTests {
VectorScoringFunctions.EUCLIDEAN);
}
@Test
@Test // GH-3868
void shouldNormalizeEuclideanSimilarity() {
SearchResults<WithVector> results = repository.searchTop5ByCountryAndEmbeddingWithin("de", VECTOR,
@@ -108,7 +109,16 @@ abstract class AbstractVectorIntegrationTests {
assertThat(two.getScore().getValue()).isGreaterThan(0.99);
}
@Test
// Runs the derived "...OrderByDistance" search for country "de" and expects the results' getDistance()
// values to come back exactly as d1, d2, d4.
// NOTE(review): the test name suggests the ORDER BY is meant to target the entity's "distance" property
// rather than the computed vector distance - confirm against the query derivation.
@Test // GH-3868
void orderTargetsProperty() {
SearchResults<WithVector> results = repository.searchTop5ByCountryAndEmbeddingWithinOrderByDistance("de", VECTOR,
Similarity.of(0, VectorScoringFunctions.EUCLIDEAN));
assertThat(results.getContent()).extracting(it -> it.getContent().getDistance()).containsExactly("d1", "d2", "d4");
}
@Test// GH-3868
void shouldNormalizeCosineSimilarity() {
SearchResults<WithVector> results = repository.searchTop5ByCountryAndEmbeddingWithin("de", VECTOR,
@@ -123,7 +133,7 @@ abstract class AbstractVectorIntegrationTests {
assertThat(two.getScore().getValue()).isGreaterThan(0.99);
}
@Test
@Test // GH-3868
void shouldRunStringQuery() {
List<WithVector> results = repository.findAnnotatedByCountryAndEmbeddingWithin("de", VECTOR,
@@ -133,7 +143,7 @@ abstract class AbstractVectorIntegrationTests {
assertThat(results).extracting(WithVector::getDescription).containsSequence("two", "one", "four");
}
@Test
@Test // GH-3868
void shouldRunStringQueryWithDistance() {
SearchResults<WithVector> results = repository.searchAnnotatedByCountryAndEmbeddingWithin("de", VECTOR,
@@ -149,7 +159,7 @@ abstract class AbstractVectorIntegrationTests {
assertThat(result.getScore().getFunction()).isEqualTo(VectorScoringFunctions.COSINE);
}
@Test
@Test // GH-3868
void shouldRunStringQueryWithFloatDistance() {
SearchResults<WithVector> results = repository.searchAnnotatedByCountryAndEmbeddingWithin("de", VECTOR, 2);
@@ -164,7 +174,7 @@ abstract class AbstractVectorIntegrationTests {
assertThat(result.getScore().getFunction()).isEqualTo(ScoringFunction.unspecified());
}
@Test
@Test // GH-3868
void shouldApplyVectorSearchWithRange() {
SearchResults<WithVector> results = repository.searchAllByCountryAndEmbeddingWithin("de", VECTOR,
@@ -176,7 +186,7 @@ abstract class AbstractVectorIntegrationTests {
.containsSequence("two", "one", "four");
}
@Test
@Test // GH-3868
void shouldApplyVectorSearchAndReturnList() {
List<WithVector> results = repository.findAllByCountryAndEmbeddingWithin("de", VECTOR,
@@ -186,7 +196,7 @@ abstract class AbstractVectorIntegrationTests {
assertThat(results).extracting(WithVector::getDescription).containsSequence("one", "two", "four");
}
@Test
@Test // GH-3868
void shouldProjectVectorSearchAsInterface() {
SearchResults<WithDescription> results = repository.searchInterfaceProjectionByCountryAndEmbeddingWithin("de",
@@ -196,7 +206,7 @@ abstract class AbstractVectorIntegrationTests {
.containsSequence("two", "one", "four");
}
@Test
@Test // GH-3868
void shouldProjectVectorSearchAsDto() {
SearchResults<DescriptionDto> results = repository.searchDtoByCountryAndEmbeddingWithin("de", VECTOR,
@@ -206,7 +216,7 @@ abstract class AbstractVectorIntegrationTests {
.containsSequence("two", "one", "four");
}
@Test
@Test // GH-3868
void shouldProjectVectorSearchDynamically() {
SearchResults<DescriptionDto> dtos = repository.searchDynamicByCountryAndEmbeddingWithin("de", VECTOR,
@@ -233,16 +243,19 @@ abstract class AbstractVectorIntegrationTests {
private String country;
private String description;
private String distance;
@Column(name = "the_embedding")
@JdbcTypeCode(SqlTypes.VECTOR)
@Array(length = 5) private float[] embedding;
// No-argument constructor with an empty body; presumably required by the persistence provider - confirm.
public WithVector() {}
public WithVector(String country, String description, float[] embedding) {
public WithVector(String country, String description, String distance, float[] embedding) {
this.country = country;
this.description = description;
this.embedding = embedding;
this.distance = distance;
}
public Integer getId() {
@@ -273,9 +286,22 @@ abstract class AbstractVectorIntegrationTests {
this.embedding = embedding;
}
/**
 * Sets the description.
 *
 * @param description the value to store in the {@code description} field.
 */
public void setDescription(String description) {
this.description = description;
}
/**
 * @return the current value of the {@code distance} field.
 */
public String getDistance() {
return distance;
}
/**
 * Sets the distance.
 *
 * @param distance the value to store in the {@code distance} field.
 */
public void setDistance(String distance) {
this.distance = distance;
}
@Override
public String toString() {
return "WithVector{" + "country='" + country + '\'' + ", description='" + description + '\'' + '}';
return "WithVector{" + "id=" + id + ", country='" + country + '\'' + ", description='" + description + '\''
+ ", distance='" + distance + '\'' + ", embedding=" + Arrays.toString(embedding) + '}';
}
}
@@ -328,6 +354,9 @@ abstract class AbstractVectorIntegrationTests {
// Search used by the similarity-normalization tests (shouldNormalizeEuclideanSimilarity,
// shouldNormalizeCosineSimilarity); takes a country, a query vector and a Score bound.
SearchResults<WithVector> searchTop5ByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance);
// Variant with an "OrderByDistance" suffix, exercised by orderTargetsProperty().
// NOTE(review): presumably sorts by the entity's "distance" property - confirm.
SearchResults<WithVector> searchTop5ByCountryAndEmbeddingWithinOrderByDistance(String country, Vector embedding,
Score distance);
// Search returning WithDescription results; exercised by shouldProjectVectorSearchAsInterface().
SearchResults<WithDescription> searchInterfaceProjectionByCountryAndEmbeddingWithin(String country,
Vector embedding, Score distance);

View File

@@ -5,6 +5,7 @@ CREATE TABLE IF NOT EXISTS with_vector
id NUMBER GENERATED BY DEFAULT ON NULL AS IDENTITY,
country varchar2(10),
description varchar2(10),
distance varchar2(10),
the_embedding vector(5, FLOAT32) annotations(Distance 'COSINE', IndexType 'IVF')
);;

View File

@@ -2,6 +2,6 @@ CREATE EXTENSION IF NOT EXISTS vector;
DROP TABLE IF EXISTS with_vector;
CREATE TABLE IF NOT EXISTS with_vector (id bigserial PRIMARY KEY,country varchar(10), description varchar(10),the_embedding vector(5));
CREATE TABLE IF NOT EXISTS with_vector (id bigserial PRIMARY KEY,country varchar(10), description varchar(10), distance varchar(10), the_embedding vector(5));
CREATE INDEX ON with_vector USING hnsw (the_embedding vector_l2_ops);

View File

@@ -11,7 +11,7 @@ interface CommentRepository extends Repository<Comment, String> {
WHERE c.country = ?1
AND cosine_distance(c.embedding, :embedding) <= :distance
ORDER BY distance asc""")
SearchResults<WithVector> searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding,
SearchResults<Comment> searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding,
Score distance);
@Query("""
@@ -19,7 +19,7 @@ interface CommentRepository extends Repository<Comment, String> {
WHERE c.country = ?1
AND cosine_distance(c.embedding, :embedding) <= :distance
ORDER BY cosine_distance(c.embedding, :embedding) asc""")
List<WithVector> findAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance);
List<Comment> findAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance);
}
----
====

View File

@@ -12,7 +12,7 @@ interface CommentRepository extends Repository<Comment, String> {
WHERE c.country = ?1
AND cosine_distance(c.embedding, :embedding) <= :distance
ORDER BY distance asc""")
SearchResults<WithVector> searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding,
SearchResults<Comment> searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding,
Score distance);
}