Support similarity scores in Document API

Document
* Introduced “score” attribute in Document API. It stores the similarity score.
* Consolidated “distance” metadata for Documents. It stores the distance measurement.
* Adopted prefix-less naming convention in Document.Builder and deprecated old methods.
* Deprecated the many overloaded Document constructors in favour of Document.Builder.

Vector Stores
* Every vector store implementation now configures a “score” attribute with the similarity score of the Document embedding. It also includes the “distance” metadata with the distance measurement.
* Fixed error in Elasticsearch where distance and similarity were mixed up.
* Added missing integration tests for SimpleVectorStore.
* The Azure Vector Store and HanaDB Vector Store do not include those measurements because the product documentation does not include information about how the similarity score is returned, and without access to the cloud products I could not verify that via debugging.
* Improved tests to actually assert the result of the similarity search based on the returned score.

Signed-off-by: Thomas Vitale <ThomasVitale@users.noreply.github.com>
This commit is contained in:
Thomas Vitale
2024-11-26 08:03:40 +01:00
committed by Mark Pollack
parent 50223d20e3
commit fe58fd30eb
59 changed files with 1096 additions and 434 deletions

View File

@@ -176,14 +176,14 @@ public class MarkdownDocumentReader implements DocumentReader {
}
translateLineBreakToSpace();
this.currentDocumentBuilder.withMetadata("category", "blockquote");
this.currentDocumentBuilder.metadata("category", "blockquote");
super.visit(blockQuote);
}
@Override
public void visit(Code code) {
this.currentParagraphs.add(code.getLiteral());
this.currentDocumentBuilder.withMetadata("category", "code_inline");
this.currentDocumentBuilder.metadata("category", "code_inline");
super.visit(code);
}
@@ -195,8 +195,8 @@ public class MarkdownDocumentReader implements DocumentReader {
translateLineBreakToSpace();
this.currentParagraphs.add(fencedCodeBlock.getLiteral());
this.currentDocumentBuilder.withMetadata("category", "code_block");
this.currentDocumentBuilder.withMetadata("lang", fencedCodeBlock.getInfo());
this.currentDocumentBuilder.metadata("category", "code_block");
this.currentDocumentBuilder.metadata("lang", fencedCodeBlock.getInfo());
buildAndFlush();
@@ -206,8 +206,8 @@ public class MarkdownDocumentReader implements DocumentReader {
@Override
public void visit(Text text) {
if (text.getParent() instanceof Heading heading) {
this.currentDocumentBuilder.withMetadata("category", "header_%d".formatted(heading.getLevel()))
.withMetadata("title", text.getLiteral());
this.currentDocumentBuilder.metadata("category", "header_%d".formatted(heading.getLevel()))
.metadata("title", text.getLiteral());
}
else {
this.currentParagraphs.add(text.getLiteral());
@@ -226,9 +226,9 @@ public class MarkdownDocumentReader implements DocumentReader {
if (!this.currentParagraphs.isEmpty()) {
String content = String.join("", this.currentParagraphs);
Document.Builder builder = this.currentDocumentBuilder.withContent(content);
Document.Builder builder = this.currentDocumentBuilder.content(content);
this.config.additionalMetadata.forEach(builder::withMetadata);
this.config.additionalMetadata.forEach(builder::metadata);
Document document = builder.build();

View File

@@ -111,7 +111,7 @@ class VertexAiMultimodalEmbeddingModelIT {
assertThat(this.multiModelEmbeddingModel).isNotNull();
var document = Document.builder()
.withMedia(new Media(MimeTypeUtils.TEXT_PLAIN, URI.create("http://example.com/image.png").toURL()))
.media(new Media(MimeTypeUtils.TEXT_PLAIN, URI.create("http://example.com/image.png").toURL()))
.build();
DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(document);
@@ -135,7 +135,7 @@ class VertexAiMultimodalEmbeddingModelIT {
void imageEmbedding() {
var document = Document.builder()
.withMedia(new Media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.image.png")))
.media(new Media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.image.png")))
.build();
DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(document);
@@ -161,7 +161,7 @@ class VertexAiMultimodalEmbeddingModelIT {
void videoEmbedding() {
var document = Document.builder()
.withMedia(new Media(new MimeType("video", "mp4"), new ClassPathResource("/test.video.mp4")))
.media(new Media(new MimeType("video", "mp4"), new ClassPathResource("/test.video.mp4")))
.build();
DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(document);
@@ -186,9 +186,9 @@ class VertexAiMultimodalEmbeddingModelIT {
void textImageAndVideoEmbedding() {
var document = Document.builder()
.withContent("Hello World")
.withMedia(new Media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.image.png")))
.withMedia(new Media(new MimeType("video", "mp4"), new ClassPathResource("/test.video.mp4")))
.content("Hello World")
.media(new Media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.image.png")))
.media(new Media(new MimeType("video", "mp4"), new ClassPathResource("/test.video.mp4")))
.build();
DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(document);

View File

@@ -16,6 +16,7 @@
package org.springframework.ai.chat.client.advisor;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -184,10 +185,14 @@ public class VectorStoreChatMemoryAdvisor extends AbstractChatMemoryAdvisor<Vect
metadata.put(DOCUMENT_METADATA_CONVERSATION_ID, conversationId);
metadata.put(DOCUMENT_METADATA_MESSAGE_TYPE, message.getMessageType().name());
if (message instanceof UserMessage userMessage) {
return new Document(userMessage.getContent(), userMessage.getMedia(), metadata);
return Document.builder()
.content(userMessage.getContent())
.media(new ArrayList<>(userMessage.getMedia()))
.metadata(metadata)
.build();
}
else if (message instanceof AssistantMessage assistantMessage) {
return new Document(assistantMessage.getContent(), metadata);
return Document.builder().content(assistantMessage.getContent()).metadata(metadata).build();
}
throw new RuntimeException("Unknown message type: " + message.getMessageType());
})

View File

@@ -21,6 +21,7 @@ import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnore;
@@ -31,6 +32,7 @@ import org.springframework.ai.document.id.IdGenerator;
import org.springframework.ai.document.id.RandomIdGenerator;
import org.springframework.ai.model.Media;
import org.springframework.ai.model.MediaContent;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
@@ -41,9 +43,9 @@ import org.springframework.util.StringUtils;
@JsonIgnoreProperties({ "contentFormatter" })
public class Document implements MediaContent {
public final static ContentFormatter DEFAULT_CONTENT_FORMATTER = DefaultContentFormatter.defaultConfig();
public static final ContentFormatter DEFAULT_CONTENT_FORMATTER = DefaultContentFormatter.defaultConfig();
public final static String EMPTY_TEXT = "";
public static final String EMPTY_TEXT = "";
/**
* Unique ID
@@ -61,7 +63,15 @@ public class Document implements MediaContent {
* Metadata for the document. It should not be nested and values should be restricted
* to string, int, float, boolean for simple use with Vector Dbs.
*/
private Map<String, Object> metadata;
private final Map<String, Object> metadata;
/**
* Measure of similarity between the document embedding and the query vector. The
* higher the score, the more they are similar. It's the opposite of the distance
* measure.
*/
@Nullable
private final Double score;
/**
* Embedding of the document. Note: ephemeral field.
@@ -84,10 +94,18 @@ public class Document implements MediaContent {
this(content, metadata, new RandomIdGenerator());
}
/**
* @deprecated Use builder instead: {@link Document#builder()}.
*/
@Deprecated(since = "1.0.0-M5", forRemoval = true)
public Document(String content, Collection<Media> media, Map<String, Object> metadata) {
this(new RandomIdGenerator().generateId(content, metadata), content, media, metadata);
}
/**
* @deprecated Use builder instead: {@link Document#builder()}.
*/
@Deprecated(since = "1.0.0-M5", forRemoval = true)
public Document(String content, Map<String, Object> metadata, IdGenerator idGenerator) {
this(idGenerator.generateId(content, metadata), content, metadata);
}
@@ -96,15 +114,33 @@ public class Document implements MediaContent {
this(id, content, List.of(), metadata);
}
/**
* @deprecated Use builder instead: {@link Document#builder()}.
*/
@Deprecated(since = "1.0.0-M5", forRemoval = true)
public Document(String id, String content, Collection<Media> media, Map<String, Object> metadata) {
Assert.hasText(id, "id must not be null or empty");
Assert.notNull(content, "content must not be null");
Assert.notNull(metadata, "metadata must not be null");
this(id, content, media, metadata, null);
}
/**
* @deprecated Use builder instead: {@link Document#builder()}.
*/
@Deprecated(since = "1.0.0-M5", forRemoval = true)
public Document(String id, String content, @Nullable Collection<Media> media,
@Nullable Map<String, Object> metadata, @Nullable Double score) {
Assert.hasText(id, "id cannot be null or empty");
Assert.notNull(content, "content cannot be null");
Assert.notNull(media, "media cannot be null");
Assert.noNullElements(media, "media cannot have null elements");
Assert.notNull(metadata, "metadata cannot be null");
Assert.noNullElements(metadata.keySet(), "metadata cannot have null keys");
Assert.noNullElements(metadata.values(), "metadata cannot have null values");
this.id = id;
this.content = content;
this.media = media;
this.metadata = metadata;
this.media = media != null ? media : List.of();
this.metadata = metadata != null ? metadata : new HashMap<>();
this.score = score;
}
public static Builder builder() {
@@ -149,6 +185,11 @@ public class Document implements MediaContent {
return this.metadata;
}
@Nullable
public Double getScore() {
return this.score;
}
/**
* Return the embedding that was calculated.
* @deprecated We are considering getting rid of this, please comment on
@@ -172,6 +213,7 @@ public class Document implements MediaContent {
* @return the current ContentFormatter instance used for formatting the document
* content.
*/
@Deprecated(since = "1.0.0-M4")
public ContentFormatter getContentFormatter() {
return this.contentFormatter;
}
@@ -184,59 +226,34 @@ public class Document implements MediaContent {
this.contentFormatter = contentFormatter;
}
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + ((this.id == null) ? 0 : this.id.hashCode());
result = prime * result + ((this.metadata == null) ? 0 : this.metadata.hashCode());
result = prime * result + ((this.content == null) ? 0 : this.content.hashCode());
return result;
public Builder mutate() {
return new Builder().id(this.id)
.content(this.content)
.media(new ArrayList<>(this.media))
.metadata(this.metadata)
.score(this.score);
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null) {
public boolean equals(Object o) {
if (o == null || this.getClass() != o.getClass()) {
return false;
}
if (getClass() != obj.getClass()) {
return false;
}
Document other = (Document) obj;
if (this.id == null) {
if (other.id != null) {
return false;
}
}
else if (!this.id.equals(other.id)) {
return false;
}
if (this.metadata == null) {
if (other.metadata != null) {
return false;
}
}
else if (!this.metadata.equals(other.metadata)) {
return false;
}
if (this.content == null) {
if (other.content != null) {
return false;
}
}
else if (!this.content.equals(other.content)) {
return false;
}
return true;
Document document = (Document) o;
return Objects.equals(this.id, document.id) && Objects.equals(this.content, document.content)
&& Objects.equals(this.media, document.media) && Objects.equals(this.metadata, document.metadata)
&& Objects.equals(this.score, document.score);
}
@Override
public int hashCode() {
return Objects.hash(this.id, this.content, this.media, this.metadata, this.score);
}
@Override
public String toString() {
return "Document{" + "id='" + this.id + '\'' + ", metadata=" + this.metadata + ", content='" + this.content
+ '\'' + ", media=" + this.media + '}';
return "Document{" + "id='" + this.id + '\'' + ", content='" + this.content + '\'' + ", media=" + this.media
+ ", metadata=" + this.metadata + ", score=" + this.score + '}';
}
public static class Builder {
@@ -249,56 +266,103 @@ public class Document implements MediaContent {
private Map<String, Object> metadata = new HashMap<>();
private float[] embedding = new float[0];
@Nullable
private Double score;
private IdGenerator idGenerator = new RandomIdGenerator();
public Builder withIdGenerator(IdGenerator idGenerator) {
Assert.notNull(idGenerator, "idGenerator must not be null");
public Builder idGenerator(IdGenerator idGenerator) {
Assert.notNull(idGenerator, "idGenerator cannot be null");
this.idGenerator = idGenerator;
return this;
}
public Builder withId(String id) {
Assert.hasText(id, "id must not be null or empty");
public Builder id(String id) {
Assert.hasText(id, "id cannot be null or empty");
this.id = id;
return this;
}
public Builder withContent(String content) {
Assert.notNull(content, "content must not be null");
public Builder content(String content) {
this.content = content;
return this;
}
public Builder withMedia(List<Media> media) {
Assert.notNull(media, "media must not be null");
this.media = media;
public Builder media(List<Media> media) {
this.media.addAll(media);
return this;
}
public Builder withMedia(Media media) {
Assert.notNull(media, "media must not be null");
this.media.add(media);
public Builder media(Media... media) {
Assert.noNullElements(media, "media cannot contain null elements");
this.media.addAll(List.of(media));
return this;
}
public Builder withMetadata(Map<String, Object> metadata) {
Assert.notNull(metadata, "metadata must not be null");
public Builder metadata(Map<String, Object> metadata) {
this.metadata = metadata;
return this;
}
public Builder withMetadata(String key, Object value) {
Assert.notNull(key, "key must not be null");
Assert.notNull(value, "value must not be null");
public Builder metadata(String key, Object value) {
this.metadata.put(key, value);
return this;
}
public Builder embedding(float[] embedding) {
this.embedding = embedding;
return this;
}
public Builder score(@Nullable Double score) {
this.score = score;
return this;
}
@Deprecated(since = "1.0.0-M5", forRemoval = true)
public Builder withIdGenerator(IdGenerator idGenerator) {
return idGenerator(idGenerator);
}
@Deprecated(since = "1.0.0-M5", forRemoval = true)
public Builder withId(String id) {
return id(id);
}
@Deprecated(since = "1.0.0-M5", forRemoval = true)
public Builder withContent(String content) {
return content(content);
}
@Deprecated(since = "1.0.0-M5", forRemoval = true)
public Builder withMedia(List<Media> media) {
return media(media);
}
@Deprecated(since = "1.0.0-M5", forRemoval = true)
public Builder withMedia(Media media) {
return media(media);
}
@Deprecated(since = "1.0.0-M5", forRemoval = true)
public Builder withMetadata(Map<String, Object> metadata) {
return metadata(metadata);
}
@Deprecated(since = "1.0.0-M5", forRemoval = true)
public Builder withMetadata(String key, Object value) {
return metadata(key, value);
}
public Document build() {
if (!StringUtils.hasText(this.id)) {
this.id = this.idGenerator.generateId(this.content, this.metadata);
}
return new Document(this.id, this.content, this.media, this.metadata);
var document = new Document(this.id, this.content, this.media, this.metadata, this.score);
document.setEmbedding(this.embedding);
return document;
}
}

View File

@@ -0,0 +1,55 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.document;
import org.springframework.ai.vectorstore.VectorStore;
/**
* Common set of metadata keys used in {@link Document}s by {@link DocumentReader}s and
* {@link VectorStore}s.
*
* @author Thomas Vitale
* @since 1.0.0
*/
public enum DocumentMetadata {

	// @formatter:off

	/**
	 * How far apart the document embedding and the query vector are.
	 * Smaller distances indicate more similar vectors; this is the
	 * counterpart of the similarity score.
	 */
	DISTANCE("distance");

	// @formatter:on

	/** Metadata key string carried by this constant. */
	private final String key;

	DocumentMetadata(String key) {
		this.key = key;
	}

	/**
	 * Returns the metadata key string for this constant.
	 * @return the key, e.g. {@code "distance"}
	 */
	public String value() {
		return this.key;
	}

	/**
	 * Returns the same string as {@link #value()} so the constant can be used
	 * directly wherever its key text is expected.
	 */
	@Override
	public String toString() {
		return value();
	}

}

View File

@@ -0,0 +1,22 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
@NonNullApi
@NonNullFields
package org.springframework.ai.document;
import org.springframework.lang.NonNullApi;
import org.springframework.lang.NonNullFields;

View File

@@ -22,6 +22,7 @@ import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder;
import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
/**
@@ -30,6 +31,7 @@ import org.springframework.util.Assert;
* instance and then apply the 'with' methods to alter the default values.
*
* @author Christian Tzolov
* @author Thomas Vitale
*/
public final class SearchRequest {
@@ -45,12 +47,13 @@ public final class SearchRequest {
*/
public static final int DEFAULT_TOP_K = 4;
public String query;
private String query;
private int topK = DEFAULT_TOP_K;
private double similarityThreshold = SIMILARITY_THRESHOLD_ACCEPT_ALL;
@Nullable
private Filter.Expression filterExpression;
private SearchRequest(String query) {
@@ -186,7 +189,7 @@ public final class SearchRequest {
* filter criteria. The 'null' value stands for no expression filters.
* @return this builder.
*/
public SearchRequest withFilterExpression(Filter.Expression expression) {
public SearchRequest withFilterExpression(@Nullable Filter.Expression expression) {
this.filterExpression = expression;
return this;
}
@@ -225,7 +228,7 @@ public final class SearchRequest {
* 'null' value stands for no expression filters.
* @return this.builder
*/
public SearchRequest withFilterExpression(String textExpression) {
public SearchRequest withFilterExpression(@Nullable String textExpression) {
this.filterExpression = (textExpression != null) ? new FilterExpressionTextParser().parse(textExpression)
: null;
return this;
@@ -243,6 +246,7 @@ public final class SearchRequest {
return this.similarityThreshold;
}
@Nullable
public Filter.Expression getFilterExpression() {
return this.filterExpression;
}

View File

@@ -43,6 +43,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric;
@@ -68,6 +69,7 @@ import org.springframework.core.io.Resource;
* @author Christian Tzolov
* @author Sebastien Deleuze
* @author Ilayaperumal Gopinathan
* @author Thomas Vitale
*/
public class SimpleVectorStore extends AbstractObservationVectorStore {
@@ -127,12 +129,11 @@ public class SimpleVectorStore extends AbstractObservationVectorStore {
float[] userQueryEmbedding = getUserQueryEmbedding(request.getQuery());
return this.store.values()
.stream()
.map(entry -> new Similarity(entry,
EmbeddingMath.cosineSimilarity(userQueryEmbedding, entry.getEmbedding())))
.filter(s -> s.score >= request.getSimilarityThreshold())
.sorted(Comparator.<Similarity>comparingDouble(s -> s.score).reversed())
.map(content -> content
.toDocument(EmbeddingMath.cosineSimilarity(userQueryEmbedding, content.getEmbedding())))
.filter(document -> document.getScore() >= request.getSimilarityThreshold())
.sorted(Comparator.comparing(Document::getScore).reversed())
.limit(request.getTopK())
.map(s -> s.getDocument())
.toList();
}
@@ -235,28 +236,7 @@ public class SimpleVectorStore extends AbstractObservationVectorStore {
.withSimilarityMetric(VectorStoreSimilarityMetric.COSINE.value());
}
public static class Similarity {
private SimpleVectorStoreContent content;
private double score;
public Similarity(SimpleVectorStoreContent content, double score) {
this.content = content;
this.score = score;
}
Document getDocument() {
return Document.builder()
.withId(this.content.getId())
.withContent(this.content.getContent())
.withMetadata(this.content.getMetadata())
.build();
}
}
public final class EmbeddingMath {
public static final class EmbeddingMath {
private EmbeddingMath() {
throw new UnsupportedOperationException("This is a utility class and cannot be instantiated");

View File

@@ -25,6 +25,8 @@ import java.util.Objects;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.document.id.IdGenerator;
import org.springframework.ai.document.id.RandomIdGenerator;
import org.springframework.ai.model.Content;
@@ -135,6 +137,12 @@ public final class SimpleVectorStoreContent implements Content {
return Arrays.copyOf(this.embedding, this.embedding.length);
}
/**
 * Converts this stored content into a {@link Document} carrying the given score.
 * The stored metadata is copied (the copy, not the original map, receives the
 * extra entry) and a distance entry computed as {@code 1.0 - score} is added
 * under the {@link DocumentMetadata#DISTANCE} key.
 * @param score the similarity score to attach to the resulting document
 * @return a new document with this content's id, content, enriched metadata and
 * the score
 */
public Document toDocument(Double score) {
	var enrichedMetadata = new HashMap<>(this.metadata);
	enrichedMetadata.put(DocumentMetadata.DISTANCE.value(), 1.0 - score);

	Document.Builder documentBuilder = Document.builder();
	documentBuilder.id(this.id);
	documentBuilder.content(this.content);
	documentBuilder.metadata(enrichedMetadata);
	documentBuilder.score(score);
	return documentBuilder.build();
}
@Override
public boolean equals(Object o) {
if (this == o) {

View File

@@ -0,0 +1,22 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
@NonNullApi
@NonNullFields
package org.springframework.ai.vectorstore;
import org.springframework.lang.NonNullApi;
import org.springframework.lang.NonNullFields;

View File

@@ -70,8 +70,8 @@ class RetrievalAugmentationAdvisorTests {
.build());
// Document Retriever
var documentContext = List.of(Document.builder().withId("1").withContent("doc1").build(),
Document.builder().withId("2").withContent("doc2").build());
var documentContext = List.of(Document.builder().id("1").content("doc1").build(),
Document.builder().id("2").content("doc2").build());
var documentRetriever = mock(DocumentRetriever.class);
var queryCaptor = ArgumentCaptor.forClass(Query.class);
given(documentRetriever.retrieve(queryCaptor.capture())).willReturn(documentContext);

View File

@@ -42,13 +42,11 @@ public class DocumentBuilderTests {
URL mediaUrl2 = new URL("http://type2");
Media media1 = new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl1);
Media media2 = new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl2);
List<Media> mediaList = List.of(media1, media2);
return mediaList;
return List.of(media1, media2);
}
catch (MalformedURLException e) {
throw new RuntimeException(e);
}
}
@BeforeEach
@@ -58,32 +56,26 @@ public class DocumentBuilderTests {
@Test
void testWithIdGenerator() {
IdGenerator mockGenerator = new IdGenerator() {
IdGenerator mockGenerator = contents -> "mockedId";
@Override
public String generateId(Object... contents) {
return "mockedId";
}
};
Document.Builder result = this.builder.withIdGenerator(mockGenerator);
Document.Builder result = this.builder.idGenerator(mockGenerator);
assertThat(result).isSameAs(this.builder);
Document document = result.withContent("Test content").withMetadata("key", "value").build();
Document document = result.content("Test content").metadata("key", "value").build();
assertThat(document.getId()).isEqualTo("mockedId");
}
@Test
void testWithIdGeneratorNull() {
assertThatThrownBy(() -> this.builder.withIdGenerator(null)).isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("idGenerator must not be null");
assertThatThrownBy(() -> this.builder.idGenerator(null).build()).isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("idGenerator cannot be null");
}
@Test
void testWithId() {
Document.Builder result = this.builder.withId("testId");
Document.Builder result = this.builder.id("testId");
assertThat(result).isSameAs(this.builder);
assertThat(result.build().getId()).isEqualTo("testId");
@@ -91,16 +83,16 @@ public class DocumentBuilderTests {
@Test
void testWithIdNullOrEmpty() {
assertThatThrownBy(() -> this.builder.withId(null)).isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("id must not be null or empty");
assertThatThrownBy(() -> this.builder.id(null).build()).isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("id cannot be null or empty");
assertThatThrownBy(() -> this.builder.withId("")).isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("id must not be null or empty");
assertThatThrownBy(() -> this.builder.id("").build()).isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("id cannot be null or empty");
}
@Test
void testWithContent() {
Document.Builder result = this.builder.withContent("Test content");
Document.Builder result = this.builder.content("Test content");
assertThat(result).isSameAs(this.builder);
assertThat(result.build().getContent()).isEqualTo("Test content");
@@ -108,14 +100,14 @@ public class DocumentBuilderTests {
@Test
void testWithContentNull() {
assertThatThrownBy(() -> this.builder.withContent(null)).isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("content must not be null");
assertThatThrownBy(() -> this.builder.content(null).build()).isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("content cannot be null");
}
@Test
void testWithMediaList() {
List<Media> mediaList = getMediaList();
Document.Builder result = this.builder.withMedia(mediaList);
Document.Builder result = this.builder.media(mediaList);
assertThat(result).isSameAs(this.builder);
assertThat(result.build().getMedia()).isEqualTo(mediaList);
@@ -123,9 +115,9 @@ public class DocumentBuilderTests {
@Test
void testWithMediaListNull() {
assertThatThrownBy(() -> this.builder.withMedia((List<Media>) null))
assertThatThrownBy(() -> this.builder.media((List<Media>) null).build())
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("media must not be null");
.hasMessageContaining("media cannot be null");
}
@Test
@@ -133,7 +125,7 @@ public class DocumentBuilderTests {
URL mediaUrl = new URL("http://test");
Media media = new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl);
Document.Builder result = this.builder.withMedia(media);
Document.Builder result = this.builder.media(media);
assertThat(result).isSameAs(this.builder);
assertThat(result.build().getMedia()).contains(media);
@@ -141,8 +133,8 @@ public class DocumentBuilderTests {
@Test
void testWithMediaSingleNull() {
assertThatThrownBy(() -> this.builder.withMedia((Media) null)).isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("media must not be null");
assertThatThrownBy(() -> this.builder.media((Media) null).build()).isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("media cannot contain null elements");
}
@Test
@@ -150,7 +142,7 @@ public class DocumentBuilderTests {
Map<String, Object> metadata = new HashMap<>();
metadata.put("key1", "value1");
metadata.put("key2", 2);
Document.Builder result = this.builder.withMetadata(metadata);
Document.Builder result = this.builder.metadata(metadata);
assertThat(result).isSameAs(this.builder);
assertThat(result.build().getMetadata()).isEqualTo(metadata);
@@ -158,47 +150,51 @@ public class DocumentBuilderTests {
@Test
void testWithMetadataMapNull() {
assertThatThrownBy(() -> this.builder.withMetadata((Map<String, Object>) null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("metadata must not be null");
assertThatThrownBy(() -> this.builder.metadata(null).build()).isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("metadata cannot be null");
}
@Test
void testWithMetadataKeyValue() {
Document.Builder result = this.builder.withMetadata("key", "value");
Document.Builder result = this.builder.metadata("key", "value");
assertThat(result).isSameAs(this.builder);
assertThat(result.build().getMetadata()).containsEntry("key", "value");
}
@Test
void testWithMetadataKeyValueNull() {
assertThatThrownBy(() -> this.builder.withMetadata(null, "value")).isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("key must not be null");
void testWithMetadataKeyNull() {
assertThatThrownBy(() -> this.builder.metadata(null, "value").build())
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("metadata cannot have null keys");
}
assertThatThrownBy(() -> this.builder.withMetadata("key", null)).isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("value must not be null");
@Test
void testWithMetadataValueNull() {
assertThatThrownBy(() -> this.builder.metadata("key", null).build())
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("metadata cannot have null values");
}
@Test
void testBuildWithoutId() {
Document document = this.builder.withContent("Test content").build();
Document document = this.builder.content("Test content").build();
assertThat(document.getId()).isNotNull().isNotEmpty();
assertThat(document.getContent()).isEqualTo("Test content");
}
@Test
void testBuildWithAllProperties() throws MalformedURLException {
void testBuildWithAllProperties() {
List<Media> mediaList = getMediaList();
Map<String, Object> metadata = new HashMap<>();
metadata.put("key", "value");
Document document = this.builder.withId("customId")
.withContent("Test content")
.withMedia(mediaList)
.withMetadata(metadata)
Document document = this.builder.id("customId")
.content("Test content")
.media(mediaList)
.metadata(metadata)
.build();
assertThat(document.getId()).isEqualTo("customId");

View File

@@ -0,0 +1,181 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.document;
import org.junit.jupiter.api.Test;
import org.springframework.ai.model.Media;
import org.springframework.util.MimeTypeUtils;
import java.net.MalformedURLException;
import java.net.URL;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
/**
 * Unit tests for {@link Document}: score handling, additive media building,
 * {@code mutate()}, equality/hash code, and the {@code toString()} output.
 */
public class DocumentTests {

	@Test
	void testScore() {
		Double expectedScore = 0.95;

		Document scored = Document.builder().content("Test content").score(expectedScore).build();

		assertThat(scored.getScore()).isEqualTo(expectedScore);
	}

	@Test
	void testNullScore() {
		Document unscored = Document.builder().content("Test content").score(null).build();

		assertThat(unscored.getScore()).isNull();
	}

	@Test
	void testMediaBuilderIsAdditive() {
		Media first = jpegAt("http://type1");
		Media second = jpegAt("http://type2");
		Media third = jpegAt("http://type3");

		// Mix the single-item and the list overloads; every call must append.
		Document document = Document.builder().media(first).media(second).media(List.of(third)).build();

		assertThat(document.getMedia()).hasSize(3).containsExactly(first, second, third);
	}

	@Test
	void testMutate() {
		Document original = populatedDocument("customId", "Test content", getMediaList(), singleEntryMetadata(),
				0.95);

		Document mutated = original.mutate().build();

		// A distinct instance that is field-by-field equal to its source.
		assertThat(mutated).isNotSameAs(original).usingRecursiveComparison().isEqualTo(original);
	}

	@Test
	void testEquals() {
		// The same media, metadata and score instances are shared by all three documents.
		List<Media> sharedMedia = getMediaList();
		Map<String, Object> sharedMetadata = singleEntryMetadata();
		Double sharedScore = 0.95;

		Document doc1 = populatedDocument("customId", "Test content", sharedMedia, sharedMetadata, sharedScore);
		Document doc2 = populatedDocument("customId", "Test content", sharedMedia, sharedMetadata, sharedScore);
		Document differentDoc = populatedDocument("differentId", "Different content", sharedMedia, sharedMetadata,
				sharedScore);

		assertThat(doc1).isEqualTo(doc2).isNotEqualTo(differentDoc).isNotEqualTo(null).isNotEqualTo(new Object());
		assertThat(doc1.hashCode()).isEqualTo(doc2.hashCode());
	}

	@Test
	void testEmptyDocument() {
		Document blank = Document.builder().build();

		assertThat(blank.getContent()).isEqualTo(Document.EMPTY_TEXT).isEmpty();
		assertThat(blank.getMedia()).isEmpty();
		assertThat(blank.getMetadata()).isEmpty();
		assertThat(blank.getScore()).isNull();
	}

	@Test
	void testToString() {
		List<Media> mediaList = getMediaList();
		Map<String, Object> metadata = singleEntryMetadata();
		Double score = 0.95;

		String rendered = populatedDocument("customId", "Test content", mediaList, metadata, score).toString();

		assertThat(rendered).contains("id='customId'")
			.contains("content='Test content'")
			.contains("media=" + mediaList)
			.contains("metadata=" + metadata)
			.contains("score=" + score);
	}

	@Test
	void testToStringWithEmptyDocument() {
		String rendered = Document.builder().build().toString();

		assertThat(rendered).contains("content=''")
			.contains("media=[]")
			.contains("metadata={}")
			.contains("score=null");
	}

	/**
	 * Builds a document with every builder property set, in the order id, content,
	 * media, metadata, score.
	 */
	private static Document populatedDocument(String id, String content, List<Media> media,
			Map<String, Object> metadata, Double score) {
		return Document.builder().id(id).content(content).media(media).metadata(metadata).score(score).build();
	}

	/**
	 * Returns a fresh mutable map holding the single entry {@code key -> value}.
	 */
	private static Map<String, Object> singleEntryMetadata() {
		Map<String, Object> metadata = new HashMap<>();
		metadata.put("key", "value");
		return metadata;
	}

	/**
	 * Returns two JPEG media items pointing at {@code http://type1} and
	 * {@code http://type2}.
	 */
	private static List<Media> getMediaList() {
		return List.of(jpegAt("http://type1"), jpegAt("http://type2"));
	}

	/**
	 * Creates a JPEG {@link Media} for the given URL, rethrowing a malformed URL as an
	 * unchecked exception so tests need no {@code throws} clause.
	 */
	private static Media jpegAt(String location) {
		try {
			return new Media(MimeTypeUtils.IMAGE_JPEG, new URL(location));
		}
		catch (MalformedURLException ex) {
			throw new RuntimeException(ex);
		}
	}

}

View File

@@ -27,6 +27,7 @@ import static org.assertj.core.api.Assertions.assertThat;
/**
* @author Ilayaperumal Gopinathan
* @author Thomas Vitale
*/
public class SimpleVectorStoreSimilarityTests {
@@ -38,8 +39,7 @@ public class SimpleVectorStoreSimilarityTests {
SimpleVectorStoreContent storeContent = new SimpleVectorStoreContent("1", "hello, how are you?", metadata,
testEmbedding);
SimpleVectorStore.Similarity similarity = new SimpleVectorStore.Similarity(storeContent, 0.6d);
Document document = similarity.getDocument();
Document document = storeContent.toDocument(0.6);
assertThat(document).isNotNull();
assertThat(document.getId()).isEqualTo("1");
assertThat(document.getContent()).isEqualTo("hello, how are you?");

View File

@@ -62,11 +62,7 @@ class SimpleVectorStoreTests {
@Test
void shouldAddAndRetrieveDocument() {
Document doc = Document.builder()
.withId("1")
.withContent("test content")
.withMetadata(Map.of("key", "value"))
.build();
Document doc = Document.builder().id("1").content("test content").metadata(Map.of("key", "value")).build();
this.vectorStore.add(List.of(doc));
@@ -80,8 +76,8 @@ class SimpleVectorStoreTests {
@Test
void shouldAddMultipleDocuments() {
List<Document> docs = Arrays.asList(Document.builder().withId("1").withContent("first").build(),
Document.builder().withId("2").withContent("second").build());
List<Document> docs = Arrays.asList(Document.builder().id("1").content("first").build(),
Document.builder().id("2").content("second").build());
this.vectorStore.add(docs);
@@ -104,7 +100,7 @@ class SimpleVectorStoreTests {
@Test
void shouldDeleteDocuments() {
Document doc = Document.builder().withId("1").withContent("test content").build();
Document doc = Document.builder().id("1").content("test content").build();
this.vectorStore.add(List.of(doc));
assertThat(this.vectorStore.similaritySearch("test")).hasSize(1);
@@ -125,7 +121,7 @@ class SimpleVectorStoreTests {
// Configure mock to return different embeddings for different queries
when(this.mockEmbeddingModel.embed("query")).thenReturn(new float[] { 0.9f, 0.9f, 0.9f });
Document doc = Document.builder().withId("1").withContent("test content").build();
Document doc = Document.builder().id("1").content("test content").build();
this.vectorStore.add(List.of(doc));
@@ -138,9 +134,9 @@ class SimpleVectorStoreTests {
@Test
void shouldSaveAndLoadVectorStore() throws IOException {
Document doc = Document.builder()
.withId("1")
.withContent("test content")
.withMetadata(new HashMap<>(Map.of("key", "value")))
.id("1")
.content("test content")
.metadata(new HashMap<>(Map.of("key", "value")))
.build();
this.vectorStore.add(List.of(doc));
@@ -185,7 +181,7 @@ class SimpleVectorStoreTests {
for (int i = 0; i < numThreads; i++) {
final String id = String.valueOf(i);
threads[i] = new Thread(() -> {
Document doc = Document.builder().withId(id).withContent("content " + id).build();
Document doc = Document.builder().id(id).content("content " + id).build();
this.vectorStore.add(List.of(doc));
});
threads[i].start();

View File

@@ -54,7 +54,6 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-openai-spring-boot-starter</artifactId>
@@ -76,6 +75,13 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-test</artifactId>
<version>${project.parent.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-testcontainers</artifactId>

View File

@@ -16,7 +16,10 @@
package org.springframework.ai.integration.tests;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;
/**
@@ -28,4 +31,9 @@ import org.springframework.context.annotation.Import;
@Import(TestcontainersConfiguration.class)
public class TestApplication {
@Bean
SimpleVectorStore simpleVectorStore(EmbeddingModel embeddingModel) {
return new SimpleVectorStore(embeddingModel);
}
}

View File

@@ -0,0 +1,122 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.integration.tests.vectorstore;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.integration.tests.TestApplication;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.core.io.DefaultResourceLoader;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import static org.assertj.core.api.Assertions.assertThat;
/**
 * Integration tests for {@link SimpleVectorStore}.
 *
 * <p>
 * Runs only when the {@code OPENAI_API_KEY} environment variable is present; the
 * store under test is the bean autowired from {@link TestApplication}.
 *
 * @author Thomas Vitale
 */
@SpringBootTest(classes = TestApplication.class)
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*")
public class SimpleVectorStoreIT {

	@Autowired
	private SimpleVectorStore vectorStore;

	// Shared fixture documents with fixed ids; their ids are removed from the store
	// after every test.
	List<Document> documents = List.of(
			Document.builder()
				.id("471a8c78-549a-4b2c-bce5-ef3ae6579be3")
				.content(getText("classpath:/test/data/spring.ai.txt"))
				.metadata(Map.of("meta1", "meta1"))
				.build(),
			Document.builder()
				.id("bc51d7f7-627b-4ba6-adf4-f0bcd1998f8f")
				.content(getText("classpath:/test/data/time.shelter.txt"))
				.metadata(Map.of())
				.build(),
			Document.builder()
				.id("d0237682-1150-44ff-b4d2-1be9b1731ee5")
				.content(getText("classpath:/test/data/great.depression.txt"))
				.metadata(Map.of("meta2", "meta2"))
				.build());

	/**
	 * Reads the resource at the given location as UTF-8 text.
	 * @param uri resource location understood by {@link DefaultResourceLoader}, for
	 * example a {@code classpath:} location
	 * @return the full content of the resource
	 * @throws RuntimeException wrapping the {@link IOException} if the resource cannot
	 * be read
	 */
	public static String getText(String uri) {
		var resource = new DefaultResourceLoader().getResource(uri);
		try {
			return resource.getContentAsString(StandardCharsets.UTF_8);
		}
		catch (IOException e) {
			throw new RuntimeException(e);
		}
	}

	/**
	 * Removes the shared fixture documents from the store after each test.
	 */
	@AfterEach
	void tearDown() {
		this.vectorStore.delete(this.documents.stream().map(Document::getId).toList());
	}

	/**
	 * Adds a document, searches for it, then adds a second document with the same id
	 * and checks that the search returns the replacement content and metadata.
	 *
	 * <p>
	 * Note: despite the method name, the search requests below only set {@code topK};
	 * no explicit similarity threshold is passed.
	 */
	@Test
	public void searchWithThreshold() {
		Document document = Document.builder()
			.id(UUID.randomUUID().toString())
			.content("Spring AI rocks!!")
			.metadata("meta1", "meta1")
			.build();

		try {
			this.vectorStore.add(List.of(document));

			List<Document> results = this.vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5));

			assertThat(results).hasSize(1);
			Document resultDoc = results.get(0);
			assertThat(resultDoc.getId()).isEqualTo(document.getId());
			assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
			assertThat(resultDoc.getMetadata()).containsKey("meta1");
			assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());

			// Re-adding under the same id must leave a single entry with the new content.
			Document sameIdDocument = Document.builder()
				.id(document.getId())
				.content("The World is Big and Salvation Lurks Around the Corner")
				.metadata("meta2", "meta2")
				.build();

			this.vectorStore.add(List.of(sameIdDocument));

			results = this.vectorStore.similaritySearch(SearchRequest.query("FooBar").withTopK(5));

			assertThat(results).hasSize(1);
			resultDoc = results.get(0);
			assertThat(resultDoc.getId()).isEqualTo(document.getId());
			assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
			assertThat(resultDoc.getMetadata()).containsKey("meta2");
			assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
		}
		finally {
			// Clean up even when an assertion fails, so the injected store is not left
			// holding this test's document.
			this.vectorStore.delete(List.of(document.getId()));
		}
	}

}

View File

@@ -18,6 +18,7 @@ package org.springframework.ai.autoconfigure.vectorstore.pinecone;
import java.time.Duration;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.vectorstore.PineconeVectorStore;
import org.springframework.boot.context.properties.ConfigurationProperties;
@@ -25,6 +26,7 @@ import org.springframework.boot.context.properties.ConfigurationProperties;
* Configuration properties for Pinecone Vector Store.
*
* @author Christian Tzolov
* @author Thomas Vitale
*/
@ConfigurationProperties(PineconeVectorStoreProperties.CONFIG_PREFIX)
public class PineconeVectorStoreProperties {
@@ -43,7 +45,7 @@ public class PineconeVectorStoreProperties {
private String contentFieldName = PineconeVectorStore.CONTENT_FIELD_NAME;
private String distanceMetadataFieldName = PineconeVectorStore.DISTANCE_METADATA_FIELD_NAME;
private String distanceMetadataFieldName = DocumentMetadata.DISTANCE.value();
private Duration serverSideTimeout = Duration.ofSeconds(20);

View File

@@ -20,12 +20,14 @@ import java.time.Duration;
import org.junit.jupiter.api.Test;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.vectorstore.PineconeVectorStore;
import static org.assertj.core.api.Assertions.assertThat;
/**
* @author Christian Tzolov
* @author Thomas Vitale
*/
public class PineconeVectorStorePropertiesTests {
@@ -39,7 +41,7 @@ public class PineconeVectorStorePropertiesTests {
assertThat(props.getIndexName()).isNull();
assertThat(props.getServerSideTimeout()).isEqualTo(Duration.ofSeconds(20));
assertThat(props.getContentFieldName()).isEqualTo(PineconeVectorStore.CONTENT_FIELD_NAME);
assertThat(props.getDistanceMetadataFieldName()).isEqualTo(PineconeVectorStore.DISTANCE_METADATA_FIELD_NAME);
assertThat(props.getDistanceMetadataFieldName()).isEqualTo(DocumentMetadata.DISTANCE.value());
}
@Test

View File

@@ -74,6 +74,7 @@ import org.springframework.ai.vectorstore.observation.VectorStoreObservationConv
*
* @author Theo van Kraay
* @author Soby Chacko
* @author Thomas Vitale
* @since 1.0.0
*/
public class CosmosDBVectorStore extends AbstractObservationVectorStore implements AutoCloseable {
@@ -338,7 +339,7 @@ public class CosmosDBVectorStore extends AbstractObservationVectorStore implemen
.block();
// Convert JsonNode to Document
List<Document> docs = documents.stream()
.map(doc -> new Document(doc.get("id").asText(), doc.get("content").asText(), new HashMap<>()))
.map(doc -> Document.builder().id(doc.get("id").asText()).content(doc.get("content").asText()).build())
.collect(Collectors.toList());
return docs != null ? docs : List.of();

View File

@@ -42,6 +42,7 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
/**
* @author Theo van Kraay
* @author Thomas Vitale
* @since 1.0.0
*/
@EnabledIfEnvironmentVariable(named = "AZURE_COSMOSDB_ENDPOINT", matches = ".+")

View File

@@ -0,0 +1,35 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.vectorstore;
import org.testcontainers.utility.DockerImageName;
/**
 * Docker image reference for the Azure Cosmos DB emulator.
 *
 * @author Thomas Vitale
 */
public final class CosmosDbImage {

	// The tag must stay "latest": Azure locks older emulator images after a while.
	// See https://github.com/Azure/azure-cosmos-db-emulator-docker/issues/60
	private static final String EMULATOR_IMAGE = "mcr.microsoft.com/cosmosdb/linux/azure-cosmos-emulator:latest";

	/** Parsed image name for the emulator image above. */
	public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse(EMULATOR_IMAGE);

	// Constants holder; not meant to be instantiated.
	private CosmosDbImage() {
	}

}

View File

@@ -47,6 +47,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
@@ -96,8 +97,6 @@ public class AzureVectorStore extends AbstractObservationVectorStore implements
private static final String METADATA_FIELD_NAME = "metadata";
private static final String DISTANCE_METADATA_FIELD_NAME = "distance";
private static final int DEFAULT_TOP_K = 4;
private static final Double DEFAULT_SIMILARITY_THRESHOLD = 0.0;
@@ -321,13 +320,15 @@ public class AzureVectorStore extends AbstractObservationVectorStore implements
}) : Map.of();
metadata.put(DISTANCE_METADATA_FIELD_NAME, 1 - (float) result.getScore());
final Document doc = new Document(entry.id(), entry.content(), metadata);
doc.setEmbedding(EmbeddingUtils.toPrimitive(entry.embedding()));
return doc;
metadata.put(DocumentMetadata.DISTANCE.value(), 1.0 - result.getScore());
return Document.builder()
.id(entry.id())
.content(entry.content)
.metadata(metadata)
.score(result.getScore())
.embedding(EmbeddingUtils.toPrimitive(entry.embedding))
.build();
})
.collect(Collectors.toList());
}

View File

@@ -35,6 +35,7 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.transformers.TransformersEmbeddingModel;
import org.springframework.ai.vectorstore.SearchRequest;
@@ -52,6 +53,7 @@ import static org.hamcrest.Matchers.hasSize;
/**
* @author Christian Tzolov
* @author Thomas Vitale
*/
@EnabledIfEnvironmentVariable(named = "AZURE_AI_SEARCH_API_KEY", matches = ".+")
@EnabledIfEnvironmentVariable(named = "AZURE_AI_SEARCH_ENDPOINT", matches = ".+")
@@ -103,7 +105,7 @@ public class AzureVectorStoreIT {
assertThat(resultDoc.getContent()).contains("The Great Depression (19291939) was an economic shock");
assertThat(resultDoc.getMetadata()).hasSize(2);
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
// Remove all documents from the store
vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList());
@@ -224,7 +226,7 @@ public class AzureVectorStoreIT {
assertThat(resultDoc.getId()).isEqualTo(document.getId());
assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
assertThat(resultDoc.getMetadata()).containsKey("meta1");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
Document sameIdDocument = new Document(document.getId(),
"The World is Big and Salvation Lurks Around the Corner",
@@ -245,7 +247,7 @@ public class AzureVectorStoreIT {
assertThat(resultDoc.getId()).isEqualTo(document.getId());
assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
// Remove all documents from the store
vectorStore.delete(List.of(document.getId()));
@@ -271,21 +273,22 @@ public class AzureVectorStoreIT {
List<Document> fullResult = vectorStore
.similaritySearch(SearchRequest.query("Depression").withTopK(5).withSimilarityThresholdAll());
List<Float> distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList();
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
assertThat(distances).hasSize(3);
assertThat(scores).hasSize(3);
float threshold = (distances.get(0) + distances.get(1)) / 2;
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
List<Document> results = vectorStore
.similaritySearch(SearchRequest.query("Depression").withTopK(5).withSimilarityThreshold(1 - threshold));
List<Document> results = vectorStore.similaritySearch(
SearchRequest.query("Depression").withTopK(5).withSimilarityThreshold(similarityThreshold));
assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId());
assertThat(resultDoc.getContent()).contains("The Great Depression (19291939) was an economic shock");
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
// Remove all documents from the store
vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList());

View File

@@ -45,6 +45,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
@@ -106,8 +107,6 @@ import org.springframework.ai.vectorstore.observation.VectorStoreObservationConv
*/
public class CassandraVectorStore extends AbstractObservationVectorStore implements AutoCloseable {
public static final String SIMILARITY_FIELD_NAME = "similarity_score";
public static final String DRIVER_PROFILE_UPDATES = "spring-ai-updates";
public static final String DRIVER_PROFILE_SEARCH = "spring-ai-search";
@@ -252,14 +251,19 @@ public class CassandraVectorStore extends AbstractObservationVectorStore impleme
break;
}
Map<String, Object> docFields = new HashMap<>();
docFields.put(SIMILARITY_FIELD_NAME, score);
docFields.put(DocumentMetadata.DISTANCE.value(), 1 - score);
for (var metadata : this.conf.schema.metadataColumns()) {
var value = row.get(metadata.name(), metadata.javaType());
if (null != value) {
docFields.put(metadata.name(), value);
}
}
Document doc = new Document(getDocumentId(row), row.getString(this.conf.schema.content()), docFields);
Document doc = Document.builder()
.id(getDocumentId(row))
.content(row.getString(this.conf.schema.content()))
.metadata(docFields)
.score((double) score)
.build();
if (this.conf.returnEmbeddings) {
doc.setEmbedding(EmbeddingUtils

View File

@@ -37,6 +37,7 @@ import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.DocumentMetadata;
import org.testcontainers.containers.CassandraContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
@@ -226,8 +227,7 @@ class CassandraRichSchemaVectorStoreIT {
assertThat(resultDoc.getMetadata()).hasSize(3);
assertThat(resultDoc.getMetadata()).containsKeys("id", "revision",
CassandraVectorStore.SIMILARITY_FIELD_NAME);
assertThat(resultDoc.getMetadata()).containsKeys("id", "revision", DocumentMetadata.DISTANCE.value());
// Remove all documents from the createStore
store.delete(documents.stream().map(doc -> doc.getId()).toList());
@@ -494,8 +494,7 @@ class CassandraRichSchemaVectorStoreIT {
assertThat(resultDoc.getId()).isNotEqualTo(sameIdDocument.getId());
assertThat(resultDoc.getContent()).doesNotContain(newContent);
assertThat(resultDoc.getMetadata()).containsKeys("id", "revision",
CassandraVectorStore.SIMILARITY_FIELD_NAME);
assertThat(resultDoc.getMetadata()).containsKeys("id", "revision", DocumentMetadata.DISTANCE.value());
}
});
}
@@ -509,16 +508,15 @@ class CassandraRichSchemaVectorStoreIT {
List<Document> fullResult = store
.similaritySearch(SearchRequest.query(URANUS_ORBIT_QUERY).withTopK(5).withSimilarityThresholdAll());
List<Float> distances = fullResult.stream()
.map(doc -> (Float) doc.getMetadata().get(CassandraVectorStore.SIMILARITY_FIELD_NAME))
.toList();
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
assertThat(distances).hasSize(3);
assertThat(scores).hasSize(3);
float threshold = (distances.get(0) + distances.get(1)) / 2;
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
List<Document> results = store.similaritySearch(
SearchRequest.query(URANUS_ORBIT_QUERY).withTopK(5).withSimilarityThreshold(threshold));
List<Document> results = store.similaritySearch(SearchRequest.query(URANUS_ORBIT_QUERY)
.withTopK(5)
.withSimilarityThreshold(similarityThreshold));
assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
@@ -526,8 +524,8 @@ class CassandraRichSchemaVectorStoreIT {
assertThat(resultDoc.getContent()).contains(URANUS_ORBIT_QUERY);
assertThat(resultDoc.getMetadata()).containsKeys("id", "revision",
CassandraVectorStore.SIMILARITY_FIELD_NAME);
assertThat(resultDoc.getMetadata()).containsKeys("id", "revision", DocumentMetadata.DISTANCE.value());
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
}
});
}

View File

@@ -30,6 +30,7 @@ import com.datastax.oss.driver.api.core.servererrors.SyntaxError;
import com.datastax.oss.driver.api.core.type.DataTypes;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.springframework.ai.document.DocumentMetadata;
import org.testcontainers.containers.CassandraContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
@@ -138,7 +139,7 @@ class CassandraVectorStoreIT {
"Spring AI provides abstractions that serve as the foundation for developing AI applications.");
assertThat(resultDoc.getMetadata()).hasSize(2);
assertThat(resultDoc.getMetadata()).containsKeys("meta1", CassandraVectorStore.SIMILARITY_FIELD_NAME);
assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value());
// Remove all documents from the store
store.delete(documents().stream().map(doc -> doc.getId()).toList());
@@ -174,7 +175,7 @@ class CassandraVectorStoreIT {
"Spring AI provides abstractions that serve as the foundation for developing AI applications.");
assertThat(resultDoc.getMetadata()).hasSize(1);
assertThat(resultDoc.getMetadata()).containsKey(CassandraVectorStore.SIMILARITY_FIELD_NAME);
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
// Remove all documents from the store
store.delete(documents().stream().map(doc -> doc.getId()).toList());
@@ -359,7 +360,7 @@ class CassandraVectorStoreIT {
resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(document.getId());
assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
assertThat(resultDoc.getMetadata()).containsKeys("meta2", CassandraVectorStore.SIMILARITY_FIELD_NAME);
assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value());
store.delete(List.of(document.getId()));
}
@@ -375,16 +376,14 @@ class CassandraVectorStoreIT {
List<Document> fullResult = store
.similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll());
List<Float> distances = fullResult.stream()
.map(doc -> (Float) doc.getMetadata().get(CassandraVectorStore.SIMILARITY_FIELD_NAME))
.toList();
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
assertThat(distances).hasSize(3);
assertThat(scores).hasSize(3);
float threshold = (distances.get(0) + distances.get(1)) / 2;
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
List<Document> results = store
.similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(threshold));
List<Document> results = store.similaritySearch(
SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(similarityThreshold));
assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
@@ -393,7 +392,8 @@ class CassandraVectorStoreIT {
assertThat(resultDoc.getContent()).contains(
"Spring AI provides abstractions that serve as the foundation for developing AI applications.");
assertThat(resultDoc.getMetadata()).containsKeys("meta1", CassandraVectorStore.SIMILARITY_FIELD_NAME);
assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value());
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
}
});
}

View File

@@ -32,6 +32,7 @@ import org.springframework.ai.chroma.ChromaApi.AddEmbeddingsRequest;
import org.springframework.ai.chroma.ChromaApi.DeleteEmbeddingsRequest;
import org.springframework.ai.chroma.ChromaApi.Embedding;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
@@ -58,11 +59,10 @@ import org.springframework.util.CollectionUtils;
* @author Fu Cheng
* @author Sebastien Deleuze
* @author Soby Chacko
* @author Thomas Vitale
*/
public class ChromaVectorStore extends AbstractObservationVectorStore implements InitializingBean {
public static final String DISTANCE_FIELD_NAME = "distance";
public static final String DEFAULT_COLLECTION_NAME = "SpringAiCollection";
private final EmbeddingModel embeddingModel;
@@ -192,9 +192,14 @@ public class ChromaVectorStore extends AbstractObservationVectorStore implements
if (metadata == null) {
metadata = new HashMap<>();
}
metadata.put(DISTANCE_FIELD_NAME, distance);
Document document = new Document(id, content, metadata);
document.setEmbedding(chromaEmbedding.embedding());
metadata.put(DocumentMetadata.DISTANCE.value(), distance);
Document document = Document.builder()
.id(id)
.content(content)
.metadata(metadata)
.embedding(chromaEmbedding.embedding())
.score(1.0 - distance)
.build();
responseDocuments.add(document);
}
}
@@ -244,8 +249,7 @@ public class ChromaVectorStore extends AbstractObservationVectorStore implements
@NonNull String operationName) {
return VectorStoreObservationContext.builder(VectorStoreProvider.CHROMA.value(), operationName)
.withDimensions(this.embeddingModel.dimensions())
.withCollectionName(this.collectionName + ":" + this.collectionId)
.withFieldName(this.initializeSchema ? DISTANCE_FIELD_NAME : null);
.withCollectionName(this.collectionName + ":" + this.collectionId);
}
public static class Builder {

View File

@@ -24,6 +24,7 @@ import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.springframework.ai.document.DocumentMetadata;
import org.testcontainers.chromadb.ChromaDBContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
@@ -81,7 +82,7 @@ public class ChromaVectorStoreIT {
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId());
assertThat(resultDoc.getContent()).isEqualTo(
"Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression");
assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance");
assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value());
// Remove all documents from the store
assertThat(vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()))
@@ -99,8 +100,8 @@ public class ChromaVectorStoreIT {
VectorStore vectorStore = context.getBean(VectorStore.class);
var document = Document.builder()
.withId("simpleDoc")
.withContent("The sky is blue because of Rayleigh scattering.")
.id("simpleDoc")
.content("The sky is blue because of Rayleigh scattering.")
.build();
vectorStore.add(List.of(document));
@@ -179,7 +180,7 @@ public class ChromaVectorStoreIT {
assertThat(resultDoc.getId()).isEqualTo(document.getId());
assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
assertThat(resultDoc.getMetadata()).containsKey("meta1");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
Document sameIdDocument = new Document(document.getId(),
"The World is Big and Salvation Lurks Around the Corner",
@@ -194,7 +195,7 @@ public class ChromaVectorStoreIT {
assertThat(resultDoc.getId()).isEqualTo(document.getId());
assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
// Remove all documents from the store
vectorStore.delete(List.of(document.getId()));
@@ -213,13 +214,13 @@ public class ChromaVectorStoreIT {
var request = SearchRequest.query("Great").withTopK(5);
List<Document> fullResult = vectorStore.similaritySearch(request.withSimilarityThresholdAll());
List<Float> distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList();
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
assertThat(distances).hasSize(3);
assertThat(scores).hasSize(3);
float threshold = (distances.get(0) + distances.get(1)) / 2;
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
List<Document> results = vectorStore.similaritySearch(request.withSimilarityThreshold(1 - threshold));
List<Document> results = vectorStore.similaritySearch(request.withSimilarityThreshold(similarityThreshold));
assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
@@ -227,7 +228,8 @@ public class ChromaVectorStoreIT {
assertThat(resultDoc.getContent()).isEqualTo(
"Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression");
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
// Remove all documents from the store
vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList());

View File

@@ -107,7 +107,6 @@ public class ChromaVectorStoreObservationIT {
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_COLLECTION_NAME.asString(),
"TestCollection:" + vectorStore.getCollectionId())
.doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.DB_NAMESPACE.asString())
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_VECTOR_FIELD_NAME.asString(), "distance")
.doesNotHaveHighCardinalityKeyValueWithKey(
HighCardinalityKeyNames.DB_SEARCH_SIMILARITY_METRIC.asString())
.doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.DB_VECTOR_QUERY_TOP_K.asString())
@@ -141,7 +140,6 @@ public class ChromaVectorStoreObservationIT {
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_COLLECTION_NAME.asString(),
"TestCollection:" + vectorStore.getCollectionId())
.doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.DB_NAMESPACE.asString())
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_VECTOR_FIELD_NAME.asString(), "distance")
.doesNotHaveHighCardinalityKeyValueWithKey(
HighCardinalityKeyNames.DB_SEARCH_SIMILARITY_METRIC.asString())
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_VECTOR_QUERY_TOP_K.asString(), "1")

View File

@@ -37,6 +37,7 @@ import com.tangosol.net.Session;
import com.tangosol.util.Filter;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.filter.Filter.Expression;
import org.springframework.beans.factory.InitializingBean;
@@ -62,6 +63,7 @@ import org.springframework.beans.factory.InitializingBean;
* </ul>
*
* @author Aleks Seovic
* @author Thomas Vitale
* @since 1.0.0
*/
public class CoherenceVectorStore implements VectorStore, InitializingBean {
@@ -211,8 +213,13 @@ public class CoherenceVectorStore implements VectorStore, InitializingBean {
if (this.distanceType != DistanceType.COSINE || (1 - r.getDistance()) >= request.getSimilarityThreshold()) {
DocumentChunk.Id id = r.getKey();
DocumentChunk chunk = r.getValue();
chunk.metadata().put("distance", r.getDistance());
documents.add(new Document(id.docId(), chunk.text(), chunk.metadata()));
chunk.metadata().put(DocumentMetadata.DISTANCE.value(), r.getDistance());
documents.add(Document.builder()
.id(id.docId())
.content(chunk.text())
.metadata(chunk.metadata())
.score(1 - r.getDistance())
.build());
}
}
return documents;

View File

@@ -47,6 +47,7 @@ import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.transformers.TransformersEmbeddingModel;
import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser;
@@ -125,7 +126,7 @@ public class CoherenceVectorStoreIT {
assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId());
assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance");
assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value());
// Remove all documents from the store
vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList());
@@ -223,7 +224,7 @@ public class CoherenceVectorStoreIT {
assertThat(resultDoc.getId()).isEqualTo(document.getId());
assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance");
assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value());
Document sameIdDocument = new Document(document.getId(),
"The World is Big and Salvation Lurks Around the Corner",
@@ -236,7 +237,7 @@ public class CoherenceVectorStoreIT {
resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(document.getId());
assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance");
assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value());
truncateMap(context, ((CoherenceVectorStore) vectorStore).getMapName());
});
@@ -257,18 +258,18 @@ public class CoherenceVectorStoreIT {
assertThat(isSortedByDistance(fullResult)).isTrue();
List<Double> distances = fullResult.stream()
.map(doc -> (Double) doc.getMetadata().get("distance"))
.toList();
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
double threshold = 1d - (distances.get(0) + distances.get(1)) / 2f;
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
List<Document> results = vectorStore
.similaritySearch(SearchRequest.query("Time Shelter").withTopK(5).withSimilarityThreshold(threshold));
List<Document> results = vectorStore.similaritySearch(
SearchRequest.query("Time Shelter").withTopK(5).withSimilarityThreshold(similarityThreshold));
assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(1).getId());
assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value());
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
truncateMap(context, ((CoherenceVectorStore) vectorStore).getMapName());
});
@@ -276,7 +277,7 @@ public class CoherenceVectorStoreIT {
private static boolean isSortedByDistance(final List<Document> documents) {
final List<Double> distances = documents.stream()
.map(doc -> (Double) doc.getMetadata().get("distance"))
.map(doc -> (Double) doc.getMetadata().get(DocumentMetadata.DISTANCE.value()))
.toList();
if (CollectionUtils.isEmpty(distances) || distances.size() == 1) {

View File

@@ -40,6 +40,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
@@ -210,20 +211,24 @@ public class ElasticsearchVectorStore extends AbstractObservationVectorStore imp
private Document toDocument(Hit<Document> hit) {
Document document = hit.source();
document.getMetadata().put("distance", calculateDistance(hit.score().floatValue()));
return document;
Document.Builder documentBuilder = document.mutate();
if (hit.score() != null) {
documentBuilder.metadata(DocumentMetadata.DISTANCE.value(), 1 - normalizeSimilarityScore(hit.score()));
documentBuilder.score(normalizeSimilarityScore(hit.score()));
}
return documentBuilder.build();
}
// more info on score/distance calculation
// https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html#knn-similarity-search
private float calculateDistance(Float score) {
private double normalizeSimilarityScore(double score) {
switch (this.options.getSimilarity()) {
case l2_norm:
// the returned value of l2_norm is the opposite of the other functions
// (closest to zero means more accurate), so to make it consistent
// with the other functions the reverse is returned applying a "1-"
// to the standard transformation
return (float) (1 - (java.lang.Math.sqrt((1 / score) - 1)));
return (1 - (java.lang.Math.sqrt((1 / score) - 1)));
// cosine and dot_product
default:
return (2 * score) - 1;

View File

@@ -42,6 +42,7 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.springframework.ai.document.DocumentMetadata;
import org.testcontainers.elasticsearch.ElasticsearchContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
@@ -165,7 +166,7 @@ class ElasticsearchVectorStoreIT {
assertThat(resultDoc.getContent()).contains("The Great Depression (19291939) was an economic shock");
assertThat(resultDoc.getMetadata()).hasSize(2);
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
// Remove all documents from the store
vectorStore.delete(this.documents.stream().map(Document::getId).toList());
@@ -299,7 +300,7 @@ class ElasticsearchVectorStoreIT {
assertThat(resultDoc.getId()).isEqualTo(document.getId());
assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
assertThat(resultDoc.getMetadata()).containsKey("meta1");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
Document sameIdDocument = new Document(document.getId(),
"The World is Big and Salvation Lurks Around the Corner", Map.of("meta2", "meta2"));
@@ -318,7 +319,7 @@ class ElasticsearchVectorStoreIT {
assertThat(resultDoc.getId()).isEqualTo(document.getId());
assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
// Remove all documents from the store
vectorStore.delete(List.of(document.getId()));
@@ -343,21 +344,22 @@ class ElasticsearchVectorStoreIT {
List<Document> fullResult = vectorStore.similaritySearch(query);
List<Float> distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList();
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
assertThat(distances).hasSize(3);
assertThat(scores).hasSize(3);
float thresholdResult = (distances.get(0) + distances.get(1)) / 2;
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
List<Document> results = vectorStore.similaritySearch(
SearchRequest.query("Great Depression").withTopK(50).withSimilarityThreshold(thresholdResult));
SearchRequest.query("Great Depression").withTopK(50).withSimilarityThreshold(similarityThreshold));
assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId());
assertThat(resultDoc.getContent()).contains("The Great Depression (19291939) was an economic shock");
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
// Remove all documents from the store
vectorStore.delete(this.documents.stream().map(Document::getId).toList());

View File

@@ -30,6 +30,7 @@ import com.fasterxml.jackson.databind.json.JsonMapper;
import io.micrometer.observation.ObservationRegistry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.DocumentMetadata;
import reactor.util.annotation.NonNull;
import org.springframework.ai.document.Document;
@@ -172,8 +173,6 @@ public class GemFireVectorStore extends AbstractObservationVectorStore implement
// Query Defaults
private static final String QUERY = "/query";
private static final String DISTANCE_METADATA_FIELD_NAME = "distance";
/**
* Initializes the GemFireVectorStore after properties are set. This method is called
* after all bean properties have been set and allows the bean to perform any
@@ -271,9 +270,9 @@ public class GemFireVectorStore extends AbstractObservationVectorStore implement
metadata = new HashMap<>();
metadata.put(DOCUMENT_FIELD, "--Deleted--");
}
metadata.put(DISTANCE_METADATA_FIELD_NAME, 1 - r.score);
metadata.put(DocumentMetadata.DISTANCE.value(), 1 - r.score);
String content = (String) metadata.remove(DOCUMENT_FIELD);
return new Document(r.key, content, metadata);
return Document.builder().id(r.key).content(content).metadata(metadata).score((double) r.score).build();
})
.collectList()
.onErrorMap(WebClientException.class, this::handleHttpClientException)

View File

@@ -34,6 +34,7 @@ import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.transformers.TransformersEmbeddingModel;
import org.springframework.boot.SpringBootConfiguration;
@@ -134,7 +135,7 @@ public class GemFireVectorStoreIT {
assertThat(resultDoc.getContent()).contains("The Great Depression (19291939)" + " was an economic shock");
assertThat(resultDoc.getMetadata()).hasSize(2);
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
});
}
@@ -156,7 +157,7 @@ public class GemFireVectorStoreIT {
assertThat(resultDoc.getId()).isEqualTo(document.getId());
assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
assertThat(resultDoc.getMetadata()).containsKey("meta1");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
Document sameIdDocument = new Document(document.getId(),
"The World is Big and Salvation Lurks " + "Around the Corner",
@@ -171,7 +172,7 @@ public class GemFireVectorStoreIT {
assertThat(resultDoc.getId()).isEqualTo(document.getId());
assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation" + " Lurks Around the Corner");
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
});
}
@@ -191,12 +192,12 @@ public class GemFireVectorStoreIT {
List<Document> fullResult = vectorStore
.similaritySearch(SearchRequest.query("Depression").withTopK(5).withSimilarityThresholdAll());
List<Float> distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList();
assertThat(distances).hasSize(3);
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
assertThat(scores).hasSize(3);
float threshold = (distances.get(0) + distances.get(1)) / 2;
List<Document> results = vectorStore
.similaritySearch(SearchRequest.query("Depression").withTopK(5).withSimilarityThreshold(1 - threshold));
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
List<Document> results = vectorStore.similaritySearch(
SearchRequest.query("Depression").withTopK(5).withSimilarityThreshold(similarityThreshold));
assertThat(results).hasSize(1);
@@ -204,7 +205,8 @@ public class GemFireVectorStoreIT {
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId());
assertThat(resultDoc.getContent()).contains("The Great Depression " + "(19291939) was an economic shock");
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
});
}

View File

@@ -54,6 +54,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
@@ -97,7 +98,7 @@ public class MilvusVectorStore extends AbstractObservationVectorStore implements
public static final String EMBEDDING_FIELD_NAME = "embedding";
// Metadata, automatically assigned by Milvus.
public static final String DISTANCE_FIELD_NAME = "distance";
private static final String DISTANCE_FIELD_NAME = "distance";
private static final Logger logger = LoggerFactory.getLogger(MilvusVectorStore.class);
@@ -258,13 +259,18 @@ public class MilvusVectorStore extends AbstractObservationVectorStore implements
try {
metadata = (JSONObject) rowRecord.get(this.config.metadataFieldName);
// inject the distance into the metadata.
metadata.put(DISTANCE_FIELD_NAME, 1 - getResultSimilarity(rowRecord));
metadata.put(DocumentMetadata.DISTANCE.value(), 1 - getResultSimilarity(rowRecord));
}
catch (ParamException e) {
// skip the ParamException if metadata doesn't exist for the custom
// collection
}
return new Document(docId, content, (metadata != null) ? metadata.getInnerMap() : Map.of());
return Document.builder()
.id(docId)
.content(content)
.metadata((metadata != null) ? metadata.getInnerMap() : Map.of())
.score((double) getResultSimilarity(rowRecord))
.build();
})
.toList();
}

View File

@@ -30,6 +30,7 @@ import io.milvus.param.MetricType;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.springframework.ai.document.DocumentMetadata;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.milvus.MilvusContainer;
@@ -106,7 +107,7 @@ public class MilvusVectorStoreIT {
assertThat(resultDoc.getContent()).contains(
"Spring AI provides abstractions that serve as the foundation for developing AI applications.");
assertThat(resultDoc.getMetadata()).hasSize(2);
assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance");
assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value());
// Remove all documents from the store
vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList());
@@ -200,7 +201,7 @@ public class MilvusVectorStoreIT {
assertThat(resultDoc.getId()).isEqualTo(document.getId());
assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
assertThat(resultDoc.getMetadata()).containsKey("meta1");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
Document sameIdDocument = new Document(document.getId(),
"The World is Big and Salvation Lurks Around the Corner",
@@ -215,7 +216,7 @@ public class MilvusVectorStoreIT {
assertThat(resultDoc.getId()).isEqualTo(document.getId());
assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
vectorStore.delete(List.of(document.getId()));
@@ -238,24 +239,22 @@ public class MilvusVectorStoreIT {
List<Document> fullResult = vectorStore
.similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll());
List<Float> distances = fullResult.stream()
.map(doc -> (Float) doc.getMetadata().get("distance"))
.toList();
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
assertThat(distances).hasSize(3);
assertThat(scores).hasSize(3);
float threshold = (distances.get(0) + distances.get(1)) / 2;
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
List<Document> results = vectorStore
.similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(1 - threshold));
List<Document> results = vectorStore.similaritySearch(
SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(similarityThreshold));
assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId());
assertThat(resultDoc.getContent()).contains(
"Spring AI provides abstractions that serve as the foundation for developing AI applications.");
assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance");
assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value());
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
});
}

View File

@@ -26,6 +26,7 @@ import com.mongodb.MongoCommandException;
import io.micrometer.observation.ObservationRegistry;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
@@ -172,12 +173,17 @@ public class MongoDBAtlasVectorStore extends AbstractObservationVectorStore impl
private Document mapMongoDocument(org.bson.Document mongoDocument, float[] queryEmbedding) {
String id = mongoDocument.getString(ID_FIELD_NAME);
String content = mongoDocument.getString(CONTENT_FIELD_NAME);
double score = mongoDocument.getDouble(SCORE_FIELD_NAME);
Map<String, Object> metadata = mongoDocument.get(METADATA_FIELD_NAME, org.bson.Document.class);
metadata.put(DocumentMetadata.DISTANCE.value(), 1 - score);
Document document = new Document(id, content, metadata);
document.setEmbedding(queryEmbedding);
return document;
return Document.builder()
.id(id)
.content(content)
.metadata(metadata)
.score(score)
.embedding(queryEmbedding)
.build();
}
@Override

View File

@@ -16,6 +16,8 @@
package org.springframework.ai.vectorstore;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
@@ -27,6 +29,8 @@ import com.mongodb.client.MongoClient;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.core.io.DefaultResourceLoader;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.mongodb.MongoDBAtlasLocalContainer;
@@ -198,6 +202,55 @@ class MongoDBAtlasVectorStoreIT {
});
}
/**
 * Checks that a similarity search with a similarity threshold only returns
 * documents whose score is at or above that threshold.
 * <p>
 * Three documents are added, an unrestricted search for "Spring" collects their
 * scores, and the midpoint of the first two scores is then used as the threshold
 * for a second search, which is expected to return exactly one document.
 */
@Test
public void searchWithThreshold() {
this.contextRunner.run(context -> {
VectorStore vectorStore = context.getBean(VectorStore.class);
var documents = List.of(
new Document("471a8c78-549a-4b2c-bce5-ef3ae6579be3", getText("classpath:/test/data/spring.ai.txt"),
Map.of("meta1", "meta1")),
new Document("bc51d7f7-627b-4ba6-adf4-f0bcd1998f8f",
getText("classpath:/test/data/time.shelter.txt"), Map.of()),
new Document("d0237682-1150-44ff-b4d2-1be9b1731ee5",
getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2")));
vectorStore.add(documents);
Thread.sleep(5000); // Wait five seconds to give the store time to index the documents
// Unrestricted search: all three documents are expected back.
List<Document> fullResult = vectorStore
.similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll());
assertThat(fullResult).hasSize(3);
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
assertThat(scores).hasSize(3);
// Midpoint between the first two returned scores.
// NOTE(review): this assumes results come back ordered by descending score,
// so that only the first document passes the threshold - confirm against the store.
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
List<Document> results = vectorStore.similaritySearch(
SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(similarityThreshold));
assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId());
assertThat(resultDoc.getContent()).contains(
"Spring AI provides abstractions that serve as the foundation for developing AI applications.");
// The result must carry both the original metadata and the distance entry.
assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value());
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
});
}
/**
 * Loads the resource at the given location and returns its content decoded as UTF-8.
 * @param uri the resource location handed to {@link DefaultResourceLoader#getResource}
 * (the tests in this class pass {@code classpath:} locations)
 * @return the full content of the resource as a string
 * @throws RuntimeException wrapping the {@link IOException} if the content cannot be
 * read
 */
public static String getText(String uri) {
	try {
		return new DefaultResourceLoader().getResource(uri).getContentAsString(StandardCharsets.UTF_8);
	}
	catch (IOException ex) {
		// Rethrow unchecked so callers can use this helper without declaring IOException.
		throw new RuntimeException(ex);
	}
}
@SpringBootConfiguration
@EnableAutoConfiguration
public static class TestApplication {

View File

@@ -29,6 +29,7 @@ import org.neo4j.driver.SessionConfig;
import org.neo4j.driver.Values;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
@@ -224,15 +225,19 @@ public class Neo4jVectorStore extends AbstractObservationVectorStore implements
var node = neoRecord.get("node").asNode();
var score = neoRecord.get("score").asFloat();
var metaData = new HashMap<String, Object>();
metaData.put("distance", 1 - score);
metaData.put(DocumentMetadata.DISTANCE.value(), 1 - score);
node.keys().forEach(key -> {
if (key.startsWith("metadata.")) {
metaData.put(key.substring(key.indexOf(".") + 1), node.get(key).asObject());
}
});
return new Document(node.get(this.config.idProperty).asString(), node.get("text").asString(),
Map.copyOf(metaData));
return Document.builder()
.id(node.get(this.config.idProperty).asString())
.content(node.get("text").asString())
.metadata(Map.copyOf(metaData))
.score((double) score)
.build();
}
@Override

View File

@@ -28,6 +28,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.neo4j.driver.AuthTokens;
import org.neo4j.driver.Driver;
import org.neo4j.driver.GraphDatabase;
import org.springframework.ai.document.DocumentMetadata;
import org.testcontainers.containers.Neo4jContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
@@ -91,7 +92,7 @@ class Neo4jVectorStoreIT {
assertThat(resultDoc.getContent()).isEqualTo(
"Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression");
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
// Remove all documents from the store
vectorStore.delete(this.documents.stream().map(Document::getId).toList());
@@ -203,7 +204,7 @@ class Neo4jVectorStoreIT {
assertThat(resultDoc.getId()).isEqualTo(document.getId());
assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
assertThat(resultDoc.getMetadata()).containsKey("meta1");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
Document sameIdDocument = new Document(document.getId(),
"The World is Big and Salvation Lurks Around the Corner",
@@ -218,7 +219,7 @@ class Neo4jVectorStoreIT {
assertThat(resultDoc.getId()).isEqualTo(document.getId());
assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
});
}
@@ -235,14 +236,14 @@ class Neo4jVectorStoreIT {
List<Document> fullResult = vectorStore
.similaritySearch(SearchRequest.query("Great").withTopK(5).withSimilarityThresholdAll());
List<Float> distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList();
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
assertThat(distances).hasSize(3);
assertThat(scores).hasSize(3);
float threshold = (distances.get(0) + distances.get(1)) / 2;
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
List<Document> results = vectorStore
.similaritySearch(SearchRequest.query("Great").withTopK(5).withSimilarityThreshold(1 - threshold));
List<Document> results = vectorStore.similaritySearch(
SearchRequest.query("Great").withTopK(5).withSimilarityThreshold(similarityThreshold));
assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
@@ -250,8 +251,8 @@ class Neo4jVectorStoreIT {
assertThat(resultDoc.getContent()).isEqualTo(
"Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression");
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
});
}

View File

@@ -40,6 +40,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
@@ -231,8 +232,12 @@ public class OpenSearchVectorStore extends AbstractObservationVectorStore implem
private Document toDocument(Hit<Document> hit) {
Document document = hit.source();
document.getMetadata().put("distance", 1 - hit.score().floatValue());
return document;
Document.Builder documentBuilder = document.mutate();
if (hit.score() != null) {
documentBuilder.metadata(DocumentMetadata.DISTANCE.value(), 1 - hit.score().floatValue());
documentBuilder.score(hit.score());
}
return documentBuilder.build();
}
public boolean exists(String targetIndex) {

View File

@@ -39,6 +39,7 @@ import org.junit.jupiter.params.provider.ValueSource;
import org.opensearch.client.opensearch.OpenSearchClient;
import org.opensearch.client.transport.httpclient5.ApacheHttpClient5TransportBuilder;
import org.opensearch.testcontainers.OpensearchContainer;
import org.springframework.ai.document.DocumentMetadata;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
@@ -144,7 +145,7 @@ class OpenSearchVectorStoreIT {
assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock");
assertThat(resultDoc.getMetadata()).hasSize(2);
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
// Remove all documents from the store
vectorStore.delete(this.documents.stream().map(Document::getId).toList());
@@ -281,7 +282,7 @@ class OpenSearchVectorStoreIT {
assertThat(resultDoc.getId()).isEqualTo(document.getId());
assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
assertThat(resultDoc.getMetadata()).containsKey("meta1");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
Document sameIdDocument = new Document(document.getId(),
"The World is Big and Salvation Lurks Around the Corner", Map.of("meta2", "meta2"));
@@ -300,7 +301,7 @@ class OpenSearchVectorStoreIT {
assertThat(resultDoc.getId()).isEqualTo(document.getId());
assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
// Remove all documents from the store
vectorStore.delete(List.of(document.getId()));
@@ -330,21 +331,22 @@ class OpenSearchVectorStoreIT {
List<Document> fullResult = vectorStore.similaritySearch(query);
List<Float> distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList();
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
assertThat(distances).hasSize(3);
assertThat(scores).hasSize(3);
float threshold = (distances.get(0) + distances.get(1)) / 2;
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
List<Document> results = vectorStore.similaritySearch(
SearchRequest.query("Great Depression").withTopK(50).withSimilarityThreshold(1 - threshold));
SearchRequest.query("Great Depression").withTopK(50).withSimilarityThreshold(similarityThreshold));
assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId());
assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock");
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
// Remove all documents from the store
vectorStore.delete(this.documents.stream().map(Document::getId).toList());

View File

@@ -39,6 +39,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
@@ -80,6 +81,7 @@ import org.springframework.util.StringUtils;
* @author Loïc Lefèvre
* @author Christian Tzolov
* @author Soby Chacko
* @author Thomas Vitale
*/
public class OracleVectorStore extends AbstractObservationVectorStore implements InitializingBean {
@@ -649,12 +651,16 @@ public class OracleVectorStore extends AbstractObservationVectorStore implements
@Override
public Document mapRow(ResultSet rs, int rowNum) throws SQLException {
final Map<String, Object> metadata = getMap(rs.getObject(3, OracleJsonValue.class));
metadata.put("distance", rs.getDouble(5));
metadata.put(DocumentMetadata.DISTANCE.value(), rs.getDouble(5));
final Document document = new Document(rs.getString(1), rs.getString(2), metadata);
final float[] embedding = rs.getObject(4, float[].class);
document.setEmbedding(embedding);
return document;
return Document.builder()
.id(rs.getString(1))
.content(rs.getString(2))
.metadata(metadata)
.score(1 - rs.getDouble(5))
.embedding(embedding)
.build();
}
private Map<String, Object> getMap(OracleJsonValue value) {

View File

@@ -32,6 +32,7 @@ import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.springframework.ai.document.DocumentMetadata;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.oracle.OracleContainer;
@@ -94,21 +95,19 @@ public class OracleVectorStoreIT {
jdbcTemplate.execute("DROP TABLE IF EXISTS " + tableName + " PURGE");
}
private static boolean isSortedByDistance(final List<Document> documents) {
final List<Double> distances = documents.stream()
.map(doc -> (Double) doc.getMetadata().get("distance"))
.toList();
private static boolean isSortedBySimilarity(final List<Document> documents) {
final List<Double> scores = documents.stream().map(Document::getScore).toList();
if (CollectionUtils.isEmpty(distances) || distances.size() == 1) {
if (CollectionUtils.isEmpty(scores) || scores.size() == 1) {
return true;
}
Iterator<Double> iter = distances.iterator();
Iterator<Double> iter = scores.iterator();
Double current;
Double previous = iter.next();
while (iter.hasNext()) {
current = iter.next();
if (previous > current) {
if (previous < current) {
return false;
}
previous = current;
@@ -134,7 +133,7 @@ public class OracleVectorStoreIT {
assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId());
assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance");
assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value());
// Remove all documents from the store
vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList());
@@ -243,7 +242,7 @@ public class OracleVectorStoreIT {
assertThat(resultDoc.getId()).isEqualTo(document.getId());
assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance");
assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value());
Document sameIdDocument = new Document(document.getId(),
"The World is Big and Salvation Lurks Around the Corner",
@@ -256,7 +255,7 @@ public class OracleVectorStoreIT {
resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(document.getId());
assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance");
assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value());
dropTable(context, ((OracleVectorStore) vectorStore).getTableName());
});
@@ -279,20 +278,19 @@ public class OracleVectorStoreIT {
assertThat(fullResult).hasSize(3);
assertThat(isSortedByDistance(fullResult)).isTrue();
assertThat(isSortedBySimilarity(fullResult)).isTrue();
List<Double> distances = fullResult.stream()
.map(doc -> (Double) doc.getMetadata().get("distance"))
.toList();
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
double threshold = (distances.get(0) + distances.get(1)) / 2d;
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2d;
List<Document> results = vectorStore.similaritySearch(
SearchRequest.query("Time Shelter").withTopK(5).withSimilarityThreshold(1d - threshold));
SearchRequest.query("Time Shelter").withTopK(5).withSimilarityThreshold(similarityThreshold));
assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(1).getId());
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
dropTable(context, ((OracleVectorStore) vectorStore).getTableName());
});

View File

@@ -35,6 +35,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
@@ -502,12 +503,15 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini
Float distance = rs.getFloat(COLUMN_DISTANCE);
Map<String, Object> metadata = toMap(pgMetadata);
metadata.put(COLUMN_DISTANCE, distance);
metadata.put(DocumentMetadata.DISTANCE.value(), distance);
Document document = new Document(id, content, metadata);
document.setEmbedding(toFloatArray(embedding));
return document;
return Document.builder()
.id(id)
.content(content)
.metadata(metadata)
.score(1.0 - distance)
.embedding(toFloatArray(embedding))
.build();
}
private float[] toFloatArray(PGobject embedding) throws SQLException {

View File

@@ -34,6 +34,7 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.springframework.ai.document.DocumentMetadata;
import org.testcontainers.containers.PostgreSQLContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
@@ -113,20 +114,20 @@ public class PgVectorStoreIT {
);
}
private static boolean isSortedByDistance(List<Document> docs) {
private static boolean isSortedBySimilarity(List<Document> docs) {
List<Float> distances = docs.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList();
List<Double> scores = docs.stream().map(Document::getScore).toList();
if (CollectionUtils.isEmpty(distances) || distances.size() == 1) {
if (CollectionUtils.isEmpty(scores) || scores.size() == 1) {
return true;
}
Iterator<Float> iter = distances.iterator();
Float current;
Float previous = iter.next();
Iterator<Double> iter = scores.iterator();
Double current;
Double previous = iter.next();
while (iter.hasNext()) {
current = iter.next();
if (previous > current) {
if (previous < current) {
return false;
}
previous = current;
@@ -150,7 +151,7 @@ public class PgVectorStoreIT {
assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId());
assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance");
assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value());
// Remove all documents from the store
vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList());
@@ -289,7 +290,7 @@ public class PgVectorStoreIT {
Document resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(document.getId());
assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance");
assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value());
Document sameIdDocument = new Document(document.getId(),
"The World is Big and Salvation Lurks Around the Corner",
@@ -303,7 +304,7 @@ public class PgVectorStoreIT {
resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(document.getId());
assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance");
assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value());
dropTable(context);
});
@@ -326,20 +327,19 @@ public class PgVectorStoreIT {
assertThat(fullResult).hasSize(3);
assertThat(isSortedByDistance(fullResult)).isTrue();
assertThat(isSortedBySimilarity(fullResult)).isTrue();
List<Float> distances = fullResult.stream()
.map(doc -> (Float) doc.getMetadata().get("distance"))
.toList();
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
float threshold = (distances.get(0) + distances.get(1)) / 2;
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
List<Document> results = vectorStore.similaritySearch(
SearchRequest.query("Time Shelter").withTopK(5).withSimilarityThreshold(1 - threshold));
SearchRequest.query("Time Shelter").withTopK(5).withSimilarityThreshold(similarityThreshold));
assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(1).getId());
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
dropTable(context);
});

View File

@@ -38,6 +38,7 @@ import io.pinecone.proto.UpsertRequest;
import io.pinecone.proto.Vector;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
@@ -60,13 +61,12 @@ import org.springframework.util.StringUtils;
* @author Christian Tzolov
* @author Adam Bchouti
* @author Soby Chacko
* @author Thomas Vitale
*/
public class PineconeVectorStore extends AbstractObservationVectorStore {
public static final String CONTENT_FIELD_NAME = "document_content";
public static final String DISTANCE_METADATA_FIELD_NAME = "distance";
public final FilterExpressionConverter filterExpressionConverter = new PineconeFilterExpressionConverter();
private final EmbeddingModel embeddingModel;
@@ -236,7 +236,12 @@ public class PineconeVectorStore extends AbstractObservationVectorStore {
var content = metadataStruct.getFieldsOrThrow(this.pineconeContentFieldName).getStringValue();
Map<String, Object> metadata = extractMetadata(metadataStruct);
metadata.put(this.pineconeDistanceMetadataFieldName, 1 - scoredVector.getScore());
return new Document(id, content, metadata);
return Document.builder()
.id(id)
.content(content)
.metadata(metadata)
.score((double) scoredVector.getScore())
.build();
})
.toList();
}
@@ -298,6 +303,8 @@ public class PineconeVectorStore extends AbstractObservationVectorStore {
private final String contentFieldName;
// TODO: Why is this field configurable? Can we remove this after standardizing
// the key?
private final String distanceMetadataFieldName;
private final PineconeConnectionConfig connectionConfig;
@@ -357,7 +364,7 @@ public class PineconeVectorStore extends AbstractObservationVectorStore {
private String contentFieldName = CONTENT_FIELD_NAME;
private String distanceMetadataFieldName = DISTANCE_METADATA_FIELD_NAME;
private String distanceMetadataFieldName = DocumentMetadata.DISTANCE.value();
/**
* Optional server-side timeout in seconds for all operations. Default: 20

View File

@@ -31,6 +31,7 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.transformers.TransformersEmbeddingModel;
import org.springframework.ai.vectorstore.PineconeVectorStore.PineconeVectorStoreConfig;
@@ -46,6 +47,7 @@ import static org.hamcrest.Matchers.hasSize;
/**
* @author Christian Tzolov
* @author Thomas Vitale
*/
@EnabledIfEnvironmentVariable(named = "PINECONE_API_KEY", matches = ".+")
public class PineconeVectorStoreIT {
@@ -109,7 +111,7 @@ public class PineconeVectorStoreIT {
assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock");
assertThat(resultDoc.getMetadata()).hasSize(2);
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
// Remove all documents from the store
vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList());
@@ -193,7 +195,7 @@ public class PineconeVectorStoreIT {
assertThat(resultDoc.getId()).isEqualTo(document.getId());
assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
assertThat(resultDoc.getMetadata()).containsKey("meta1");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
Document sameIdDocument = new Document(document.getId(),
"The World is Big and Salvation Lurks Around the Corner",
@@ -214,7 +216,7 @@ public class PineconeVectorStoreIT {
assertThat(resultDoc.getId()).isEqualTo(document.getId());
assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
// Remove all documents from the store
vectorStore.delete(List.of(document.getId()));
@@ -240,21 +242,22 @@ public class PineconeVectorStoreIT {
List<Document> fullResult = vectorStore
.similaritySearch(SearchRequest.query("Depression").withTopK(5).withSimilarityThresholdAll());
List<Float> distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList();
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
assertThat(distances).hasSize(3);
assertThat(scores).hasSize(3);
float threshold = (distances.get(0) + distances.get(1)) / 2;
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
List<Document> results = vectorStore
.similaritySearch(SearchRequest.query("Depression").withTopK(5).withSimilarityThreshold(1 - threshold));
List<Document> results = vectorStore.similaritySearch(
SearchRequest.query("Depression").withTopK(5).withSimilarityThreshold(similarityThreshold));
assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId());
assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock");
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
// Remove all documents from the store
vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList());

View File

@@ -35,6 +35,7 @@ import io.qdrant.client.grpc.Points.SearchPoints;
import io.qdrant.client.grpc.Points.UpdateStatus;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
@@ -57,6 +58,7 @@ import org.springframework.util.Assert;
* @author Eddú Meléndez
* @author Josh Long
* @author Soby Chacko
* @author Thomas Vitale
* @since 0.8.1
*/
public class QdrantVectorStore extends AbstractObservationVectorStore implements InitializingBean {
@@ -65,8 +67,6 @@ public class QdrantVectorStore extends AbstractObservationVectorStore implements
private static final String CONTENT_FIELD_NAME = "doc_content";
private static final String DISTANCE_FIELD_NAME = "distance";
private final EmbeddingModel embeddingModel;
private final QdrantClient qdrantClient;
@@ -208,12 +208,17 @@ public class QdrantVectorStore extends AbstractObservationVectorStore implements
try {
var id = point.getId().getUuid();
var payload = QdrantObjectFactory.toObjectMap(point.getPayloadMap());
payload.put(DISTANCE_FIELD_NAME, 1 - point.getScore());
var metadata = QdrantObjectFactory.toObjectMap(point.getPayloadMap());
metadata.put(DocumentMetadata.DISTANCE.value(), 1 - point.getScore());
var content = (String) payload.remove(CONTENT_FIELD_NAME);
var content = (String) metadata.remove(CONTENT_FIELD_NAME);
return new Document(id, content, payload);
return Document.builder()
.id(id)
.content(content)
.metadata(metadata)
.score((double) point.getScore())
.build();
}
catch (Exception e) {
throw new RuntimeException(e);

View File

@@ -30,6 +30,7 @@ import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariables;
import org.springframework.ai.document.DocumentMetadata;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.qdrant.QdrantContainer;
@@ -107,7 +108,7 @@ public class QdrantVectorStoreIT {
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId());
assertThat(resultDoc.getContent()).isEqualTo(
"Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression");
assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance");
assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value());
// Remove all documents from the store
vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList());
@@ -185,7 +186,7 @@ public class QdrantVectorStoreIT {
assertThat(resultDoc.getId()).isEqualTo(document.getId());
assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
assertThat(resultDoc.getMetadata()).containsKey("meta1");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
Document sameIdDocument = new Document(document.getId(),
"The World is Big and Salvation Lurks Around the Corner",
@@ -200,7 +201,7 @@ public class QdrantVectorStoreIT {
assertThat(resultDoc.getId()).isEqualTo(document.getId());
assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
vectorStore.delete(List.of(document.getId()));
});
@@ -218,13 +219,13 @@ public class QdrantVectorStoreIT {
var request = SearchRequest.query("Great").withTopK(5);
List<Document> fullResult = vectorStore.similaritySearch(request.withSimilarityThresholdAll());
List<Float> distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList();
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
assertThat(distances).hasSize(3);
assertThat(scores).hasSize(3);
float threshold = (distances.get(0) + distances.get(1)) / 2;
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
List<Document> results = vectorStore.similaritySearch(request.withSimilarityThreshold(1 - threshold));
List<Document> results = vectorStore.similaritySearch(request.withSimilarityThreshold(similarityThreshold));
assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
@@ -232,7 +233,8 @@ public class QdrantVectorStoreIT {
assertThat(resultDoc.getContent()).isEqualTo(
"Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression");
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
// Remove all documents from the store
vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList());

View File

@@ -30,6 +30,7 @@ import java.util.stream.Collectors;
import io.micrometer.observation.ObservationRegistry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.DocumentMetadata;
import redis.clients.jedis.JedisPooled;
import redis.clients.jedis.Pipeline;
import redis.clients.jedis.json.Path2;
@@ -240,14 +241,21 @@ public class RedisVectorStore extends AbstractObservationVectorStore implements
private Document toDocument(redis.clients.jedis.search.Document doc) {
var id = doc.getId().substring(this.config.prefix.length());
var content = doc.hasProperty(this.config.contentFieldName) ? doc.getString(this.config.contentFieldName)
: null;
var content = doc.hasProperty(this.config.contentFieldName) ? doc.getString(this.config.contentFieldName) : "";
Map<String, Object> metadata = this.config.metadataFields.stream()
.map(MetadataField::name)
.filter(doc::hasProperty)
.collect(Collectors.toMap(Function.identity(), doc::getString));
// TODO: this seems wrong. The key is named "vector_store", but the value is the
// distance. Can we remove this after standardizing the metadata?
metadata.put(DISTANCE_FIELD_NAME, 1 - similarityScore(doc));
return new Document(id, content, metadata);
metadata.put(DocumentMetadata.DISTANCE.value(), 1 - similarityScore(doc));
return Document.builder()
.id(id)
.content(content)
.metadata(metadata)
.score((double) similarityScore(doc))
.build();
}
private float similarityScore(redis.clients.jedis.search.Document doc) {

View File

@@ -26,6 +26,7 @@ import java.util.UUID;
import com.redis.testcontainers.RedisStackContainer;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.ai.document.DocumentMetadata;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import redis.clients.jedis.JedisPooled;
@@ -50,6 +51,7 @@ import static org.assertj.core.api.Assertions.assertThat;
/**
* @author Julien Ruaux
* @author Eddú Meléndez
* @author Thomas Vitale
*/
@Testcontainers
class RedisVectorStoreIT {
@@ -105,8 +107,9 @@ class RedisVectorStoreIT {
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId());
assertThat(resultDoc.getContent()).contains(
"Spring AI provides abstractions that serve as the foundation for developing AI applications.");
assertThat(resultDoc.getMetadata()).hasSize(2);
assertThat(resultDoc.getMetadata()).containsKeys("meta1", RedisVectorStore.DISTANCE_FIELD_NAME);
assertThat(resultDoc.getMetadata()).hasSize(3);
assertThat(resultDoc.getMetadata()).containsKeys("meta1", RedisVectorStore.DISTANCE_FIELD_NAME,
DocumentMetadata.DISTANCE.value());
// Remove all documents from the store
vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList());
@@ -190,6 +193,7 @@ class RedisVectorStoreIT {
assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
assertThat(resultDoc.getMetadata()).containsKey("meta1");
assertThat(resultDoc.getMetadata()).containsKey(RedisVectorStore.DISTANCE_FIELD_NAME);
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
Document sameIdDocument = new Document(document.getId(),
"The World is Big and Salvation Lurks Around the Corner",
@@ -205,6 +209,7 @@ class RedisVectorStoreIT {
assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey(RedisVectorStore.DISTANCE_FIELD_NAME);
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
vectorStore.delete(List.of(document.getId()));
@@ -223,24 +228,23 @@ class RedisVectorStoreIT {
List<Document> fullResult = vectorStore
.similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll());
List<Float> distances = fullResult.stream()
.map(doc -> (Float) doc.getMetadata().get(RedisVectorStore.DISTANCE_FIELD_NAME))
.toList();
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
assertThat(distances).hasSize(3);
assertThat(scores).hasSize(3);
float threshold = (distances.get(0) + distances.get(1)) / 2;
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
List<Document> results = vectorStore
.similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(1 - threshold));
List<Document> results = vectorStore.similaritySearch(
SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(similarityThreshold));
assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId());
assertThat(resultDoc.getContent()).contains(
"Spring AI provides abstractions that serve as the foundation for developing AI applications.");
assertThat(resultDoc.getMetadata()).containsKeys("meta1", RedisVectorStore.DISTANCE_FIELD_NAME);
assertThat(resultDoc.getMetadata()).containsKeys("meta1", RedisVectorStore.DISTANCE_FIELD_NAME,
DocumentMetadata.DISTANCE.value());
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
});
}

View File

@@ -26,6 +26,7 @@ import java.util.stream.Stream;
import io.micrometer.observation.ObservationRegistry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.DocumentMetadata;
import org.typesense.api.Client;
import org.typesense.api.FieldTypes;
import org.typesense.model.CollectionResponse;
@@ -57,6 +58,7 @@ import org.springframework.util.Assert;
* @author Pablo Sanchidrian Herrera
* @author Soby Chacko
* @author Christian Tzolov
* @author Thomas Vitale
*/
public class TypesenseVectorStore extends AbstractObservationVectorStore implements InitializingBean {
@@ -212,8 +214,13 @@ public class TypesenseVectorStore extends AbstractObservationVectorStore impleme
String content = rawDocument.get(CONTENT_FIELD_NAME).toString();
Map<String, Object> metadata = rawDocument.get(METADATA_FIELD_NAME) instanceof Map
? (Map<String, Object>) rawDocument.get(METADATA_FIELD_NAME) : Map.of();
metadata.put("distance", hit.getVectorDistance());
return new Document(docId, content, metadata);
metadata.put(DocumentMetadata.DISTANCE.value(), hit.getVectorDistance());
return Document.builder()
.id(docId)
.content(content)
.metadata(metadata)
.score(1.0 - hit.getVectorDistance())
.build();
}))
.toList();

View File

@@ -26,6 +26,7 @@ import java.util.Map;
import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.springframework.ai.document.DocumentMetadata;
import org.testcontainers.containers.GenericContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
@@ -96,7 +97,7 @@ public class TypesenseVectorStoreIT {
assertThat(resultDoc.getId()).isEqualTo(document.getId());
assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
assertThat(resultDoc.getMetadata()).containsKey("meta1");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
Document sameIdDocument = new Document(document.getId(),
"The World is Big and Salvation Lurks Around the Corner",
@@ -114,7 +115,7 @@ public class TypesenseVectorStoreIT {
assertThat(resultDoc.getId()).isEqualTo(document.getId());
assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
vectorStore.delete(List.of(document.getId()));
@@ -211,21 +212,22 @@ public class TypesenseVectorStoreIT {
List<Document> fullResult = vectorStore
.similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll());
List<Float> distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList();
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
assertThat(distances).hasSize(3);
assertThat(scores).hasSize(3);
float threshold = (distances.get(0) + distances.get(1)) / 2;
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
List<Document> results = vectorStore
.similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(1 - threshold));
List<Document> results = vectorStore.similaritySearch(
SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(similarityThreshold));
assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId());
assertThat(resultDoc.getContent()).contains(
"Spring AI provides abstractions that serve as the foundation for developing AI applications.");
assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance");
assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value());
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
((TypesenseVectorStore) vectorStore).dropCollection();

View File

@@ -45,6 +45,7 @@ import io.weaviate.client.v1.graphql.query.fields.Field;
import io.weaviate.client.v1.graphql.query.fields.Fields;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
@@ -73,11 +74,10 @@ import org.springframework.util.StringUtils;
* @author Eddú Meléndez
* @author Josh Long
* @author Soby Chacko
* @author Thomas Vitale
*/
public class WeaviateVectorStore extends AbstractObservationVectorStore {
public static final String DOCUMENT_METADATA_DISTANCE_KEY_NAME = "distance";
private static final String METADATA_FIELD_PREFIX = "meta_";
private static final String CONTENT_FIELD_NAME = "content";
@@ -367,7 +367,7 @@ public class WeaviateVectorStore extends AbstractObservationVectorStore {
// Metadata
Map<String, Object> metadata = new HashMap<>();
metadata.put(DOCUMENT_METADATA_DISTANCE_KEY_NAME, 1 - certainty);
metadata.put(DocumentMetadata.DISTANCE.value(), 1 - certainty);
try {
String metadataJson = (String) item.get(METADATA_FIELD_NAME);
@@ -382,10 +382,13 @@ public class WeaviateVectorStore extends AbstractObservationVectorStore {
// Content
String content = (String) item.get(CONTENT_FIELD_NAME);
var document = new Document(id, content, metadata);
document.setEmbedding(EmbeddingUtils.toPrimitive(EmbeddingUtils.doubleToFloat(embedding)));
return document;
return Document.builder()
.id(id)
.content(content)
.metadata(metadata)
.embedding(EmbeddingUtils.toPrimitive(EmbeddingUtils.doubleToFloat(embedding)))
.score(certainty)
.build();
}
@Override

View File

@@ -26,6 +26,7 @@ import java.util.UUID;
import io.weaviate.client.Config;
import io.weaviate.client.WeaviateClient;
import org.junit.jupiter.api.Test;
import org.springframework.ai.document.DocumentMetadata;
import org.testcontainers.containers.wait.strategy.Wait;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
@@ -101,7 +102,7 @@ public class WeaviateVectorStoreIT {
assertThat(resultDoc.getContent()).contains(
"Spring AI provides abstractions that serve as the foundation for developing AI applications.");
assertThat(resultDoc.getMetadata()).hasSize(2);
assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance");
assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value());
// Remove all documents from the store
vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList());
@@ -186,7 +187,7 @@ public class WeaviateVectorStoreIT {
assertThat(resultDoc.getId()).isEqualTo(document.getId());
assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
assertThat(resultDoc.getMetadata()).containsKey("meta1");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
Document sameIdDocument = new Document(document.getId(),
"The World is Big and Salvation Lurks Around the Corner",
@@ -201,7 +202,7 @@ public class WeaviateVectorStoreIT {
assertThat(resultDoc.getId()).isEqualTo(document.getId());
assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
vectorStore.delete(List.of(document.getId()));
@@ -222,23 +223,22 @@ public class WeaviateVectorStoreIT {
List<Document> fullResult = vectorStore
.similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll());
List<Double> distances = fullResult.stream()
.map(doc -> (Double) doc.getMetadata().get("distance"))
.toList();
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
assertThat(distances).hasSize(3);
assertThat(scores).hasSize(3);
double threshold = (distances.get(0) + distances.get(1)) / 2;
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
List<Document> results = vectorStore
.similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(1 - threshold));
List<Document> results = vectorStore.similaritySearch(
SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(similarityThreshold));
assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId());
assertThat(resultDoc.getContent()).contains(
"Spring AI provides abstractions that serve as the foundation for developing AI applications.");
assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance");
assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value());
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
});
}