Add LocalStack OpenSearch Service Connection support for Docker Compose and Testcontainers

* Add property `spring.ai.vectorstore.opensearch.aws.domain-name`
* Require `AwsCredentialsProvider` to enable `AwsOpenSearchConfiguration`
* Add Testcontainers Service Connection support
* Add Docker Compose Service Connection support
This commit is contained in:
Eddú Meléndez
2024-11-12 22:23:05 -05:00
committed by Ilayaperumal Gopinathan
parent c8dc342b30
commit 6261ce02ff
16 changed files with 579 additions and 58 deletions

View File

@@ -31,6 +31,9 @@ The following service connection factories are provided in the `spring-ai-spring
[cols="|,|"]
|====
| Connection Details | Matched on
| `AwsOpenSearchConnectionDetails`
| Containers named `localstack/localstack`
| `ChromaConnectionDetails`
| Containers named `chromadb/chroma`, `ghcr.io/chroma-core/chroma`

View File

@@ -31,6 +31,10 @@ The following service connection factories are provided in the `spring-ai-spring
[cols="|,|"]
|====
| Connection Details | Matched on
| `AwsOpenSearchConnectionDetails`
| Containers of type `LocalStackContainer`
| `ChromaConnectionDetails`
| Containers of type `ChromaDBContainer`

View File

@@ -0,0 +1,31 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.autoconfigure.vectorstore.opensearch;
import org.springframework.boot.autoconfigure.service.connection.ConnectionDetails;
public interface AwsOpenSearchConnectionDetails extends ConnectionDetails {
String getRegion();
String getAccessKey();
String getSecretKey();
String getHost(String domainName);
}

View File

