committed by
Mark Paluch
parent
17a59905a7
commit
e802143baf
@@ -34,8 +34,6 @@ enum EmptyIntrospectedQuery implements EntityQuery {
|
||||
|
||||
EmptyIntrospectedQuery() {}
|
||||
|
||||
|
||||
|
||||
@Override
|
||||
public boolean hasParameterBindings() {
|
||||
return false;
|
||||
@@ -61,6 +59,7 @@ enum EmptyIntrospectedQuery implements EntityQuery {
|
||||
}
|
||||
|
||||
@Override
|
||||
@SuppressWarnings("NullAway")
|
||||
public <T> T doWithEnhancer(Function<QueryEnhancer, T> function) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -15,8 +15,9 @@
|
||||
*/
|
||||
package org.springframework.data.jpa.repository.query;
|
||||
|
||||
import static org.springframework.data.jpa.repository.query.JSqlParserUtils.*;
|
||||
import static org.springframework.data.jpa.repository.query.QueryUtils.*;
|
||||
import static org.springframework.data.jpa.repository.query.JSqlParserUtils.getJSqlCount;
|
||||
import static org.springframework.data.jpa.repository.query.JSqlParserUtils.getJSqlLower;
|
||||
import static org.springframework.data.jpa.repository.query.QueryUtils.checkSortExpression;
|
||||
|
||||
import net.sf.jsqlparser.expression.Alias;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
@@ -52,7 +53,6 @@ import java.util.function.Predicate;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
import org.jspecify.annotations.Nullable;
|
||||
|
||||
import org.springframework.data.domain.Sort;
|
||||
import org.springframework.data.util.Predicates;
|
||||
import org.springframework.util.Assert;
|
||||
@@ -356,6 +356,8 @@ public class JSqlParserQueryEnhancer implements QueryEnhancer {
|
||||
|
||||
private String applySorting(@Nullable Select selectStatement, Sort sort, @Nullable String alias) {
|
||||
|
||||
Assert.notNull(selectStatement, "SelectStatement must not be null");
|
||||
|
||||
if (selectStatement instanceof SetOperationList setOperationList) {
|
||||
return applySortingToSetOperationList(setOperationList, sort);
|
||||
}
|
||||
@@ -381,6 +383,7 @@ public class JSqlParserQueryEnhancer implements QueryEnhancer {
|
||||
}
|
||||
|
||||
@Override
|
||||
@SuppressWarnings("NullAway")
|
||||
public String createCountQueryFor(@Nullable String countProjection) {
|
||||
|
||||
if (this.parsedType != ParsedType.SELECT) {
|
||||
|
||||
@@ -295,6 +295,7 @@ public abstract class JpaQueryExecution {
|
||||
return provider.getResultCount(resultQuery, () -> doCount(repositoryQuery, accessor));
|
||||
}
|
||||
|
||||
@SuppressWarnings("NullAway")
|
||||
long doCount(AbstractJpaQuery repositoryQuery, JpaParametersParameterAccessor accessor) {
|
||||
|
||||
List<?> totals = repositoryQuery.createCountQuery(accessor).getResultList();
|
||||
|
||||
@@ -323,10 +323,6 @@ public class ParameterBinding {
|
||||
|
||||
return Collections.singleton(value);
|
||||
}
|
||||
|
||||
public String lower() {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -120,7 +120,7 @@ public class ParameterMetadataProvider {
|
||||
this.templates = templates;
|
||||
}
|
||||
|
||||
public JpaParameters getParameters() {
|
||||
JpaParameters getParameters() {
|
||||
return this.jpaParameters;
|
||||
}
|
||||
|
||||
@@ -216,6 +216,10 @@ public class ParameterMetadataProvider {
|
||||
return binding;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return the scoring function if available {@link ScoringFunction#unspecified()} by default.
|
||||
* @since 4.0
|
||||
*/
|
||||
ScoringFunction getScoringFunction() {
|
||||
|
||||
if (accessor != null) {
|
||||
@@ -225,6 +229,12 @@ public class ParameterMetadataProvider {
|
||||
return ScoringFunction.unspecified();
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return the vector binding identifier.
|
||||
* @throws IllegalStateException if parameters do not cotain
|
||||
* @since 4.0
|
||||
*/
|
||||
ParameterBinding getVectorBinding() {
|
||||
|
||||
if (!getParameters().hasVectorParameter()) {
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
*/
|
||||
package org.springframework.data.jpa.repository;
|
||||
|
||||
import static org.assertj.core.api.Assertions.*;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
import jakarta.persistence.Column;
|
||||
import jakarta.persistence.Entity;
|
||||
@@ -36,7 +36,6 @@ import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.data.domain.Range;
|
||||
import org.springframework.data.domain.Score;
|
||||
@@ -53,6 +52,7 @@ import org.springframework.transaction.annotation.Transactional;
|
||||
* Testcase to verify Vector Search work with Hibernate.
|
||||
*
|
||||
* @author Mark Paluch
|
||||
* @author Christoph Strobl
|
||||
*/
|
||||
@Transactional
|
||||
@Rollback(value = false)
|
||||
@@ -65,10 +65,11 @@ abstract class AbstractVectorIntegrationTests {
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
|
||||
WithVector w1 = new WithVector("de", "one", new float[] { 0.1001f, 0.22345f, 0.33456f, 0.44567f, 0.55678f });
|
||||
WithVector w2 = new WithVector("de", "two", new float[] { 0.2001f, 0.32345f, 0.43456f, 0.54567f, 0.65678f });
|
||||
WithVector w3 = new WithVector("en", "three", new float[] { 0.9001f, 0.82345f, 0.73456f, 0.64567f, 0.55678f });
|
||||
WithVector w4 = new WithVector("de", "four", new float[] { 0.9001f, 0.92345f, 0.93456f, 0.94567f, 0.95678f });
|
||||
WithVector w1 = new WithVector("de", "one", "d1", new float[] { 0.1001f, 0.22345f, 0.33456f, 0.44567f, 0.55678f });
|
||||
WithVector w2 = new WithVector("de", "two", "d2", new float[] { 0.2001f, 0.32345f, 0.43456f, 0.54567f, 0.65678f });
|
||||
WithVector w3 = new WithVector("en", "three", "d3",
|
||||
new float[] { 0.9001f, 0.82345f, 0.73456f, 0.64567f, 0.55678f });
|
||||
WithVector w4 = new WithVector("de", "four", "d4", new float[] { 0.9001f, 0.92345f, 0.93456f, 0.94567f, 0.95678f });
|
||||
|
||||
repository.deleteAllInBatch();
|
||||
repository.saveAllAndFlush(Arrays.asList(w1, w2, w3, w4));
|
||||
@@ -93,7 +94,7 @@ abstract class AbstractVectorIntegrationTests {
|
||||
VectorScoringFunctions.EUCLIDEAN);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test // GH-3868
|
||||
void shouldNormalizeEuclideanSimilarity() {
|
||||
|
||||
SearchResults<WithVector> results = repository.searchTop5ByCountryAndEmbeddingWithin("de", VECTOR,
|
||||
@@ -108,7 +109,16 @@ abstract class AbstractVectorIntegrationTests {
|
||||
assertThat(two.getScore().getValue()).isGreaterThan(0.99);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test // GH-3868
|
||||
void orderTargetsProperty() {
|
||||
|
||||
SearchResults<WithVector> results = repository.searchTop5ByCountryAndEmbeddingWithinOrderByDistance("de", VECTOR,
|
||||
Similarity.of(0, VectorScoringFunctions.EUCLIDEAN));
|
||||
|
||||
assertThat(results.getContent()).extracting(it -> it.getContent().getDistance()).containsExactly("d1", "d2", "d4");
|
||||
}
|
||||
|
||||
@Test// GH-3868
|
||||
void shouldNormalizeCosineSimilarity() {
|
||||
|
||||
SearchResults<WithVector> results = repository.searchTop5ByCountryAndEmbeddingWithin("de", VECTOR,
|
||||
@@ -123,7 +133,7 @@ abstract class AbstractVectorIntegrationTests {
|
||||
assertThat(two.getScore().getValue()).isGreaterThan(0.99);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test // GH-3868
|
||||
void shouldRunStringQuery() {
|
||||
|
||||
List<WithVector> results = repository.findAnnotatedByCountryAndEmbeddingWithin("de", VECTOR,
|
||||
@@ -133,7 +143,7 @@ abstract class AbstractVectorIntegrationTests {
|
||||
assertThat(results).extracting(WithVector::getDescription).containsSequence("two", "one", "four");
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test // GH-3868
|
||||
void shouldRunStringQueryWithDistance() {
|
||||
|
||||
SearchResults<WithVector> results = repository.searchAnnotatedByCountryAndEmbeddingWithin("de", VECTOR,
|
||||
@@ -149,7 +159,7 @@ abstract class AbstractVectorIntegrationTests {
|
||||
assertThat(result.getScore().getFunction()).isEqualTo(VectorScoringFunctions.COSINE);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test // GH-3868
|
||||
void shouldRunStringQueryWithFloatDistance() {
|
||||
|
||||
SearchResults<WithVector> results = repository.searchAnnotatedByCountryAndEmbeddingWithin("de", VECTOR, 2);
|
||||
@@ -164,7 +174,7 @@ abstract class AbstractVectorIntegrationTests {
|
||||
assertThat(result.getScore().getFunction()).isEqualTo(ScoringFunction.unspecified());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test // GH-3868
|
||||
void shouldApplyVectorSearchWithRange() {
|
||||
|
||||
SearchResults<WithVector> results = repository.searchAllByCountryAndEmbeddingWithin("de", VECTOR,
|
||||
@@ -176,7 +186,7 @@ abstract class AbstractVectorIntegrationTests {
|
||||
.containsSequence("two", "one", "four");
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test // GH-3868
|
||||
void shouldApplyVectorSearchAndReturnList() {
|
||||
|
||||
List<WithVector> results = repository.findAllByCountryAndEmbeddingWithin("de", VECTOR,
|
||||
@@ -186,7 +196,7 @@ abstract class AbstractVectorIntegrationTests {
|
||||
assertThat(results).extracting(WithVector::getDescription).containsSequence("one", "two", "four");
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test // GH-3868
|
||||
void shouldProjectVectorSearchAsInterface() {
|
||||
|
||||
SearchResults<WithDescription> results = repository.searchInterfaceProjectionByCountryAndEmbeddingWithin("de",
|
||||
@@ -196,7 +206,7 @@ abstract class AbstractVectorIntegrationTests {
|
||||
.containsSequence("two", "one", "four");
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test // GH-3868
|
||||
void shouldProjectVectorSearchAsDto() {
|
||||
|
||||
SearchResults<DescriptionDto> results = repository.searchDtoByCountryAndEmbeddingWithin("de", VECTOR,
|
||||
@@ -206,7 +216,7 @@ abstract class AbstractVectorIntegrationTests {
|
||||
.containsSequence("two", "one", "four");
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test // GH-3868
|
||||
void shouldProjectVectorSearchDynamically() {
|
||||
|
||||
SearchResults<DescriptionDto> dtos = repository.searchDynamicByCountryAndEmbeddingWithin("de", VECTOR,
|
||||
@@ -233,16 +243,19 @@ abstract class AbstractVectorIntegrationTests {
|
||||
private String country;
|
||||
private String description;
|
||||
|
||||
private String distance;
|
||||
|
||||
@Column(name = "the_embedding")
|
||||
@JdbcTypeCode(SqlTypes.VECTOR)
|
||||
@Array(length = 5) private float[] embedding;
|
||||
|
||||
public WithVector() {}
|
||||
|
||||
public WithVector(String country, String description, float[] embedding) {
|
||||
public WithVector(String country, String description, String distance, float[] embedding) {
|
||||
this.country = country;
|
||||
this.description = description;
|
||||
this.embedding = embedding;
|
||||
this.distance = distance;
|
||||
}
|
||||
|
||||
public Integer getId() {
|
||||
@@ -273,9 +286,22 @@ abstract class AbstractVectorIntegrationTests {
|
||||
this.embedding = embedding;
|
||||
}
|
||||
|
||||
public void setDescription(String description) {
|
||||
this.description = description;
|
||||
}
|
||||
|
||||
public String getDistance() {
|
||||
return distance;
|
||||
}
|
||||
|
||||
public void setDistance(String distance) {
|
||||
this.distance = distance;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "WithVector{" + "country='" + country + '\'' + ", description='" + description + '\'' + '}';
|
||||
return "WithVector{" + "id=" + id + ", country='" + country + '\'' + ", description='" + description + '\''
|
||||
+ ", distance='" + distance + '\'' + ", embedding=" + Arrays.toString(embedding) + '}';
|
||||
}
|
||||
}
|
||||
|
||||
@@ -328,6 +354,9 @@ abstract class AbstractVectorIntegrationTests {
|
||||
|
||||
SearchResults<WithVector> searchTop5ByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance);
|
||||
|
||||
SearchResults<WithVector> searchTop5ByCountryAndEmbeddingWithinOrderByDistance(String country, Vector embedding,
|
||||
Score distance);
|
||||
|
||||
SearchResults<WithDescription> searchInterfaceProjectionByCountryAndEmbeddingWithin(String country,
|
||||
Vector embedding, Score distance);
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ CREATE TABLE IF NOT EXISTS with_vector
|
||||
id NUMBER GENERATED BY DEFAULT ON NULL AS IDENTITY,
|
||||
country varchar2(10),
|
||||
description varchar2(10),
|
||||
distance varchar2(10),
|
||||
the_embedding vector(5, FLOAT32) annotations(Distance 'COSINE', IndexType 'IVF')
|
||||
);;
|
||||
|
||||
|
||||
@@ -2,6 +2,6 @@ CREATE EXTENSION IF NOT EXISTS vector;
|
||||
|
||||
DROP TABLE IF EXISTS with_vector;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS with_vector (id bigserial PRIMARY KEY,country varchar(10), description varchar(10),the_embedding vector(5));
|
||||
CREATE TABLE IF NOT EXISTS with_vector (id bigserial PRIMARY KEY,country varchar(10), description varchar(10), distance varchar(10), the_embedding vector(5));
|
||||
|
||||
CREATE INDEX ON with_vector USING hnsw (the_embedding vector_l2_ops);
|
||||
|
||||
@@ -11,7 +11,7 @@ interface CommentRepository extends Repository<Comment, String> {
|
||||
WHERE c.country = ?1
|
||||
AND cosine_distance(c.embedding, :embedding) <= :distance
|
||||
ORDER BY distance asc""")
|
||||
SearchResults<WithVector> searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding,
|
||||
SearchResults<Comment> searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding,
|
||||
Score distance);
|
||||
|
||||
@Query("""
|
||||
@@ -19,7 +19,7 @@ interface CommentRepository extends Repository<Comment, String> {
|
||||
WHERE c.country = ?1
|
||||
AND cosine_distance(c.embedding, :embedding) <= :distance
|
||||
ORDER BY cosine_distance(c.embedding, :embedding) asc""")
|
||||
List<WithVector> findAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance);
|
||||
List<Comment> findAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance);
|
||||
}
|
||||
----
|
||||
====
|
||||
|
||||
@@ -12,7 +12,7 @@ interface CommentRepository extends Repository<Comment, String> {
|
||||
WHERE c.country = ?1
|
||||
AND cosine_distance(c.embedding, :embedding) <= :distance
|
||||
ORDER BY distance asc""")
|
||||
SearchResults<WithVector> searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding,
|
||||
SearchResults<Comment> searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding,
|
||||
Score distance);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user