Support similarity scores in Document API
Document * Introduced “score” attribute in Document API. It stores the similarity score. * Consolidated “distance” metadata for Documents. It stores the distance measurement. * Adopted prefix-less naming convention in Document.Builder and deprecated old methods. * Deprecated the many overloaded Document constructors in favour of Document.Builder. Vector Stores * Every vector store implementation now configures a “score” attribute with the similarity score of the Document embedding. It also includes the “distance” metadata with the distance measurement. * Fixed error in Elasticsearch where distance and similarity were mixed up. * Added missing integration tests for SimpleVectorStore. * The Azure Vector Store and HanaDB Vector Store do not include those measurements because the product documentation does not include information about how the similarity score is returned, and without access to the cloud products I could not verify that via debugging. * Improved tests to actually assert the result of the similarity search based on the returned score. Signed-off-by: Thomas Vitale <ThomasVitale@users.noreply.github.com>
This commit is contained in:
committed by
Mark Pollack
parent
50223d20e3
commit
fe58fd30eb
@@ -176,14 +176,14 @@ public class MarkdownDocumentReader implements DocumentReader {
|
||||
}
|
||||
|
||||
translateLineBreakToSpace();
|
||||
this.currentDocumentBuilder.withMetadata("category", "blockquote");
|
||||
this.currentDocumentBuilder.metadata("category", "blockquote");
|
||||
super.visit(blockQuote);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(Code code) {
|
||||
this.currentParagraphs.add(code.getLiteral());
|
||||
this.currentDocumentBuilder.withMetadata("category", "code_inline");
|
||||
this.currentDocumentBuilder.metadata("category", "code_inline");
|
||||
super.visit(code);
|
||||
}
|
||||
|
||||
@@ -195,8 +195,8 @@ public class MarkdownDocumentReader implements DocumentReader {
|
||||
|
||||
translateLineBreakToSpace();
|
||||
this.currentParagraphs.add(fencedCodeBlock.getLiteral());
|
||||
this.currentDocumentBuilder.withMetadata("category", "code_block");
|
||||
this.currentDocumentBuilder.withMetadata("lang", fencedCodeBlock.getInfo());
|
||||
this.currentDocumentBuilder.metadata("category", "code_block");
|
||||
this.currentDocumentBuilder.metadata("lang", fencedCodeBlock.getInfo());
|
||||
|
||||
buildAndFlush();
|
||||
|
||||
@@ -206,8 +206,8 @@ public class MarkdownDocumentReader implements DocumentReader {
|
||||
@Override
|
||||
public void visit(Text text) {
|
||||
if (text.getParent() instanceof Heading heading) {
|
||||
this.currentDocumentBuilder.withMetadata("category", "header_%d".formatted(heading.getLevel()))
|
||||
.withMetadata("title", text.getLiteral());
|
||||
this.currentDocumentBuilder.metadata("category", "header_%d".formatted(heading.getLevel()))
|
||||
.metadata("title", text.getLiteral());
|
||||
}
|
||||
else {
|
||||
this.currentParagraphs.add(text.getLiteral());
|
||||
@@ -226,9 +226,9 @@ public class MarkdownDocumentReader implements DocumentReader {
|
||||
if (!this.currentParagraphs.isEmpty()) {
|
||||
String content = String.join("", this.currentParagraphs);
|
||||
|
||||
Document.Builder builder = this.currentDocumentBuilder.withContent(content);
|
||||
Document.Builder builder = this.currentDocumentBuilder.content(content);
|
||||
|
||||
this.config.additionalMetadata.forEach(builder::withMetadata);
|
||||
this.config.additionalMetadata.forEach(builder::metadata);
|
||||
|
||||
Document document = builder.build();
|
||||
|
||||
|
||||
@@ -111,7 +111,7 @@ class VertexAiMultimodalEmbeddingModelIT {
|
||||
assertThat(this.multiModelEmbeddingModel).isNotNull();
|
||||
|
||||
var document = Document.builder()
|
||||
.withMedia(new Media(MimeTypeUtils.TEXT_PLAIN, URI.create("http://example.com/image.png").toURL()))
|
||||
.media(new Media(MimeTypeUtils.TEXT_PLAIN, URI.create("http://example.com/image.png").toURL()))
|
||||
.build();
|
||||
|
||||
DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(document);
|
||||
@@ -135,7 +135,7 @@ class VertexAiMultimodalEmbeddingModelIT {
|
||||
void imageEmbedding() {
|
||||
|
||||
var document = Document.builder()
|
||||
.withMedia(new Media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.image.png")))
|
||||
.media(new Media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.image.png")))
|
||||
.build();
|
||||
|
||||
DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(document);
|
||||
@@ -161,7 +161,7 @@ class VertexAiMultimodalEmbeddingModelIT {
|
||||
void videoEmbedding() {
|
||||
|
||||
var document = Document.builder()
|
||||
.withMedia(new Media(new MimeType("video", "mp4"), new ClassPathResource("/test.video.mp4")))
|
||||
.media(new Media(new MimeType("video", "mp4"), new ClassPathResource("/test.video.mp4")))
|
||||
.build();
|
||||
|
||||
DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(document);
|
||||
@@ -186,9 +186,9 @@ class VertexAiMultimodalEmbeddingModelIT {
|
||||
void textImageAndVideoEmbedding() {
|
||||
|
||||
var document = Document.builder()
|
||||
.withContent("Hello World")
|
||||
.withMedia(new Media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.image.png")))
|
||||
.withMedia(new Media(new MimeType("video", "mp4"), new ClassPathResource("/test.video.mp4")))
|
||||
.content("Hello World")
|
||||
.media(new Media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.image.png")))
|
||||
.media(new Media(new MimeType("video", "mp4"), new ClassPathResource("/test.video.mp4")))
|
||||
.build();
|
||||
|
||||
DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(document);
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
package org.springframework.ai.chat.client.advisor;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -184,10 +185,14 @@ public class VectorStoreChatMemoryAdvisor extends AbstractChatMemoryAdvisor<Vect
|
||||
metadata.put(DOCUMENT_METADATA_CONVERSATION_ID, conversationId);
|
||||
metadata.put(DOCUMENT_METADATA_MESSAGE_TYPE, message.getMessageType().name());
|
||||
if (message instanceof UserMessage userMessage) {
|
||||
return new Document(userMessage.getContent(), userMessage.getMedia(), metadata);
|
||||
return Document.builder()
|
||||
.content(userMessage.getContent())
|
||||
.media(new ArrayList<>(userMessage.getMedia()))
|
||||
.metadata(metadata)
|
||||
.build();
|
||||
}
|
||||
else if (message instanceof AssistantMessage assistantMessage) {
|
||||
return new Document(assistantMessage.getContent(), metadata);
|
||||
return Document.builder().content(assistantMessage.getContent()).metadata(metadata).build();
|
||||
}
|
||||
throw new RuntimeException("Unknown message type: " + message.getMessageType());
|
||||
})
|
||||
|
||||
@@ -21,6 +21,7 @@ import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonCreator;
|
||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
@@ -31,6 +32,7 @@ import org.springframework.ai.document.id.IdGenerator;
|
||||
import org.springframework.ai.document.id.RandomIdGenerator;
|
||||
import org.springframework.ai.model.Media;
|
||||
import org.springframework.ai.model.MediaContent;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
@@ -41,9 +43,9 @@ import org.springframework.util.StringUtils;
|
||||
@JsonIgnoreProperties({ "contentFormatter" })
|
||||
public class Document implements MediaContent {
|
||||
|
||||
public final static ContentFormatter DEFAULT_CONTENT_FORMATTER = DefaultContentFormatter.defaultConfig();
|
||||
public static final ContentFormatter DEFAULT_CONTENT_FORMATTER = DefaultContentFormatter.defaultConfig();
|
||||
|
||||
public final static String EMPTY_TEXT = "";
|
||||
public static final String EMPTY_TEXT = "";
|
||||
|
||||
/**
|
||||
* Unique ID
|
||||
@@ -61,7 +63,15 @@ public class Document implements MediaContent {
|
||||
* Metadata for the document. It should not be nested and values should be restricted
|
||||
* to string, int, float, boolean for simple use with Vector Dbs.
|
||||
*/
|
||||
private Map<String, Object> metadata;
|
||||
private final Map<String, Object> metadata;
|
||||
|
||||
/**
|
||||
* Measure of similarity between the document embedding and the query vector. The
|
||||
* higher the score, the more they are similar. It's the opposite of the distance
|
||||
* measure.
|
||||
*/
|
||||
@Nullable
|
||||
private final Double score;
|
||||
|
||||
/**
|
||||
* Embedding of the document. Note: ephemeral field.
|
||||
@@ -84,10 +94,18 @@ public class Document implements MediaContent {
|
||||
this(content, metadata, new RandomIdGenerator());
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use builder instead: {@link Document#builder()}.
|
||||
*/
|
||||
@Deprecated(since = "1.0.0-M5", forRemoval = true)
|
||||
public Document(String content, Collection<Media> media, Map<String, Object> metadata) {
|
||||
this(new RandomIdGenerator().generateId(content, metadata), content, media, metadata);
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use builder instead: {@link Document#builder()}.
|
||||
*/
|
||||
@Deprecated(since = "1.0.0-M5", forRemoval = true)
|
||||
public Document(String content, Map<String, Object> metadata, IdGenerator idGenerator) {
|
||||
this(idGenerator.generateId(content, metadata), content, metadata);
|
||||
}
|
||||
@@ -96,15 +114,33 @@ public class Document implements MediaContent {
|
||||
this(id, content, List.of(), metadata);
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use builder instead: {@link Document#builder()}.
|
||||
*/
|
||||
@Deprecated(since = "1.0.0-M5", forRemoval = true)
|
||||
public Document(String id, String content, Collection<Media> media, Map<String, Object> metadata) {
|
||||
Assert.hasText(id, "id must not be null or empty");
|
||||
Assert.notNull(content, "content must not be null");
|
||||
Assert.notNull(metadata, "metadata must not be null");
|
||||
this(id, content, media, metadata, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use builder instead: {@link Document#builder()}.
|
||||
*/
|
||||
@Deprecated(since = "1.0.0-M5", forRemoval = true)
|
||||
public Document(String id, String content, @Nullable Collection<Media> media,
|
||||
@Nullable Map<String, Object> metadata, @Nullable Double score) {
|
||||
Assert.hasText(id, "id cannot be null or empty");
|
||||
Assert.notNull(content, "content cannot be null");
|
||||
Assert.notNull(media, "media cannot be null");
|
||||
Assert.noNullElements(media, "media cannot have null elements");
|
||||
Assert.notNull(metadata, "metadata cannot be null");
|
||||
Assert.noNullElements(metadata.keySet(), "metadata cannot have null keys");
|
||||
Assert.noNullElements(metadata.values(), "metadata cannot have null values");
|
||||
|
||||
this.id = id;
|
||||
this.content = content;
|
||||
this.media = media;
|
||||
this.metadata = metadata;
|
||||
this.media = media != null ? media : List.of();
|
||||
this.metadata = metadata != null ? metadata : new HashMap<>();
|
||||
this.score = score;
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
@@ -149,6 +185,11 @@ public class Document implements MediaContent {
|
||||
return this.metadata;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
public Double getScore() {
|
||||
return this.score;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the embedding that was calculated.
|
||||
* @deprecated We are considering getting rid of this, please comment on
|
||||
@@ -172,6 +213,7 @@ public class Document implements MediaContent {
|
||||
* @return the current ContentFormatter instance used for formatting the document
|
||||
* content.
|
||||
*/
|
||||
@Deprecated(since = "1.0.0-M4")
|
||||
public ContentFormatter getContentFormatter() {
|
||||
return this.contentFormatter;
|
||||
}
|
||||
@@ -184,59 +226,34 @@ public class Document implements MediaContent {
|
||||
this.contentFormatter = contentFormatter;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
final int prime = 31;
|
||||
int result = 1;
|
||||
result = prime * result + ((this.id == null) ? 0 : this.id.hashCode());
|
||||
result = prime * result + ((this.metadata == null) ? 0 : this.metadata.hashCode());
|
||||
result = prime * result + ((this.content == null) ? 0 : this.content.hashCode());
|
||||
return result;
|
||||
public Builder mutate() {
|
||||
return new Builder().id(this.id)
|
||||
.content(this.content)
|
||||
.media(new ArrayList<>(this.media))
|
||||
.metadata(this.metadata)
|
||||
.score(this.score);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (this == obj) {
|
||||
return true;
|
||||
}
|
||||
if (obj == null) {
|
||||
public boolean equals(Object o) {
|
||||
if (o == null || this.getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
if (getClass() != obj.getClass()) {
|
||||
return false;
|
||||
}
|
||||
Document other = (Document) obj;
|
||||
if (this.id == null) {
|
||||
if (other.id != null) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if (!this.id.equals(other.id)) {
|
||||
return false;
|
||||
}
|
||||
if (this.metadata == null) {
|
||||
if (other.metadata != null) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if (!this.metadata.equals(other.metadata)) {
|
||||
return false;
|
||||
}
|
||||
if (this.content == null) {
|
||||
if (other.content != null) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if (!this.content.equals(other.content)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
Document document = (Document) o;
|
||||
return Objects.equals(this.id, document.id) && Objects.equals(this.content, document.content)
|
||||
&& Objects.equals(this.media, document.media) && Objects.equals(this.metadata, document.metadata)
|
||||
&& Objects.equals(this.score, document.score);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(this.id, this.content, this.media, this.metadata, this.score);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "Document{" + "id='" + this.id + '\'' + ", metadata=" + this.metadata + ", content='" + this.content
|
||||
+ '\'' + ", media=" + this.media + '}';
|
||||
return "Document{" + "id='" + this.id + '\'' + ", content='" + this.content + '\'' + ", media=" + this.media
|
||||
+ ", metadata=" + this.metadata + ", score=" + this.score + '}';
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
@@ -249,56 +266,103 @@ public class Document implements MediaContent {
|
||||
|
||||
private Map<String, Object> metadata = new HashMap<>();
|
||||
|
||||
private float[] embedding = new float[0];
|
||||
|
||||
@Nullable
|
||||
private Double score;
|
||||
|
||||
private IdGenerator idGenerator = new RandomIdGenerator();
|
||||
|
||||
public Builder withIdGenerator(IdGenerator idGenerator) {
|
||||
Assert.notNull(idGenerator, "idGenerator must not be null");
|
||||
public Builder idGenerator(IdGenerator idGenerator) {
|
||||
Assert.notNull(idGenerator, "idGenerator cannot be null");
|
||||
this.idGenerator = idGenerator;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withId(String id) {
|
||||
Assert.hasText(id, "id must not be null or empty");
|
||||
public Builder id(String id) {
|
||||
Assert.hasText(id, "id cannot be null or empty");
|
||||
this.id = id;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withContent(String content) {
|
||||
Assert.notNull(content, "content must not be null");
|
||||
public Builder content(String content) {
|
||||
this.content = content;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withMedia(List<Media> media) {
|
||||
Assert.notNull(media, "media must not be null");
|
||||
this.media = media;
|
||||
public Builder media(List<Media> media) {
|
||||
this.media.addAll(media);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withMedia(Media media) {
|
||||
Assert.notNull(media, "media must not be null");
|
||||
this.media.add(media);
|
||||
public Builder media(Media... media) {
|
||||
Assert.noNullElements(media, "media cannot contain null elements");
|
||||
this.media.addAll(List.of(media));
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withMetadata(Map<String, Object> metadata) {
|
||||
Assert.notNull(metadata, "metadata must not be null");
|
||||
public Builder metadata(Map<String, Object> metadata) {
|
||||
this.metadata = metadata;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withMetadata(String key, Object value) {
|
||||
Assert.notNull(key, "key must not be null");
|
||||
Assert.notNull(value, "value must not be null");
|
||||
public Builder metadata(String key, Object value) {
|
||||
this.metadata.put(key, value);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder embedding(float[] embedding) {
|
||||
this.embedding = embedding;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder score(@Nullable Double score) {
|
||||
this.score = score;
|
||||
return this;
|
||||
}
|
||||
|
||||
@Deprecated(since = "1.0.0-M5", forRemoval = true)
|
||||
public Builder withIdGenerator(IdGenerator idGenerator) {
|
||||
return idGenerator(idGenerator);
|
||||
}
|
||||
|
||||
@Deprecated(since = "1.0.0-M5", forRemoval = true)
|
||||
public Builder withId(String id) {
|
||||
return id(id);
|
||||
}
|
||||
|
||||
@Deprecated(since = "1.0.0-M5", forRemoval = true)
|
||||
public Builder withContent(String content) {
|
||||
return content(content);
|
||||
}
|
||||
|
||||
@Deprecated(since = "1.0.0-M5", forRemoval = true)
|
||||
public Builder withMedia(List<Media> media) {
|
||||
return media(media);
|
||||
}
|
||||
|
||||
@Deprecated(since = "1.0.0-M5", forRemoval = true)
|
||||
public Builder withMedia(Media media) {
|
||||
return media(media);
|
||||
}
|
||||
|
||||
@Deprecated(since = "1.0.0-M5", forRemoval = true)
|
||||
public Builder withMetadata(Map<String, Object> metadata) {
|
||||
return metadata(metadata);
|
||||
}
|
||||
|
||||
@Deprecated(since = "1.0.0-M5", forRemoval = true)
|
||||
public Builder withMetadata(String key, Object value) {
|
||||
return metadata(key, value);
|
||||
}
|
||||
|
||||
public Document build() {
|
||||
if (!StringUtils.hasText(this.id)) {
|
||||
this.id = this.idGenerator.generateId(this.content, this.metadata);
|
||||
}
|
||||
return new Document(this.id, this.content, this.media, this.metadata);
|
||||
var document = new Document(this.id, this.content, this.media, this.metadata, this.score);
|
||||
document.setEmbedding(this.embedding);
|
||||
return document;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.document;
|
||||
|
||||
import org.springframework.ai.vectorstore.VectorStore;
|
||||
|
||||
/**
|
||||
* Common set of metadata keys used in {@link Document}s by {@link DocumentReader}s and
|
||||
* {@link VectorStore}s.
|
||||
*
|
||||
* @author Thomas Vitale
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public enum DocumentMetadata {
|
||||
|
||||
// @formatter:off
|
||||
|
||||
/**
|
||||
* Measure of distance between the document embedding and the query vector.
|
||||
* The lower the distance, the more they are similar.
|
||||
* It's the opposite of the similarity score.
|
||||
*/
|
||||
DISTANCE("distance");
|
||||
|
||||
private final String value;
|
||||
|
||||
DocumentMetadata(String value) {
|
||||
this.value = value;
|
||||
}
|
||||
public String value() {
|
||||
return this.value;
|
||||
}
|
||||
|
||||
// @formatter:on
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return value;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
@NonNullApi
|
||||
@NonNullFields
|
||||
package org.springframework.ai.document;
|
||||
|
||||
import org.springframework.lang.NonNullApi;
|
||||
import org.springframework.lang.NonNullFields;
|
||||
@@ -22,6 +22,7 @@ import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.vectorstore.filter.Filter;
|
||||
import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder;
|
||||
import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
/**
|
||||
@@ -30,6 +31,7 @@ import org.springframework.util.Assert;
|
||||
* instance and then apply the 'with' methods to alter the default values.
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
public final class SearchRequest {
|
||||
|
||||
@@ -45,12 +47,13 @@ public final class SearchRequest {
|
||||
*/
|
||||
public static final int DEFAULT_TOP_K = 4;
|
||||
|
||||
public String query;
|
||||
private String query;
|
||||
|
||||
private int topK = DEFAULT_TOP_K;
|
||||
|
||||
private double similarityThreshold = SIMILARITY_THRESHOLD_ACCEPT_ALL;
|
||||
|
||||
@Nullable
|
||||
private Filter.Expression filterExpression;
|
||||
|
||||
private SearchRequest(String query) {
|
||||
@@ -186,7 +189,7 @@ public final class SearchRequest {
|
||||
* filter criteria. The 'null' value stands for no expression filters.
|
||||
* @return this builder.
|
||||
*/
|
||||
public SearchRequest withFilterExpression(Filter.Expression expression) {
|
||||
public SearchRequest withFilterExpression(@Nullable Filter.Expression expression) {
|
||||
this.filterExpression = expression;
|
||||
return this;
|
||||
}
|
||||
@@ -225,7 +228,7 @@ public final class SearchRequest {
|
||||
* 'null' value stands for no expression filters.
|
||||
* @return this builder.
|
||||
*/
|
||||
public SearchRequest withFilterExpression(String textExpression) {
|
||||
public SearchRequest withFilterExpression(@Nullable String textExpression) {
|
||||
this.filterExpression = (textExpression != null) ? new FilterExpressionTextParser().parse(textExpression)
|
||||
: null;
|
||||
return this;
|
||||
@@ -243,6 +246,7 @@ public final class SearchRequest {
|
||||
return this.similarityThreshold;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
public Filter.Expression getFilterExpression() {
|
||||
return this.filterExpression;
|
||||
}
|
||||
|
||||
@@ -43,6 +43,7 @@ import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.observation.conventions.VectorStoreProvider;
|
||||
import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric;
|
||||
@@ -68,6 +69,7 @@ import org.springframework.core.io.Resource;
|
||||
* @author Christian Tzolov
|
||||
* @author Sebastien Deleuze
|
||||
* @author Ilayaperumal Gopinathan
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
public class SimpleVectorStore extends AbstractObservationVectorStore {
|
||||
|
||||
@@ -127,12 +129,11 @@ public class SimpleVectorStore extends AbstractObservationVectorStore {
|
||||
float[] userQueryEmbedding = getUserQueryEmbedding(request.getQuery());
|
||||
return this.store.values()
|
||||
.stream()
|
||||
.map(entry -> new Similarity(entry,
|
||||
EmbeddingMath.cosineSimilarity(userQueryEmbedding, entry.getEmbedding())))
|
||||
.filter(s -> s.score >= request.getSimilarityThreshold())
|
||||
.sorted(Comparator.<Similarity>comparingDouble(s -> s.score).reversed())
|
||||
.map(content -> content
|
||||
.toDocument(EmbeddingMath.cosineSimilarity(userQueryEmbedding, content.getEmbedding())))
|
||||
.filter(document -> document.getScore() >= request.getSimilarityThreshold())
|
||||
.sorted(Comparator.comparing(Document::getScore).reversed())
|
||||
.limit(request.getTopK())
|
||||
.map(s -> s.getDocument())
|
||||
.toList();
|
||||
}
|
||||
|
||||
@@ -235,28 +236,7 @@ public class SimpleVectorStore extends AbstractObservationVectorStore {
|
||||
.withSimilarityMetric(VectorStoreSimilarityMetric.COSINE.value());
|
||||
}
|
||||
|
||||
public static class Similarity {
|
||||
|
||||
private SimpleVectorStoreContent content;
|
||||
|
||||
private double score;
|
||||
|
||||
public Similarity(SimpleVectorStoreContent content, double score) {
|
||||
this.content = content;
|
||||
this.score = score;
|
||||
}
|
||||
|
||||
Document getDocument() {
|
||||
return Document.builder()
|
||||
.withId(this.content.getId())
|
||||
.withContent(this.content.getContent())
|
||||
.withMetadata(this.content.getMetadata())
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public final class EmbeddingMath {
|
||||
public static final class EmbeddingMath {
|
||||
|
||||
private EmbeddingMath() {
|
||||
throw new UnsupportedOperationException("This is a utility class and cannot be instantiated");
|
||||
|
||||
@@ -25,6 +25,8 @@ import java.util.Objects;
|
||||
import com.fasterxml.jackson.annotation.JsonCreator;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.springframework.ai.document.id.IdGenerator;
|
||||
import org.springframework.ai.document.id.RandomIdGenerator;
|
||||
import org.springframework.ai.model.Content;
|
||||
@@ -135,6 +137,12 @@ public final class SimpleVectorStoreContent implements Content {
|
||||
return Arrays.copyOf(this.embedding, this.embedding.length);
|
||||
}
|
||||
|
||||
public Document toDocument(Double score) {
|
||||
var metadata = new HashMap<>(this.metadata);
|
||||
metadata.put(DocumentMetadata.DISTANCE.value(), 1.0 - score);
|
||||
return Document.builder().id(this.id).content(this.content).metadata(metadata).score(score).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
@NonNullApi
|
||||
@NonNullFields
|
||||
package org.springframework.ai.vectorstore;
|
||||
|
||||
import org.springframework.lang.NonNullApi;
|
||||
import org.springframework.lang.NonNullFields;
|
||||
@@ -70,8 +70,8 @@ class RetrievalAugmentationAdvisorTests {
|
||||
.build());
|
||||
|
||||
// Document Retriever
|
||||
var documentContext = List.of(Document.builder().withId("1").withContent("doc1").build(),
|
||||
Document.builder().withId("2").withContent("doc2").build());
|
||||
var documentContext = List.of(Document.builder().id("1").content("doc1").build(),
|
||||
Document.builder().id("2").content("doc2").build());
|
||||
var documentRetriever = mock(DocumentRetriever.class);
|
||||
var queryCaptor = ArgumentCaptor.forClass(Query.class);
|
||||
given(documentRetriever.retrieve(queryCaptor.capture())).willReturn(documentContext);
|
||||
|
||||
@@ -42,13 +42,11 @@ public class DocumentBuilderTests {
|
||||
URL mediaUrl2 = new URL("http://type2");
|
||||
Media media1 = new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl1);
|
||||
Media media2 = new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl2);
|
||||
List<Media> mediaList = List.of(media1, media2);
|
||||
return mediaList;
|
||||
return List.of(media1, media2);
|
||||
}
|
||||
catch (MalformedURLException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@BeforeEach
|
||||
@@ -58,32 +56,26 @@ public class DocumentBuilderTests {
|
||||
|
||||
@Test
|
||||
void testWithIdGenerator() {
|
||||
IdGenerator mockGenerator = new IdGenerator() {
|
||||
IdGenerator mockGenerator = contents -> "mockedId";
|
||||
|
||||
@Override
|
||||
public String generateId(Object... contents) {
|
||||
return "mockedId";
|
||||
}
|
||||
};
|
||||
|
||||
Document.Builder result = this.builder.withIdGenerator(mockGenerator);
|
||||
Document.Builder result = this.builder.idGenerator(mockGenerator);
|
||||
|
||||
assertThat(result).isSameAs(this.builder);
|
||||
|
||||
Document document = result.withContent("Test content").withMetadata("key", "value").build();
|
||||
Document document = result.content("Test content").metadata("key", "value").build();
|
||||
|
||||
assertThat(document.getId()).isEqualTo("mockedId");
|
||||
}
|
||||
|
||||
@Test
|
||||
void testWithIdGeneratorNull() {
|
||||
assertThatThrownBy(() -> this.builder.withIdGenerator(null)).isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("idGenerator must not be null");
|
||||
assertThatThrownBy(() -> this.builder.idGenerator(null).build()).isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("idGenerator cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void testWithId() {
|
||||
Document.Builder result = this.builder.withId("testId");
|
||||
Document.Builder result = this.builder.id("testId");
|
||||
|
||||
assertThat(result).isSameAs(this.builder);
|
||||
assertThat(result.build().getId()).isEqualTo("testId");
|
||||
@@ -91,16 +83,16 @@ public class DocumentBuilderTests {
|
||||
|
||||
@Test
|
||||
void testWithIdNullOrEmpty() {
|
||||
assertThatThrownBy(() -> this.builder.withId(null)).isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("id must not be null or empty");
|
||||
assertThatThrownBy(() -> this.builder.id(null).build()).isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("id cannot be null or empty");
|
||||
|
||||
assertThatThrownBy(() -> this.builder.withId("")).isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("id must not be null or empty");
|
||||
assertThatThrownBy(() -> this.builder.id("").build()).isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("id cannot be null or empty");
|
||||
}
|
||||
|
||||
@Test
|
||||
void testWithContent() {
|
||||
Document.Builder result = this.builder.withContent("Test content");
|
||||
Document.Builder result = this.builder.content("Test content");
|
||||
|
||||
assertThat(result).isSameAs(this.builder);
|
||||
assertThat(result.build().getContent()).isEqualTo("Test content");
|
||||
@@ -108,14 +100,14 @@ public class DocumentBuilderTests {
|
||||
|
||||
@Test
|
||||
void testWithContentNull() {
|
||||
assertThatThrownBy(() -> this.builder.withContent(null)).isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("content must not be null");
|
||||
assertThatThrownBy(() -> this.builder.content(null).build()).isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("content cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void testWithMediaList() {
|
||||
List<Media> mediaList = getMediaList();
|
||||
Document.Builder result = this.builder.withMedia(mediaList);
|
||||
Document.Builder result = this.builder.media(mediaList);
|
||||
|
||||
assertThat(result).isSameAs(this.builder);
|
||||
assertThat(result.build().getMedia()).isEqualTo(mediaList);
|
||||
@@ -123,9 +115,9 @@ public class DocumentBuilderTests {
|
||||
|
||||
@Test
|
||||
void testWithMediaListNull() {
|
||||
assertThatThrownBy(() -> this.builder.withMedia((List<Media>) null))
|
||||
assertThatThrownBy(() -> this.builder.media((List<Media>) null).build())
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("media must not be null");
|
||||
.hasMessageContaining("media cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -133,7 +125,7 @@ public class DocumentBuilderTests {
|
||||
URL mediaUrl = new URL("http://test");
|
||||
Media media = new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl);
|
||||
|
||||
Document.Builder result = this.builder.withMedia(media);
|
||||
Document.Builder result = this.builder.media(media);
|
||||
|
||||
assertThat(result).isSameAs(this.builder);
|
||||
assertThat(result.build().getMedia()).contains(media);
|
||||
@@ -141,8 +133,8 @@ public class DocumentBuilderTests {
|
||||
|
||||
@Test
|
||||
void testWithMediaSingleNull() {
|
||||
assertThatThrownBy(() -> this.builder.withMedia((Media) null)).isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("media must not be null");
|
||||
assertThatThrownBy(() -> this.builder.media((Media) null).build()).isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("media cannot contain null elements");
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -150,7 +142,7 @@ public class DocumentBuilderTests {
|
||||
Map<String, Object> metadata = new HashMap<>();
|
||||
metadata.put("key1", "value1");
|
||||
metadata.put("key2", 2);
|
||||
Document.Builder result = this.builder.withMetadata(metadata);
|
||||
Document.Builder result = this.builder.metadata(metadata);
|
||||
|
||||
assertThat(result).isSameAs(this.builder);
|
||||
assertThat(result.build().getMetadata()).isEqualTo(metadata);
|
||||
@@ -158,47 +150,51 @@ public class DocumentBuilderTests {
|
||||
|
||||
@Test
|
||||
void testWithMetadataMapNull() {
|
||||
assertThatThrownBy(() -> this.builder.withMetadata((Map<String, Object>) null))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("metadata must not be null");
|
||||
assertThatThrownBy(() -> this.builder.metadata(null).build()).isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("metadata cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void testWithMetadataKeyValue() {
|
||||
Document.Builder result = this.builder.withMetadata("key", "value");
|
||||
Document.Builder result = this.builder.metadata("key", "value");
|
||||
|
||||
assertThat(result).isSameAs(this.builder);
|
||||
assertThat(result.build().getMetadata()).containsEntry("key", "value");
|
||||
}
|
||||
|
||||
@Test
|
||||
void testWithMetadataKeyValueNull() {
|
||||
assertThatThrownBy(() -> this.builder.withMetadata(null, "value")).isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("key must not be null");
|
||||
void testWithMetadataKeyNull() {
|
||||
assertThatThrownBy(() -> this.builder.metadata(null, "value").build())
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("metadata cannot have null keys");
|
||||
}
|
||||
|
||||
assertThatThrownBy(() -> this.builder.withMetadata("key", null)).isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("value must not be null");
|
||||
@Test
|
||||
void testWithMetadataValueNull() {
|
||||
assertThatThrownBy(() -> this.builder.metadata("key", null).build())
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("metadata cannot have null values");
|
||||
}
|
||||
|
||||
@Test
|
||||
void testBuildWithoutId() {
|
||||
Document document = this.builder.withContent("Test content").build();
|
||||
Document document = this.builder.content("Test content").build();
|
||||
|
||||
assertThat(document.getId()).isNotNull().isNotEmpty();
|
||||
assertThat(document.getContent()).isEqualTo("Test content");
|
||||
}
|
||||
|
||||
@Test
|
||||
void testBuildWithAllProperties() throws MalformedURLException {
|
||||
void testBuildWithAllProperties() {
|
||||
|
||||
List<Media> mediaList = getMediaList();
|
||||
Map<String, Object> metadata = new HashMap<>();
|
||||
metadata.put("key", "value");
|
||||
|
||||
Document document = this.builder.withId("customId")
|
||||
.withContent("Test content")
|
||||
.withMedia(mediaList)
|
||||
.withMetadata(metadata)
|
||||
Document document = this.builder.id("customId")
|
||||
.content("Test content")
|
||||
.media(mediaList)
|
||||
.metadata(metadata)
|
||||
.build();
|
||||
|
||||
assertThat(document.getId()).isEqualTo("customId");
|
||||
|
||||
@@ -0,0 +1,181 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.document;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.model.Media;
|
||||
import org.springframework.util.MimeTypeUtils;
|
||||
|
||||
import java.net.MalformedURLException;
|
||||
import java.net.URL;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
public class DocumentTests {
|
||||
|
||||
@Test
|
||||
void testScore() {
|
||||
Double score = 0.95;
|
||||
Document document = Document.builder().content("Test content").score(score).build();
|
||||
|
||||
assertThat(document.getScore()).isEqualTo(score);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testNullScore() {
|
||||
Document document = Document.builder().content("Test content").score(null).build();
|
||||
|
||||
assertThat(document.getScore()).isNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
void testMediaBuilderIsAdditive() {
|
||||
try {
|
||||
URL mediaUrl1 = new URL("http://type1");
|
||||
URL mediaUrl2 = new URL("http://type2");
|
||||
URL mediaUrl3 = new URL("http://type3");
|
||||
|
||||
Media media1 = new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl1);
|
||||
Media media2 = new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl2);
|
||||
Media media3 = new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl3);
|
||||
|
||||
Document document = Document.builder().media(media1).media(media2).media(List.of(media3)).build();
|
||||
|
||||
assertThat(document.getMedia()).hasSize(3).containsExactly(media1, media2, media3);
|
||||
|
||||
}
|
||||
catch (MalformedURLException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testMutate() {
|
||||
List<Media> mediaList = getMediaList();
|
||||
Map<String, Object> metadata = new HashMap<>();
|
||||
metadata.put("key", "value");
|
||||
Double score = 0.95;
|
||||
|
||||
Document original = Document.builder()
|
||||
.id("customId")
|
||||
.content("Test content")
|
||||
.media(mediaList)
|
||||
.metadata(metadata)
|
||||
.score(score)
|
||||
.build();
|
||||
|
||||
Document mutated = original.mutate().build();
|
||||
|
||||
assertThat(mutated).isNotSameAs(original).usingRecursiveComparison().isEqualTo(original);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testEquals() {
|
||||
List<Media> mediaList = getMediaList();
|
||||
Map<String, Object> metadata = new HashMap<>();
|
||||
metadata.put("key", "value");
|
||||
Double score = 0.95;
|
||||
|
||||
Document doc1 = Document.builder()
|
||||
.id("customId")
|
||||
.content("Test content")
|
||||
.media(mediaList)
|
||||
.metadata(metadata)
|
||||
.score(score)
|
||||
.build();
|
||||
|
||||
Document doc2 = Document.builder()
|
||||
.id("customId")
|
||||
.content("Test content")
|
||||
.media(mediaList)
|
||||
.metadata(metadata)
|
||||
.score(score)
|
||||
.build();
|
||||
|
||||
Document differentDoc = Document.builder()
|
||||
.id("differentId")
|
||||
.content("Different content")
|
||||
.media(mediaList)
|
||||
.metadata(metadata)
|
||||
.score(score)
|
||||
.build();
|
||||
|
||||
assertThat(doc1).isEqualTo(doc2).isNotEqualTo(differentDoc).isNotEqualTo(null).isNotEqualTo(new Object());
|
||||
|
||||
assertThat(doc1.hashCode()).isEqualTo(doc2.hashCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testEmptyDocument() {
|
||||
Document emptyDoc = Document.builder().build();
|
||||
|
||||
assertThat(emptyDoc.getContent()).isEqualTo(Document.EMPTY_TEXT).isEmpty();
|
||||
assertThat(emptyDoc.getMedia()).isEmpty();
|
||||
assertThat(emptyDoc.getMetadata()).isEmpty();
|
||||
assertThat(emptyDoc.getScore()).isNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
void testToString() {
|
||||
List<Media> mediaList = getMediaList();
|
||||
Map<String, Object> metadata = new HashMap<>();
|
||||
metadata.put("key", "value");
|
||||
Double score = 0.95;
|
||||
|
||||
Document document = Document.builder()
|
||||
.id("customId")
|
||||
.content("Test content")
|
||||
.media(mediaList)
|
||||
.metadata(metadata)
|
||||
.score(score)
|
||||
.build();
|
||||
|
||||
String toString = document.toString();
|
||||
|
||||
assertThat(toString).contains("id='customId'")
|
||||
.contains("content='Test content'")
|
||||
.contains("media=" + mediaList)
|
||||
.contains("metadata=" + metadata)
|
||||
.contains("score=" + score);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testToStringWithEmptyDocument() {
|
||||
Document emptyDoc = Document.builder().build();
|
||||
|
||||
String toString = emptyDoc.toString();
|
||||
|
||||
assertThat(toString).contains("content=''").contains("media=[]").contains("metadata={}").contains("score=null");
|
||||
}
|
||||
|
||||
private static List<Media> getMediaList() {
|
||||
try {
|
||||
URL mediaUrl1 = new URL("http://type1");
|
||||
URL mediaUrl2 = new URL("http://type2");
|
||||
Media media1 = new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl1);
|
||||
Media media2 = new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl2);
|
||||
return List.of(media1, media2);
|
||||
}
|
||||
catch (MalformedURLException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -27,6 +27,7 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* @author Ilayaperumal Gopinathan
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
public class SimpleVectorStoreSimilarityTests {
|
||||
|
||||
@@ -38,8 +39,7 @@ public class SimpleVectorStoreSimilarityTests {
|
||||
|
||||
SimpleVectorStoreContent storeContent = new SimpleVectorStoreContent("1", "hello, how are you?", metadata,
|
||||
testEmbedding);
|
||||
SimpleVectorStore.Similarity similarity = new SimpleVectorStore.Similarity(storeContent, 0.6d);
|
||||
Document document = similarity.getDocument();
|
||||
Document document = storeContent.toDocument(0.6);
|
||||
assertThat(document).isNotNull();
|
||||
assertThat(document.getId()).isEqualTo("1");
|
||||
assertThat(document.getContent()).isEqualTo("hello, how are you?");
|
||||
|
||||
@@ -62,11 +62,7 @@ class SimpleVectorStoreTests {
|
||||
|
||||
@Test
|
||||
void shouldAddAndRetrieveDocument() {
|
||||
Document doc = Document.builder()
|
||||
.withId("1")
|
||||
.withContent("test content")
|
||||
.withMetadata(Map.of("key", "value"))
|
||||
.build();
|
||||
Document doc = Document.builder().id("1").content("test content").metadata(Map.of("key", "value")).build();
|
||||
|
||||
this.vectorStore.add(List.of(doc));
|
||||
|
||||
@@ -80,8 +76,8 @@ class SimpleVectorStoreTests {
|
||||
|
||||
@Test
|
||||
void shouldAddMultipleDocuments() {
|
||||
List<Document> docs = Arrays.asList(Document.builder().withId("1").withContent("first").build(),
|
||||
Document.builder().withId("2").withContent("second").build());
|
||||
List<Document> docs = Arrays.asList(Document.builder().id("1").content("first").build(),
|
||||
Document.builder().id("2").content("second").build());
|
||||
|
||||
this.vectorStore.add(docs);
|
||||
|
||||
@@ -104,7 +100,7 @@ class SimpleVectorStoreTests {
|
||||
|
||||
@Test
|
||||
void shouldDeleteDocuments() {
|
||||
Document doc = Document.builder().withId("1").withContent("test content").build();
|
||||
Document doc = Document.builder().id("1").content("test content").build();
|
||||
|
||||
this.vectorStore.add(List.of(doc));
|
||||
assertThat(this.vectorStore.similaritySearch("test")).hasSize(1);
|
||||
@@ -125,7 +121,7 @@ class SimpleVectorStoreTests {
|
||||
// Configure mock to return different embeddings for different queries
|
||||
when(this.mockEmbeddingModel.embed("query")).thenReturn(new float[] { 0.9f, 0.9f, 0.9f });
|
||||
|
||||
Document doc = Document.builder().withId("1").withContent("test content").build();
|
||||
Document doc = Document.builder().id("1").content("test content").build();
|
||||
|
||||
this.vectorStore.add(List.of(doc));
|
||||
|
||||
@@ -138,9 +134,9 @@ class SimpleVectorStoreTests {
|
||||
@Test
|
||||
void shouldSaveAndLoadVectorStore() throws IOException {
|
||||
Document doc = Document.builder()
|
||||
.withId("1")
|
||||
.withContent("test content")
|
||||
.withMetadata(new HashMap<>(Map.of("key", "value")))
|
||||
.id("1")
|
||||
.content("test content")
|
||||
.metadata(new HashMap<>(Map.of("key", "value")))
|
||||
.build();
|
||||
|
||||
this.vectorStore.add(List.of(doc));
|
||||
@@ -185,7 +181,7 @@ class SimpleVectorStoreTests {
|
||||
for (int i = 0; i < numThreads; i++) {
|
||||
final String id = String.valueOf(i);
|
||||
threads[i] = new Thread(() -> {
|
||||
Document doc = Document.builder().withId(id).withContent("content " + id).build();
|
||||
Document doc = Document.builder().id(id).content("content " + id).build();
|
||||
this.vectorStore.add(List.of(doc));
|
||||
});
|
||||
threads[i].start();
|
||||
|
||||
@@ -54,7 +54,6 @@
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-openai-spring-boot-starter</artifactId>
|
||||
@@ -76,6 +75,13 @@
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-test</artifactId>
|
||||
<version>${project.parent.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-testcontainers</artifactId>
|
||||
|
||||
@@ -16,7 +16,10 @@
|
||||
|
||||
package org.springframework.ai.integration.tests;
|
||||
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.vectorstore.SimpleVectorStore;
|
||||
import org.springframework.boot.autoconfigure.SpringBootApplication;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Import;
|
||||
|
||||
/**
|
||||
@@ -28,4 +31,9 @@ import org.springframework.context.annotation.Import;
|
||||
@Import(TestcontainersConfiguration.class)
|
||||
public class TestApplication {
|
||||
|
||||
@Bean
|
||||
SimpleVectorStore simpleVectorStore(EmbeddingModel embeddingModel) {
|
||||
return new SimpleVectorStore(embeddingModel);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.integration.tests.vectorstore;
|
||||
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.springframework.ai.integration.tests.TestApplication;
|
||||
import org.springframework.ai.vectorstore.SearchRequest;
|
||||
import org.springframework.ai.vectorstore.SimpleVectorStore;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.core.io.DefaultResourceLoader;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* Integration tests for {@link SimpleVectorStore}.
|
||||
*
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
@SpringBootTest(classes = TestApplication.class)
|
||||
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*")
|
||||
public class SimpleVectorStoreIT {
|
||||
|
||||
@Autowired
|
||||
private SimpleVectorStore vectorStore;
|
||||
|
||||
List<Document> documents = List.of(
|
||||
Document.builder()
|
||||
.id("471a8c78-549a-4b2c-bce5-ef3ae6579be3")
|
||||
.content(getText("classpath:/test/data/spring.ai.txt"))
|
||||
.metadata(Map.of("meta1", "meta1"))
|
||||
.build(),
|
||||
Document.builder()
|
||||
.id("bc51d7f7-627b-4ba6-adf4-f0bcd1998f8f")
|
||||
.content(getText("classpath:/test/data/time.shelter.txt"))
|
||||
.metadata(Map.of())
|
||||
.build(),
|
||||
Document.builder()
|
||||
.id("d0237682-1150-44ff-b4d2-1be9b1731ee5")
|
||||
.content(getText("classpath:/test/data/great.depression.txt"))
|
||||
.metadata(Map.of("meta2", "meta2"))
|
||||
.build());
|
||||
|
||||
public static String getText(String uri) {
|
||||
var resource = new DefaultResourceLoader().getResource(uri);
|
||||
try {
|
||||
return resource.getContentAsString(StandardCharsets.UTF_8);
|
||||
}
|
||||
catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
void setUp() {
|
||||
vectorStore.delete(this.documents.stream().map(Document::getId).toList());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void searchWithThreshold() {
|
||||
Document document = Document.builder()
|
||||
.id(UUID.randomUUID().toString())
|
||||
.content("Spring AI rocks!!")
|
||||
.metadata("meta1", "meta1")
|
||||
.build();
|
||||
|
||||
vectorStore.add(List.of(document));
|
||||
|
||||
List<Document> results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5));
|
||||
|
||||
assertThat(results).hasSize(1);
|
||||
Document resultDoc = results.get(0);
|
||||
assertThat(resultDoc.getId()).isEqualTo(document.getId());
|
||||
assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta1");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
|
||||
Document sameIdDocument = Document.builder()
|
||||
.id(document.getId())
|
||||
.content("The World is Big and Salvation Lurks Around the Corner")
|
||||
.metadata("meta2", "meta2")
|
||||
.build();
|
||||
|
||||
vectorStore.add(List.of(sameIdDocument));
|
||||
|
||||
results = vectorStore.similaritySearch(SearchRequest.query("FooBar").withTopK(5));
|
||||
|
||||
assertThat(results).hasSize(1);
|
||||
resultDoc = results.get(0);
|
||||
assertThat(resultDoc.getId()).isEqualTo(document.getId());
|
||||
assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta2");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
|
||||
vectorStore.delete(List.of(document.getId()));
|
||||
}
|
||||
|
||||
}
|
||||
@@ -18,6 +18,7 @@ package org.springframework.ai.autoconfigure.vectorstore.pinecone;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.springframework.ai.vectorstore.PineconeVectorStore;
|
||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
|
||||
@@ -25,6 +26,7 @@ import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
* Configuration properties for Pinecone Vector Store.
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
@ConfigurationProperties(PineconeVectorStoreProperties.CONFIG_PREFIX)
|
||||
public class PineconeVectorStoreProperties {
|
||||
@@ -43,7 +45,7 @@ public class PineconeVectorStoreProperties {
|
||||
|
||||
private String contentFieldName = PineconeVectorStore.CONTENT_FIELD_NAME;
|
||||
|
||||
private String distanceMetadataFieldName = PineconeVectorStore.DISTANCE_METADATA_FIELD_NAME;
|
||||
private String distanceMetadataFieldName = DocumentMetadata.DISTANCE.value();
|
||||
|
||||
private Duration serverSideTimeout = Duration.ofSeconds(20);
|
||||
|
||||
|
||||
@@ -20,12 +20,14 @@ import java.time.Duration;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.springframework.ai.vectorstore.PineconeVectorStore;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* @author Christian Tzolov
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
public class PineconeVectorStorePropertiesTests {
|
||||
|
||||
@@ -39,7 +41,7 @@ public class PineconeVectorStorePropertiesTests {
|
||||
assertThat(props.getIndexName()).isNull();
|
||||
assertThat(props.getServerSideTimeout()).isEqualTo(Duration.ofSeconds(20));
|
||||
assertThat(props.getContentFieldName()).isEqualTo(PineconeVectorStore.CONTENT_FIELD_NAME);
|
||||
assertThat(props.getDistanceMetadataFieldName()).isEqualTo(PineconeVectorStore.DISTANCE_METADATA_FIELD_NAME);
|
||||
assertThat(props.getDistanceMetadataFieldName()).isEqualTo(DocumentMetadata.DISTANCE.value());
|
||||
}
|
||||
|
||||
@Test
|
||||
|
||||
@@ -74,6 +74,7 @@ import org.springframework.ai.vectorstore.observation.VectorStoreObservationConv
|
||||
*
|
||||
* @author Theo van Kraay
|
||||
* @author Soby Chacko
|
||||
* @author Thomas Vitale
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public class CosmosDBVectorStore extends AbstractObservationVectorStore implements AutoCloseable {
|
||||
@@ -338,7 +339,7 @@ public class CosmosDBVectorStore extends AbstractObservationVectorStore implemen
|
||||
.block();
|
||||
// Convert JsonNode to Document
|
||||
List<Document> docs = documents.stream()
|
||||
.map(doc -> new Document(doc.get("id").asText(), doc.get("content").asText(), new HashMap<>()))
|
||||
.map(doc -> Document.builder().id(doc.get("id").asText()).content(doc.get("content").asText()).build())
|
||||
.collect(Collectors.toList());
|
||||
|
||||
return docs != null ? docs : List.of();
|
||||
|
||||
@@ -42,6 +42,7 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
|
||||
/**
|
||||
* @author Theo van Kraay
|
||||
* @author Thomas Vitale
|
||||
* @since 1.0.0
|
||||
*/
|
||||
@EnabledIfEnvironmentVariable(named = "AZURE_COSMOSDB_ENDPOINT", matches = ".+")
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.vectorstore;
|
||||
|
||||
import org.testcontainers.utility.DockerImageName;
|
||||
|
||||
/**
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
public final class CosmosDbImage {
|
||||
|
||||
// It must always be "latest" or else Azure locks the image after a while. See:
|
||||
// https://github.com/Azure/azure-cosmos-db-emulator-docker/issues/60
|
||||
public static final DockerImageName DEFAULT_IMAGE = DockerImageName
|
||||
.parse("mcr.microsoft.com/cosmosdb/linux/azure-cosmos-emulator:latest");
|
||||
|
||||
private CosmosDbImage() {
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -47,6 +47,7 @@ import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.springframework.ai.embedding.BatchingStrategy;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
|
||||
@@ -96,8 +97,6 @@ public class AzureVectorStore extends AbstractObservationVectorStore implements
|
||||
|
||||
private static final String METADATA_FIELD_NAME = "metadata";
|
||||
|
||||
private static final String DISTANCE_METADATA_FIELD_NAME = "distance";
|
||||
|
||||
private static final int DEFAULT_TOP_K = 4;
|
||||
|
||||
private static final Double DEFAULT_SIMILARITY_THRESHOLD = 0.0;
|
||||
@@ -321,13 +320,15 @@ public class AzureVectorStore extends AbstractObservationVectorStore implements
|
||||
|
||||
}) : Map.of();
|
||||
|
||||
metadata.put(DISTANCE_METADATA_FIELD_NAME, 1 - (float) result.getScore());
|
||||
|
||||
final Document doc = new Document(entry.id(), entry.content(), metadata);
|
||||
doc.setEmbedding(EmbeddingUtils.toPrimitive(entry.embedding()));
|
||||
|
||||
return doc;
|
||||
metadata.put(DocumentMetadata.DISTANCE.value(), 1.0 - result.getScore());
|
||||
|
||||
return Document.builder()
|
||||
.id(entry.id())
|
||||
.content(entry.content)
|
||||
.metadata(metadata)
|
||||
.score(result.getScore())
|
||||
.embedding(EmbeddingUtils.toPrimitive(entry.embedding))
|
||||
.build();
|
||||
})
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@@ -35,6 +35,7 @@ import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.transformers.TransformersEmbeddingModel;
|
||||
import org.springframework.ai.vectorstore.SearchRequest;
|
||||
@@ -52,6 +53,7 @@ import static org.hamcrest.Matchers.hasSize;
|
||||
|
||||
/**
|
||||
* @author Christian Tzolov
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
@EnabledIfEnvironmentVariable(named = "AZURE_AI_SEARCH_API_KEY", matches = ".+")
|
||||
@EnabledIfEnvironmentVariable(named = "AZURE_AI_SEARCH_ENDPOINT", matches = ".+")
|
||||
@@ -103,7 +105,7 @@ public class AzureVectorStoreIT {
|
||||
assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock");
|
||||
assertThat(resultDoc.getMetadata()).hasSize(2);
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta2");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
|
||||
// Remove all documents from the store
|
||||
vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList());
|
||||
@@ -224,7 +226,7 @@ public class AzureVectorStoreIT {
|
||||
assertThat(resultDoc.getId()).isEqualTo(document.getId());
|
||||
assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta1");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
|
||||
Document sameIdDocument = new Document(document.getId(),
|
||||
"The World is Big and Salvation Lurks Around the Corner",
|
||||
@@ -245,7 +247,7 @@ public class AzureVectorStoreIT {
|
||||
assertThat(resultDoc.getId()).isEqualTo(document.getId());
|
||||
assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta2");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
|
||||
// Remove all documents from the store
|
||||
vectorStore.delete(List.of(document.getId()));
|
||||
@@ -271,21 +273,22 @@ public class AzureVectorStoreIT {
|
||||
List<Document> fullResult = vectorStore
|
||||
.similaritySearch(SearchRequest.query("Depression").withTopK(5).withSimilarityThresholdAll());
|
||||
|
||||
List<Float> distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList();
|
||||
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
|
||||
|
||||
assertThat(distances).hasSize(3);
|
||||
assertThat(scores).hasSize(3);
|
||||
|
||||
float threshold = (distances.get(0) + distances.get(1)) / 2;
|
||||
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
|
||||
|
||||
List<Document> results = vectorStore
|
||||
.similaritySearch(SearchRequest.query("Depression").withTopK(5).withSimilarityThreshold(1 - threshold));
|
||||
List<Document> results = vectorStore.similaritySearch(
|
||||
SearchRequest.query("Depression").withTopK(5).withSimilarityThreshold(similarityThreshold));
|
||||
|
||||
assertThat(results).hasSize(1);
|
||||
Document resultDoc = results.get(0);
|
||||
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId());
|
||||
assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta2");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
|
||||
|
||||
// Remove all documents from the store
|
||||
vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList());
|
||||
|
||||
@@ -45,6 +45,7 @@ import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.springframework.ai.embedding.BatchingStrategy;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
|
||||
@@ -106,8 +107,6 @@ import org.springframework.ai.vectorstore.observation.VectorStoreObservationConv
|
||||
*/
|
||||
public class CassandraVectorStore extends AbstractObservationVectorStore implements AutoCloseable {
|
||||
|
||||
public static final String SIMILARITY_FIELD_NAME = "similarity_score";
|
||||
|
||||
public static final String DRIVER_PROFILE_UPDATES = "spring-ai-updates";
|
||||
|
||||
public static final String DRIVER_PROFILE_SEARCH = "spring-ai-search";
|
||||
@@ -252,14 +251,19 @@ public class CassandraVectorStore extends AbstractObservationVectorStore impleme
|
||||
break;
|
||||
}
|
||||
Map<String, Object> docFields = new HashMap<>();
|
||||
docFields.put(SIMILARITY_FIELD_NAME, score);
|
||||
docFields.put(DocumentMetadata.DISTANCE.value(), 1 - score);
|
||||
for (var metadata : this.conf.schema.metadataColumns()) {
|
||||
var value = row.get(metadata.name(), metadata.javaType());
|
||||
if (null != value) {
|
||||
docFields.put(metadata.name(), value);
|
||||
}
|
||||
}
|
||||
Document doc = new Document(getDocumentId(row), row.getString(this.conf.schema.content()), docFields);
|
||||
Document doc = Document.builder()
|
||||
.id(getDocumentId(row))
|
||||
.content(row.getString(this.conf.schema.content()))
|
||||
.metadata(docFields)
|
||||
.score((double) score)
|
||||
.build();
|
||||
|
||||
if (this.conf.returnEmbeddings) {
|
||||
doc.setEmbedding(EmbeddingUtils
|
||||
|
||||
@@ -37,6 +37,7 @@ import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.testcontainers.containers.CassandraContainer;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
@@ -226,8 +227,7 @@ class CassandraRichSchemaVectorStoreIT {
|
||||
|
||||
assertThat(resultDoc.getMetadata()).hasSize(3);
|
||||
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("id", "revision",
|
||||
CassandraVectorStore.SIMILARITY_FIELD_NAME);
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("id", "revision", DocumentMetadata.DISTANCE.value());
|
||||
|
||||
// Remove all documents from the store
|
||||
store.delete(documents.stream().map(doc -> doc.getId()).toList());
|
||||
@@ -494,8 +494,7 @@ class CassandraRichSchemaVectorStoreIT {
|
||||
assertThat(resultDoc.getId()).isNotEqualTo(sameIdDocument.getId());
|
||||
assertThat(resultDoc.getContent()).doesNotContain(newContent);
|
||||
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("id", "revision",
|
||||
CassandraVectorStore.SIMILARITY_FIELD_NAME);
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("id", "revision", DocumentMetadata.DISTANCE.value());
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -509,16 +508,15 @@ class CassandraRichSchemaVectorStoreIT {
|
||||
List<Document> fullResult = store
|
||||
.similaritySearch(SearchRequest.query(URANUS_ORBIT_QUERY).withTopK(5).withSimilarityThresholdAll());
|
||||
|
||||
List<Float> distances = fullResult.stream()
|
||||
.map(doc -> (Float) doc.getMetadata().get(CassandraVectorStore.SIMILARITY_FIELD_NAME))
|
||||
.toList();
|
||||
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
|
||||
|
||||
assertThat(distances).hasSize(3);
|
||||
assertThat(scores).hasSize(3);
|
||||
|
||||
float threshold = (distances.get(0) + distances.get(1)) / 2;
|
||||
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
|
||||
|
||||
List<Document> results = store.similaritySearch(
|
||||
SearchRequest.query(URANUS_ORBIT_QUERY).withTopK(5).withSimilarityThreshold(threshold));
|
||||
List<Document> results = store.similaritySearch(SearchRequest.query(URANUS_ORBIT_QUERY)
|
||||
.withTopK(5)
|
||||
.withSimilarityThreshold(similarityThreshold));
|
||||
|
||||
assertThat(results).hasSize(1);
|
||||
Document resultDoc = results.get(0);
|
||||
@@ -526,8 +524,8 @@ class CassandraRichSchemaVectorStoreIT {
|
||||
|
||||
assertThat(resultDoc.getContent()).contains(URANUS_ORBIT_QUERY);
|
||||
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("id", "revision",
|
||||
CassandraVectorStore.SIMILARITY_FIELD_NAME);
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("id", "revision", DocumentMetadata.DISTANCE.value());
|
||||
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -30,6 +30,7 @@ import com.datastax.oss.driver.api.core.servererrors.SyntaxError;
|
||||
import com.datastax.oss.driver.api.core.type.DataTypes;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.testcontainers.containers.CassandraContainer;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
@@ -138,7 +139,7 @@ class CassandraVectorStoreIT {
|
||||
"Spring AI provides abstractions that serve as the foundation for developing AI applications.");
|
||||
|
||||
assertThat(resultDoc.getMetadata()).hasSize(2);
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta1", CassandraVectorStore.SIMILARITY_FIELD_NAME);
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value());
|
||||
|
||||
// Remove all documents from the store
|
||||
store.delete(documents().stream().map(doc -> doc.getId()).toList());
|
||||
@@ -174,7 +175,7 @@ class CassandraVectorStoreIT {
|
||||
"Spring AI provides abstractions that serve as the foundation for developing AI applications.");
|
||||
|
||||
assertThat(resultDoc.getMetadata()).hasSize(1);
|
||||
assertThat(resultDoc.getMetadata()).containsKey(CassandraVectorStore.SIMILARITY_FIELD_NAME);
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
|
||||
// Remove all documents from the store
|
||||
store.delete(documents().stream().map(doc -> doc.getId()).toList());
|
||||
@@ -359,7 +360,7 @@ class CassandraVectorStoreIT {
|
||||
resultDoc = results.get(0);
|
||||
assertThat(resultDoc.getId()).isEqualTo(document.getId());
|
||||
assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta2", CassandraVectorStore.SIMILARITY_FIELD_NAME);
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value());
|
||||
|
||||
store.delete(List.of(document.getId()));
|
||||
}
|
||||
@@ -375,16 +376,14 @@ class CassandraVectorStoreIT {
|
||||
List<Document> fullResult = store
|
||||
.similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll());
|
||||
|
||||
List<Float> distances = fullResult.stream()
|
||||
.map(doc -> (Float) doc.getMetadata().get(CassandraVectorStore.SIMILARITY_FIELD_NAME))
|
||||
.toList();
|
||||
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
|
||||
|
||||
assertThat(distances).hasSize(3);
|
||||
assertThat(scores).hasSize(3);
|
||||
|
||||
float threshold = (distances.get(0) + distances.get(1)) / 2;
|
||||
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
|
||||
|
||||
List<Document> results = store
|
||||
.similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(threshold));
|
||||
List<Document> results = store.similaritySearch(
|
||||
SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(similarityThreshold));
|
||||
|
||||
assertThat(results).hasSize(1);
|
||||
Document resultDoc = results.get(0);
|
||||
@@ -393,7 +392,8 @@ class CassandraVectorStoreIT {
|
||||
assertThat(resultDoc.getContent()).contains(
|
||||
"Spring AI provides abstractions that serve as the foundation for developing AI applications.");
|
||||
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta1", CassandraVectorStore.SIMILARITY_FIELD_NAME);
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value());
|
||||
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -32,6 +32,7 @@ import org.springframework.ai.chroma.ChromaApi.AddEmbeddingsRequest;
|
||||
import org.springframework.ai.chroma.ChromaApi.DeleteEmbeddingsRequest;
|
||||
import org.springframework.ai.chroma.ChromaApi.Embedding;
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.springframework.ai.embedding.BatchingStrategy;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
|
||||
@@ -58,11 +59,10 @@ import org.springframework.util.CollectionUtils;
|
||||
* @author Fu Cheng
|
||||
* @author Sebastien Deleuze
|
||||
* @author Soby Chacko
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
public class ChromaVectorStore extends AbstractObservationVectorStore implements InitializingBean {
|
||||
|
||||
public static final String DISTANCE_FIELD_NAME = "distance";
|
||||
|
||||
public static final String DEFAULT_COLLECTION_NAME = "SpringAiCollection";
|
||||
|
||||
private final EmbeddingModel embeddingModel;
|
||||
@@ -192,9 +192,14 @@ public class ChromaVectorStore extends AbstractObservationVectorStore implements
|
||||
if (metadata == null) {
|
||||
metadata = new HashMap<>();
|
||||
}
|
||||
metadata.put(DISTANCE_FIELD_NAME, distance);
|
||||
Document document = new Document(id, content, metadata);
|
||||
document.setEmbedding(chromaEmbedding.embedding());
|
||||
metadata.put(DocumentMetadata.DISTANCE.value(), distance);
|
||||
Document document = Document.builder()
|
||||
.id(id)
|
||||
.content(content)
|
||||
.metadata(metadata)
|
||||
.embedding(chromaEmbedding.embedding())
|
||||
.score(1.0 - distance)
|
||||
.build();
|
||||
responseDocuments.add(document);
|
||||
}
|
||||
}
|
||||
@@ -244,8 +249,7 @@ public class ChromaVectorStore extends AbstractObservationVectorStore implements
|
||||
@NonNull String operationName) {
|
||||
return VectorStoreObservationContext.builder(VectorStoreProvider.CHROMA.value(), operationName)
|
||||
.withDimensions(this.embeddingModel.dimensions())
|
||||
.withCollectionName(this.collectionName + ":" + this.collectionId)
|
||||
.withFieldName(this.initializeSchema ? DISTANCE_FIELD_NAME : null);
|
||||
.withCollectionName(this.collectionName + ":" + this.collectionId);
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
@@ -24,6 +24,7 @@ import java.util.UUID;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.testcontainers.chromadb.ChromaDBContainer;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
@@ -81,7 +82,7 @@ public class ChromaVectorStoreIT {
|
||||
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId());
|
||||
assertThat(resultDoc.getContent()).isEqualTo(
|
||||
"Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression");
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value());
|
||||
|
||||
// Remove all documents from the store
|
||||
assertThat(vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()))
|
||||
@@ -99,8 +100,8 @@ public class ChromaVectorStoreIT {
|
||||
VectorStore vectorStore = context.getBean(VectorStore.class);
|
||||
|
||||
var document = Document.builder()
|
||||
.withId("simpleDoc")
|
||||
.withContent("The sky is blue because of Rayleigh scattering.")
|
||||
.id("simpleDoc")
|
||||
.content("The sky is blue because of Rayleigh scattering.")
|
||||
.build();
|
||||
|
||||
vectorStore.add(List.of(document));
|
||||
@@ -179,7 +180,7 @@ public class ChromaVectorStoreIT {
|
||||
assertThat(resultDoc.getId()).isEqualTo(document.getId());
|
||||
assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta1");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
|
||||
Document sameIdDocument = new Document(document.getId(),
|
||||
"The World is Big and Salvation Lurks Around the Corner",
|
||||
@@ -194,7 +195,7 @@ public class ChromaVectorStoreIT {
|
||||
assertThat(resultDoc.getId()).isEqualTo(document.getId());
|
||||
assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta2");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
|
||||
// Remove all documents from the store
|
||||
vectorStore.delete(List.of(document.getId()));
|
||||
@@ -213,13 +214,13 @@ public class ChromaVectorStoreIT {
|
||||
var request = SearchRequest.query("Great").withTopK(5);
|
||||
List<Document> fullResult = vectorStore.similaritySearch(request.withSimilarityThresholdAll());
|
||||
|
||||
List<Float> distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList();
|
||||
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
|
||||
|
||||
assertThat(distances).hasSize(3);
|
||||
assertThat(scores).hasSize(3);
|
||||
|
||||
float threshold = (distances.get(0) + distances.get(1)) / 2;
|
||||
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
|
||||
|
||||
List<Document> results = vectorStore.similaritySearch(request.withSimilarityThreshold(1 - threshold));
|
||||
List<Document> results = vectorStore.similaritySearch(request.withSimilarityThreshold(similarityThreshold));
|
||||
|
||||
assertThat(results).hasSize(1);
|
||||
Document resultDoc = results.get(0);
|
||||
@@ -227,7 +228,8 @@ public class ChromaVectorStoreIT {
|
||||
assertThat(resultDoc.getContent()).isEqualTo(
|
||||
"Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta2");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
|
||||
|
||||
// Remove all documents from the store
|
||||
vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList());
|
||||
|
||||
@@ -107,7 +107,6 @@ public class ChromaVectorStoreObservationIT {
|
||||
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_COLLECTION_NAME.asString(),
|
||||
"TestCollection:" + vectorStore.getCollectionId())
|
||||
.doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.DB_NAMESPACE.asString())
|
||||
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_VECTOR_FIELD_NAME.asString(), "distance")
|
||||
.doesNotHaveHighCardinalityKeyValueWithKey(
|
||||
HighCardinalityKeyNames.DB_SEARCH_SIMILARITY_METRIC.asString())
|
||||
.doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.DB_VECTOR_QUERY_TOP_K.asString())
|
||||
@@ -141,7 +140,6 @@ public class ChromaVectorStoreObservationIT {
|
||||
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_COLLECTION_NAME.asString(),
|
||||
"TestCollection:" + vectorStore.getCollectionId())
|
||||
.doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.DB_NAMESPACE.asString())
|
||||
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_VECTOR_FIELD_NAME.asString(), "distance")
|
||||
.doesNotHaveHighCardinalityKeyValueWithKey(
|
||||
HighCardinalityKeyNames.DB_SEARCH_SIMILARITY_METRIC.asString())
|
||||
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_VECTOR_QUERY_TOP_K.asString(), "1")
|
||||
|
||||
@@ -37,6 +37,7 @@ import com.tangosol.net.Session;
|
||||
import com.tangosol.util.Filter;
|
||||
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.vectorstore.filter.Filter.Expression;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
@@ -62,6 +63,7 @@ import org.springframework.beans.factory.InitializingBean;
|
||||
* </ul>
|
||||
*
|
||||
* @author Aleks Seovic
|
||||
* @author Thomas Vitale
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public class CoherenceVectorStore implements VectorStore, InitializingBean {
|
||||
@@ -211,8 +213,13 @@ public class CoherenceVectorStore implements VectorStore, InitializingBean {
|
||||
if (this.distanceType != DistanceType.COSINE || (1 - r.getDistance()) >= request.getSimilarityThreshold()) {
|
||||
DocumentChunk.Id id = r.getKey();
|
||||
DocumentChunk chunk = r.getValue();
|
||||
chunk.metadata().put("distance", r.getDistance());
|
||||
documents.add(new Document(id.docId(), chunk.text(), chunk.metadata()));
|
||||
chunk.metadata().put(DocumentMetadata.DISTANCE.value(), r.getDistance());
|
||||
documents.add(Document.builder()
|
||||
.id(id.docId())
|
||||
.content(chunk.text())
|
||||
.metadata(chunk.metadata())
|
||||
.score(1 - r.getDistance())
|
||||
.build());
|
||||
}
|
||||
}
|
||||
return documents;
|
||||
|
||||
@@ -47,6 +47,7 @@ import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.transformers.TransformersEmbeddingModel;
|
||||
import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser;
|
||||
@@ -125,7 +126,7 @@ public class CoherenceVectorStoreIT {
|
||||
assertThat(results).hasSize(1);
|
||||
Document resultDoc = results.get(0);
|
||||
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId());
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value());
|
||||
|
||||
// Remove all documents from the store
|
||||
vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList());
|
||||
@@ -223,7 +224,7 @@ public class CoherenceVectorStoreIT {
|
||||
assertThat(resultDoc.getId()).isEqualTo(document.getId());
|
||||
|
||||
assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value());
|
||||
|
||||
Document sameIdDocument = new Document(document.getId(),
|
||||
"The World is Big and Salvation Lurks Around the Corner",
|
||||
@@ -236,7 +237,7 @@ public class CoherenceVectorStoreIT {
|
||||
resultDoc = results.get(0);
|
||||
assertThat(resultDoc.getId()).isEqualTo(document.getId());
|
||||
assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value());
|
||||
|
||||
truncateMap(context, ((CoherenceVectorStore) vectorStore).getMapName());
|
||||
});
|
||||
@@ -257,18 +258,18 @@ public class CoherenceVectorStoreIT {
|
||||
|
||||
assertThat(isSortedByDistance(fullResult)).isTrue();
|
||||
|
||||
List<Double> distances = fullResult.stream()
|
||||
.map(doc -> (Double) doc.getMetadata().get("distance"))
|
||||
.toList();
|
||||
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
|
||||
|
||||
double threshold = 1d - (distances.get(0) + distances.get(1)) / 2f;
|
||||
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
|
||||
|
||||
List<Document> results = vectorStore
|
||||
.similaritySearch(SearchRequest.query("Time Shelter").withTopK(5).withSimilarityThreshold(threshold));
|
||||
List<Document> results = vectorStore.similaritySearch(
|
||||
SearchRequest.query("Time Shelter").withTopK(5).withSimilarityThreshold(similarityThreshold));
|
||||
|
||||
assertThat(results).hasSize(1);
|
||||
Document resultDoc = results.get(0);
|
||||
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(1).getId());
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value());
|
||||
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
|
||||
|
||||
truncateMap(context, ((CoherenceVectorStore) vectorStore).getMapName());
|
||||
});
|
||||
@@ -276,7 +277,7 @@ public class CoherenceVectorStoreIT {
|
||||
|
||||
private static boolean isSortedByDistance(final List<Document> documents) {
|
||||
final List<Double> distances = documents.stream()
|
||||
.map(doc -> (Double) doc.getMetadata().get("distance"))
|
||||
.map(doc -> (Double) doc.getMetadata().get(DocumentMetadata.DISTANCE.value()))
|
||||
.toList();
|
||||
|
||||
if (CollectionUtils.isEmpty(distances) || distances.size() == 1) {
|
||||
|
||||
@@ -40,6 +40,7 @@ import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.springframework.ai.embedding.BatchingStrategy;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
|
||||
@@ -210,20 +211,24 @@ public class ElasticsearchVectorStore extends AbstractObservationVectorStore imp
|
||||
|
||||
private Document toDocument(Hit<Document> hit) {
|
||||
Document document = hit.source();
|
||||
document.getMetadata().put("distance", calculateDistance(hit.score().floatValue()));
|
||||
return document;
|
||||
Document.Builder documentBuilder = document.mutate();
|
||||
if (hit.score() != null) {
|
||||
documentBuilder.metadata(DocumentMetadata.DISTANCE.value(), 1 - normalizeSimilarityScore(hit.score()));
|
||||
documentBuilder.score(normalizeSimilarityScore(hit.score()));
|
||||
}
|
||||
return documentBuilder.build();
|
||||
}
|
||||
|
||||
// more info on score/distance calculation
|
||||
// https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html#knn-similarity-search
|
||||
private float calculateDistance(Float score) {
|
||||
private double normalizeSimilarityScore(double score) {
|
||||
switch (this.options.getSimilarity()) {
|
||||
case l2_norm:
|
||||
// the returned value of l2_norm is the opposite of the other functions
|
||||
// (closest to zero means more accurate), so to make it consistent
|
||||
// with the other functions the reverse is returned applying a "1-"
|
||||
// to the standard transformation
|
||||
return (float) (1 - (java.lang.Math.sqrt((1 / score) - 1)));
|
||||
return (1 - (java.lang.Math.sqrt((1 / score) - 1)));
|
||||
// cosine and dot_product
|
||||
default:
|
||||
return (2 * score) - 1;
|
||||
|
||||
@@ -42,6 +42,7 @@ import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.testcontainers.elasticsearch.ElasticsearchContainer;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
@@ -165,7 +166,7 @@ class ElasticsearchVectorStoreIT {
|
||||
assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock");
|
||||
assertThat(resultDoc.getMetadata()).hasSize(2);
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta2");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
|
||||
// Remove all documents from the store
|
||||
vectorStore.delete(this.documents.stream().map(Document::getId).toList());
|
||||
@@ -299,7 +300,7 @@ class ElasticsearchVectorStoreIT {
|
||||
assertThat(resultDoc.getId()).isEqualTo(document.getId());
|
||||
assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta1");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
|
||||
Document sameIdDocument = new Document(document.getId(),
|
||||
"The World is Big and Salvation Lurks Around the Corner", Map.of("meta2", "meta2"));
|
||||
@@ -318,7 +319,7 @@ class ElasticsearchVectorStoreIT {
|
||||
assertThat(resultDoc.getId()).isEqualTo(document.getId());
|
||||
assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta2");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
|
||||
// Remove all documents from the store
|
||||
vectorStore.delete(List.of(document.getId()));
|
||||
@@ -343,21 +344,22 @@ class ElasticsearchVectorStoreIT {
|
||||
|
||||
List<Document> fullResult = vectorStore.similaritySearch(query);
|
||||
|
||||
List<Float> distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList();
|
||||
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
|
||||
|
||||
assertThat(distances).hasSize(3);
|
||||
assertThat(scores).hasSize(3);
|
||||
|
||||
float thresholdResult = (distances.get(0) + distances.get(1)) / 2;
|
||||
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
|
||||
|
||||
List<Document> results = vectorStore.similaritySearch(
|
||||
SearchRequest.query("Great Depression").withTopK(50).withSimilarityThreshold(thresholdResult));
|
||||
SearchRequest.query("Great Depression").withTopK(50).withSimilarityThreshold(similarityThreshold));
|
||||
|
||||
assertThat(results).hasSize(1);
|
||||
Document resultDoc = results.get(0);
|
||||
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId());
|
||||
assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta2");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
|
||||
|
||||
// Remove all documents from the store
|
||||
vectorStore.delete(this.documents.stream().map(Document::getId).toList());
|
||||
|
||||
@@ -30,6 +30,7 @@ import com.fasterxml.jackson.databind.json.JsonMapper;
|
||||
import io.micrometer.observation.ObservationRegistry;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import reactor.util.annotation.NonNull;
|
||||
|
||||
import org.springframework.ai.document.Document;
|
||||
@@ -172,8 +173,6 @@ public class GemFireVectorStore extends AbstractObservationVectorStore implement
|
||||
// Query Defaults
|
||||
private static final String QUERY = "/query";
|
||||
|
||||
private static final String DISTANCE_METADATA_FIELD_NAME = "distance";
|
||||
|
||||
/**
|
||||
* Initializes the GemFireVectorStore after properties are set. This method is called
|
||||
* after all bean properties have been set and allows the bean to perform any
|
||||
@@ -271,9 +270,9 @@ public class GemFireVectorStore extends AbstractObservationVectorStore implement
|
||||
metadata = new HashMap<>();
|
||||
metadata.put(DOCUMENT_FIELD, "--Deleted--");
|
||||
}
|
||||
metadata.put(DISTANCE_METADATA_FIELD_NAME, 1 - r.score);
|
||||
metadata.put(DocumentMetadata.DISTANCE.value(), 1 - r.score);
|
||||
String content = (String) metadata.remove(DOCUMENT_FIELD);
|
||||
return new Document(r.key, content, metadata);
|
||||
return Document.builder().id(r.key).content(content).metadata(metadata).score((double) r.score).build();
|
||||
})
|
||||
.collectList()
|
||||
.onErrorMap(WebClientException.class, this::handleHttpClientException)
|
||||
|
||||
@@ -34,6 +34,7 @@ import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.transformers.TransformersEmbeddingModel;
|
||||
import org.springframework.boot.SpringBootConfiguration;
|
||||
@@ -134,7 +135,7 @@ public class GemFireVectorStoreIT {
|
||||
assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939)" + " was an economic shock");
|
||||
assertThat(resultDoc.getMetadata()).hasSize(2);
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta2");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
});
|
||||
}
|
||||
|
||||
@@ -156,7 +157,7 @@ public class GemFireVectorStoreIT {
|
||||
assertThat(resultDoc.getId()).isEqualTo(document.getId());
|
||||
assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta1");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
|
||||
Document sameIdDocument = new Document(document.getId(),
|
||||
"The World is Big and Salvation Lurks " + "Around the Corner",
|
||||
@@ -171,7 +172,7 @@ public class GemFireVectorStoreIT {
|
||||
assertThat(resultDoc.getId()).isEqualTo(document.getId());
|
||||
assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation" + " Lurks Around the Corner");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta2");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
});
|
||||
}
|
||||
|
||||
@@ -191,12 +192,12 @@ public class GemFireVectorStoreIT {
|
||||
List<Document> fullResult = vectorStore
|
||||
.similaritySearch(SearchRequest.query("Depression").withTopK(5).withSimilarityThresholdAll());
|
||||
|
||||
List<Float> distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList();
|
||||
assertThat(distances).hasSize(3);
|
||||
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
|
||||
assertThat(scores).hasSize(3);
|
||||
|
||||
float threshold = (distances.get(0) + distances.get(1)) / 2;
|
||||
List<Document> results = vectorStore
|
||||
.similaritySearch(SearchRequest.query("Depression").withTopK(5).withSimilarityThreshold(1 - threshold));
|
||||
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
|
||||
List<Document> results = vectorStore.similaritySearch(
|
||||
SearchRequest.query("Depression").withTopK(5).withSimilarityThreshold(similarityThreshold));
|
||||
|
||||
assertThat(results).hasSize(1);
|
||||
|
||||
@@ -204,7 +205,8 @@ public class GemFireVectorStoreIT {
|
||||
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId());
|
||||
assertThat(resultDoc.getContent()).contains("The Great Depression " + "(1929–1939) was an economic shock");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta2");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -54,6 +54,7 @@ import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.springframework.ai.embedding.BatchingStrategy;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
|
||||
@@ -97,7 +98,7 @@ public class MilvusVectorStore extends AbstractObservationVectorStore implements
|
||||
public static final String EMBEDDING_FIELD_NAME = "embedding";
|
||||
|
||||
// Metadata, automatically assigned by Milvus.
|
||||
public static final String DISTANCE_FIELD_NAME = "distance";
|
||||
private static final String DISTANCE_FIELD_NAME = "distance";
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(MilvusVectorStore.class);
|
||||
|
||||
@@ -258,13 +259,18 @@ public class MilvusVectorStore extends AbstractObservationVectorStore implements
|
||||
try {
|
||||
metadata = (JSONObject) rowRecord.get(this.config.metadataFieldName);
|
||||
// inject the distance into the metadata.
|
||||
metadata.put(DISTANCE_FIELD_NAME, 1 - getResultSimilarity(rowRecord));
|
||||
metadata.put(DocumentMetadata.DISTANCE.value(), 1 - getResultSimilarity(rowRecord));
|
||||
}
|
||||
catch (ParamException e) {
|
||||
// skip the ParamException if metadata doesn't exist for the custom
|
||||
// collection
|
||||
}
|
||||
return new Document(docId, content, (metadata != null) ? metadata.getInnerMap() : Map.of());
|
||||
return Document.builder()
|
||||
.id(docId)
|
||||
.content(content)
|
||||
.metadata((metadata != null) ? metadata.getInnerMap() : Map.of())
|
||||
.score((double) getResultSimilarity(rowRecord))
|
||||
.build();
|
||||
})
|
||||
.toList();
|
||||
}
|
||||
|
||||
@@ -30,6 +30,7 @@ import io.milvus.param.MetricType;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
import org.testcontainers.milvus.MilvusContainer;
|
||||
@@ -106,7 +107,7 @@ public class MilvusVectorStoreIT {
|
||||
assertThat(resultDoc.getContent()).contains(
|
||||
"Spring AI provides abstractions that serve as the foundation for developing AI applications.");
|
||||
assertThat(resultDoc.getMetadata()).hasSize(2);
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value());
|
||||
|
||||
// Remove all documents from the store
|
||||
vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList());
|
||||
@@ -200,7 +201,7 @@ public class MilvusVectorStoreIT {
|
||||
assertThat(resultDoc.getId()).isEqualTo(document.getId());
|
||||
assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta1");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
|
||||
Document sameIdDocument = new Document(document.getId(),
|
||||
"The World is Big and Salvation Lurks Around the Corner",
|
||||
@@ -215,7 +216,7 @@ public class MilvusVectorStoreIT {
|
||||
assertThat(resultDoc.getId()).isEqualTo(document.getId());
|
||||
assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta2");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
|
||||
vectorStore.delete(List.of(document.getId()));
|
||||
|
||||
@@ -238,24 +239,22 @@ public class MilvusVectorStoreIT {
|
||||
List<Document> fullResult = vectorStore
|
||||
.similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll());
|
||||
|
||||
List<Float> distances = fullResult.stream()
|
||||
.map(doc -> (Float) doc.getMetadata().get("distance"))
|
||||
.toList();
|
||||
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
|
||||
|
||||
assertThat(distances).hasSize(3);
|
||||
assertThat(scores).hasSize(3);
|
||||
|
||||
float threshold = (distances.get(0) + distances.get(1)) / 2;
|
||||
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
|
||||
|
||||
List<Document> results = vectorStore
|
||||
.similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(1 - threshold));
|
||||
List<Document> results = vectorStore.similaritySearch(
|
||||
SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(similarityThreshold));
|
||||
|
||||
assertThat(results).hasSize(1);
|
||||
Document resultDoc = results.get(0);
|
||||
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId());
|
||||
assertThat(resultDoc.getContent()).contains(
|
||||
"Spring AI provides abstractions that serve as the foundation for developing AI applications.");
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance");
|
||||
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value());
|
||||
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ import com.mongodb.MongoCommandException;
|
||||
import io.micrometer.observation.ObservationRegistry;
|
||||
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.springframework.ai.embedding.BatchingStrategy;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
|
||||
@@ -172,12 +173,17 @@ public class MongoDBAtlasVectorStore extends AbstractObservationVectorStore impl
|
||||
private Document mapMongoDocument(org.bson.Document mongoDocument, float[] queryEmbedding) {
|
||||
String id = mongoDocument.getString(ID_FIELD_NAME);
|
||||
String content = mongoDocument.getString(CONTENT_FIELD_NAME);
|
||||
double score = mongoDocument.getDouble(SCORE_FIELD_NAME);
|
||||
Map<String, Object> metadata = mongoDocument.get(METADATA_FIELD_NAME, org.bson.Document.class);
|
||||
metadata.put(DocumentMetadata.DISTANCE.value(), 1 - score);
|
||||
|
||||
Document document = new Document(id, content, metadata);
|
||||
document.setEmbedding(queryEmbedding);
|
||||
|
||||
return document;
|
||||
return Document.builder()
|
||||
.id(id)
|
||||
.content(content)
|
||||
.metadata(metadata)
|
||||
.score(score)
|
||||
.embedding(queryEmbedding)
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
|
||||
package org.springframework.ai.vectorstore;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
@@ -27,6 +29,8 @@ import com.mongodb.client.MongoClient;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.springframework.core.io.DefaultResourceLoader;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
import org.testcontainers.mongodb.MongoDBAtlasLocalContainer;
|
||||
@@ -198,6 +202,55 @@ class MongoDBAtlasVectorStoreIT {
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void searchWithThreshold() {
|
||||
this.contextRunner.run(context -> {
|
||||
VectorStore vectorStore = context.getBean(VectorStore.class);
|
||||
|
||||
var documents = List.of(
|
||||
new Document("471a8c78-549a-4b2c-bce5-ef3ae6579be3", getText("classpath:/test/data/spring.ai.txt"),
|
||||
Map.of("meta1", "meta1")),
|
||||
new Document("bc51d7f7-627b-4ba6-adf4-f0bcd1998f8f",
|
||||
getText("classpath:/test/data/time.shelter.txt"), Map.of()),
|
||||
new Document("d0237682-1150-44ff-b4d2-1be9b1731ee5",
|
||||
getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2")));
|
||||
vectorStore.add(documents);
|
||||
Thread.sleep(5000); // Await a second for the document to be indexed
|
||||
|
||||
List<Document> fullResult = vectorStore
|
||||
.similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll());
|
||||
assertThat(fullResult).hasSize(3);
|
||||
|
||||
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
|
||||
|
||||
assertThat(scores).hasSize(3);
|
||||
|
||||
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
|
||||
|
||||
List<Document> results = vectorStore.similaritySearch(
|
||||
SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(similarityThreshold));
|
||||
|
||||
assertThat(results).hasSize(1);
|
||||
Document resultDoc = results.get(0);
|
||||
assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId());
|
||||
assertThat(resultDoc.getContent()).contains(
|
||||
"Spring AI provides abstractions that serve as the foundation for developing AI applications.");
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value());
|
||||
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
|
||||
|
||||
});
|
||||
}
|
||||
|
||||
public static String getText(String uri) {
|
||||
var resource = new DefaultResourceLoader().getResource(uri);
|
||||
try {
|
||||
return resource.getContentAsString(StandardCharsets.UTF_8);
|
||||
}
|
||||
catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@SpringBootConfiguration
|
||||
@EnableAutoConfiguration
|
||||
public static class TestApplication {
|
||||
|
||||
@@ -29,6 +29,7 @@ import org.neo4j.driver.SessionConfig;
|
||||
import org.neo4j.driver.Values;
|
||||
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.springframework.ai.embedding.BatchingStrategy;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
|
||||
@@ -224,15 +225,19 @@ public class Neo4jVectorStore extends AbstractObservationVectorStore implements
|
||||
var node = neoRecord.get("node").asNode();
|
||||
var score = neoRecord.get("score").asFloat();
|
||||
var metaData = new HashMap<String, Object>();
|
||||
metaData.put("distance", 1 - score);
|
||||
metaData.put(DocumentMetadata.DISTANCE.value(), 1 - score);
|
||||
node.keys().forEach(key -> {
|
||||
if (key.startsWith("metadata.")) {
|
||||
metaData.put(key.substring(key.indexOf(".") + 1), node.get(key).asObject());
|
||||
}
|
||||
});
|
||||
|
||||
return new Document(node.get(this.config.idProperty).asString(), node.get("text").asString(),
|
||||
Map.copyOf(metaData));
|
||||
return Document.builder()
|
||||
.id(node.get(this.config.idProperty).asString())
|
||||
.content(node.get("text").asString())
|
||||
.metadata(Map.copyOf(metaData))
|
||||
.score((double) score)
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -28,6 +28,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.neo4j.driver.AuthTokens;
|
||||
import org.neo4j.driver.Driver;
|
||||
import org.neo4j.driver.GraphDatabase;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.testcontainers.containers.Neo4jContainer;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
@@ -91,7 +92,7 @@ class Neo4jVectorStoreIT {
|
||||
assertThat(resultDoc.getContent()).isEqualTo(
|
||||
"Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta2");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
|
||||
// Remove all documents from the store
|
||||
vectorStore.delete(this.documents.stream().map(Document::getId).toList());
|
||||
@@ -203,7 +204,7 @@ class Neo4jVectorStoreIT {
|
||||
assertThat(resultDoc.getId()).isEqualTo(document.getId());
|
||||
assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta1");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
|
||||
Document sameIdDocument = new Document(document.getId(),
|
||||
"The World is Big and Salvation Lurks Around the Corner",
|
||||
@@ -218,7 +219,7 @@ class Neo4jVectorStoreIT {
|
||||
assertThat(resultDoc.getId()).isEqualTo(document.getId());
|
||||
assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta2");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
|
||||
});
|
||||
}
|
||||
@@ -235,14 +236,14 @@ class Neo4jVectorStoreIT {
|
||||
List<Document> fullResult = vectorStore
|
||||
.similaritySearch(SearchRequest.query("Great").withTopK(5).withSimilarityThresholdAll());
|
||||
|
||||
List<Float> distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList();
|
||||
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
|
||||
|
||||
assertThat(distances).hasSize(3);
|
||||
assertThat(scores).hasSize(3);
|
||||
|
||||
float threshold = (distances.get(0) + distances.get(1)) / 2;
|
||||
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
|
||||
|
||||
List<Document> results = vectorStore
|
||||
.similaritySearch(SearchRequest.query("Great").withTopK(5).withSimilarityThreshold(1 - threshold));
|
||||
List<Document> results = vectorStore.similaritySearch(
|
||||
SearchRequest.query("Great").withTopK(5).withSimilarityThreshold(similarityThreshold));
|
||||
|
||||
assertThat(results).hasSize(1);
|
||||
Document resultDoc = results.get(0);
|
||||
@@ -250,8 +251,8 @@ class Neo4jVectorStoreIT {
|
||||
assertThat(resultDoc.getContent()).isEqualTo(
|
||||
"Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta2");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -40,6 +40,7 @@ import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.springframework.ai.embedding.BatchingStrategy;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
|
||||
@@ -231,8 +232,12 @@ public class OpenSearchVectorStore extends AbstractObservationVectorStore implem
|
||||
|
||||
private Document toDocument(Hit<Document> hit) {
|
||||
Document document = hit.source();
|
||||
document.getMetadata().put("distance", 1 - hit.score().floatValue());
|
||||
return document;
|
||||
Document.Builder documentBuilder = document.mutate();
|
||||
if (hit.score() != null) {
|
||||
documentBuilder.metadata(DocumentMetadata.DISTANCE.value(), 1 - hit.score().floatValue());
|
||||
documentBuilder.score(hit.score());
|
||||
}
|
||||
return documentBuilder.build();
|
||||
}
|
||||
|
||||
public boolean exists(String targetIndex) {
|
||||
|
||||
@@ -39,6 +39,7 @@ import org.junit.jupiter.params.provider.ValueSource;
|
||||
import org.opensearch.client.opensearch.OpenSearchClient;
|
||||
import org.opensearch.client.transport.httpclient5.ApacheHttpClient5TransportBuilder;
|
||||
import org.opensearch.testcontainers.OpensearchContainer;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
|
||||
@@ -144,7 +145,7 @@ class OpenSearchVectorStoreIT {
|
||||
assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock");
|
||||
assertThat(resultDoc.getMetadata()).hasSize(2);
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta2");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
|
||||
// Remove all documents from the store
|
||||
vectorStore.delete(this.documents.stream().map(Document::getId).toList());
|
||||
@@ -281,7 +282,7 @@ class OpenSearchVectorStoreIT {
|
||||
assertThat(resultDoc.getId()).isEqualTo(document.getId());
|
||||
assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta1");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
|
||||
Document sameIdDocument = new Document(document.getId(),
|
||||
"The World is Big and Salvation Lurks Around the Corner", Map.of("meta2", "meta2"));
|
||||
@@ -300,7 +301,7 @@ class OpenSearchVectorStoreIT {
|
||||
assertThat(resultDoc.getId()).isEqualTo(document.getId());
|
||||
assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta2");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
|
||||
// Remove all documents from the store
|
||||
vectorStore.delete(List.of(document.getId()));
|
||||
@@ -330,21 +331,22 @@ class OpenSearchVectorStoreIT {
|
||||
|
||||
List<Document> fullResult = vectorStore.similaritySearch(query);
|
||||
|
||||
List<Float> distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList();
|
||||
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
|
||||
|
||||
assertThat(distances).hasSize(3);
|
||||
assertThat(scores).hasSize(3);
|
||||
|
||||
float threshold = (distances.get(0) + distances.get(1)) / 2;
|
||||
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
|
||||
|
||||
List<Document> results = vectorStore.similaritySearch(
|
||||
SearchRequest.query("Great Depression").withTopK(50).withSimilarityThreshold(1 - threshold));
|
||||
SearchRequest.query("Great Depression").withTopK(50).withSimilarityThreshold(similarityThreshold));
|
||||
|
||||
assertThat(results).hasSize(1);
|
||||
Document resultDoc = results.get(0);
|
||||
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId());
|
||||
assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta2");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
|
||||
|
||||
// Remove all documents from the store
|
||||
vectorStore.delete(this.documents.stream().map(Document::getId).toList());
|
||||
|
||||
@@ -39,6 +39,7 @@ import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.springframework.ai.embedding.BatchingStrategy;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
|
||||
@@ -80,6 +81,7 @@ import org.springframework.util.StringUtils;
|
||||
* @author Loïc Lefèvre
|
||||
* @author Christian Tzolov
|
||||
* @author Soby Chacko
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
public class OracleVectorStore extends AbstractObservationVectorStore implements InitializingBean {
|
||||
|
||||
@@ -649,12 +651,16 @@ public class OracleVectorStore extends AbstractObservationVectorStore implements
|
||||
@Override
|
||||
public Document mapRow(ResultSet rs, int rowNum) throws SQLException {
|
||||
final Map<String, Object> metadata = getMap(rs.getObject(3, OracleJsonValue.class));
|
||||
metadata.put("distance", rs.getDouble(5));
|
||||
metadata.put(DocumentMetadata.DISTANCE.value(), rs.getDouble(5));
|
||||
|
||||
final Document document = new Document(rs.getString(1), rs.getString(2), metadata);
|
||||
final float[] embedding = rs.getObject(4, float[].class);
|
||||
document.setEmbedding(embedding);
|
||||
return document;
|
||||
return Document.builder()
|
||||
.id(rs.getString(1))
|
||||
.content(rs.getString(2))
|
||||
.metadata(metadata)
|
||||
.score(1 - rs.getDouble(5))
|
||||
.embedding(embedding)
|
||||
.build();
|
||||
}
|
||||
|
||||
private Map<String, Object> getMap(OracleJsonValue value) {
|
||||
|
||||
@@ -32,6 +32,7 @@ import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.CsvSource;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
import org.testcontainers.oracle.OracleContainer;
|
||||
@@ -94,21 +95,19 @@ public class OracleVectorStoreIT {
|
||||
jdbcTemplate.execute("DROP TABLE IF EXISTS " + tableName + " PURGE");
|
||||
}
|
||||
|
||||
private static boolean isSortedByDistance(final List<Document> documents) {
|
||||
final List<Double> distances = documents.stream()
|
||||
.map(doc -> (Double) doc.getMetadata().get("distance"))
|
||||
.toList();
|
||||
private static boolean isSortedBySimilarity(final List<Document> documents) {
|
||||
final List<Double> scores = documents.stream().map(Document::getScore).toList();
|
||||
|
||||
if (CollectionUtils.isEmpty(distances) || distances.size() == 1) {
|
||||
if (CollectionUtils.isEmpty(scores) || scores.size() == 1) {
|
||||
return true;
|
||||
}
|
||||
|
||||
Iterator<Double> iter = distances.iterator();
|
||||
Iterator<Double> iter = scores.iterator();
|
||||
Double current;
|
||||
Double previous = iter.next();
|
||||
while (iter.hasNext()) {
|
||||
current = iter.next();
|
||||
if (previous > current) {
|
||||
if (previous < current) {
|
||||
return false;
|
||||
}
|
||||
previous = current;
|
||||
@@ -134,7 +133,7 @@ public class OracleVectorStoreIT {
|
||||
assertThat(results).hasSize(1);
|
||||
Document resultDoc = results.get(0);
|
||||
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId());
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value());
|
||||
|
||||
// Remove all documents from the store
|
||||
vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList());
|
||||
@@ -243,7 +242,7 @@ public class OracleVectorStoreIT {
|
||||
assertThat(resultDoc.getId()).isEqualTo(document.getId());
|
||||
|
||||
assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value());
|
||||
|
||||
Document sameIdDocument = new Document(document.getId(),
|
||||
"The World is Big and Salvation Lurks Around the Corner",
|
||||
@@ -256,7 +255,7 @@ public class OracleVectorStoreIT {
|
||||
resultDoc = results.get(0);
|
||||
assertThat(resultDoc.getId()).isEqualTo(document.getId());
|
||||
assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value());
|
||||
|
||||
dropTable(context, ((OracleVectorStore) vectorStore).getTableName());
|
||||
});
|
||||
@@ -279,20 +278,19 @@ public class OracleVectorStoreIT {
|
||||
|
||||
assertThat(fullResult).hasSize(3);
|
||||
|
||||
assertThat(isSortedByDistance(fullResult)).isTrue();
|
||||
assertThat(isSortedBySimilarity(fullResult)).isTrue();
|
||||
|
||||
List<Double> distances = fullResult.stream()
|
||||
.map(doc -> (Double) doc.getMetadata().get("distance"))
|
||||
.toList();
|
||||
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
|
||||
|
||||
double threshold = (distances.get(0) + distances.get(1)) / 2d;
|
||||
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2d;
|
||||
|
||||
List<Document> results = vectorStore.similaritySearch(
|
||||
SearchRequest.query("Time Shelter").withTopK(5).withSimilarityThreshold(1d - threshold));
|
||||
SearchRequest.query("Time Shelter").withTopK(5).withSimilarityThreshold(similarityThreshold));
|
||||
|
||||
assertThat(results).hasSize(1);
|
||||
Document resultDoc = results.get(0);
|
||||
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(1).getId());
|
||||
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
|
||||
|
||||
dropTable(context, ((OracleVectorStore) vectorStore).getTableName());
|
||||
});
|
||||
|
||||
@@ -35,6 +35,7 @@ import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.springframework.ai.embedding.BatchingStrategy;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
|
||||
@@ -502,12 +503,15 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini
|
||||
Float distance = rs.getFloat(COLUMN_DISTANCE);
|
||||
|
||||
Map<String, Object> metadata = toMap(pgMetadata);
|
||||
metadata.put(COLUMN_DISTANCE, distance);
|
||||
metadata.put(DocumentMetadata.DISTANCE.value(), distance);
|
||||
|
||||
Document document = new Document(id, content, metadata);
|
||||
document.setEmbedding(toFloatArray(embedding));
|
||||
|
||||
return document;
|
||||
return Document.builder()
|
||||
.id(id)
|
||||
.content(content)
|
||||
.metadata(metadata)
|
||||
.score(1.0 - distance)
|
||||
.embedding(toFloatArray(embedding))
|
||||
.build();
|
||||
}
|
||||
|
||||
private float[] toFloatArray(PGobject embedding) throws SQLException {
|
||||
|
||||
@@ -34,6 +34,7 @@ import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.testcontainers.containers.PostgreSQLContainer;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
@@ -113,20 +114,20 @@ public class PgVectorStoreIT {
|
||||
);
|
||||
}
|
||||
|
||||
private static boolean isSortedByDistance(List<Document> docs) {
|
||||
private static boolean isSortedBySimilarity(List<Document> docs) {
|
||||
|
||||
List<Float> distances = docs.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList();
|
||||
List<Double> scores = docs.stream().map(Document::getScore).toList();
|
||||
|
||||
if (CollectionUtils.isEmpty(distances) || distances.size() == 1) {
|
||||
if (CollectionUtils.isEmpty(scores) || scores.size() == 1) {
|
||||
return true;
|
||||
}
|
||||
|
||||
Iterator<Float> iter = distances.iterator();
|
||||
Float current;
|
||||
Float previous = iter.next();
|
||||
Iterator<Double> iter = scores.iterator();
|
||||
Double current;
|
||||
Double previous = iter.next();
|
||||
while (iter.hasNext()) {
|
||||
current = iter.next();
|
||||
if (previous > current) {
|
||||
if (previous < current) {
|
||||
return false;
|
||||
}
|
||||
previous = current;
|
||||
@@ -150,7 +151,7 @@ public class PgVectorStoreIT {
|
||||
assertThat(results).hasSize(1);
|
||||
Document resultDoc = results.get(0);
|
||||
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId());
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value());
|
||||
|
||||
// Remove all documents from the store
|
||||
vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList());
|
||||
@@ -289,7 +290,7 @@ public class PgVectorStoreIT {
|
||||
Document resultDoc = results.get(0);
|
||||
assertThat(resultDoc.getId()).isEqualTo(document.getId());
|
||||
assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value());
|
||||
|
||||
Document sameIdDocument = new Document(document.getId(),
|
||||
"The World is Big and Salvation Lurks Around the Corner",
|
||||
@@ -303,7 +304,7 @@ public class PgVectorStoreIT {
|
||||
resultDoc = results.get(0);
|
||||
assertThat(resultDoc.getId()).isEqualTo(document.getId());
|
||||
assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value());
|
||||
|
||||
dropTable(context);
|
||||
});
|
||||
@@ -326,20 +327,19 @@ public class PgVectorStoreIT {
|
||||
|
||||
assertThat(fullResult).hasSize(3);
|
||||
|
||||
assertThat(isSortedByDistance(fullResult)).isTrue();
|
||||
assertThat(isSortedBySimilarity(fullResult)).isTrue();
|
||||
|
||||
List<Float> distances = fullResult.stream()
|
||||
.map(doc -> (Float) doc.getMetadata().get("distance"))
|
||||
.toList();
|
||||
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
|
||||
|
||||
float threshold = (distances.get(0) + distances.get(1)) / 2;
|
||||
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
|
||||
|
||||
List<Document> results = vectorStore.similaritySearch(
|
||||
SearchRequest.query("Time Shelter").withTopK(5).withSimilarityThreshold(1 - threshold));
|
||||
SearchRequest.query("Time Shelter").withTopK(5).withSimilarityThreshold(similarityThreshold));
|
||||
|
||||
assertThat(results).hasSize(1);
|
||||
Document resultDoc = results.get(0);
|
||||
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(1).getId());
|
||||
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
|
||||
|
||||
dropTable(context);
|
||||
});
|
||||
|
||||
@@ -38,6 +38,7 @@ import io.pinecone.proto.UpsertRequest;
|
||||
import io.pinecone.proto.Vector;
|
||||
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.springframework.ai.embedding.BatchingStrategy;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
|
||||
@@ -60,13 +61,12 @@ import org.springframework.util.StringUtils;
|
||||
* @author Christian Tzolov
|
||||
* @author Adam Bchouti
|
||||
* @author Soby Chacko
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
public class PineconeVectorStore extends AbstractObservationVectorStore {
|
||||
|
||||
public static final String CONTENT_FIELD_NAME = "document_content";
|
||||
|
||||
public static final String DISTANCE_METADATA_FIELD_NAME = "distance";
|
||||
|
||||
public final FilterExpressionConverter filterExpressionConverter = new PineconeFilterExpressionConverter();
|
||||
|
||||
private final EmbeddingModel embeddingModel;
|
||||
@@ -236,7 +236,12 @@ public class PineconeVectorStore extends AbstractObservationVectorStore {
|
||||
var content = metadataStruct.getFieldsOrThrow(this.pineconeContentFieldName).getStringValue();
|
||||
Map<String, Object> metadata = extractMetadata(metadataStruct);
|
||||
metadata.put(this.pineconeDistanceMetadataFieldName, 1 - scoredVector.getScore());
|
||||
return new Document(id, content, metadata);
|
||||
return Document.builder()
|
||||
.id(id)
|
||||
.content(content)
|
||||
.metadata(metadata)
|
||||
.score((double) scoredVector.getScore())
|
||||
.build();
|
||||
})
|
||||
.toList();
|
||||
}
|
||||
@@ -298,6 +303,8 @@ public class PineconeVectorStore extends AbstractObservationVectorStore {
|
||||
|
||||
private final String contentFieldName;
|
||||
|
||||
// TODO: Why is this field configurable? Can we remove this after standardizing
|
||||
// the key?
|
||||
private final String distanceMetadataFieldName;
|
||||
|
||||
private final PineconeConnectionConfig connectionConfig;
|
||||
@@ -357,7 +364,7 @@ public class PineconeVectorStore extends AbstractObservationVectorStore {
|
||||
|
||||
private String contentFieldName = CONTENT_FIELD_NAME;
|
||||
|
||||
private String distanceMetadataFieldName = DISTANCE_METADATA_FIELD_NAME;
|
||||
private String distanceMetadataFieldName = DocumentMetadata.DISTANCE.value();
|
||||
|
||||
/**
|
||||
* Optional server-side timeout in seconds for all operations. Default: 20
|
||||
|
||||
@@ -31,6 +31,7 @@ import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.transformers.TransformersEmbeddingModel;
|
||||
import org.springframework.ai.vectorstore.PineconeVectorStore.PineconeVectorStoreConfig;
|
||||
@@ -46,6 +47,7 @@ import static org.hamcrest.Matchers.hasSize;
|
||||
|
||||
/**
|
||||
* @author Christian Tzolov
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
@EnabledIfEnvironmentVariable(named = "PINECONE_API_KEY", matches = ".+")
|
||||
public class PineconeVectorStoreIT {
|
||||
@@ -109,7 +111,7 @@ public class PineconeVectorStoreIT {
|
||||
assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock");
|
||||
assertThat(resultDoc.getMetadata()).hasSize(2);
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta2");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
|
||||
// Remove all documents from the store
|
||||
vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList());
|
||||
@@ -193,7 +195,7 @@ public class PineconeVectorStoreIT {
|
||||
assertThat(resultDoc.getId()).isEqualTo(document.getId());
|
||||
assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta1");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
|
||||
Document sameIdDocument = new Document(document.getId(),
|
||||
"The World is Big and Salvation Lurks Around the Corner",
|
||||
@@ -214,7 +216,7 @@ public class PineconeVectorStoreIT {
|
||||
assertThat(resultDoc.getId()).isEqualTo(document.getId());
|
||||
assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta2");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
|
||||
// Remove all documents from the store
|
||||
vectorStore.delete(List.of(document.getId()));
|
||||
@@ -240,21 +242,22 @@ public class PineconeVectorStoreIT {
|
||||
List<Document> fullResult = vectorStore
|
||||
.similaritySearch(SearchRequest.query("Depression").withTopK(5).withSimilarityThresholdAll());
|
||||
|
||||
List<Float> distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList();
|
||||
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
|
||||
|
||||
assertThat(distances).hasSize(3);
|
||||
assertThat(scores).hasSize(3);
|
||||
|
||||
float threshold = (distances.get(0) + distances.get(1)) / 2;
|
||||
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
|
||||
|
||||
List<Document> results = vectorStore
|
||||
.similaritySearch(SearchRequest.query("Depression").withTopK(5).withSimilarityThreshold(1 - threshold));
|
||||
List<Document> results = vectorStore.similaritySearch(
|
||||
SearchRequest.query("Depression").withTopK(5).withSimilarityThreshold(similarityThreshold));
|
||||
|
||||
assertThat(results).hasSize(1);
|
||||
Document resultDoc = results.get(0);
|
||||
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId());
|
||||
assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta2");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
|
||||
|
||||
// Remove all documents from the store
|
||||
vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList());
|
||||
|
||||
@@ -35,6 +35,7 @@ import io.qdrant.client.grpc.Points.SearchPoints;
|
||||
import io.qdrant.client.grpc.Points.UpdateStatus;
|
||||
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.springframework.ai.embedding.BatchingStrategy;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
|
||||
@@ -57,6 +58,7 @@ import org.springframework.util.Assert;
|
||||
* @author Eddú Meléndez
|
||||
* @author Josh Long
|
||||
* @author Soby Chacko
|
||||
* @author Thomas Vitale
|
||||
* @since 0.8.1
|
||||
*/
|
||||
public class QdrantVectorStore extends AbstractObservationVectorStore implements InitializingBean {
|
||||
@@ -65,8 +67,6 @@ public class QdrantVectorStore extends AbstractObservationVectorStore implements
|
||||
|
||||
private static final String CONTENT_FIELD_NAME = "doc_content";
|
||||
|
||||
private static final String DISTANCE_FIELD_NAME = "distance";
|
||||
|
||||
private final EmbeddingModel embeddingModel;
|
||||
|
||||
private final QdrantClient qdrantClient;
|
||||
@@ -208,12 +208,17 @@ public class QdrantVectorStore extends AbstractObservationVectorStore implements
|
||||
try {
|
||||
var id = point.getId().getUuid();
|
||||
|
||||
var payload = QdrantObjectFactory.toObjectMap(point.getPayloadMap());
|
||||
payload.put(DISTANCE_FIELD_NAME, 1 - point.getScore());
|
||||
var metadata = QdrantObjectFactory.toObjectMap(point.getPayloadMap());
|
||||
metadata.put(DocumentMetadata.DISTANCE.value(), 1 - point.getScore());
|
||||
|
||||
var content = (String) payload.remove(CONTENT_FIELD_NAME);
|
||||
var content = (String) metadata.remove(CONTENT_FIELD_NAME);
|
||||
|
||||
return new Document(id, content, payload);
|
||||
return Document.builder()
|
||||
.id(id)
|
||||
.content(content)
|
||||
.metadata(metadata)
|
||||
.score((double) point.getScore())
|
||||
.build();
|
||||
}
|
||||
catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
|
||||
@@ -30,6 +30,7 @@ import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariables;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
import org.testcontainers.qdrant.QdrantContainer;
|
||||
@@ -107,7 +108,7 @@ public class QdrantVectorStoreIT {
|
||||
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId());
|
||||
assertThat(resultDoc.getContent()).isEqualTo(
|
||||
"Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression");
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value());
|
||||
|
||||
// Remove all documents from the store
|
||||
vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList());
|
||||
@@ -185,7 +186,7 @@ public class QdrantVectorStoreIT {
|
||||
assertThat(resultDoc.getId()).isEqualTo(document.getId());
|
||||
assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta1");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
|
||||
Document sameIdDocument = new Document(document.getId(),
|
||||
"The World is Big and Salvation Lurks Around the Corner",
|
||||
@@ -200,7 +201,7 @@ public class QdrantVectorStoreIT {
|
||||
assertThat(resultDoc.getId()).isEqualTo(document.getId());
|
||||
assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta2");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
|
||||
vectorStore.delete(List.of(document.getId()));
|
||||
});
|
||||
@@ -218,13 +219,13 @@ public class QdrantVectorStoreIT {
|
||||
var request = SearchRequest.query("Great").withTopK(5);
|
||||
List<Document> fullResult = vectorStore.similaritySearch(request.withSimilarityThresholdAll());
|
||||
|
||||
List<Float> distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList();
|
||||
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
|
||||
|
||||
assertThat(distances).hasSize(3);
|
||||
assertThat(scores).hasSize(3);
|
||||
|
||||
float threshold = (distances.get(0) + distances.get(1)) / 2;
|
||||
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
|
||||
|
||||
List<Document> results = vectorStore.similaritySearch(request.withSimilarityThreshold(1 - threshold));
|
||||
List<Document> results = vectorStore.similaritySearch(request.withSimilarityThreshold(similarityThreshold));
|
||||
|
||||
assertThat(results).hasSize(1);
|
||||
Document resultDoc = results.get(0);
|
||||
@@ -232,7 +233,8 @@ public class QdrantVectorStoreIT {
|
||||
assertThat(resultDoc.getContent()).isEqualTo(
|
||||
"Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta2");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
|
||||
|
||||
// Remove all documents from the store
|
||||
vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList());
|
||||
|
||||
@@ -30,6 +30,7 @@ import java.util.stream.Collectors;
|
||||
import io.micrometer.observation.ObservationRegistry;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import redis.clients.jedis.JedisPooled;
|
||||
import redis.clients.jedis.Pipeline;
|
||||
import redis.clients.jedis.json.Path2;
|
||||
@@ -240,14 +241,21 @@ public class RedisVectorStore extends AbstractObservationVectorStore implements
|
||||
|
||||
private Document toDocument(redis.clients.jedis.search.Document doc) {
|
||||
var id = doc.getId().substring(this.config.prefix.length());
|
||||
var content = doc.hasProperty(this.config.contentFieldName) ? doc.getString(this.config.contentFieldName)
|
||||
: null;
|
||||
var content = doc.hasProperty(this.config.contentFieldName) ? doc.getString(this.config.contentFieldName) : "";
|
||||
Map<String, Object> metadata = this.config.metadataFields.stream()
|
||||
.map(MetadataField::name)
|
||||
.filter(doc::hasProperty)
|
||||
.collect(Collectors.toMap(Function.identity(), doc::getString));
|
||||
// TODO: this seems wrong. The key is named "vector_store", but the value is the
|
||||
// distance. Can we remove this after standardizing the metadata?
|
||||
metadata.put(DISTANCE_FIELD_NAME, 1 - similarityScore(doc));
|
||||
return new Document(id, content, metadata);
|
||||
metadata.put(DocumentMetadata.DISTANCE.value(), 1 - similarityScore(doc));
|
||||
return Document.builder()
|
||||
.id(id)
|
||||
.content(content)
|
||||
.metadata(metadata)
|
||||
.score((double) similarityScore(doc))
|
||||
.build();
|
||||
}
|
||||
|
||||
private float similarityScore(redis.clients.jedis.search.Document doc) {
|
||||
|
||||
@@ -26,6 +26,7 @@ import java.util.UUID;
|
||||
import com.redis.testcontainers.RedisStackContainer;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
import redis.clients.jedis.JedisPooled;
|
||||
@@ -50,6 +51,7 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||
/**
|
||||
* @author Julien Ruaux
|
||||
* @author Eddú Meléndez
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
@Testcontainers
|
||||
class RedisVectorStoreIT {
|
||||
@@ -105,8 +107,9 @@ class RedisVectorStoreIT {
|
||||
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId());
|
||||
assertThat(resultDoc.getContent()).contains(
|
||||
"Spring AI provides abstractions that serve as the foundation for developing AI applications.");
|
||||
assertThat(resultDoc.getMetadata()).hasSize(2);
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta1", RedisVectorStore.DISTANCE_FIELD_NAME);
|
||||
assertThat(resultDoc.getMetadata()).hasSize(3);
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta1", RedisVectorStore.DISTANCE_FIELD_NAME,
|
||||
DocumentMetadata.DISTANCE.value());
|
||||
|
||||
// Remove all documents from the store
|
||||
vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList());
|
||||
@@ -190,6 +193,7 @@ class RedisVectorStoreIT {
|
||||
assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta1");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(RedisVectorStore.DISTANCE_FIELD_NAME);
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
|
||||
Document sameIdDocument = new Document(document.getId(),
|
||||
"The World is Big and Salvation Lurks Around the Corner",
|
||||
@@ -205,6 +209,7 @@ class RedisVectorStoreIT {
|
||||
assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta2");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(RedisVectorStore.DISTANCE_FIELD_NAME);
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
|
||||
vectorStore.delete(List.of(document.getId()));
|
||||
|
||||
@@ -223,24 +228,23 @@ class RedisVectorStoreIT {
|
||||
List<Document> fullResult = vectorStore
|
||||
.similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll());
|
||||
|
||||
List<Float> distances = fullResult.stream()
|
||||
.map(doc -> (Float) doc.getMetadata().get(RedisVectorStore.DISTANCE_FIELD_NAME))
|
||||
.toList();
|
||||
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
|
||||
|
||||
assertThat(distances).hasSize(3);
|
||||
assertThat(scores).hasSize(3);
|
||||
|
||||
float threshold = (distances.get(0) + distances.get(1)) / 2;
|
||||
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
|
||||
|
||||
List<Document> results = vectorStore
|
||||
.similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(1 - threshold));
|
||||
List<Document> results = vectorStore.similaritySearch(
|
||||
SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(similarityThreshold));
|
||||
|
||||
assertThat(results).hasSize(1);
|
||||
Document resultDoc = results.get(0);
|
||||
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId());
|
||||
assertThat(resultDoc.getContent()).contains(
|
||||
"Spring AI provides abstractions that serve as the foundation for developing AI applications.");
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta1", RedisVectorStore.DISTANCE_FIELD_NAME);
|
||||
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta1", RedisVectorStore.DISTANCE_FIELD_NAME,
|
||||
DocumentMetadata.DISTANCE.value());
|
||||
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ import java.util.stream.Stream;
|
||||
import io.micrometer.observation.ObservationRegistry;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.typesense.api.Client;
|
||||
import org.typesense.api.FieldTypes;
|
||||
import org.typesense.model.CollectionResponse;
|
||||
@@ -57,6 +58,7 @@ import org.springframework.util.Assert;
|
||||
* @author Pablo Sanchidrian Herrera
|
||||
* @author Soby Chacko
|
||||
* @author Christian Tzolov
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
public class TypesenseVectorStore extends AbstractObservationVectorStore implements InitializingBean {
|
||||
|
||||
@@ -212,8 +214,13 @@ public class TypesenseVectorStore extends AbstractObservationVectorStore impleme
|
||||
String content = rawDocument.get(CONTENT_FIELD_NAME).toString();
|
||||
Map<String, Object> metadata = rawDocument.get(METADATA_FIELD_NAME) instanceof Map
|
||||
? (Map<String, Object>) rawDocument.get(METADATA_FIELD_NAME) : Map.of();
|
||||
metadata.put("distance", hit.getVectorDistance());
|
||||
return new Document(docId, content, metadata);
|
||||
metadata.put(DocumentMetadata.DISTANCE.value(), hit.getVectorDistance());
|
||||
return Document.builder()
|
||||
.id(docId)
|
||||
.content(content)
|
||||
.metadata(metadata)
|
||||
.score(1.0 - hit.getVectorDistance())
|
||||
.build();
|
||||
}))
|
||||
.toList();
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ import java.util.Map;
|
||||
import java.util.UUID;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.testcontainers.containers.GenericContainer;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
@@ -96,7 +97,7 @@ public class TypesenseVectorStoreIT {
|
||||
assertThat(resultDoc.getId()).isEqualTo(document.getId());
|
||||
assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta1");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
|
||||
Document sameIdDocument = new Document(document.getId(),
|
||||
"The World is Big and Salvation Lurks Around the Corner",
|
||||
@@ -114,7 +115,7 @@ public class TypesenseVectorStoreIT {
|
||||
assertThat(resultDoc.getId()).isEqualTo(document.getId());
|
||||
assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta2");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
|
||||
vectorStore.delete(List.of(document.getId()));
|
||||
|
||||
@@ -211,21 +212,22 @@ public class TypesenseVectorStoreIT {
|
||||
List<Document> fullResult = vectorStore
|
||||
.similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll());
|
||||
|
||||
List<Float> distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList();
|
||||
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
|
||||
|
||||
assertThat(distances).hasSize(3);
|
||||
assertThat(scores).hasSize(3);
|
||||
|
||||
float threshold = (distances.get(0) + distances.get(1)) / 2;
|
||||
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
|
||||
|
||||
List<Document> results = vectorStore
|
||||
.similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(1 - threshold));
|
||||
List<Document> results = vectorStore.similaritySearch(
|
||||
SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(similarityThreshold));
|
||||
|
||||
assertThat(results).hasSize(1);
|
||||
Document resultDoc = results.get(0);
|
||||
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId());
|
||||
assertThat(resultDoc.getContent()).contains(
|
||||
"Spring AI provides abstractions that serve as the foundation for developing AI applications.");
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value());
|
||||
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
|
||||
|
||||
((TypesenseVectorStore) vectorStore).dropCollection();
|
||||
|
||||
|
||||
@@ -45,6 +45,7 @@ import io.weaviate.client.v1.graphql.query.fields.Field;
|
||||
import io.weaviate.client.v1.graphql.query.fields.Fields;
|
||||
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.springframework.ai.embedding.BatchingStrategy;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
|
||||
@@ -73,11 +74,10 @@ import org.springframework.util.StringUtils;
|
||||
* @author Eddú Meléndez
|
||||
* @author Josh Long
|
||||
* @author Soby Chacko
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
public class WeaviateVectorStore extends AbstractObservationVectorStore {
|
||||
|
||||
public static final String DOCUMENT_METADATA_DISTANCE_KEY_NAME = "distance";
|
||||
|
||||
private static final String METADATA_FIELD_PREFIX = "meta_";
|
||||
|
||||
private static final String CONTENT_FIELD_NAME = "content";
|
||||
@@ -367,7 +367,7 @@ public class WeaviateVectorStore extends AbstractObservationVectorStore {
|
||||
|
||||
// Metadata
|
||||
Map<String, Object> metadata = new HashMap<>();
|
||||
metadata.put(DOCUMENT_METADATA_DISTANCE_KEY_NAME, 1 - certainty);
|
||||
metadata.put(DocumentMetadata.DISTANCE.value(), 1 - certainty);
|
||||
|
||||
try {
|
||||
String metadataJson = (String) item.get(METADATA_FIELD_NAME);
|
||||
@@ -382,10 +382,13 @@ public class WeaviateVectorStore extends AbstractObservationVectorStore {
|
||||
// Content
|
||||
String content = (String) item.get(CONTENT_FIELD_NAME);
|
||||
|
||||
var document = new Document(id, content, metadata);
|
||||
document.setEmbedding(EmbeddingUtils.toPrimitive(EmbeddingUtils.doubleToFloat(embedding)));
|
||||
|
||||
return document;
|
||||
return Document.builder()
|
||||
.id(id)
|
||||
.content(content)
|
||||
.metadata(metadata)
|
||||
.embedding(EmbeddingUtils.toPrimitive(EmbeddingUtils.doubleToFloat(embedding)))
|
||||
.score(certainty)
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -26,6 +26,7 @@ import java.util.UUID;
|
||||
import io.weaviate.client.Config;
|
||||
import io.weaviate.client.WeaviateClient;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.document.DocumentMetadata;
|
||||
import org.testcontainers.containers.wait.strategy.Wait;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
@@ -101,7 +102,7 @@ public class WeaviateVectorStoreIT {
|
||||
assertThat(resultDoc.getContent()).contains(
|
||||
"Spring AI provides abstractions that serve as the foundation for developing AI applications.");
|
||||
assertThat(resultDoc.getMetadata()).hasSize(2);
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value());
|
||||
|
||||
// Remove all documents from the store
|
||||
vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList());
|
||||
@@ -186,7 +187,7 @@ public class WeaviateVectorStoreIT {
|
||||
assertThat(resultDoc.getId()).isEqualTo(document.getId());
|
||||
assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta1");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
|
||||
Document sameIdDocument = new Document(document.getId(),
|
||||
"The World is Big and Salvation Lurks Around the Corner",
|
||||
@@ -201,7 +202,7 @@ public class WeaviateVectorStoreIT {
|
||||
assertThat(resultDoc.getId()).isEqualTo(document.getId());
|
||||
assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("meta2");
|
||||
assertThat(resultDoc.getMetadata()).containsKey("distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value());
|
||||
|
||||
vectorStore.delete(List.of(document.getId()));
|
||||
|
||||
@@ -222,23 +223,22 @@ public class WeaviateVectorStoreIT {
|
||||
List<Document> fullResult = vectorStore
|
||||
.similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll());
|
||||
|
||||
List<Double> distances = fullResult.stream()
|
||||
.map(doc -> (Double) doc.getMetadata().get("distance"))
|
||||
.toList();
|
||||
List<Double> scores = fullResult.stream().map(Document::getScore).toList();
|
||||
|
||||
assertThat(distances).hasSize(3);
|
||||
assertThat(scores).hasSize(3);
|
||||
|
||||
double threshold = (distances.get(0) + distances.get(1)) / 2;
|
||||
double similarityThreshold = (scores.get(0) + scores.get(1)) / 2;
|
||||
|
||||
List<Document> results = vectorStore
|
||||
.similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(1 - threshold));
|
||||
List<Document> results = vectorStore.similaritySearch(
|
||||
SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(similarityThreshold));
|
||||
|
||||
assertThat(results).hasSize(1);
|
||||
Document resultDoc = results.get(0);
|
||||
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId());
|
||||
assertThat(resultDoc.getContent()).contains(
|
||||
"Spring AI provides abstractions that serve as the foundation for developing AI applications.");
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance");
|
||||
assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value());
|
||||
assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold);
|
||||
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user