@@ -30,7 +30,9 @@ import org.opensearch.client.transport.OpenSearchTransport;
import org.opensearch.client.transport.aws.AwsSdk2Transport;
import org.opensearch.client.transport.aws.AwsSdk2TransportOptions;
import org.opensearch.client.transport.httpclient5.ApacheHttpClient5TransportBuilder;
import org.springframework.util.StringUtils;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.apache.ApacheHttpClient;
@@ -123,28 +125,35 @@ public class OpenSearchVectorStoreAutoConfiguration {
}
@Configuration(proxyBeanMethods = false)
@ConditionalOnClass({ Region.class, ApacheHttpClient.class })
@ConditionalOnClass({ AwsCredentialsProvider.class, Region.class, ApacheHttpClient.class })
static class AwsOpenSearchConfiguration {
@Bean
@ConditionalOnMissingBean(AwsOpenSearchConnectionDetails.class)
PropertiesAwsOpenSearchConnectionDetails awsOpenSearchConnectionDetails(
OpenSearchVectorStoreProperties properties) {
return new PropertiesAwsOpenSearchConnectionDetails(properties);
}
@Bean
@ConditionalOnMissingBean
OpenSearchClient openSearchClient(OpenSearchVectorStoreProperties properties, AwsSdk2TransportOptions options) {
OpenSearchVectorStoreProperties.Aws aws = properties.getAws();
Region region = Region.of(aws.getRegion());
OpenSearchClient openSearchClient(OpenSearchVectorStoreProperties properties,
AwsOpenSearchConnectionDetails connectionDetails, AwsSdk2TransportOptions options) {
Region region = Region.of(connectionDetails.getRegion());
SdkHttpClient httpClient = ApacheHttpClient.builder().build();
OpenSearchTransport transport = new AwsSdk2Transport(httpClient, aws.getHost(), aws.getServiceName(),
region, options);
OpenSearchTransport transport = new AwsSdk2Transport(httpClient,
connectionDetails.getHost(properties.getAws().getDomainName()),
properties.getAws().getServiceName(), region, options);
return new OpenSearchClient(transport);
}
@Bean
@ConditionalOnMissingBean
AwsSdk2TransportOptions options(OpenSearchVectorStoreProperties properties) {
OpenSearchVectorStoreProperties.Aws aws = properties.getAws();
AwsSdk2TransportOptions options(AwsOpenSearchConnectionDetails connectionDetails) {
return AwsSdk2TransportOptions.builder()
.setCredentials(StaticCredentialsProvider
.create(AwsBasicCredentials.create(aws.getAccessKey(), aws.getSecretKey())))
.setCredentials(StaticCredentialsProvider.create(
AwsBasicCredentials.create(connectionDetails.getAccessKey(), connectionDetails.getSecretKey())))
.build();
}
@@ -175,4 +184,37 @@ public class OpenSearchVectorStoreAutoConfiguration {
}
static class PropertiesAwsOpenSearchConnectionDetails implements AwsOpenSearchConnectionDetails {
private final OpenSearchVectorStoreProperties.Aws aws;
public PropertiesAwsOpenSearchConnectionDetails(OpenSearchVectorStoreProperties properties) {
this.aws = properties.getAws();
}
@Override
public String getRegion() {
return this.aws.getRegion();
}
@Override
public String getAccessKey() {
return this.aws.getAccessKey();
}
@Override
public String getSecretKey() {
return this.aws.getSecretKey();
}
@Override
public String getHost(String domainName) {
if (StringUtils.hasText(domainName)) {
return "%s.%s".formatted(this.aws.getDomainName(), this.aws.getHost());
}
return this.aws.getHost();
}
}
}

View File

@@ -91,6 +91,8 @@ public class OpenSearchVectorStoreProperties extends CommonVectorStoreProperties
static class Aws {
private String domainName;
private String host;
private String serviceName;
@@ -101,6 +103,14 @@ public class OpenSearchVectorStoreProperties extends CommonVectorStoreProperties
private String region;
public String getDomainName() {
return this.domainName;
}
public void setDomainName(String domainName) {
this.domainName = domainName;
}
public String getHost() {
return this.host;
}

View File

@@ -62,12 +62,11 @@ class AwsOpenSearchVectorStoreAutoConfigurationIT {
.withConfiguration(AutoConfigurations.of(OpenSearchVectorStoreAutoConfiguration.class,
SpringAiRetryAutoConfiguration.class))
.withUserConfiguration(Config.class)
.withPropertyValues("spring.ai.vectorstore.opensearch.initialize-schema=true")
.withPropertyValues(
.withPropertyValues("spring.ai.vectorstore.opensearch.initialize-schema=true",
OpenSearchVectorStoreProperties.CONFIG_PREFIX + ".aws.host="
+ String.format("testcontainers-domain.%s.opensearch.localhost.localstack.cloud:%s",
localstack.getRegion(), localstack.getMappedPort(4566)),
OpenSearchVectorStoreProperties.CONFIG_PREFIX + ".aws.service-name=opensearch",
OpenSearchVectorStoreProperties.CONFIG_PREFIX + ".aws.service-name=es",
OpenSearchVectorStoreProperties.CONFIG_PREFIX + ".aws.region=" + localstack.getRegion(),
OpenSearchVectorStoreProperties.CONFIG_PREFIX + ".aws.access-key=" + localstack.getAccessKey(),
OpenSearchVectorStoreProperties.CONFIG_PREFIX + ".aws.secret-key=" + localstack.getSecretKey(),

View File

@@ -0,0 +1,82 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.docker.compose.service.connection.opensearch;
import org.springframework.ai.autoconfigure.vectorstore.opensearch.AwsOpenSearchConnectionDetails;
import org.springframework.ai.autoconfigure.vectorstore.opensearch.OpenSearchConnectionDetails;
import org.springframework.boot.docker.compose.core.RunningService;
import org.springframework.boot.docker.compose.service.connection.DockerComposeConnectionDetailsFactory;
import org.springframework.boot.docker.compose.service.connection.DockerComposeConnectionSource;
/**
* @author Eddú Meléndez
*/
class AwsOpenSearchDockerComposeConnectionDetailsFactory
extends DockerComposeConnectionDetailsFactory<AwsOpenSearchConnectionDetails> {
private static final int LOCALSTACK_PORT = 4566;
protected AwsOpenSearchDockerComposeConnectionDetailsFactory() {
super("localstack/localstack");
}
@Override
protected AwsOpenSearchConnectionDetails getDockerComposeConnectionDetails(DockerComposeConnectionSource source) {
return new AwsOpenSearchDockerComposeConnectionDetails(source.getRunningService());
}
/**
* {@link OpenSearchConnectionDetails} backed by a {@code OpenSearch}
* {@link RunningService}.
*/
static class AwsOpenSearchDockerComposeConnectionDetails extends DockerComposeConnectionDetails
implements AwsOpenSearchConnectionDetails {
private final AwsOpenSearchEnvironment environment;
private final int port;
AwsOpenSearchDockerComposeConnectionDetails(RunningService service) {
super(service);
this.environment = new AwsOpenSearchEnvironment(service.env());
this.port = service.ports().get(LOCALSTACK_PORT);
}
@Override
public String getRegion() {
return this.environment.getRegion();
}
@Override
public String getAccessKey() {
return this.environment.getAccessKey();
}
@Override
public String getSecretKey() {
return this.environment.getSecretKey();
}
@Override
public String getHost(String domainName) {
return "%s.%s.opensearch.localhost.localstack.cloud:%s".formatted(domainName, this.environment.getRegion(),
this.port);
}
}
}

View File

@@ -0,0 +1,47 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.docker.compose.service.connection.opensearch;
import java.util.Map;
class AwsOpenSearchEnvironment {
private final String region;
private final String accessKey;
private final String secretKey;
AwsOpenSearchEnvironment(Map<String, String> env) {
this.region = env.getOrDefault("DEFAULT_REGION", "us-east-1");
this.accessKey = env.getOrDefault("AWS_ACCESS_KEY_ID", "test");
this.secretKey = env.getOrDefault("AWS_SECRET_ACCESS_KEY", "test");
}
public String getRegion() {
return this.region;
}
public String getAccessKey() {
return this.accessKey;
}
public String getSecretKey() {
return this.secretKey;
}
}

View File

@@ -18,6 +18,7 @@ org.springframework.boot.autoconfigure.service.connection.ConnectionDetailsFacto
org.springframework.ai.docker.compose.service.connection.chroma.ChromaDockerComposeConnectionDetailsFactory,\
org.springframework.ai.docker.compose.service.connection.mongo.MongoDbAtlasLocalDockerComposeConnectionDetailsFactory,\
org.springframework.ai.docker.compose.service.connection.ollama.OllamaDockerComposeConnectionDetailsFactory,\
org.springframework.ai.docker.compose.service.connection.opensearch.AwsOpenSearchDockerComposeConnectionDetailsFactory,\
org.springframework.ai.docker.compose.service.connection.opensearch.OpenSearchDockerComposeConnectionDetailsFactory,\
org.springframework.ai.docker.compose.service.connection.qdrant.QdrantDockerComposeConnectionDetailsFactory,\
org.springframework.ai.docker.compose.service.connection.typesense.TypesenseDockerComposeConnectionDetailsFactory,\

View File

@@ -0,0 +1,40 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.docker.compose.service.connection.opensearch;
import org.junit.jupiter.api.Test;
import org.springframework.ai.autoconfigure.vectorstore.opensearch.AwsOpenSearchConnectionDetails;
import org.springframework.boot.docker.compose.service.connection.test.AbstractDockerComposeIntegrationTests;
import org.testcontainers.utility.DockerImageName;
import static org.assertj.core.api.Assertions.assertThat;
class AwsOpenSearchDockerComposeConnectionDetailsFactoryTests extends AbstractDockerComposeIntegrationTests {
AwsOpenSearchDockerComposeConnectionDetailsFactoryTests() {
super("localstack-compose.yaml", DockerImageName.parse("localstack/localstack:3.5.0"));
}
@Test
void runCreatesConnectionDetails() {
AwsOpenSearchConnectionDetails connectionDetails = run(AwsOpenSearchConnectionDetails.class);
assertThat(connectionDetails.getAccessKey()).isEqualTo("test");
assertThat(connectionDetails.getSecretKey()).isEqualTo("test");
assertThat(connectionDetails.getRegion()).isEqualTo("us-east-1");
}
}

View File

@@ -0,0 +1,5 @@
services:
localstack:
image: '{imageName}'
ports:
- '4566'

View File

@@ -222,6 +222,27 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>apache-client</artifactId>
<version>${awssdk.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>auth</artifactId>
<version>${awssdk.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>regions</artifactId>
<version>${awssdk.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>qdrant</artifactId>
@@ -240,6 +261,12 @@
<optional>true</optional>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>localstack</artifactId>
<optional>true</optional>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>milvus</artifactId>

View File

@@ -0,0 +1,70 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.testcontainers.service.connection.opensearch;
import org.springframework.ai.autoconfigure.vectorstore.opensearch.AwsOpenSearchConnectionDetails;
import org.springframework.boot.testcontainers.service.connection.ContainerConnectionDetailsFactory;
import org.springframework.boot.testcontainers.service.connection.ContainerConnectionSource;
import org.testcontainers.containers.localstack.LocalStackContainer;
/**
* @author Eddú Meléndez
*/
class AwsOpenSearchContainerConnectionDetailsFactory
extends ContainerConnectionDetailsFactory<LocalStackContainer, AwsOpenSearchConnectionDetails> {
@Override
public AwsOpenSearchConnectionDetails getContainerConnectionDetails(
ContainerConnectionSource<LocalStackContainer> source) {
return new AwsOpenSearchContainerConnectionDetails(source);
}
/**
* {@link AwsOpenSearchConnectionDetails} backed by a
* {@link ContainerConnectionSource}.
*/
private static final class AwsOpenSearchContainerConnectionDetails
extends ContainerConnectionDetails<LocalStackContainer> implements AwsOpenSearchConnectionDetails {
private AwsOpenSearchContainerConnectionDetails(ContainerConnectionSource<LocalStackContainer> source) {
super(source);
}
@Override
public String getRegion() {
return getContainer().getRegion();
}
@Override
public String getAccessKey() {
return getContainer().getAccessKey();
}
@Override
public String getSecretKey() {
return getContainer().getSecretKey();
}
@Override
public String getHost(String domainName) {
return "%s.%s.opensearch.localhost.localstack.cloud:%s".formatted(domainName, getContainer().getRegion(),
getContainer().getMappedPort(4566));
}
}
}

View File

@@ -18,6 +18,7 @@ org.springframework.ai.testcontainers.service.connection.chroma.ChromaContainerC
org.springframework.ai.testcontainers.service.connection.milvus.MilvusContainerConnectionDetailsFactory,\
org.springframework.ai.testcontainers.service.connection.mongo.MongoDbAtlasLocalContainerConnectionDetailsFactory,\
org.springframework.ai.testcontainers.service.connection.ollama.OllamaContainerConnectionDetailsFactory,\
org.springframework.ai.testcontainers.service.connection.opensearch.AwsOpenSearchContainerConnectionDetailsFactory,\
org.springframework.ai.testcontainers.service.connection.opensearch.OpenSearchContainerConnectionDetailsFactory,\
org.springframework.ai.testcontainers.service.connection.qdrant.QdrantContainerConnectionDetailsFactory,\
org.springframework.ai.testcontainers.service.connection.typesense.TypesenseContainerConnectionDetailsFactory,\

View File

@@ -0,0 +1,147 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.testcontainers.service.connection.opensearch;
import com.jayway.jsonpath.JsonPath;
import net.minidev.json.JSONArray;
import org.awaitility.Awaitility;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.springframework.ai.autoconfigure.vectorstore.opensearch.OpenSearchVectorStoreAutoConfiguration;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.transformers.TransformersEmbeddingModel;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.ImportAutoConfiguration;
import org.springframework.boot.testcontainers.service.connection.ServiceConnection;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.DefaultResourceLoader;
import org.springframework.test.context.TestPropertySource;
import org.springframework.test.context.junit.jupiter.SpringJUnitConfig;
import org.testcontainers.containers.localstack.LocalStackContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.utility.DockerImageName;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
import static org.awaitility.Awaitility.await;
import static org.hamcrest.Matchers.hasSize;
@SpringJUnitConfig
@TestPropertySource(properties = { "spring.ai.vectorstore.opensearch.index-name=auto-spring-ai-document-index",
"spring.ai.vectorstore.opensearch.initialize-schema=true",
"spring.ai.vectorstore.opensearch.mapping-json="
+ AwsOpenSearchContainerConnectionDetailsFactoryTest.MAPPING_JSON,
"spring.ai.vectorstore.opensearch.aws.domain-name=testcontainers-domain",
"spring.ai.vectorstore.opensearch.aws.service-name=es" })
@Testcontainers
class AwsOpenSearchContainerConnectionDetailsFactoryTest {
static final String MAPPING_JSON = "{\"properties\":{\"embedding\":{\"type\":\"knn_vector\",\"dimension\":384}}}";
@Container
@ServiceConnection
private static final LocalStackContainer localstack = new LocalStackContainer(
DockerImageName.parse("localstack/localstack:3.5.0"))
.withEnv("LOCALSTACK_HOST", "localhost.localstack.cloud");
private final List<Document> documents = List.of(
new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")),
new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()),
new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2")));
@Autowired
private VectorStore vectorStore;
@BeforeAll
static void beforeAll() throws IOException, InterruptedException {
String[] createDomainCmd = { "awslocal", "opensearch", "create-domain", "--domain-name",
"testcontainers-domain", "--region", localstack.getRegion() };
localstack.execInContainer(createDomainCmd);
String[] describeDomainCmd = { "awslocal", "opensearch", "describe-domain", "--domain-name",
"testcontainers-domain", "--region", localstack.getRegion() };
await().pollInterval(Duration.ofSeconds(30)).atMost(Duration.ofSeconds(300)).untilAsserted(() -> {
org.testcontainers.containers.Container.ExecResult execResult = localstack
.execInContainer(describeDomainCmd);
String response = execResult.getStdout();
JSONArray processed = JsonPath.read(response, "$.DomainStatus[?(@.Processing == false)]");
assertThat(processed).isNotEmpty();
});
}
@Test
public void addAndSearchTest() {
this.vectorStore.add(this.documents);
Awaitility.await()
.until(() -> this.vectorStore
.similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0)),
hasSize(1));
List<Document> results = this.vectorStore
.similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0));
assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId());
assertThat(resultDoc.getContent()).contains("The Great Depression (19291939) was an economic shock");
assertThat(resultDoc.getMetadata()).hasSize(2);
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");
// Remove all documents from the store
this.vectorStore.delete(this.documents.stream().map(Document::getId).toList());
Awaitility.await()
.until(() -> this.vectorStore
.similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0)),
hasSize(0));
}
private String getText(String uri) {
var resource = new DefaultResourceLoader().getResource(uri);
try {
return resource.getContentAsString(StandardCharsets.UTF_8);
}
catch (IOException e) {
throw new RuntimeException(e);
}
}
@Configuration(proxyBeanMethods = false)
@ImportAutoConfiguration(OpenSearchVectorStoreAutoConfiguration.class)
static class Config {
@Bean
public EmbeddingModel embeddingModel() {
return new TransformersEmbeddingModel();
}
}
}

View File

@@ -24,78 +24,85 @@ import java.util.Map;
import org.awaitility.Awaitility;
import org.junit.jupiter.api.Test;
import org.opensearch.testcontainers.OpensearchContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.springframework.ai.autoconfigure.vectorstore.opensearch.OpenSearchVectorStoreAutoConfiguration;
import org.springframework.ai.autoconfigure.vectorstore.opensearch.OpenSearchVectorStoreProperties;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.FilteredClassLoader;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.transformers.TransformersEmbeddingModel;
import org.springframework.ai.vectorstore.OpenSearchVectorStore;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.ImportAutoConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.testcontainers.lifecycle.TestcontainersLifecycleApplicationContextInitializer;
import org.springframework.boot.testcontainers.service.connection.ServiceConnection;
import org.springframework.boot.testcontainers.service.connection.ServiceConnectionAutoConfiguration;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.DefaultResourceLoader;
import software.amazon.awssdk.http.apache.ApacheHttpClient;
import software.amazon.awssdk.regions.Region;
import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.Matchers.hasSize;
@SpringBootTest(properties = {
"spring.ai.vectorstore.opensearch.index-name=" + OpenSearchContainerConnectionDetailsFactoryTest.DOCUMENT_INDEX,
"spring.ai.vectorstore.opensearch.initialize-schema=true",
"spring.ai.vectorstore.opensearch.mapping-json="
+ OpenSearchContainerConnectionDetailsFactoryTest.MAPPING_JSON })
@Testcontainers
class OpenSearchContainerConnectionDetailsFactoryTest {
static final String DOCUMENT_INDEX = "auto-spring-ai-document-index";
static final String MAPPING_JSON = "{\"properties\":{\"embedding\":{\"type\":\"knn_vector\",\"dimension\":384}}}";
@Container
@ServiceConnection
private static final OpensearchContainer<?> opensearch = new OpensearchContainer<>(OpenSearchImage.DEFAULT_IMAGE);
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
.withInitializer(new TestcontainersLifecycleApplicationContextInitializer())
.withConfiguration(AutoConfigurations.of(ServiceConnectionAutoConfiguration.class,
OpenSearchVectorStoreAutoConfiguration.class))
.withClassLoader(new FilteredClassLoader(Region.class, ApacheHttpClient.class))
.withUserConfiguration(Config.class)
.withPropertyValues("spring.ai.vectorstore.opensearch.initialize-schema=true",
OpenSearchVectorStoreProperties.CONFIG_PREFIX + ".indexName=auto-spring-ai-document-index",
OpenSearchVectorStoreProperties.CONFIG_PREFIX + ".mappingJson=" + """
{
"properties":{
"embedding":{
"type":"knn_vector",
"dimension":384
}
}
}
""");
private final List<Document> documents = List.of(
new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")),
new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()),
new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2")));
@Autowired
private OpenSearchVectorStore vectorStore;
@Test
public void addAndSearchTest() {
contextRunner.run(context -> {
VectorStore vectorStore = context.getBean(VectorStore.class);
vectorStore.add(this.documents);
this.vectorStore.add(this.documents);
Awaitility.await()
.until(() -> vectorStore
.similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0)),
hasSize(1));
Awaitility.await()
.until(() -> this.vectorStore
.similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0)),
hasSize(1));
List<Document> results = vectorStore
.similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0));
List<Document> results = this.vectorStore
.similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0));
assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId());
assertThat(resultDoc.getContent()).contains("The Great Depression (19291939) was an economic shock");
assertThat(resultDoc.getMetadata()).hasSize(2);
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");
assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId());
assertThat(resultDoc.getContent()).contains("The Great Depression (19291939) was an economic shock");
assertThat(resultDoc.getMetadata()).hasSize(2);
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");
// Remove all documents from the store
vectorStore.delete(this.documents.stream().map(Document::getId).toList());
// Remove all documents from the store
this.vectorStore.delete(this.documents.stream().map(Document::getId).toList());
Awaitility.await()
.until(() -> this.vectorStore
.similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0)),
hasSize(0));
Awaitility.await()
.until(() -> vectorStore
.similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0)),
hasSize(0));
});
}
private String getText(String uri) {
@@ -109,7 +116,6 @@ class OpenSearchContainerConnectionDetailsFactoryTest {
}
@Configuration(proxyBeanMethods = false)
@ImportAutoConfiguration(OpenSearchVectorStoreAutoConfiguration.class)
static class Config {
@Bean
@@ -117,6 +123,12 @@ class OpenSearchContainerConnectionDetailsFactoryTest {
return new TransformersEmbeddingModel();
}
@Bean
@ServiceConnection
OpensearchContainer<?> opensearch() {
return new OpensearchContainer<>(OpenSearchImage.DEFAULT_IMAGE);
}
}
}